Last active
April 28, 2020 03:41
-
-
Save lucidrains/89eb7b5ad3afd20ad08441b51ad5f9d0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
from torchvision.models import resnet50 | |
from kornia import augmentation as augs | |
class OutputHiddenLayer(nn.Module):
    """Wraps a network and, on each forward pass, additionally returns the
    output of one of its top-level child modules.

    Args:
        net: the network to wrap.
        layer: index into ``net.children()`` selecting the child whose output
            is captured (default -2, i.e. the second-to-last child).
    """
    def __init__(self, net, layer=-2):
        super().__init__()
        self.net = net
        # BUG FIX: stored under a private name — assigning a list to
        # ``self.children`` shadows nn.Module.children() in the instance
        # dict, so ``wrapper.children()`` would raise "list is not callable".
        self._children = [*net.children()]
        self.layer = layer

    def forward(self, x):
        hidden = None

        def hook(_, __, output):
            nonlocal hidden
            hidden = output

        # Temporary hook captures the chosen layer's output for this pass only;
        # always removed afterwards so hooks don't accumulate across calls.
        handle = self._children[self.layer].register_forward_hook(hook)
        final = self.net(x)
        handle.remove()
        return final, hidden
def flatten(t):
    """Collapse every non-batch dimension of ``t`` into a single one."""
    batch = t.shape[0]
    return t.reshape(batch, -1)
def contrastive_loss(z_projs, aug_z_projs):
    """InfoNCE-style loss: each projection should be most similar to its own
    augmented counterpart (the diagonal) among all pairs in the batch."""
    batch = z_projs.shape[0]
    sim = torch.mm(z_projs, aug_z_projs.t())
    # subtract the per-row max before softmax for numerical stability
    # (shifting logits by a constant leaves cross-entropy unchanged)
    sim = sim - torch.max(sim, dim=-1, keepdim=True)[0]
    targets = torch.arange(batch, device=z_projs.device)
    return F.cross_entropy(sim, targets)
def nt_xent_loss(z_projs, aug_z_projs):
    """NT-Xent loss over the concatenated batch of original and augmented
    projections.

    For each of the 2b rows, the positive is its counterpart in the other
    half of the concatenated batch; the remaining 2b - 2 rows act as
    negatives. Self-similarity (the diagonal) is excluded.
    """
    b, device = z_projs.shape[0], z_projs.device
    n = b * 2
    projs = torch.cat((z_projs, aug_z_projs))
    logits = projs @ projs.t()

    # drop self-similarity; each row keeps its n - 1 other candidates
    mask = torch.eye(n, device=device).bool()
    logits = logits[~mask].reshape(n, n - 1)

    # After removing the diagonal, row i's positive column (b + i) shifts
    # left by one for the first half; for the second half (row b + i) the
    # positive column i is left of the diagonal and does not shift.
    # BUG FIX: labels must be created on the same device as the logits,
    # otherwise F.cross_entropy raises a device mismatch on GPU.
    labels = torch.cat((torch.arange(b, device=device) + b - 1,
                        torch.arange(b, device=device)), dim=0)

    loss = F.cross_entropy(logits, labels, reduction='sum')
    loss /= 2 * (b - 1)
    return loss
class RandomApply(nn.Module):
    """Applies ``fn`` to the input with probability ``p``; otherwise the
    input passes through untouched."""

    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x):
        skip = random.random() > self.p
        return x if skip else self.fn(x)
class ContrastiveLearningWrapper(nn.Module):
    """Adds a SimCLR-style contrastive objective on top of ``net``.

    ``forward`` returns ``(out, loss)`` where ``out`` is the wrapped network's
    normal output for ``x`` and ``loss`` contrasts the hidden representation
    of one view of ``x`` against an independently augmented view.

    Args:
        net: backbone network.
        image_size: side length the random-resized-crop augmentation crops to.
        hidden_layer_index: which top-level child of ``net`` supplies the
            hidden representation (default -2).
        project_dim: output dimension of the projection head.
        augment_both: if True both views are augmented; otherwise the first
            view is the un-augmented input.
        use_nt_xent_loss: use ``nt_xent_loss`` instead of ``contrastive_loss``.

    NOTE(review): the projection head is built lazily on the first forward
    pass, so an optimizer constructed before that will never see its
    parameters — run one forward pass before building the optimizer if the
    projection head should be trained.
    """

    def __init__(self, net, image_size, hidden_layer_index=-2, project_dim=128, augment_both=True, use_nt_xent_loss=False):
        super().__init__()
        self.net = OutputHiddenLayer(net, layer=hidden_layer_index)

        self.augment = nn.Sequential(
            RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
            augs.RandomResizedCrop((image_size, image_size))
        )

        self.augment_both = augment_both
        self.use_nt_xent_loss = use_nt_xent_loss
        self.projection = None
        self.project_dim = project_dim

    def _get_projection_fn(self, hidden):
        """Lazily build (once) and return the MLP projection head, sized to
        the flattened hidden dimension."""
        if self.projection is not None:
            return self.projection

        _, dim = hidden.shape
        self.projection = nn.Sequential(
            nn.Linear(dim, dim * 2),
            nn.LeakyReLU(inplace=True),
            nn.Linear(dim * 2, self.project_dim)
        ).to(hidden.device)  # BUG FIX: follow the input's device (was CPU-only)
        return self.projection

    def forward(self, x):
        out, hidden = self.net(x)

        if self.augment_both:
            # replace the clean hidden with one from an augmented view,
            # so the loss compares two independently augmented views
            _, hidden = self.net(self.augment(x))

        aug_x = self.augment(x)
        _, aug_hidden = self.net(aug_x)

        hidden, aug_hidden = map(flatten, (hidden, aug_hidden))
        project_fn = self._get_projection_fn(hidden)
        z_proj, aug_z_proj = map(project_fn, (hidden, aug_hidden))

        loss_fn = nt_xent_loss if self.use_nt_xent_loss else contrastive_loss
        loss = loss_fn(z_proj, aug_z_proj)
        return out, loss
# Self-supervised learning demo on ResNet-50.
# In practice: use large batch sizes and train for up to ~100 epochs.
r = resnet50(pretrained=True)
r = ContrastiveLearningWrapper(r, image_size=256, hidden_layer_index=-2, use_nt_xent_loss=True)

# NOTE(review): the wrapper builds its projection head lazily on the first
# forward pass, so this optimizer never sees those parameters — run one
# forward pass first if the projection head should be trained too.
opt = torch.optim.Adam(r.parameters(), lr=3e-4)

for _ in range(1):
    img = torch.randn(4, 3, 256, 256)
    # BUG FIX: was named `contrastive_loss`, shadowing the module-level
    # loss function of the same name defined above.
    _, loss = r(img)
    opt.zero_grad()
    loss.backward()
    opt.step()

# tada - resnet is magically smarter now
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment