Last active
July 16, 2018 18:46
-
-
Save risenW/d7470f1aa330d5756a94491ba50b4f71 to your computer and use it in GitHub Desktop.
Code for weighted cross-entropy using two weights
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Compare the weighted cross-entropy loss for two values of `pos_weight`.
# `pos_weight` scales the loss contribution of positive targets, so plotting
# both curves against the same predictions shows the effect of the weight.
# NOTE(review): `tf`, `sess`, `Y_labels`, `Y_pred` and `plt` are defined
# earlier in the surrounding tutorial; Y_pred is presumably an ordered range
# of logits (e.g. from tf.linspace) so the line plots are monotone in x.
weight = tf.constant(1.)
x_entropy_weighted_vals = tf.nn.weighted_cross_entropy_with_logits(
    targets=Y_labels, logits=Y_pred, pos_weight=weight)

weight2 = tf.constant(0.5)
x_entropy_weighted_val_2 = tf.nn.weighted_cross_entropy_with_logits(
    targets=Y_labels, logits=Y_pred, pos_weight=weight2)

# Fetch everything in one session run so the x values and both loss curves
# come from the same graph evaluation (and the graph is evaluated once).
x_entropy_weighted_out, x_entropy_weighted_out_2, Y_array = sess.run(
    [x_entropy_weighted_vals, x_entropy_weighted_val_2, Y_pred])

# Plot the predicted values against the weighted cross-entropy loss.
plt.plot(Y_array, x_entropy_weighted_out, 'b-', label='weight = 1.0')
plt.plot(Y_array, x_entropy_weighted_out_2, 'r--', label='weight = 0.5')
plt.title('Weighted cross entropy loss')
plt.legend(loc=4)  # loc=4 is 'lower right'
plt.xlabel('$Y_{pred}$', fontsize=15)
# The y-axis shows the loss value, not the label, so name it accordingly.
plt.ylabel('Loss', fontsize=15)
plt.ylim(-2, 5)
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment