Skip to content

Instantly share code, notes, and snippets.

@dantswain
Created November 28, 2021 03:23
Show Gist options
  • Save dantswain/38f56db677db21d8335d8a29fc73c81b to your computer and use it in GitHub Desktop.
Save dantswain/38f56db677db21d8335d8a29fc73c81b to your computer and use it in GitHub Desktop.
K-Means Clustering with Elixir NX

K-Means Clustering

The data

Mix.install([
  {:nx, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "nx", override: true},
  {:vega_lite, "~> 0.1"},
  {:kino, "~> 0.3"}
])
# Problem size, dimensionality, and the bounds of the square region that the
# initial centroids are later drawn from.
n_points = 64
n_dims = 2
{x_min, x_max} = {-4, 4}
{y_min, y_max} = {-4, 4}

# Each of the two generating distributions contributes half of the points.
n_per_init_cluster = div(n_points, 2)
cluster_shape = {n_per_init_cluster, n_dims}

# First cluster: standard-normal samples shifted by (1, 1), tagged with label 0.
r1 =
  cluster_shape
  |> Nx.random_normal(0.0, 1.0, names: [:x, :y])
  |> Nx.add(Nx.tensor([1.0, 1.0]))

label1 = Nx.broadcast(0, {n_per_init_cluster, 1})
c1 = Nx.concatenate([r1, label1], axis: 1)

# Second cluster: standard-normal samples shifted by (-1, -1), tagged with label 1.
r2 =
  cluster_shape
  |> Nx.random_normal(0.0, 1.0)
  |> Nx.add(Nx.tensor([-1.0, -1.0]))

label2 = Nx.broadcast(1, {n_per_init_cluster, 1})
c2 = Nx.concatenate([r2, label2], axis: 1)

# Both clusters stacked row-wise; the label is the last column of every row.
labeled = Nx.concatenate([c1, c2])
alias VegaLite, as: Vl

# Helper that turns labeled data into a VegaLite point layer. Indices 0, 1
# and 2 along the :y axis are read as x coordinate, y coordinate and label.
mk_data_layer = fn labeled_data ->
  column = fn ix -> Nx.to_flat_list(labeled_data[y: ix]) end
  series = [x: column.(0), y: column.(1), label: column.(2)]

  points = Vl.mark(Vl.data_from_series(Vl.new(), series), :point)

  points
  |> Vl.encode_field(:x, "x", type: :quantitative, title: "X")
  |> Vl.encode_field(:y, "y", type: :quantitative, title: "Y")
  |> Vl.encode_field(:color, "label", type: :nominal)
end

# Chart of the generated points, colored by the distribution they came from.
Vl.layers(
  Vl.new(title: "Raw Data w/ True Labels", width: 700, height: 700),
  [mk_data_layer.(labeled)]
)

Clustering - Initialization

# number of clusters to fit
k = 2

# the unlabeled data: keep the n_dims coordinate columns, drop the label column
data = labeled[y: 0..(n_dims - 1)]

# Initial centroids, drawn uniformly at random from the box
# [x_min, x_max] x [y_min, y_max] that the data spans.
#
# One row is produced per cluster (k rows; iterating over n_dims here would
# only give the right count when k happens to equal n_dims). Each row is
# [x, y, cluster_index], so this cell is specific to 2-D data.
initial_centroids =
  0..(k - 1)
  |> Enum.map(fn ix ->
    [
      x_min + (x_max - x_min) * :rand.uniform(),
      y_min + (y_max - y_min) * :rand.uniform(),
      ix
    ]
  end)
  |> Nx.tensor(names: [:x, :y])
# Helper that turns labeled centroids into a VegaLite layer of large squares.
# Indices 0, 1 and 2 along the :y axis are read as x, y and label.
mk_centroid_layer = fn labeled_centroids ->
  column = fn ix -> Nx.to_flat_list(labeled_centroids[y: ix]) end
  series = [x: column.(0), y: column.(1), label: column.(2)]

  squares = Vl.mark(Vl.data_from_series(Vl.new(), series), :square, size: 400)

  squares
  |> Vl.encode_field(:x, "x", type: :quantitative, title: "X")
  |> Vl.encode_field(:y, "y", type: :quantitative, title: "Y")
  |> Vl.encode_field(:color, "label", type: :nominal)
end

# Overlay the randomly placed starting centroids on the generated data.
initial_layers = [mk_data_layer.(labeled), mk_centroid_layer.(initial_centroids)]

Vl.layers(
  Vl.new(title: "Location of Initial Centroids w/ True Labels", width: 700, height: 700),
  initial_layers
)
# Helper: Euclidean distance from the (unlabeled) data to the (unlabeled)
# centroids. A new axis is inserted at position 1 of `centroids` so that the
# subtraction broadcasts every centroid against every row of `d`; the squared
# differences are then summed over axis 2 and square-rooted.
dist_fn = fn d, centroids ->
  diffs = Nx.subtract(d, Nx.new_axis(centroids, 1))
  squared = Nx.power(diffs, 2)

  Nx.sqrt(Nx.sum(squared, axes: [2]))
end

# Helper function to find labels: the argmin over axis 0 of the distance
# tensor, i.e. the index of the nearest centroid for each data point.
find_labels = fn d, centroids ->
  distances = dist_fn.(d, centroids)
  Nx.argmin(distances, axis: 0)
end

# Label every point using only the coordinate columns of the initial centroids.
initial_coords = initial_centroids[y: 0..(n_dims - 1)]
new_labels = find_labels.(data, initial_coords)

# Append the assigned labels to the data as an extra column.
label_column = Nx.new_axis(new_labels, 1)
alg_labeled = Nx.concatenate([data, label_column], axis: 1)
# Chart of the data colored by the labels assigned from the initial centroids.
Vl.layers(
  Vl.new(title: "Initial Labeling", width: 700, height: 700),
  [mk_data_layer.(alg_labeled)]
)

Clustering - First Iteration

# Recomputes every cluster's centroid as the mean of the points assigned to it.
#
#   data          - rank-2 tensor of unlabeled points, one point per row
#   labels        - rank-1 tensor holding one cluster index per row of `data`
#   old_centroids - labeled centroids of the previous step (coordinates
#                   followed by a label column); only consulted for clusters
#                   that received no points
#
# Returns a map of cluster index => rank-1 tensor of centroid coordinates.
calc_centroids_map = fn data, labels, old_centroids ->
  # Take the sizes from `data` itself instead of the notebook-level constants.
  {n_rows, n_cols} = Nx.shape(data)

  Enum.reduce(0..(k - 1), %{}, fn el, acc ->
    members = Nx.equal(labels, el)

    # Count on the un-tiled mask: the tiled selector below holds n_cols copies
    # of every flag, so summing it would overcount the cluster size by a
    # factor of n_cols and shrink the centroid toward the origin.
    num_in_cluster = Nx.to_scalar(Nx.sum(members))

    centroid =
      if num_in_cluster == 0 do
        # Empty cluster: keep the previous position, dropping the trailing
        # label column so every map value has the same shape.
        old_centroids
        |> Nx.take(el)
        |> Nx.slice([0], [n_cols])
      else
        selector =
          members
          |> Nx.reshape({n_rows, 1})
          |> Nx.tile([1, n_cols])

        # Zero out the rows of other clusters, then average the remainder.
        selector
        |> Nx.select(data, Nx.tensor([0]))
        |> Nx.sum(axes: [0])
        |> Nx.divide(num_in_cluster)
      end

    # NOTE(review): the names are overwritten directly on the struct, exactly
    # as before, because later cells index the stacked centroids by :y.
    # Confirm this still behaves as intended on the Nx version in use.
    Map.put(acc, el, Map.put(centroid, :names, [:x, :y]))
  end)
end

# First centroid update, driven by the labels assigned from the initial
# centroids. The result is a map of cluster index => centroid tensor.
new_centroids = calc_centroids_map.(data, new_labels, initial_centroids)

# Turns a map of cluster index => centroid coordinates into a single tensor
# with one row per cluster and the cluster index appended as the last column.
label_centroids = fn centroids ->
  # Maps do not guarantee key order, so sort by cluster index explicitly.
  # Row i must hold the centroid of cluster i for the iota column to match.
  # Assumes the keys are exactly 0..map_size-1, as calc_centroids_map builds them.
  coords =
    centroids
    |> Enum.sort_by(fn {label, _centroid} -> label end)
    |> Enum.map(fn {_label, centroid} -> centroid end)

  Nx.concatenate(
    [
      Nx.stack(coords),
      Nx.iota({map_size(centroids), 1})
    ],
    axis: 1
  )
end

# Convert the centroid map into a labeled tensor, then reassign every point
# to its nearest updated centroid and append the labels as an extra column.
new_centroids = label_centroids.(new_centroids)

updated_coords = new_centroids[y: 0..(n_dims - 1)]
new_labels = find_labels.(data, updated_coords)

relabel_column = Nx.new_axis(new_labels, 1)
alg_labeled = Nx.concatenate([data, relabel_column], axis: 1)
# Chart of the relabeled data together with the updated centroids.
first_iteration_layers = [mk_data_layer.(alg_labeled), mk_centroid_layer.(new_centroids)]

Vl.layers(
  Vl.new(title: "Result of First Iteration", width: 700, height: 700),
  first_iteration_layers
)

Clustering - N Iterations

n_iters = 10

# rename some variables
centroids = new_centroids
labels = new_labels

# One k-means step: recompute the centroids from the previous labels, then
# relabel every point against the coordinate columns of the new centroids.
kmeans_step = fn {prev_centroids, prev_labels} ->
  next_centroids = label_centroids.(calc_centroids_map.(data, prev_labels, prev_centroids))
  next_labels = find_labels.(data, next_centroids[y: 0..(n_dims - 1)])

  {next_centroids, next_labels}
end

# Apply the step a fixed number of times, starting from the first-iteration result.
{final_centroids, final_labels} =
  Enum.reduce(1..n_iters, {centroids, labels}, fn _iteration, state ->
    kmeans_step.(state)
  end)
# Data with the final algorithm-assigned labels appended as an extra column.
final_label_column = Nx.new_axis(final_labels, 1)
alg_labeled = Nx.concatenate([data, final_label_column], axis: 1)

# Layer showing the generating-distribution labels as larger points, so they
# remain visible around the algorithm-labeled points drawn in the same chart.
true_column = fn ix -> Nx.to_flat_list(labeled[y: ix]) end
true_series = [x: true_column.(0), y: true_column.(1), label: true_column.(2)]

true_labels_layer =
  Vl.new()
  |> Vl.data_from_series(true_series)
  |> Vl.mark(:point, size: 200)
  |> Vl.encode_field(:x, "x", type: :quantitative, title: "X")
  |> Vl.encode_field(:y, "y", type: :quantitative, title: "Y")
  |> Vl.encode_field(:color, "label", type: :nominal)

# Final chart: algorithm labels, generating-distribution labels and final centroids.
final_layers = [
  mk_data_layer.(alg_labeled),
  true_labels_layer,
  mk_centroid_layer.(final_centroids)
]

Vl.layers(
  Vl.new(title: "Result of N Iterations", width: 700, height: 700),
  final_layers
)
@polvalente
Copy link

  1. In dist_fn you could've used Nx.LinAlg.norm https://gist.github.com/dantswain/38f56db677db21d8335d8a29fc73c81b#file-dts_k_means-livemd-L105-L112 like so:
# Distance from every data point to every centroid via Nx.LinAlg.norm.
#
#   d         - {m, n_coords} tensor of points
#   centroids - {k, n_coords} tensor of centroid coordinates
#
# Returns a {k, m} tensor: one row of distances per centroid.
dist_fn = fn %{shape: {m, _}} = d, %{shape: {k, num_coords}} = centroids ->
  c = Nx.new_axis(centroids, 1)

  d
  |> Nx.subtract(c)
  # The broadcast difference is {k, m, num_coords}. Flatten the first two axes
  # using the centroid count k (not the coordinate count) so the element count
  # matches for any k, because the norm is taken on a 2-D tensor.
  |> Nx.reshape({k * m, num_coords})
  |> Nx.LinAlg.norm(axes: [1])
  |> Nx.reshape({k, m})
end

It requires reshaping, but since reshape is O(1), this is acceptable :)

  2. calc_centroids_map does not need to return a map. Actually, for maps of size > 32 you would end up swapping centroid labels this way. You just need to change it to Enum.map and, instead of calling Map.put, return the corresponding Nx.take/Nx.divide result directly.

  3. The last visualization was kind of confusing. The outer circle is the initial label and the inner, the final label? Perhaps you could've incorporated x for the initial and o for the final one, and then added the new 2 values to the legend. Also, keep in mind that since you initialize the centroids randomly, the labels can switch between each other, so "true" labels is perhaps not the best terminology

Other than this, I liked your code a lot!

@dantswain
Copy link
Author

Thanks @polvalente ! This all makes sense. Re: (3), the outer circle is actually a "true" label since I generated the data at the outset from two distributions and the color corresponds to which distribution. It's a little contrived, but it was a helpful comparison for me to see if the algorithm was doing what I thought it should.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment