Two RNN (1d CNN + LSTM) models for the Kaggle QuickDraw Challenge.
''' Two sample RNN (1d CNN + LSTM) networks used in the Kaggle
QuickDraw Challenge (https://www.kaggle.com/c/quickdraw-doodle-recognition).

Both networks expect a tuple input whose first element is the sequence tensor
and whose second is the sequence lengths (the typical sorted, packed format).
The sequence tensor should have shape (batch_size, channels, seq_len), where
channels consists of stroke [x, y, t, end]. End indicates whether a point is
the last in its stroke segment (pen up); it could easily be changed to start
(pen down) or a combination of both.

Copyright (c) 2018 Ross Wightman. All rights reserved.
This work is licensed under the terms of the MIT license.
For a copy, see <https://opensource.org/licenses/MIT>.
'''
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.utils.rnn as rnn_utils
from collections import OrderedDict
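
# Illustrative sketch of the expected input format (an assumed example, not part
# of the original gist; `model`, `true_lens`, `batch_size`, and `max_len` are
# hypothetical names). A drawing becomes per-point channels [x, y, t, end],
# zero-padded to a common length, plus true lengths sorted descending:
#
#   seqs = torch.zeros(batch_size, 4, max_len)   # (batch, channels, seq_len)
#   seqs[i, 3, j] = 1.                           # point j ends a stroke (pen up)
#   lengths = torch.tensor(sorted(true_lens, reverse=True))
#   logits = model((seqs, lengths))
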
def basic_seq_stroke_net(num_classes):
    return BasicSeqStrokeNet(num_classes=num_classes)


def se_seq_stroke_net(num_classes):
    return SEStrokeSeqNet(num_classes=num_classes, dropout_final=0.)

def _initialize_module(m):
    classname = m.__class__.__name__
    if isinstance(m, (nn.LSTM, nn.GRU)):  # match GRU too, so both RNN types get this init
        # orthogonal init for recurrent weight matrices, normal for biases
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)
    elif isinstance(m, nn.Linear):
        init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        m.bias.data.fill_(0.01)
    elif 'Conv' in classname:
        init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif 'BatchNorm' in classname:
        init.constant_(m.weight, 1)
        init.constant_(m.bias, 0)

class SEModule(nn.Module):
    ''' Squeeze-and-Excitation channel attention for 1d feature maps. '''
    def __init__(self, channels, reduction):
        super(SEModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc1 = nn.Conv1d(
            channels, channels // reduction, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv1d(channels // reduction, channels, kernel_size=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        module_input = x
        x = self.avg_pool(x)   # squeeze: (B, C, L) -> (B, C, 1)
        x = self.fc1(x)        # reduce channels by `reduction`
        x = self.relu(x)
        x = self.fc2(x)        # restore channels
        x = self.sigmoid(x)    # per-channel gates in (0, 1)
        return module_input * x
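
# A minimal shape sketch (illustrative only, not part of the original gist):
# given x of shape (B, C, L), SEModule squeezes to (B, C, 1), produces one
# sigmoid gate per channel, and rescales x channel-wise, e.g.:
#
#   se = SEModule(channels=64, reduction=16)
#   x = torch.randn(8, 64, 100)
#   out = se(x)  # out.shape == (8, 64, 100), each channel scaled by its gate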

class Bottleneck(nn.Module):
    ''' Base class for bottleneck variants; subclasses define the conv/bn layers. '''
    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out = self.se_module(out) + residual
        out = self.relu(out)
        return out

class SEBottleneck(Bottleneck):
    expansion = 4

    def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None):
        super(SEBottleneck, self).__init__()
        self.conv1 = nn.Conv1d(inplanes, planes * 2, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(planes * 2)
        self.conv2 = nn.Conv1d(
            planes * 2, planes * 4, kernel_size=3, stride=stride,
            padding=1, groups=groups, bias=False)
        self.bn2 = nn.BatchNorm1d(planes * 4)
        self.conv3 = nn.Conv1d(
            planes * 4, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm1d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.se_module = SEModule(planes * 4, reduction=reduction)
        self.downsample = downsample
        self.stride = stride

class SEResNetBottleneck(Bottleneck):
    expansion = 4

    def __init__(self, inplanes, planes, groups, reduction, stride=1,
                 downsample=None):
        super(SEResNetBottleneck, self).__init__()
        self.conv1 = nn.Conv1d(
            inplanes, planes, kernel_size=1, bias=False, stride=stride)
        self.bn1 = nn.BatchNorm1d(planes)
        self.conv2 = nn.Conv1d(
            planes, planes, kernel_size=3, padding=1, groups=groups, bias=False)
        self.bn2 = nn.BatchNorm1d(planes)
        self.conv3 = nn.Conv1d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm1d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.se_module = SEModule(planes * 4, reduction=reduction)
        self.downsample = downsample
        self.stride = stride
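
# Design note (added commentary, not from the original gist): SEResNetBottleneck
# places the stride on the 1x1 conv1 (as in the Cadene SE-ResNet port referenced
# below), while SEBottleneck strides the 3x3 conv2 as in the original SENet; both
# output planes * 4 channels so the residual add lines up.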

class SEStrokeSeqNet(nn.Module):
    ''' A QuickDraw stroke based RNN with 1d CNN blocks inspired by SE (Squeeze-Excite) networks.
    PyTorch code for SE blocks ripped from https://github.com/hujie-frank/SENet by way of
    https://github.com/Cadene/pretrained-models.pytorch
    '''
    def __init__(self, in_ch=4, block=SEResNetBottleneck, layers=[3, 4], groups=1, reduction=16,
                 dropout_final=0., dropout_lstm=0.1, inplanes=128, input_3x3=True, num_classes=340,
                 rnn='lstm'):
        super(SEStrokeSeqNet, self).__init__()
        self.num_classes = num_classes
        n = 1  # divisor to make it easy to experiment with smaller networks
        reduction = reduction // n
        self.inplanes = inplanes // n
        if input_3x3:
            layer0_modules = [
                ('conv1', nn.Conv1d(in_ch, 64 // n, 3, stride=1, padding=1, bias=False)),
                ('bn1', nn.BatchNorm1d(64 // n)),
                ('relu1', nn.ReLU(inplace=True)),
                ('conv2', nn.Conv1d(64 // n, 64 // n, 3, stride=1, padding=1, bias=False)),
                ('bn2', nn.BatchNorm1d(64 // n)),
                ('relu2', nn.ReLU(inplace=True)),
                ('conv3', nn.Conv1d(64 // n, inplanes, 3, stride=1, padding=1, bias=False)),
                ('bn3', nn.BatchNorm1d(inplanes)),
                ('relu3', nn.ReLU(inplace=True)),
            ]
        else:
            layer0_modules = [
                ('conv1', nn.Conv1d(in_ch, inplanes, kernel_size=7, stride=1, padding=3, bias=False)),
                ('bn1', nn.BatchNorm1d(inplanes)),
                ('relu1', nn.ReLU(inplace=True)),
            ]
        self.layer0 = nn.Sequential(OrderedDict(layer0_modules))
        self.layer1 = self._make_layer(
            block,
            planes=64 // n,
            blocks=layers[0],
            groups=1,
            reduction=reduction,
            downsample_kernel_size=3,
            downsample_padding=1
        )
        self.layer2 = self._make_layer(
            block,
            planes=128 // n,
            blocks=layers[1],
            stride=1,
            groups=1,
            reduction=reduction,
            downsample_kernel_size=1,
            downsample_padding=0
        )
        # layer2 outputs (128 // n) * expansion channels, which become the
        # per-timestep RNN input features after the transpose in rnn_features()
        rnn_input_size = (128 // n) * block.expansion
        if rnn == 'lstm':
            self.rnn = torch.nn.LSTM(
                input_size=rnn_input_size,
                hidden_size=512,
                num_layers=4,
                dropout=dropout_lstm,
                batch_first=True,
                bidirectional=True)
        else:
            self.rnn = torch.nn.GRU(
                input_size=rnn_input_size,
                hidden_size=512,
                num_layers=4,
                dropout=dropout_lstm,
                batch_first=True,
                bidirectional=True)
        # post RNN penultimate FC
        self.penultimate = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(1024, 512)),
            ('bn1', nn.BatchNorm1d(512)),
            ('relu1', nn.ReLU()),
        ]))
        self.dropout = nn.Dropout(dropout_final) if dropout_final is not None else None
        # classifier
        self.cls = nn.Linear(512, self.num_classes)
        self._initialize()

    def _initialize(self):
        for m in self.modules():
            _initialize_module(m)

    def _make_layer(self, block, planes, blocks, groups, reduction, stride=1,
                    downsample_kernel_size=1, downsample_padding=0):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv1d(self.inplanes, planes * block.expansion,
                          kernel_size=downsample_kernel_size, stride=stride,
                          padding=downsample_padding, bias=False),
                nn.BatchNorm1d(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, groups, reduction, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups, reduction))
        return nn.Sequential(*layers)

    def conv_features(self, x):
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        return x

    def rnn_features(self, x_seq, x_seq_len):
        x_seq = x_seq.transpose(2, 1)  # (B, C, L) -> (B, L, C) for batch_first RNN
        x = rnn_utils.pack_padded_sequence(x_seq, x_seq_len, batch_first=True)
        x, h = self.rnn(x)
        if not isinstance(self.rnn, nn.GRU):
            h = h[0]  # LSTM returns (h_n, c_n); keep the hidden state
        # concat last-layer forward and backward hidden states
        hc = torch.cat((h[-2], h[-1]), 1)
        return hc

    def logits(self, x):
        x = self.penultimate(x)
        if self.dropout is not None:
            x = self.dropout(x)
        x = self.cls(x)
        return x

    def forward(self, x_tuple):
        x_seq, x_seq_len = x_tuple
        x_seq = self.conv_features(x_seq)
        x = self.rnn_features(x_seq, x_seq_len)
        x = self.logits(x)
        return x

class BasicSeqStrokeNet(nn.Module):
    ''' A QuickDraw stroke based RNN with 1d CNN blocks.
    General structure inspired by https://github.com/tensorflow/models/blob/master/tutorials/rnn/quickdraw/train_model.py
    '''
    def __init__(self, dropout=0.1, num_classes=340, rnn='lstm'):
        super(BasicSeqStrokeNet, self).__init__()
        self.num_classes = num_classes
        # (in_channels, out_channels, kernel_size) per conv block
        conv_params = [(4, 8, 3), (8, 16, 3), (16, 32, 3), (32, 64, 3)]
        conv_list = []
        for n, p in enumerate(conv_params):
            conv_list.append(('conv%d' % n, nn.Conv1d(p[0], p[1], p[2], stride=1, padding=1 if p[2] == 3 else 0)))
            conv_list.append(('bn%d' % n, nn.BatchNorm1d(p[1])))
            conv_list.append(('act%d' % n, nn.ReLU()))
        self.conv = torch.nn.Sequential(OrderedDict(conv_list))
        rnn_input_size = conv_params[-1][1]
        if rnn == 'lstm':
            self.rnn = torch.nn.LSTM(
                input_size=rnn_input_size,
                hidden_size=512,
                num_layers=4,
                dropout=dropout,
                batch_first=True,
                bidirectional=True)
        else:
            self.rnn = torch.nn.GRU(
                input_size=rnn_input_size,
                hidden_size=512,
                num_layers=4,
                dropout=dropout,
                batch_first=True,
                bidirectional=True)
        self.cls = nn.Sequential(
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, self.num_classes))
        self._initialize()

    def _initialize(self):
        for m in self.modules():
            _initialize_module(m)

    def forward(self, x):
        x_seq, x_seq_len = x
        x_seq = self.conv(x_seq)
        x_seq = x_seq.transpose(2, 1)  # (B, C, L) -> (B, L, C) for batch_first RNN
        x = rnn_utils.pack_padded_sequence(x_seq, x_seq_len, batch_first=True)
        x, h = self.rnn(x)
        if not isinstance(self.rnn, nn.GRU):
            h = h[0]  # LSTM returns (h_n, c_n); keep the hidden state
        # concat last-layer forward and backward hidden states
        hc = torch.cat((h[-2], h[-1]), 1)
        # This extra work with the output was not necessary; simpler to use the last hidden state:
        #
        # xo, xl = rnn_utils.pad_packed_sequence(x, batch_first=True)
        # idx = (xl - 1).view(-1, 1).expand(xl.size(0), xo.size(2) // 2).unsqueeze(1).cuda()
        # # Shape: (batch_size, rnn_hidden_dim)
        # last_output_forward = xo[:, :, :512].gather(1, idx).squeeze(1)
        # last_output_reverse = xo[:, 0, 512:]
        # last_output = torch.cat((last_output_forward, last_output_reverse), 1)
        # assert torch.all(hc == last_output)
        x = self.cls(hc)
        return x
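
# A minimal smoke-test sketch (added here, not part of the original gist): builds
# each model and runs a dummy batch through it. The batch shape, channel meaning,
# and descending-sorted lengths follow the tuple input format described in the
# module docstring; 340 classes matches the QuickDraw label count used above.
if __name__ == '__main__':
    batch_size, channels, max_len = 2, 4, 50
    x_seq = torch.randn(batch_size, channels, max_len)  # [x, y, t, end] per point
    x_seq_len = torch.tensor([50, 42])  # must be sorted descending for packing
    for model in (basic_seq_stroke_net(340), se_seq_stroke_net(340)):
        model.eval()  # eval mode so BatchNorm uses running stats on a tiny batch
        with torch.no_grad():
            logits = model((x_seq, x_seq_len))
        print(type(model).__name__, logits.shape)  # expect (2, 340)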