BYOL implementation
import copy

import torch
from torch import nn
import torch.nn.functional as F
from torchvision import models
from torchvision.models._utils import IntermediateLayerGetter
class BYOL(nn.Module):
    def __init__(self, backbone: nn.Module, target_momentum=0.996):
        super().__init__()
        self.online_network = backbone
        self.target_network = copy.deepcopy(backbone)
        # Projection Head
        self.online_projector = ProjectorHead()
        self.target_projector = ProjectorHead()
        # Predictor Head
        self.predictor = MLPHead(self.online_projector.out_channels, 4096, 256)
        self.m = target_momentum
    def initialize_target_network(self):
        for param_q, param_k in zip(self.online_network.parameters(), self.target_network.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False
        for param_q, param_k in zip(self.online_projector.parameters(), self.target_projector.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False
    @torch.no_grad()
    def update_target_network(self):
        for param_q, param_k in zip(self.online_network.parameters(), self.target_network.parameters()):
            param_k.data = self.m * param_k.data + (1 - self.m) * param_q.data
        for param_q, param_k in zip(self.online_projector.parameters(), self.target_projector.parameters()):
            param_k.data = self.m * param_k.data + (1 - self.m) * param_q.data
    @staticmethod
    def regression_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        x_norm = F.normalize(x, dim=1)  # L2-normalize
        y_norm = F.normalize(y, dim=1)  # L2-normalize
        loss = 2 - 2 * (x_norm * y_norm).sum(dim=-1)  # dot product
        return loss.mean()
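    # For unit vectors, ||x - y||^2 = 2 - 2 * <x, y>, so regression_loss is the
    # mean squared error between the L2-normalized prediction and target, as in
    # the BYOL paper.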
class ProjectorHead(nn.Module):
    def __init__(self):
        super().__init__()
        num_features = 2048  # channel count of ResNet-50's layer4 output
        self.projection = MLPHead(num_features, 4096, 256)
        self.out_channels = 256
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        x_pooled = self.avg_pool(x)
        h = x_pooled.view(x_pooled.shape[0], x_pooled.shape[1])  # flatten the 1x1 spatial dimensions
        return self.projection(h)
class MLPHead(nn.Module):
    def __init__(self, in_channels: int, hidden_size: int, out_size: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_channels, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, out_size)
        )

    def forward(self, x):
        return self.net(x)
base = models.resnet50(pretrained=False)
extract_layers = {'layer4': 'feat5'}
# IntermediateLayerGetter is a module whose forward returns a dict mapping
# 'feat5' to the layer4 feature map; pass the module itself to BYOL rather
# than indexing it.
backbone = IntermediateLayerGetter(base, extract_layers)
byol = BYOL(backbone)
byol.initialize_target_network()  # copy online weights into the (frozen) target before training
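The gist stops at constructing the model, so here is a minimal sketch of one training iteration, assuming v1 and v2 are two augmented views of the same batch; the train_step helper, the optimizer argument, and the symmetrized loss wiring are illustrative additions, not part of the original gist:

def train_step(byol: BYOL, optimizer: torch.optim.Optimizer,
               v1: torch.Tensor, v2: torch.Tensor) -> float:
    # Online branch: backbone -> projector -> predictor, for both views.
    q1 = byol.predictor(byol.online_projector(byol.online_network(v1)['feat5']))
    q2 = byol.predictor(byol.online_projector(byol.online_network(v2)['feat5']))
    # Target branch: no gradients; its weights move only via the momentum update.
    with torch.no_grad():
        z1 = byol.target_projector(byol.target_network(v1)['feat5'])
        z2 = byol.target_projector(byol.target_network(v2)['feat5'])
    # Symmetrized loss: each view's prediction regresses the other view's target.
    loss = BYOL.regression_loss(q1, z2) + BYOL.regression_loss(q2, z1)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    byol.update_target_network()  # EMA update after every optimizer step
    return loss.item()

The optimizer here would be built over the online parameters only (byol.online_network, byol.online_projector, and byol.predictor); the target parameters are frozen by initialize_target_network and receive no gradients.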