TF (TensorFlow 1.x) version of the CSD (Common-Specific Decomposition) classification layer
import tensorflow as tf  # TF 1.x API (tf.get_variable / tf.variable_scope)


def csd(embeds, label_placeholder, domain_placeholder, num_classes, num_domains, K=1, is_training=False, scope=""):
    """CSD layer to be used as a replacement for your final classification layer.

    Args:
        embeds (tensor): final-layer representations, rank 2 ([batch_size, emb_size])
        label_placeholder (tensor): tf tensor with label indices, rank 1
        domain_placeholder (tensor): tf tensor with domain indices, rank 1 -- set to all zeros when testing
        num_classes (int): number of label classes
        num_domains (int): number of training domains
        K (int): number of domain-specific components; should be >= 1 and <= num_domains - 1
        is_training (bool): whether the graph is being built for training
        scope (str): name of the variable scope

    Returns:
        tuple of (final loss, common logits)
    """
    batch_size = tf.shape(embeds)[0]
    EMB_SIZE = embeds.get_shape()[-1]

    with tf.variable_scope(scope, 'csd'):
        # Combination weights: a fixed weight for the common component and
        # learned per-domain weights for the K specific components.
        common_wt = tf.get_variable("common_wt", shape=[1], trainable=False, initializer=tf.ones_initializer)
        specialized_common_wt = tf.get_variable("specialized_wt", shape=[1], initializer=tf.random_normal_initializer(.5, 1e-2))
        emb_matrix = tf.get_variable("emb_matrix", shape=[num_domains, K], initializer=tf.random_normal_initializer(0, 1e-4))
        common_cwt = tf.identity(tf.concat([common_wt, tf.zeros([K])], axis=0), name='common_cwt')

        # batch_size x (K + 1): per-example combination weights over the K + 1 softmax matrices
        c_wts = tf.nn.embedding_lookup(emb_matrix, domain_placeholder)
        c_wts = tf.concat([tf.ones([batch_size, 1]) * specialized_common_wt, c_wts], axis=1)
        c_wts = tf.reshape(c_wts, [batch_size, K + 1])

        # K + 1 classification matrices and biases: index 0 is the common component, the rest are domain-specific.
        sms = tf.get_variable("sm_matrices", shape=[K + 1, EMB_SIZE, num_classes], trainable=True, initializer=tf.random_normal_initializer(0, 0.05))
        sm_biases = tf.get_variable("sm_bias", shape=[K + 1, num_classes], trainable=True)

        specific_sms = tf.einsum("ij,jkl->ikl", c_wts, sms)
        common_sm = tf.einsum("j,jkl->kl", common_cwt, sms)
        specific_bias = tf.einsum("ij,jl->il", c_wts, sm_biases)
        common_bias = tf.einsum("j,jl->l", common_cwt, sm_biases)

        # Orthonormality regularizer: pushes the K + 1 components toward mutual orthogonality for every class.
        diag_tensor = tf.eye(K + 1, batch_shape=[num_classes])
        cps = tf.stack([tf.matmul(sms[:, :, _], sms[:, :, _], transpose_b=True) for _ in range(num_classes)])
        orthn_loss = tf.reduce_mean((cps - diag_tensor) ** 2)
        reg_loss = orthn_loss

        logits1 = tf.einsum("ik,ikl->il", embeds, specific_sms) + specific_bias
        logits2 = tf.matmul(embeds, common_sm) + common_bias

        loss1 = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits1, labels=label_placeholder))
        loss2 = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits2, labels=label_placeholder))

        # alpha is the hyperparameter that weighs the common loss against the specific loss; tune it if you wish.
        # It is expected to be stable with the default value, though.
        alpha = 0.5
        loss = (1 - alpha) * loss1 + alpha * loss2 + reg_loss

        # Return the common logits.
        return loss, logits2
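
A minimal TF 1.x usage sketch. Everything below (`inputs`, `labels`, `domains`, the single dense layer standing in for a feature extractor, and all shapes/hyperparameters) is an illustrative assumption, not part of the gist; wire `csd` onto whatever penultimate-layer features your own model produces.

import tensorflow as tf

# Hypothetical inputs: flattened 28x28 images, class labels, and domain indices.
inputs = tf.placeholder(tf.float32, [None, 784], name="inputs")
labels = tf.placeholder(tf.int64, [None], name="labels")    # class indices
domains = tf.placeholder(tf.int32, [None], name="domains")  # domain indices; feed zeros at test time

# Any feature extractor works; a single dense layer stands in for it here.
features = tf.layers.dense(inputs, 256, activation=tf.nn.relu)  # [batch_size, 256] embeddings

# CSD replaces the usual final dense + softmax classification layer.
loss, logits = csd(features, labels, domains,
                   num_classes=10, num_domains=5, K=1,
                   is_training=True, scope="csd")

train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)
predictions = tf.argmax(logits, axis=-1)

At test time, feed zeros for `domains` (as noted in the docstring) and use `logits`/`predictions` as usual; only the common component is used for prediction.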