Mix.install([
{:torchx, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "torchx"},
{:nx, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "nx", override: true}
])
# Prompt on stdin for one integer parameter. The trailing newline is trimmed
# before parsing; input that is not an integer raises ArgumentError.
read_int = fn prompt ->
  prompt
  |> IO.gets()
  |> String.trim()
  |> String.to_integer()
end

# Number of clusters, number of points and number of coordinates per point.
k = read_int.("K (clusters)")
num_points = read_int.("num_points")
dimensions = read_int.("dimensions")
# Sample dataset: three groups of three 2-D points near (1, 1), (10, 10) and
# (100, 100). To cluster random data of the requested size instead, use:
# points = Nx.random_uniform({num_points, dimensions}) |> IO.inspect(label: "points")
points =
  Nx.tensor([
    [1, 1],
    [1.1, 1.1],
    [0.9, 0.9],
    [10, 10],
    [11, 10],
    [10, 11],
    [100, 100],
    [101, 100],
    [100, 101]
  ])

# Every reshape/tile/slice below is sized from the num_points and dimensions
# entered above, so fail fast with a readable message when they do not match
# the hard-coded dataset (otherwise Nx raises an opaque shape error later on).
if Nx.shape(points) != {num_points, dimensions} do
  raise ArgumentError,
        "num_points/dimensions (#{num_points}, #{dimensions}) do not match " <>
          "the shape of points #{inspect(Nx.shape(points))}"
end
# Initial centroids: k distinct rows of `points`, picked at random.
# Distinct random rows are used rather than a contiguous slice because the
# sample data is grouped, so a contiguous slice tends to place every initial
# centroid in the same group. The result always has shape {k, dimensions}.
if k < 1 or k > num_points do
  raise ArgumentError, "K must be between 1 and num_points (#{num_points}), got: #{k}"
end

initial_indices = Enum.take_random(0..(num_points - 1)//1, k)
centroids = Nx.take(points, Nx.tensor(initial_indices))
# Run k-means (Lloyd's algorithm): alternately assign every point to its nearest
# centroid and move each centroid to the mean of the points assigned to it.
# Stops after max_iter iterations, or earlier once the assignments stop changing.
max_iter = IO.gets("max_iter") |> String.trim() |> String.to_integer()

# `//1` so that max_iter = 0 runs zero iterations instead of counting down 1, 0.
{assignments, centroids} =
  Enum.reduce_while(1..max_iter//1, {nil, centroids}, fn _, {previous_assignments, centroids} ->
    tiled_centroids =
      centroids
      |> IO.inspect(label: "centroids")
      |> Nx.reshape({1, k * dimensions})
      |> Nx.tile([num_points, 1])

    # Euclidean distance from every point to every centroid, as a {num_points, k}
    # matrix; each point is assigned to the column with the smallest distance.
    assignments =
      points
      |> Nx.tile([1, k])
      |> Nx.subtract(tiled_centroids)
      |> Nx.reshape({k * num_points, dimensions})
      |> Nx.LinAlg.norm(axes: [1])
      |> Nx.reshape({num_points, k})
      |> Nx.argmin(axis: 1)

    new_centroids =
      Nx.stack(
        for i <- 0..(k - 1)//1 do
          mask = Nx.equal(assignments, i)

          # Count points from the un-tiled mask: the tiled selector below has
          # `dimensions` entries per point, so summing it would divide by
          # count * dimensions and shrink every centroid.
          count = Nx.sum(mask)

          selector =
            mask
            |> Nx.reshape({num_points, 1})
            |> Nx.tile([1, dimensions])

          total = Nx.select(selector, points, Nx.broadcast(0, points)) |> Nx.sum(axes: [0])

          if Nx.to_scalar(count) == 0 do
            # Empty cluster: keep its previous centroid instead of dividing by zero.
            Nx.take(centroids, i)
          else
            Nx.divide(total, count)
          end
        end
      )

    # Converged: identical assignments would reproduce identical centroids.
    if previous_assignments != nil and
         Nx.to_flat_list(assignments) == Nx.to_flat_list(previous_assignments) do
      {:halt, {assignments, new_centroids}}
    else
      {:cont, {assignments, new_centroids}}
    end
  end)

IO.inspect(assignments, label: "final assignments")
IO.inspect(centroids, label: "final centroids")
# @polvalente Thanks for the example! Here is my version: https://gist.github.com/dantswain/38f56db677db21d8335d8a29fc73c81b — I would love feedback.