Skip to content

Instantly share code, notes, and snippets.

@yindia
Created May 18, 2017 07:52
Show Gist options
  • Save yindia/6bad038afd754ea09dbff558d04199eb to your computer and use it in GitHub Desktop.
Save yindia/6bad038afd754ea09dbff558d04199eb to your computer and use it in GitHub Desktop.
stackoverflow question
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
# K-means clustering of random 2-D points using the TensorFlow graph API.
#
# NOTE(review): the graph below uses the pre-1.0 TensorFlow call forms
# (`tf.concat(axis, values)`, `tf.initialize_all_variables`). On
# TensorFlow >= 1.0 the concat arguments are `(values, axis)` -- confirm the
# installed version before changing them.
points_n = 200      # number of sample points
clusters_n = 3      # number of clusters (k)
iteration_n = 100   # number of centroid-update steps to run

# The sample data stays a NumPy array: values passed through `feed_dict`
# must be plain Python/NumPy data, a tf.Tensor cannot be fed.
x = np.random.rand(points_n, 2).astype(np.float32)

points = tf.placeholder(tf.float32, [points_n, 2])

# Initial centroids are the first `clusters_n` rows of a shuffle of the
# points. Because this initial value is computed from the placeholder, the
# initializer op has to be run with `points` fed (see below).
centroids = tf.Variable(
    tf.slice(tf.random_shuffle(points), [0, 0], [clusters_n, -1]))

# Broadcast (1, points_n, 2) against (clusters_n, 1, 2) to get the squared
# distance from every centroid to every point: shape (clusters_n, points_n).
points_expanded = tf.expand_dims(points, 0)
centroids_expanded = tf.expand_dims(centroids, 1)
distances = tf.reduce_sum(
    tf.square(points_expanded - centroids_expanded), 2)

# For each point, the index of its nearest centroid.
assignments = tf.argmin(distances, 0)

# New centroid of each cluster = mean of the points assigned to it.
# NOTE(review): a cluster that ends up with no points would average an empty
# set (NaN centroid) -- not handled here, same as the original.
means = []
for c in range(clusters_n):
    # Row indices of the points in cluster c, reshaped to (1, m) so that the
    # gather yields (1, m, 2) and the mean over axis 1 yields (1, 2).
    member_rows = tf.reshape(tf.where(tf.equal(assignments, c)), [1, -1])
    means.append(
        tf.reduce_mean(tf.gather(points, member_rows), reduction_indices=[1]))

new_centroids = tf.concat(0, means)
update_centroids = tf.assign(centroids, new_centroids)

init = tf.initialize_all_variables()
sess = tf.InteractiveSession()

feed = {points: x}
# The variable initializer reads `points`, so it needs the feed as well.
sess.run(init, feed_dict=feed)

for step in range(iteration_n):
    # Take the centroid values from the assign op's output: fetching the
    # variable itself in the same run() is not ordered against the update.
    [centroid_values, points_values, assignment_values] = sess.run(
        [update_centroids, points, assignments], feed_dict=feed)

# Single-argument form prints the same text under Python 2 and Python 3.
print("centroids\n%s" % (centroid_values,))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment