Skip to content

Instantly share code, notes, and snippets.

@enijkamp
Created September 17, 2018 17:48
Show Gist options
  • Save enijkamp/567dec0a4c14ba4014f193da04229c85 to your computer and use it in GitHub Desktop.
Save enijkamp/567dec0a4c14ba4014f193da04229c85 to your computer and use it in GitHub Desktop.
# Scalar interpolation parameter t; a single value is fed per evaluation.
t_pl = tf.placeholder(shape=[], dtype=tf.float32, name='t_pl')

# Cubic Bezier curve in weight space, applied tensor-by-tensor to the
# control points P0..P3:
#   B(t) = (1-t)^3 P0 + 3(1-t)^2 t P1 + 3(1-t) t^2 P2 + t^3 P3
# The four Bernstein coefficients depend only on t, so they are built once
# and shared by every tensor instead of being re-created per tensor.
# They sum to 1 for every t; note the P2 term is 3(1-t)t^2 (not 3(t-1)t^2,
# which flips its sign and pulls interior points off the curve).
bez_c0 = (1 - t_pl) ** 3
bez_c1 = 3 * (1 - t_pl) ** 2 * t_pl
bez_c2 = 3 * (1 - t_pl) * t_pl ** 2
bez_c3 = t_pl ** 3
# The coefficient tensor stays the left operand of each product so the
# TensorFlow operator overload performs the multiplication.
wt = [
    bez_c0 * np.array(p0) + bez_c1 * p1 + bez_c2 * p2 + bez_c3 * p3
    for p0, p1, p2, p3 in zip(P0, P1, P2, P3)
]

# Model evaluated with the interpolated weights wt (Cifar10ModelSimple is
# defined elsewhere; presumably it uses wt as its parameters — confirm).
model_t = Cifar10ModelSimple(wt, resnet_size=32)
logits_t = model_t(x, is_training)
# Softmax cross-entropy against the one-hot labels y_.
cross_entropy_t = tf.losses.softmax_cross_entropy(logits=logits_t, onehot_labels=y_)
# Fraction of examples whose argmax over axis 1 matches the label's argmax.
correct_prediction_t = tf.equal(tf.argmax(logits_t, axis=1), tf.argmax(y_, axis=1))
accuracy_t = tf.reduce_mean(tf.cast(correct_prediction_t, tf.float32))
# Distance of the interpolated weights from the start point P0
# (distance_to is defined elsewhere; metric not visible here — confirm).
distance_t = distance_to(wt, P0)

sess.run(tf.global_variables_initializer())
# Freeze the graph so the evaluation loop below cannot add new ops.
tf.get_default_graph().finalize()
# Sweep t over 100 evenly spaced points in [0, 1] and, at each point on the
# curve, record the distance from P0 and an estimate of the training loss.
t_s = np.linspace(0.0, 1.0, 100)
c_train_s = []
# Distance from P0 at each t, kept alongside the losses rather than discarded.
d_s = []
for t in t_s:
    d_val = sess.run(distance_t, feed_dict={t_pl: t})
    d_s.append(d_val)
    # estimate() is defined elsewhere; it presumably evaluates the loss over
    # `steps` batches with t_pl fed in addition to its own inputs — confirm.
    c_train = estimate(cross_entropy_t, steps=200, feed_more={t_pl: t})
    c_train_s.append(c_train)
# Plot the estimated training loss along the Bezier path and save it to
# output_dir/plot_1.png.
plt.figure(1)
plt.plot(t_s, c_train_s, linewidth=.5, label='c_train')
plt.xlabel('t')
plt.ylabel('c_train (cross-entropy)')
plt.legend(loc='upper right')
plt.savefig(os.path.join(output_dir, 'plot_1.png'))
# Release the figure once it has been written so it does not linger in memory.
plt.close(1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment