@vihari
Created June 9, 2020 16:30
TF version of CSD
import tensorflow as tf


def csd(embeds, label_placeholder, domain_placeholder, num_classes, num_domains, K=1, is_training=False, scope=""):
"""CSD layer to be used as a replacement for your final classification layer
Args:
embeds (tensor): final layer representations of dim 2
label_placeholder (tensor): tf tensor with label index of dim 1
domain_placeholder (tensor): tf tensor with domain index of dim 1 -- set to all zeros when testing
num_classes (int): Number of label classes: scalar
num_domains (int): Number of domains: scalar
K (int): Number of domain specific components to use. should be >=1 and <=num_domains-1
is_training (bool): boolean value indicating if it is training
scope (str): name of the scope
Returns:
tuple of final loss, logits
"""
    batch_size = tf.shape(embeds)[0]
    EMB_SIZE = embeds.get_shape()[-1]

    with tf.variable_scope(scope, 'csd'):
        # Mixing weights for the common classifier: weight 1 on the common
        # component, 0 on the K domain-specific components.
        common_wt = tf.get_variable("common_wt", shape=[1], trainable=False, initializer=tf.ones_initializer)
        specialized_common_wt = tf.get_variable("specialized_wt", shape=[1], initializer=tf.random_normal_initializer(.5, 1e-2))
        emb_matrix = tf.get_variable("emb_matrix", shape=[num_domains, K], initializer=tf.random_normal_initializer(0, 1e-4))
        common_cwt = tf.identity(tf.concat([common_wt, tf.zeros([K])], axis=0), name='common_cwt')

        # Per-example mixing weights over the K+1 components: [batch_size, K+1]
        c_wts = tf.nn.embedding_lookup(emb_matrix, domain_placeholder)
        c_wts = tf.concat([tf.ones([batch_size, 1]) * specialized_common_wt, c_wts], axis=1)
        c_wts = tf.reshape(c_wts, [batch_size, K + 1])

        # K+1 classifier matrices and biases: one common + K domain-specific components
        sms = tf.get_variable("sm_matrices", shape=[K + 1, EMB_SIZE, num_classes], trainable=True, initializer=tf.random_normal_initializer(0, 0.05))
        sm_biases = tf.get_variable("sm_bias", shape=[K + 1, num_classes], trainable=True)

        # Compose the per-example (domain-specific) and common classifiers from the shared components
        specific_sms = tf.einsum("ij,jkl->ikl", c_wts, sms)
        common_sm = tf.einsum("j,jkl->kl", common_cwt, sms)
        specific_bias = tf.einsum("ij,jl->il", c_wts, sm_biases)
        common_bias = tf.einsum("j,jl->l", common_cwt, sm_biases)

        # Orthonormality regularizer: pushes the K+1 component matrices toward
        # mutual orthogonality for every class.
        diag_tensor = tf.eye(K + 1, batch_shape=[num_classes])
        cps = tf.stack([tf.matmul(sms[:, :, _], sms[:, :, _], transpose_b=True) for _ in range(num_classes)])
        orthn_loss = tf.reduce_mean((cps - diag_tensor) ** 2)
        reg_loss = orthn_loss

        # Logits from the domain-specific classifier and from the common classifier
        logits1 = tf.einsum("ik,ikl->il", embeds, specific_sms) + specific_bias
        logits2 = tf.matmul(embeds, common_sm) + common_bias
        loss1 = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits1, labels=label_placeholder))
        loss2 = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits2, labels=label_placeholder))

        # alpha is the hyperparameter that weights the common loss vs. the specific loss; tune it if you wish.
        # It is expected to be stable with the default value, though.
        alpha = 0.5
        loss = (1 - alpha) * loss1 + alpha * loss2 + reg_loss

        # Return the common logits for prediction
        return loss, logits2
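
A minimal usage sketch, assuming TF 1.x graph mode: csd replaces the final dense + softmax-cross-entropy layer, taking backbone features plus label and domain index placeholders, and returning the loss to optimize and the common logits to predict with. The feature size, class/domain counts, optimizer, and learning rate below are illustrative choices, not part of the gist.

# --- usage sketch (illustrative values, TF 1.x graph mode) ---
feats = tf.placeholder(tf.float32, [None, 128])   # backbone features [batch, emb_size]
labels = tf.placeholder(tf.int32, [None])         # label indices
domains = tf.placeholder(tf.int32, [None])        # domain indices; all zeros at test time

loss, logits = csd(feats, labels, domains,
                   num_classes=10, num_domains=5, K=1,
                   is_training=True, scope="csd")

train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)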