Skip to content

Instantly share code, notes, and snippets.

@taless474
Created May 1, 2020 18:34
Show Gist options
  • Save taless474/de32022edcf0e9da8aa9754fdcec6652 to your computer and use it in GitHub Desktop.
Save taless474/de32022edcf0e9da8aa9754fdcec6652 to your computer and use it in GitHub Desktop.
// k-means clustering program held as source text. It appears to be PhyLiSp,
// presumably compiled and run by the Phylanx runtime (confirm against the
// caller). The program defines three functions:
//
//   closest_centroids(points, centroids)
//     For every point, the index of the nearest centroid by Euclidean
//     distance. Centroids are reshaped to (k, 1, d) so that
//     `points - expanded_centroids` broadcasts against every point. Squared
//     differences are summed over axis 2 and square-rooted. argmin is then
//     taken over axis 0, the centroid axis.
//
//   move_centroids(points, closest, centroids)
//     For each centroid index k in range(shape(centroids, 0)), returns the mean
//     of the points whose label in `closest` equals k (one vector per centroid).
//
//   kmeans(points, iterations, initial_centroids, enable_output)
//     Starts from initial_centroids and repeats "assign points, move
//     centroids" `iterations` times (Lloyd-style). When enable_output is true
//     it prints the centroids at the start of each iteration. Each round
//     stacks the per-centroid vectors with vstack, and the function returns
//     the final centroids.
//
// Assumes points is (n_points, d) and centroids is (k, d), with the same d.
// TODO: confirm against callers. The feature dimension d is taken from
// shape(centroids, 1) rather than a fixed constant, so points of any
// dimensionality are supported. For 2-D points this matches a literal 2.
//
// NOTE(review): move_centroids divides the summed coordinates by the number
// of points assigned to each centroid. A centroid that attracts no points
// gives 0/0, which would presumably yield a non-finite centroid. Consider
// guarding this if empty clusters can occur with your data.
char const* const kmeans_code = R"(
define(closest_centroids, points, centroids, block(
define(expanded_centroids,
reshape(centroids, list(shape(centroids, 0), 1, shape(centroids, 1)))
),
define(distances,
sqrt(sum(square(points - expanded_centroids), 2))
),
argmin(distances, 0)
))
define(move_centroids, points, closest, centroids, block(
fmap(lambda(k, block(
define(x, closest == k),
define(count, sum(x * constant(1, shape(x, 0)))),
define(sum_, sum(points * expand_dims(x, -1), 0)),
sum_/count
)),
range(shape(centroids, 0))
)
))
define(kmeans, points, iterations, initial_centroids, enable_output,
block(
define(centroids, initial_centroids),
for(define(i, 0), i < iterations, store(i, i + 1),
block(
if(enable_output, block(
cout("centroids in iteration ", i,": ", centroids)
)),
store(centroids,
apply(vstack,
list(move_centroids(points,
closest_centroids(points, centroids),
centroids))))
)
), centroids
)
)
)";
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment