@vihari
Last active February 23, 2022 13:59
PyTorch version of CSD (Common-Specific Decomposition), a drop-in replacement for the final classification layer for domain generalization
def csd(self, embeds, labels, domains, num_classes, num_domains, K=1, is_training=False, scope=""):
    """CSD layer to be used as a replacement for your final classification layer

    Args:
        embeds (tensor): final-layer representations; 2-D tensor of shape (batch_size, 512)
        labels (tensor): torch tensor of label indices; 1-D tensor of shape (batch_size,)
        domains (tensor): torch tensor of domain indices; 1-D tensor of shape (batch_size,) -- set to all zeros when testing
        num_classes (int): number of label classes
        num_domains (int): number of training domains
        K (int): number of domain-specific components to use; should be >= 1 and <= num_domains - 1
        is_training (bool): whether this is a training pass (unused in the body; kept for API parity)
        scope (str): name of the scope (unused in the body; kept for API parity)

    Returns:
        tuple of (final loss, common logits)

    Note:
        *You should have the following parameters defined in your model definition.*
        The variables num_classes, num_domains and K carry the same interpretation as documented
        above; 512 is the feature dimension of embeds.
        ```
        self.sms = torch.nn.Parameter(torch.normal(0, 1e-1, size=[K+1, 512, num_classes], dtype=torch.float, device='cuda'), requires_grad=True)
        self.sm_biases = torch.nn.Parameter(torch.normal(0, 1e-1, size=[K+1, num_classes], dtype=torch.float, device='cuda'), requires_grad=True)
        self.embs = torch.nn.Parameter(torch.normal(mean=0., std=1e-4, size=[num_domains, K], dtype=torch.float, device='cuda'), requires_grad=True)
        self.cs_wt = torch.nn.Parameter(torch.normal(mean=.1, std=1e-4, size=[], dtype=torch.float, device='cuda'), requires_grad=True)
        ```
        This routine can then be invoked from your forward function.
    """
    # Common (domain-shared) classifier: the 0-th component of sms / sm_biases.
    w_c, b_c = self.sms[0, :, :], self.sm_biases[0, :]
    logits_common = torch.matmul(embeds, w_c) + b_c

    # Per-example component weights looked up from the learned domain embeddings.
    # one_hot returns a long tensor, so cast before the matmul.
    domains = torch.nn.functional.one_hot(domains, num_domains).to(embeds.dtype)
    c_wts = torch.matmul(domains, self.embs)  # B x K
    batch_size = embeds.shape[0]
    # Prepend the learned weight of the common component, then squash to (-1, 1).
    c_wts = torch.cat((torch.ones((batch_size, 1), dtype=embeds.dtype, device=embeds.device) * self.cs_wt, c_wts), 1)
    c_wts = torch.tanh(c_wts)

    # Mix the K+1 softmax components into one specialized classifier per example.
    w_d = torch.einsum("bk,krl->brl", c_wts, self.sms)
    b_d = torch.einsum("bk,kl->bl", c_wts, self.sm_biases)
    logits_specialized = torch.einsum("brl,br->bl", w_d, embeds) + b_d

    criterion = torch.nn.CrossEntropyLoss()
    specific_loss = criterion(logits_specialized, labels)
    class_loss = criterion(logits_common, labels)

    # Orthonormality regularizer: pushes the K+1 components of each class towards an orthonormal set.
    sms = self.sms
    diag_tensor = torch.stack([torch.eye(K + 1) for _ in range(num_classes)], dim=0).to(embeds.device)
    cps = torch.stack([torch.matmul(sms[:, :, _], torch.transpose(sms[:, :, _], 0, 1)) for _ in range(num_classes)], dim=0)
    orth_loss = torch.mean((cps - diag_tensor) ** 2)

    loss = class_loss + specific_loss + orth_loss
    return loss, logits_common
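
For context, here is a minimal sketch of how this might be wired into a model, assuming `csd` is defined at module scope and bound as a method, and that the backbone emits 512-d features. `CSDNet`, the linear-stub backbone, and all hyper-parameter values below are placeholders, not part of the gist; the parameters are created on CPU here so the sketch runs self-contained (move everything to 'cuda' to match the definitions in the docstring above).

```python
import torch

class CSDNet(torch.nn.Module):
    def __init__(self, num_classes, num_domains, K=1):
        super().__init__()
        self.num_classes, self.num_domains, self.K = num_classes, num_domains, K
        # Any backbone that emits 512-d features works; a linear stub keeps this runnable.
        self.backbone = torch.nn.Linear(3 * 32 * 32, 512)
        # The four CSD parameters from the docstring (device argument dropped for a CPU demo).
        self.sms = torch.nn.Parameter(torch.normal(0, 1e-1, size=[K + 1, 512, num_classes]))
        self.sm_biases = torch.nn.Parameter(torch.normal(0, 1e-1, size=[K + 1, num_classes]))
        self.embs = torch.nn.Parameter(torch.normal(mean=0., std=1e-4, size=[num_domains, K]))
        self.cs_wt = torch.nn.Parameter(torch.normal(mean=.1, std=1e-4, size=[]))

    csd = csd  # bind the routine above as a method

    def forward(self, x, labels, domains):
        embeds = self.backbone(x.flatten(1))
        return self.csd(embeds, labels, domains, self.num_classes, self.num_domains, K=self.K)

model = CSDNet(num_classes=10, num_domains=5, K=1)
x = torch.randn(8, 3, 32, 32)
labels = torch.randint(0, 10, (8,))
domains = torch.randint(0, 5, (8,))  # set to all zeros at test time
loss, logits = model(x, labels, domains)
loss.backward()
```

At test time, prediction uses the returned common logits (with `domains` all zeros, per the docstring), which is why `logits_common` is returned alongside the loss.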