
@shnhrtkyk
Created January 23, 2020 05:36
Weighted loss
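import tensorflow as tf

def loss(labels, logits):
    # Per-class loss weights, indexed by label (assumes 6 classes, labels 0-5):
    # classes 0 and 1 -> 1, class 2 -> 2, classes 3, 4 and 5 -> 20.
    class_weights = tf.constant([1.0, 1.0, 2.0, 20.0, 20.0, 20.0])
    # Look up one weight per sample. This replaces a chain of tf.where(labels == k, ...)
    # calls: in TF1 graph mode `labels == k` is not element-wise (tf.equal is),
    # and TF1 tf.where cannot broadcast scalar branches to the label shape.
    weights = tf.gather(class_weights, labels)
    # TF1 loss API (tf.compat.v1 under TF2); its default reduction is
    # SUM_BY_NONZERO_WEIGHTS.
    classify_loss = tf.compat.v1.losses.sparse_softmax_cross_entropy(
        labels=labels, logits=logits, weights=weights, scope='loss')
    return classify_loss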
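To make the effect of the class weights concrete, here is a minimal pure-Python sketch of the same computation. It assumes the TF1 default reduction `Reduction.SUM_BY_NONZERO_WEIGHTS` (weighted sum of per-sample losses divided by the number of non-zero weights). The helper names `softmax_cross_entropy` and `weighted_loss` are hypothetical and exist only for illustration.

```python
import math

# Class -> weight mapping from the loss above: 0, 1 -> 1; 2 -> 2; 3, 4, 5 -> 20.
CLASS_WEIGHTS = [1.0, 1.0, 2.0, 20.0, 20.0, 20.0]


def softmax_cross_entropy(label, logits):
    """Per-sample loss -log(softmax(logits)[label]), computed stably."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def weighted_loss(labels, logits_batch):
    """Weighted sum divided by the count of non-zero weights (assumed reduction)."""
    weights = [CLASS_WEIGHTS[y] for y in labels]
    losses = [softmax_cross_entropy(y, z) for y, z in zip(labels, logits_batch)]
    weighted_sum = sum(w * l for w, l in zip(weights, losses))
    num_nonzero = sum(1 for w in weights if w != 0)
    return weighted_sum / num_nonzero if num_nonzero else 0.0


if __name__ == "__main__":
    uniform = [0.0] * 6  # uninformative logits over 6 classes
    print(weighted_loss([0, 3], [uniform, uniform]))
```

Because the denominator counts samples rather than summing the weights, a misclassified sample from classes 3 to 5 contributes 20 times as much as one from classes 0 or 1, and the overall loss scale grows with the share of those classes in the batch.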