Skip to content

Instantly share code, notes, and snippets.

@ptrblck
Created May 13, 2019 16:02
Show Gist options
  • Save ptrblck/032c93816ccb1defb19222f1ab1951f6 to your computer and use it in GitHub Desktop.
Save ptrblck/032c93816ccb1defb19222f1ab1951f6 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torch.optim as optim
from apex import amp
class SiameseTower(nn.Module):
    """Convolutional feature tower.

    Layout: a 3x3 stem conv (``inplanes`` -> ``planes`` channels) followed by three
    ``BasicBlock(planes, planes)`` modules (``self.preprocessor``), then ``blocks``
    further stride-1 ``BasicBlock`` modules (``self.residual_blocks``), and a final
    3x3 conv that keeps ``planes`` channels (``self.final``).

    Args:
        inplanes: channel count of the input tensor fed to the stem conv.
        planes: channel width used throughout the tower.
        blocks: number of ``BasicBlock`` modules in ``self.residual_blocks``.
    """

    def __init__(self, inplanes=3, planes=32, blocks=3):
        super().__init__()
        # Construction order (stem conv, stem blocks, residual blocks, final conv)
        # fixes the order of random weight initialisation and of parameters.
        stem_conv = nn.Conv2d(inplanes, planes, kernel_size=3, stride=1, padding=1, bias=True)
        stem_blocks = [BasicBlock(planes, planes) for _ in range(3)]
        self.preprocessor = nn.Sequential(stem_conv, *stem_blocks)

        tower_blocks = [BasicBlock(planes, planes, stride=1) for _ in range(blocks)]
        self.residual_blocks = nn.Sequential(*tower_blocks)

        self.final = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Run the stem, the residual blocks and the final conv on ``x``.

        The shape of the stem output is printed as debug output.
        """
        features = self.preprocessor(x)
        print(features.shape)
        features = self.residual_blocks(features)
        return self.final(features)
class BasicBlock(torch.nn.Module):
    """Residual block: two 3x3 conv + BatchNorm stages with LeakyReLU(0.2) and a skip add.

    Args:
        inplanes: number of input channels.
        planes: number of output channels.
        stride: stride of the first 3x3 conv and of the auto-created shortcut conv.
            The second conv always uses stride 1.
        dilation: dilation of both 3x3 convs. Padding is set equal to the dilation,
            so the spatial size is changed only by ``stride``.
        downsample: optional module applied to the input to build the residual.
            If None, a 3x3 conv shortcut is created whenever the block changes the
            spatial size (``stride != 1``) or the channel count
            (``inplanes != planes``). Otherwise the input is added unchanged.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        super(BasicBlock, self).__init__()
        # padding == dilation keeps a 3x3 conv size-preserving for any dilation.
        # A fixed padding of 1 would shrink dilated outputs and break the skip add.
        self.conv1 = nn.Conv2d(inplanes, planes, 3, stride, dilation, dilation=dilation)
        self.bn1 = nn.BatchNorm2d(planes)
        # Only conv1 strides. Striding conv2 as well would downsample the main path
        # twice while the shortcut downsamples once.
        self.conv2 = nn.Conv2d(planes, planes, 3, 1, dilation, dilation=dilation)
        self.bn2 = nn.BatchNorm2d(planes)
        if downsample is None and (stride != 1 or inplanes != planes):
            # With kernel 3 / padding 1 this gives floor((H - 1) / stride) + 1 rows,
            # matching conv1's output size for any dilation.
            downsample = nn.Conv2d(
                inplanes, planes, kernel_size=3, stride=stride, padding=1, dilation=1
            )
        self.downsample = downsample
        self.stride = stride
        self.act = torch.nn.LeakyReLU(negative_slope=0.2)

    def forward(self, x):
        """Return act(bn2(conv2(act(bn1(conv1(x))))) + residual)."""
        out = self.act(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        residual = x if self.downsample is None else self.downsample(x)
        # Shape print kept from the original script.
        print(out.shape, residual.shape)
        out += residual
        return self.act(out)
class MyModel(nn.Module):
    """Wrapper around one ``SiameseTower`` stored as ``self.feature_extractor``.

    The tower is built with ``inplanes=3`` (so the input needs 3 channels) and
    ``blocks=1``. ``planes`` keeps the tower's default.
    """

    def __init__(self):
        super().__init__()
        self.feature_extractor = SiameseTower(inplanes=3, blocks=1)

    def forward(self, x):
        """Return the tower's output for ``x``."""
        return self.feature_extractor(x)
# Script body: build the model on the GPU, wrap it with apex amp (opt_level 'O1'),
# and run one forward pass plus a loss-scaled backward pass. optimizer.step() is
# never called, so no weights are updated.
model = MyModel().cuda()
criterion = nn.MSELoss()

# Give Adam only the parameters that require gradients.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=1e-3, eps=1e-8)
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

# Random input and a random regression target shaped like the model output.
x = torch.randn(10, 3, 24, 24, device='cuda')
output = model(x)
target = torch.randn(output.size(), device='cuda')
loss = criterion(output, target)

# Backward runs on the loss yielded by apex's scale_loss context manager.
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment