Skip to content

Instantly share code, notes, and snippets.

@zengyu714
Last active April 19, 2017 02:18
Show Gist options
  • Save zengyu714/fba95dde5b0711b445828923eeab9d05 to your computer and use it in GitHub Desktop.
Save zengyu714/fba95dde5b0711b445828923eeab9d05 to your computer and use it in GitHub Desktop.
TensorFlow utils
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
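# Read the values of all checkpoint variables whose name contains 'weights'
# (adjust the name filter to select other variables).
# Note: `filename` must be the full checkpoint path, e.g. './model-7000'.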
def get_weights(filename):
    """Return the stored values of every checkpoint variable named like a weight.

    Args:
        filename: Checkpoint path handed straight to
            ``pywrap_tensorflow.NewCheckpointReader`` (e.g. './model-7000').

    Returns:
        A list with one ``reader.get_tensor(name)`` result for each variable
        name containing the substring 'weights', in the iteration order of
        the reader's variable-to-shape map.
    """
    ckpt_reader = pywrap_tensorflow.NewCheckpointReader(filename)
    selected = []
    for var_name in ckpt_reader.get_variable_to_shape_map():
        # Skip everything that is not a weight variable (biases, counters, ...).
        if 'weights' not in var_name:
            continue
        selected.append(ckpt_reader.get_tensor(var_name))
    return selected
def dice_loss(y_true, y_conv, eps=1e-7):
    """Compute the dice loss over the **positive** class only.

    Restricting the score to the positive (foreground) channel keeps the
    loss from being dominated by the background when classes are unbalanced.

    Args:
        y_true: Labels, [batch_size, depth, height, width, 1]; only channel 0
            is used and it is expected to be binary.
        y_conv: Predictions, [batch_size, depth, height, width, 2]; only
            channel 1 (the positive class) is used.
        eps: Lower bound applied to the denominator so that an all-zero
            label/prediction pair yields a finite loss instead of NaN (0/0).
            Results are unchanged whenever the denominator is >= eps.

    Returns:
        A scalar tensor ``1 - dice`` where the dice coefficient is clipped to
        [0, 1 - 1e-7].
    """
    # Flatten the selected channel of each input to a 1-D float32 vector.
    # tf.cast replaces the deprecated tf.to_float (removed in TF 2.x).
    labels = tf.cast(tf.reshape(y_true[..., 0], [-1]), tf.float32)
    positives = tf.cast(tf.reshape(y_conv[..., 1], [-1]), tf.float32)

    intersection = tf.reduce_sum(positives * labels)
    # Sum-of-squares denominator; for binary labels sum(labels^2) == sum(labels).
    denominator = (tf.reduce_sum(positives * positives)
                   + tf.reduce_sum(labels * labels))

    # Floor the denominator: without it an empty mask with zero predictions
    # divides 0 by 0 and the NaN propagates through clip_by_value.
    dice_coef = 2.0 * intersection / tf.maximum(denominator, eps)
    return 1 - tf.clip_by_value(dice_coef, 0.0, 1.0 - 1e-7)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment