@risenW
Last active July 16, 2018 18:46
Code for weighted cross-entropy loss using two weights
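import matplotlib.pyplot as plt
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, the targets= keyword)

# Assumes Y_pred (logits), Y_labels (targets, same shape and dtype) and sess
# were defined earlier. A hypothetical setup for running this snippet alone:
# Y_pred = tf.linspace(-1., 1., 500)
# Y_labels = tf.fill([500], 1.)
# sess = tf.Session()

# Compare two values of pos_weight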
# pos_weight = 1.0: positive targets are neither up- nor down-weighted
weight = tf.constant(1.)
x_entropy_weighted_vals = tf.nn.weighted_cross_entropy_with_logits(targets=Y_labels, logits=Y_pred, pos_weight=weight)
x_entropy_weighted_out = sess.run(x_entropy_weighted_vals)
# pos_weight = 0.5: the loss contributed by positive targets is halved
weight2 = tf.constant(0.5)
x_entropy_weighted_val_2 = tf.nn.weighted_cross_entropy_with_logits(targets=Y_labels, logits=Y_pred, pos_weight=weight2)
x_entropy_weighted_out_2 = sess.run(x_entropy_weighted_val_2)

# Plot the weighted cross-entropy loss against the predicted values
Y_array = sess.run(Y_pred)
plt.plot(Y_array, x_entropy_weighted_out, 'b-', label='weight = 1.0')
plt.plot(Y_array, x_entropy_weighted_out_2, 'r--', label='weight = 0.5')
plt.title('Weighted cross entropy loss')
plt.legend(loc='lower right')
plt.xlabel('$Y_{pred}$', fontsize=15)
plt.ylabel('Loss', fontsize=15)  # the y-axis shows loss values, not labels
plt.ylim(-2, 5)
plt.show()
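To see what `pos_weight` does to each point on the plot without running TensorFlow, here is a small pure-Python sketch of the per-element loss. It follows the numerically stable formula documented for `tf.nn.weighted_cross_entropy_with_logits`, `(1 - z) * x + (1 + (q - 1) * z) * (log(1 + exp(-|x|)) + max(-x, 0))`, where `z` is the target, `x` the logit and `q` the `pos_weight`. The function names `weighted_cross_entropy` and `naive_weighted_cross_entropy` and the sample logits are hypothetical, chosen only for illustration.

```python
import math


def weighted_cross_entropy(target, logit, pos_weight):
    """Per-element weighted sigmoid cross-entropy in the numerically stable form
    documented for tf.nn.weighted_cross_entropy_with_logits (hypothetical helper)."""
    # -log(sigmoid(x)) computed stably as log(1 + exp(-|x|)) + max(-x, 0)
    softplus_neg = math.log1p(math.exp(-abs(logit))) + max(-logit, 0.0)
    # pos_weight only scales the term that applies to positive targets
    l = 1.0 + (pos_weight - 1.0) * target
    return (1.0 - target) * logit + l * softplus_neg


def naive_weighted_cross_entropy(target, logit, pos_weight):
    """Direct definition (can overflow for large |logit|); for cross-checking only."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return -pos_weight * target * math.log(p) - (1.0 - target) * math.log(1.0 - p)


# Hypothetical logits spanning [-1, 1], similar in spirit to Y_pred
logits = [-1.0 + 2.0 * i / 10 for i in range(11)]

for x in logits:
    full = weighted_cross_entropy(1.0, x, 1.0)
    half = weighted_cross_entropy(1.0, x, 0.5)
    print(f"{x:+.1f}  weight=1.0: {full:.4f}  weight=0.5: {half:.4f}")
```

For a target of 1 the loss is just `pos_weight` times the unweighted loss, so if `Y_labels` is all ones the weight = 0.5 curve sits at exactly half the height of the weight = 1.0 curve. For a target of 0, `pos_weight` has no effect at all.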