@elbow-jason
Created February 17, 2021 20:04
Nx: get and set by index for rank-1 tensors
iex(1)> t1 = Nx.tensor([3, 4, 5])
#Nx.Tensor<
  s64[3]
  [3, 4, 5]
>
iex(2)> Rank1.at(t1, 0)
#Nx.Tensor<
  s64
  3
>
iex(3)> t2 = Rank1.put(t1, 0, -2)
#Nx.Tensor<
  s64[3]
  [-2, 4, 5]
>
iex(4)> Rank1.at(t1, 0)
#Nx.Tensor<
  s64
  3
>
iex(5)> Rank1.at(t2, 0)
#Nx.Tensor<
  s64
  -2
>
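defmodule Rank1 do
  import Nx.Defn

  # Reads element i of a rank-1 tensor without dynamic indexing:
  # build a mask that is true only at position i, zero every other
  # element, then sum what is left. An out-of-range i yields 0.
  defn at(tensor, i) do
    shape = Nx.shape(tensor)
    # iotas is [0, 1, ..., n - 1], the index of every element
    iotas = Nx.iota(shape)
    predicate = Nx.equal(iotas, i)
    values = Nx.select(predicate, tensor, 0)
    Nx.sum(values)
  end

  # Returns a new tensor with element i replaced by val; the input
  # tensor is not modified (t1 is unchanged in the session above).
  # An out-of-range i returns the original values.
  defn put(tensor, i, val) do
    shape = Nx.shape(tensor)
    iotas = Nx.iota(shape)
    predicate = Nx.equal(iotas, i)
    Nx.select(predicate, val, tensor)
  end
end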
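To make the mask trick easier to follow, here is a small list-based model of what `at` and `put` compute, step by step. It is a hypothetical illustration only: the `MaskSketch` module and its functions are not part of the gist or of Nx, and plain Elixir lists stand in for tensors so it runs without any dependencies.

```elixir
defmodule MaskSketch do
  @moduledoc """
  Hypothetical list-based model of the Rank1 mask technique (illustration only).
  """

  # Models Nx.equal(Nx.iota(shape), i): true only at position i.
  def mask(list, i) do
    list
    |> Enum.with_index()
    |> Enum.map(fn {_value, index} -> index == i end)
  end

  # Models Rank1.at/2: keep the selected value, replace the rest with 0, then sum.
  def at(list, i) do
    list
    |> Enum.zip(mask(list, i))
    |> Enum.map(fn {value, selected} -> if selected, do: value, else: 0 end)
    |> Enum.sum()
  end

  # Models Rank1.put/3: take val where the mask is true, the original value elsewhere.
  def put(list, i, val) do
    list
    |> Enum.zip(mask(list, i))
    |> Enum.map(fn {value, selected} -> if selected, do: val, else: value end)
  end
end
```

With the inputs from the iex session, `MaskSketch.mask([3, 4, 5], 0)` is `[true, false, false]`, `MaskSketch.at([3, 4, 5], 0)` is `3`, and `MaskSketch.put([3, 4, 5], 0, -2)` is `[-2, 4, 5]`. An index such as 7 produces an all-false mask, so `at` returns 0 and `put` returns the list unchanged. Because every element is visited, each read or write costs O(n) rather than O(1).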