PyTorch version of CSD (Common Specific Decomposition)
def csd(self, embeds, labels, domains, num_classes, num_domains, K=1, is_training=False, scope=""):
    """CSD layer to be used as a replacement for your final classification layer.

    Args:
        embeds (tensor): 2-D torch tensor of final-layer representations (batch_size x feature_dim)
        labels (tensor): 1-D torch tensor of label indices
        domains (tensor): 1-D torch tensor of domain indices -- set to all zeros when testing
        num_classes (int): number of label classes
        num_domains (int): number of training domains
        K (int): number of domain-specific components to use; should be >= 1 and <= num_domains - 1
        is_training (bool): whether the layer is being run in training mode
        scope (str): name of the scope

    Returns:
        tuple of (final loss, logits)

    Note:
        *You should have the following parameters defined in your model definition.*
        The variables num_classes, num_domains and K carry the same interpretation as documented
        above; 512 is the dimensionality of `embeds` and should be changed to match your network.
        ```
        self.sms = torch.nn.Parameter(torch.normal(0, 1e-1, size=[K+1, 512, num_classes], dtype=torch.float, device='cuda'), requires_grad=True)
        self.sm_biases = torch.nn.Parameter(torch.normal(0, 1e-1, size=[K+1, num_classes], dtype=torch.float, device='cuda'), requires_grad=True)
        self.embs = torch.nn.Parameter(torch.normal(mean=0., std=1e-4, size=[num_domains, K], dtype=torch.float, device='cuda'), requires_grad=True)
        self.cs_wt = torch.nn.Parameter(torch.normal(mean=.1, std=1e-4, size=[], dtype=torch.float, device='cuda'), requires_grad=True)
        ```
        This routine can then be invoked from your forward function; see the usage sketch below.
    """
    # Common (domain-shared) classifier: the 0-th component of sms.
    w_c, b_c = self.sms[0, :, :], self.sm_biases[0, :]
    logits_common = torch.matmul(embeds, w_c) + b_c

    # Per-example mixing weights over the K domain-specific components (B x K).
    domains = torch.nn.functional.one_hot(domains, num_domains).float()
    c_wts = torch.matmul(domains, self.embs)

    # Prepend the weight of the common component -> B x (K+1), squashed with tanh.
    batch_size = embeds.shape[0]
    c_wts = torch.cat((torch.ones((batch_size, 1), dtype=torch.float, device='cuda') * self.cs_wt, c_wts), 1)
    c_wts = torch.tanh(c_wts)

    # Per-example specialized classifier: mix the K+1 weight/bias components per example.
    w_d, b_d = torch.einsum("bk,krl->brl", c_wts, self.sms), torch.einsum("bk,kl->bl", c_wts, self.sm_biases)
    logits_specialized = torch.einsum("brl,br->bl", w_d, embeds) + b_d

    criterion = torch.nn.CrossEntropyLoss()
    specific_loss = criterion(logits_specialized, labels)
    class_loss = criterion(logits_common, labels)
    # Orthogonality regularizer: pushes the K+1 classifier components towards
    # being mutually orthogonal, separately for every class.
    sms = self.sms
    diag_tensor = torch.stack([torch.eye(K + 1) for _ in range(num_classes)], dim=0).cuda()
    cps = torch.stack([torch.matmul(sms[:, :, _], torch.transpose(sms[:, :, _], 0, 1)) for _ in range(num_classes)], dim=0)
    orth_loss = torch.mean((cps - diag_tensor) ** 2)

    loss = class_loss + specific_loss + orth_loss
    return loss, logits_common
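
For reference, here is a minimal usage sketch (not part of the original gist) showing how the parameters from the Note are defined inside a model and how `csd` is invoked from `forward`. The `Net` class, its `featurizer`, and the input size are illustrative assumptions; only the CSD-related parameters follow the gist.

```python
# Minimal usage sketch -- `Net`, `featurizer`, and the 3*32*32 input size are
# hypothetical; the CSD parameters mirror the Note in the docstring above.
import torch


class Net(torch.nn.Module):
    def __init__(self, num_classes, num_domains, K=1):
        super().__init__()
        self.num_classes, self.num_domains, self.K = num_classes, num_domains, K
        # Any feature extractor producing 512-d embeddings can be used here.
        self.featurizer = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(3 * 32 * 32, 512),
            torch.nn.ReLU())
        # Parameters required by csd (copied from the Note above).
        self.sms = torch.nn.Parameter(torch.normal(0, 1e-1, size=[K + 1, 512, num_classes],
                                                   dtype=torch.float, device='cuda'), requires_grad=True)
        self.sm_biases = torch.nn.Parameter(torch.normal(0, 1e-1, size=[K + 1, num_classes],
                                                         dtype=torch.float, device='cuda'), requires_grad=True)
        self.embs = torch.nn.Parameter(torch.normal(mean=0., std=1e-4, size=[num_domains, K],
                                                    dtype=torch.float, device='cuda'), requires_grad=True)
        self.cs_wt = torch.nn.Parameter(torch.normal(mean=.1, std=1e-4, size=[],
                                                     dtype=torch.float, device='cuda'), requires_grad=True)

    # The csd routine above is assumed to be pasted into this class as a method.
    def forward(self, x, labels, domains):
        embeds = self.featurizer(x)
        # At test time, pass domains = torch.zeros_like(labels) as noted in the docstring.
        return self.csd(embeds, labels, domains, self.num_classes, self.num_domains,
                        K=self.K, is_training=self.training)


# Hypothetical training step:
# net = Net(num_classes=10, num_domains=5, K=1).cuda()
# loss, logits = net(images.cuda(), labels.cuda(), domains.cuda())
# loss.backward()
```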