Skip to content

Instantly share code, notes, and snippets.

@sthomp
Created February 17, 2016 03:22
Show Gist options
  • Save sthomp/c908ee182aa4c08f294d to your computer and use it in GitHub Desktop.
Save sthomp/c908ee182aa4c08f294d to your computer and use it in GitHub Desktop.
Serializing pool_3 features (activations) in TensorFlow
def load_pool3_data(timestamp='20160212-00:06:14', directory=''):
    """Load serialized pool_3 feature arrays and their label arrays.

    Files are expected to be named ``<prefix>_<timestamp>.npy`` for the
    prefixes ``X_train``, ``y_train``, ``X_test`` and ``y_test``.

    Args:
        timestamp: suffix identifying which serialization run to load. The
            default is the run this script was originally written against.
        directory: directory holding the files. ``''`` (the default) keeps
            the paths relative to the current working directory, exactly as
            before.

    Returns:
        Tuple ``(X_train, y_train, X_test, y_test)`` of arrays from ``np.load``.
    """
    import os

    def _load(prefix):
        # Build '<prefix>_<timestamp>.npy' instead of hard-coding four names.
        return np.load(os.path.join(directory, '%s_%s.npy' % (prefix, timestamp)))

    return _load('X_train'), _load('y_train'), _load('X_test'), _load('y_test')
def batch_pool3_features(sess, X_input, tensor_name='pool_3:0',
                         input_name='DecodeJpeg:0', verbose=True):
    """Feed each row of ``X_input`` through the graph and collect the outputs.

    Currently tensorflow can't extract pool3 in batch so this is slow:
    https://github.com/tensorflow/tensorflow/issues/1021
    Rows are therefore fed one at a time.

    Args:
        sess: session whose graph contains ``tensor_name`` and ``input_name``.
        X_input: array whose first axis indexes the rows to feed; each
            ``X_input[i, :]`` is fed to ``input_name``.
        tensor_name: name of the tensor to fetch (default ``'pool_3:0'``).
        input_name: name of the tensor each row is fed into
            (default ``'DecodeJpeg:0'``).
        verbose: when true (the default), print a progress line per row.

    Returns:
        ``np.array`` stacking the squeezed output for every row. For an
        empty ``X_input`` this is an empty 1-D array.
    """
    n_train = X_input.shape[0]
    print('Extracting features for %i rows' % n_train)
    # Look the fetch tensor up once, outside the per-row loop.
    pool3 = sess.graph.get_tensor_by_name(tensor_name)
    X_pool3 = []
    for i in range(n_train):
        if verbose:
            print('Iteration %i' % i)
        pool3_features = sess.run(pool3, {input_name: X_input[i, :]})
        # Drop size-1 axes; presumably pool_3 yields (1, 1, 1, D), leaving a
        # D-length vector per row -- TODO confirm against the loaded graph.
        X_pool3.append(np.squeeze(pool3_features))
    return np.array(X_pool3)
def serialize_cifar_pool3(X, filename):
    """Extract pool_3 features for ``X`` and save them with ``np.save``.

    Opens a ``tf.InteractiveSession`` on the current default graph, runs
    ``batch_pool3_features`` over ``X`` and writes the result to
    ``filename``. The session is always closed afterwards.

    Args:
        X: rows to feed through the graph (see ``batch_pool3_features``).
        filename: output path passed to ``np.save``.

    Returns:
        The feature array that was saved. The function previously returned
        ``None``, so existing callers that ignore the result are unaffected.
    """
    print('About to generate file: %s' % filename)
    sess = tf.InteractiveSession()
    try:
        X_pool3 = batch_pool3_features(sess, X)
        np.save(filename, X_pool3)
    finally:
        # Release the session even on failure; previously every call left an
        # open InteractiveSession behind.
        sess.close()
    return X_pool3
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment