Last active
November 18, 2023 20:07
Floating point to GPTQ in Nx
defmodule NxExt do
  import Nx.Defn

  # Weights that pack each pair of nybbles into one byte: the first nybble is
  # kept as-is (x1) and the second is moved into the high four bits (x16).
  @bitshift Nx.tensor([[1], [16]], type: {:u, 8})

  @doc """
  Takes an N-vector of floats (of any float type) and converts it into 4-bit GPTQ values,
  which have a range of -8..7. The result is packed two values per byte, with the
  lower-indexed value in the less significant nybble. N must be even.

  ### TODO: check that the sub-endianness is correct.

  ```elixir
  iex> [-6.0, 1.0, 7.0, -3.0]
  ...> |> Nx.tensor(type: {:f, 16})
  ...> |> NxExt.to_gptq()
  ...> |> Nx.to_binary()
  <<1::signed-size(4), -6::signed-size(4), -3::signed-size(4), 7::signed-size(4)>>
  ```
  """
  defn to_gptq(tensor) do
    tensor
    # Saturate to the signed 4-bit range, then convert to 8-bit integers.
    |> Nx.clip(-8, 7)
    |> Nx.as_type({:s, 8})
    # Reinterpret the bits as unsigned and keep only the low nybble.
    |> Nx.bitcast({:u, 8})
    |> Nx.bitwise_and(15)
    # Pair up neighbours and combine each pair into one byte: lo + 16 * hi.
    |> Nx.reshape({:auto, 2})
    |> Nx.dot(@bitshift)
    |> Nx.reshape({:auto})
  end
end
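One way to work through the sub-endianness TODO is to compare the packed bytes against a reference written in plain Elixir, with no Nx dependency. The sketch below is only an illustration: `NibbleRef`, `pack/1` and `unpack/1` are hypothetical names that are not part of `NxExt` or Nx, and truncating toward zero with `trunc/1` is an assumption about how fractional inputs should be handled.

```elixir
defmodule NibbleRef do
  # Hypothetical reference helper for checking nybble order; not part of NxExt or Nx.
  import Bitwise

  # Clamp to -8..7, truncate to an integer (an assumption about rounding),
  # and keep the low four bits of the two's-complement representation.
  defp nibble(x), do: x |> max(-8) |> min(7) |> trunc() |> band(0xF)

  # Packs an even-length list of numbers, two per byte, with the
  # lower-indexed value in the less significant nybble.
  def pack(values) do
    values
    |> Enum.map(&nibble/1)
    |> Enum.chunk_every(2)
    |> Enum.map(fn [lo, hi] -> lo ||| (hi <<< 4) end)
    |> :binary.list_to_bin()
  end

  # Inverse of pack/1: reads each byte as a high and a low signed nybble
  # and emits the low one first.
  def unpack(binary) do
    for <<hi::signed-size(4), lo::signed-size(4) <- binary>>, v <- [lo, hi], do: v
  end
end
```

For the doctest input, `NibbleRef.pack([-6.0, 1.0, 7.0, -3.0])` gives `<<26, 215>>`, which is the same binary as `<<1::signed-size(4), -6::signed-size(4), -3::signed-size(4), 7::signed-size(4)>>`, and `NibbleRef.unpack/1` recovers `[-6, 1, 7, -3]` from it.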