BYOL (Bootstrap Your Own Latent) implementation in PyTorch
import copy

import torch
from torch import nn
import torch.nn.functional as F
from torchvision import models
from torchvision.models._utils import IntermediateLayerGetter

class BYOL(nn.Module):
    """BYOL: an online network is trained to predict the target network's
    projection of a different augmented view of the same image; the target
    network is an exponential moving average (EMA) of the online network."""

    def __init__(self, backbone: nn.Module, target_momentum=0.996):
        super().__init__()
        self.online_network = backbone
        self.target_network = copy.deepcopy(backbone)
        # Projection heads for the online and target networks
        self.online_projector = ProjectorHead()
        self.target_projector = ProjectorHead()
        # Predictor head (applied to the online projection only)
        self.predictor = MLPHead(self.online_projector.out_channels, 4096, 256)
        self.m = target_momentum

    def initialize_target_network(self):
        # Copy the online weights into the target network and freeze them;
        # the target is only ever updated through the EMA rule below.
        for param_q, param_k in zip(self.online_network.parameters(), self.target_network.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False
        for param_q, param_k in zip(self.online_projector.parameters(), self.target_projector.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False

    @torch.no_grad()
    def update_target_network(self):
        # EMA update: theta_target <- m * theta_target + (1 - m) * theta_online
        for param_q, param_k in zip(self.online_network.parameters(), self.target_network.parameters()):
            param_k.data = self.m * param_k.data + (1 - self.m) * param_q.data
        for param_q, param_k in zip(self.online_projector.parameters(), self.target_projector.parameters()):
            param_k.data = self.m * param_k.data + (1 - self.m) * param_q.data

    @staticmethod
    def regression_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        x_norm = F.normalize(x, dim=1)  # L2-normalize
        y_norm = F.normalize(y, dim=1)  # L2-normalize
        loss = 2 - 2 * (x_norm * y_norm).sum(dim=-1)  # 2 - 2 * cosine similarity
        return loss.mean()
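
# Note: for L2-normalized vectors, 2 - 2 * cos(x, y) equals ||x_hat - y_hat||^2,
# the mean-squared-error form written in the BYOL paper. A quick (hypothetical)
# numerical check of that identity:
#
#   x, y = torch.randn(8, 256), torch.randn(8, 256)
#   mse = (F.normalize(x, dim=1) - F.normalize(y, dim=1)).pow(2).sum(dim=-1)
#   assert torch.allclose(BYOL.regression_loss(x, y), mse.mean(), atol=1e-6)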

class ProjectorHead(nn.Module):
    def __init__(self):
        super().__init__()
        num_features = 2048  # channel count of ResNet-50's layer4 output
        self.projection = MLPHead(num_features, 4096, 256)
        self.out_channels = 256
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        x_pooled = self.avg_pool(x)  # (N, C, H, W) -> (N, C, 1, 1)
        h = x_pooled.view(x_pooled.shape[0], x_pooled.shape[1])  # flatten to (N, C)
        return self.projection(h)
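
# Quick shape check (hypothetical input, assuming layer4 of ResNet-50 on a
# 224x224 image, which produces a (N, 2048, 7, 7) feature map):
#
#   feats = torch.randn(4, 2048, 7, 7)
#   ProjectorHead()(feats).shape  # torch.Size([4, 256])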

class MLPHead(nn.Module):
    def __init__(self, in_channels: int, hidden_size: int, out_size: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_channels, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, out_size)
        )

    def forward(self, x):
        return self.net(x)

base = models.resnet50(pretrained=False)
extract_layers = {'layer4': 'feat5'}  # expose layer4's output under the name 'feat5'
backbone = IntermediateLayerGetter(base, extract_layers)
# Pass the backbone module itself; indexing it with 'feat5' would raise a
# KeyError, since IntermediateLayerGetter is keyed by the original layer names.
byol = BYOL(backbone)
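
# A minimal training-step sketch (an assumption, not part of the gist):
# `view1` and `view2` are two augmented views of the same image batch, and
# `opt` optimizes the online branches only. Note that the backbone is an
# IntermediateLayerGetter, so calling it returns a dict keyed by 'feat5'.
# The loss is symmetrized over the two views, as in the BYOL paper, and the
# target network is updated by EMA after every optimizer step. The SGD choice
# below is also an assumption (the paper uses LARS with a cosine schedule).
byol.initialize_target_network()
opt = torch.optim.SGD(
    list(byol.online_network.parameters())
    + list(byol.online_projector.parameters())
    + list(byol.predictor.parameters()),
    lr=0.03, momentum=0.9)

def training_step(view1: torch.Tensor, view2: torch.Tensor) -> torch.Tensor:
    def project(net, projector, x):
        return projector(net(x)['feat5'])

    # Online predictions for both views
    p1 = byol.predictor(project(byol.online_network, byol.online_projector, view1))
    p2 = byol.predictor(project(byol.online_network, byol.online_projector, view2))
    # Target projections: no gradients flow into the target network
    with torch.no_grad():
        z1 = project(byol.target_network, byol.target_projector, view1)
        z2 = project(byol.target_network, byol.target_projector, view2)
    # Each view's prediction regresses the *other* view's target projection
    loss = byol.regression_loss(p1, z2) + byol.regression_loss(p2, z1)
    opt.zero_grad()
    loss.backward()
    opt.step()
    byol.update_target_network()
    return loss.detach()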