Skip to content

Instantly share code, notes, and snippets.

@lucidrains
Last active April 28, 2020 03:41
Show Gist options
  • Save lucidrains/89eb7b5ad3afd20ad08441b51ad5f9d0 to your computer and use it in GitHub Desktop.
import random
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.models import resnet50
from kornia import augmentation as augs
class OutputHiddenLayer(nn.Module):
    """Run `net` and additionally capture the output of one of its child layers.

    A forward hook is attached to the child at index `layer` (default -2, i.e.
    the second-to-last child) just for the duration of the forward pass, so the
    hidden activation is captured without modifying `net` itself.

    Returns (final_output, hidden_activation) from forward().
    """

    def __init__(self, net, layer=-2):
        super().__init__()
        self.net = net
        # BUGFIX: the original stored this list as `self.children`, shadowing
        # nn.Module.children() — which breaks .to(device), .cuda(), .apply()
        # and anything else that iterates submodules. Use a private name.
        self._layers = list(self.net.children())
        self.layer = layer

    def forward(self, x):
        hidden = None

        def hook(_, __, output):
            # capture the layer's output into the enclosing scope
            nonlocal hidden
            hidden = output

        handle = self._layers[self.layer].register_forward_hook(hook)
        final = self.net(x)
        # always detach the hook so repeated calls don't stack hooks
        handle.remove()
        return final, hidden
def flatten(t):
    """Collapse every dimension of `t` after the leading (batch) dim into one."""
    batch = len(t)
    return t.reshape(batch, -1)
def contrastive_loss(z_projs, aug_z_projs):
    """InfoNCE-style loss: row i's positive is the same-index augmented projection.

    Cross-entropy over the similarity matrix with targets on the diagonal.
    """
    batch = z_projs.shape[0]
    similarity = torch.mm(z_projs, aug_z_projs.t())
    # subtract each row's max for numerical stability (cross_entropy is shift-invariant)
    row_max = similarity.max(dim=-1, keepdim=True).values
    similarity = similarity - row_max
    targets = torch.arange(batch, device=z_projs.device)
    return F.cross_entropy(similarity, targets)
def nt_xent_loss(z_projs, aug_z_projs):
    """NT-Xent-style loss over the concatenated (2b) projections.

    Both views are stacked into one 2b x 2b similarity matrix; the diagonal
    (self-similarity) is masked out, and each row's positive is its counterpart
    in the other view. Requires b >= 2 (the 2*(b-1) normalizer is zero at b=1).
    """
    b, device = z_projs.shape[0], z_projs.device
    n = b * 2
    projs = torch.cat((z_projs, aug_z_projs))
    logits = projs @ projs.t()
    # drop self-similarities: keep n-1 candidates per row
    mask = torch.eye(n, dtype=torch.bool, device=device)
    logits = logits[~mask].reshape(n, n - 1)
    # after removing the diagonal, row i's positive (originally column i+b)
    # shifts left by one; row i+b's positive (column i) keeps its index.
    # BUGFIX: labels must live on the same device as the logits, otherwise
    # F.cross_entropy raises a device-mismatch error on GPU inputs.
    labels = torch.cat((
        torch.arange(b, device=device) + b - 1,
        torch.arange(b, device=device),
    ), dim=0)
    loss = F.cross_entropy(logits, labels, reduction='sum')
    loss /= 2 * (b - 1)
    return loss
class RandomApply(nn.Module):
    """Apply `fn` to the input with probability `p`; otherwise pass it through."""

    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x):
        skip = random.random() > self.p
        return x if skip else self.fn(x)
class ContrastiveLearningWrapper(nn.Module):
    """Self-supervised contrastive training wrapper around an image network.

    Captures a hidden-layer representation of the wrapped `net` (via
    OutputHiddenLayer), projects it through a small MLP head, and computes a
    contrastive loss between the representations of the input and an augmented
    view of it (or of two augmented views when `augment_both=True`).

    Args:
        net: backbone whose hidden layer is used as the representation.
        image_size: side length the RandomResizedCrop augmentation outputs.
        hidden_layer_index: child index of the layer to tap (default -2).
        project_dim: output dimension of the projection head.
        augment_both: if True, both contrastive branches use augmented input
            (the returned `out` is still computed on the raw input).
        use_nt_xent_loss: use nt_xent_loss instead of contrastive_loss.

    NOTE: the projection head is built lazily on the first forward pass (the
    hidden dim is only known then). Build the optimizer AFTER the first
    forward, otherwise the projection parameters are never optimized.

    forward() returns (net_output_on_raw_input, contrastive_loss_tensor).
    """

    def __init__(self, net, image_size, hidden_layer_index=-2, project_dim=128, augment_both=True, use_nt_xent_loss=False):
        super().__init__()
        self.net = OutputHiddenLayer(net, layer=hidden_layer_index)
        self.augment = nn.Sequential(
            RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
            augs.RandomResizedCrop((image_size, image_size))
        )
        self.augment_both = augment_both
        self.use_nt_xent_loss = use_nt_xent_loss
        # built lazily in _get_projection_fn once the hidden width is known
        self.projection = None
        self.project_dim = project_dim

    def _get_projection_fn(self, hidden):
        # Return the projection MLP, creating it on first use.
        if self.projection is not None:
            return self.projection
        _, dim = hidden.shape
        self.projection = nn.Sequential(
            nn.Linear(dim, dim * 2),
            nn.LeakyReLU(inplace=True),
            nn.Linear(dim * 2, self.project_dim)
        )
        # BUGFIX: a freshly-built head lives on CPU/float32 by default; move it
        # to match the activations it projects so GPU/half inputs don't crash.
        self.projection.to(hidden.device, dtype=hidden.dtype)
        return self.projection

    def forward(self, x):
        out, hidden = self.net(x)
        if self.augment_both:
            # replace the raw-input hidden with an augmented-view hidden;
            # `out` stays the raw-input prediction for the caller
            _, hidden = self.net(self.augment(x))
        aug_x = self.augment(x)
        _, aug_hidden = self.net(aug_x)
        hidden, aug_hidden = map(flatten, (hidden, aug_hidden))
        project_fn = self._get_projection_fn(hidden)
        z_proj, aug_z_proj = map(project_fn, (hidden, aug_hidden))
        loss_fn = nt_xent_loss if self.use_nt_xent_loss else contrastive_loss
        loss = loss_fn(z_proj, aug_z_proj)
        return out, loss
# self supervised learning on resnet 50 - usage instructions
# use big batch sizes for up to 100 epochs
r = resnet50(pretrained=True)
r = ContrastiveLearningWrapper(r, image_size=256, hidden_layer_index=-2, use_nt_xent_loss=True)

# BUGFIX: run one forward pass BEFORE building the optimizer — the projection
# head is created lazily on first forward, so an optimizer built earlier would
# never see (and never train) its parameters.
with torch.no_grad():
    r(torch.randn(2, 3, 256, 256))

opt = torch.optim.Adam(r.parameters(), lr=3e-4)

for _ in range(1):
    img = torch.randn(4, 3, 256, 256)
    # BUGFIX: don't name this `contrastive_loss` — that rebound the module-level
    # loss function to a tensor, crashing any later forward that looks it up.
    _, loss = r(img)
    opt.zero_grad()
    loss.backward()
    opt.step()
# tada - resnet is magically smarter now
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment