''' Two sample RNN (1d CNN + LSTM) networks used in the Kaggle
QuickDraw Challenge (https://www.kaggle.com/c/quickdraw-doodle-recognition).

Both of these networks expect a tuple input, with the first element being the
sequences and the second being the sequence lengths (typical sorted, packed
format). The sequence tensor should have the shape (batch_size, channels,
seq_len), where the channels are the stroke values [x, y, t, end]. The end
flag indicates whether a point is the last in a stroke segment (pen up); it
could easily be changed to start (pen down) or a combination of both.

Copyright (c) 2018 Ross Wightman. All rights reserved.
This work is licensed under the terms of the MIT license.
For a copy, see <https://opensource.org/licenses/MIT>.
'''
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn_utils
from collections import OrderedDict
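

# --- Not part of the original gist: a minimal input-construction sketch. ---
# The models below expect (sequences, lengths) with sequences shaped
# (batch_size, channels, seq_len), channels = [x, y, t, end], and lengths
# sorted in descending order for pack_padded_sequence. make_dummy_batch is a
# hypothetical helper; all sizes here are arbitrary assumptions.
def make_dummy_batch(batch_size=4, channels=4, max_len=128):
    lengths = torch.sort(torch.randint(8, max_len + 1, (batch_size,)), descending=True)[0]
    seqs = torch.zeros(batch_size, channels, max_len)
    for i, l in enumerate(lengths.tolist()):
        seqs[i, :, :l] = torch.rand(channels, l)  # random values as stand-in strokes
    return seqs, lengths

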
def basic_seq_stroke_net(num_classes):
    return BasicSeqStrokeNet(num_classes=num_classes)


def se_seq_stroke_net(num_classes):
    return SEStrokeSeqNet(num_classes=num_classes, dropout_final=0.)


def _initialize_module(m):
    classname = m.__class__.__name__
    if isinstance(m, nn.LSTM):
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)
    elif isinstance(m, nn.Linear):
        init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        m.bias.data.fill_(0.01)
    elif 'Conv' in classname:
        init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif 'BatchNorm' in classname:
        init.constant_(m.weight, 1)
        init.constant_(m.bias, 0)


class SEModule(nn.Module):

    def __init__(self, channels, reduction):
        super(SEModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc1 = nn.Conv1d(
            channels, channels // reduction, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv1d(channels // reduction, channels, kernel_size=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        module_input = x
        x = self.avg_pool(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return module_input * x
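

# A quick shape sketch (not part of the original gist): SEModule squeezes each
# channel to a single value via global average pooling, then gates the input
# channel-wise, so shapes are preserved, e.g.:
#   x = torch.randn(2, 64, 50)         # (batch, channels, seq_len)
#   y = SEModule(64, reduction=16)(x)  # y.shape == (2, 64, 50)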


class Bottleneck(nn.Module):

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out = self.se_module(out) + residual
        out = self.relu(out)
        return out


class SEBottleneck(Bottleneck):
    expansion = 4

    def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None):
        super(SEBottleneck, self).__init__()
        self.conv1 = nn.Conv1d(inplanes, planes * 2, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(planes * 2)
        self.conv2 = nn.Conv1d(
            planes * 2, planes * 4, kernel_size=3, stride=stride,
            padding=1, groups=groups, bias=False)
        self.bn2 = nn.BatchNorm1d(planes * 4)
        self.conv3 = nn.Conv1d(
            planes * 4, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm1d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.se_module = SEModule(planes * 4, reduction=reduction)
        self.downsample = downsample
        self.stride = stride


class SEResNetBottleneck(Bottleneck):
    expansion = 4

    def __init__(self, inplanes, planes, groups, reduction, stride=1,
                 downsample=None):
        super(SEResNetBottleneck, self).__init__()
        self.conv1 = nn.Conv1d(
            inplanes, planes, kernel_size=1, bias=False, stride=stride)
        self.bn1 = nn.BatchNorm1d(planes)
        self.conv2 = nn.Conv1d(
            planes, planes, kernel_size=3, padding=1, groups=groups, bias=False)
        self.bn2 = nn.BatchNorm1d(planes)
        self.conv3 = nn.Conv1d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm1d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.se_module = SEModule(planes * 4, reduction=reduction)
        self.downsample = downsample
        self.stride = stride


class SEStrokeSeqNet(nn.Module):
    ''' A QuickDraw stroke based RNN with 1d CNN blocks inspired by SE (Squeeze-Excite) networks.
    PyTorch code for SE blocks ripped from https://github.com/hujie-frank/SENet by way of
    https://github.com/Cadene/pretrained-models.pytorch
    '''

    def __init__(self, in_ch=4, block=SEResNetBottleneck, layers=[3, 4], groups=1, reduction=16,
                 dropout_final=0., dropout_lstm=0.1, inplanes=128, input_3x3=True, num_classes=340,
                 rnn='lstm'):
        super(SEStrokeSeqNet, self).__init__()
        self.num_classes = num_classes
        n = 1  # make it easy to experiment with smaller networks
        reduction = reduction // n
        self.inplanes = inplanes // n
        if input_3x3:
            layer0_modules = [
                ('conv1', nn.Conv1d(in_ch, 64 // n, 3, stride=1, padding=1, bias=False)),
                ('bn1', nn.BatchNorm1d(64 // n)),
                ('relu1', nn.ReLU(inplace=True)),
                ('conv2', nn.Conv1d(64 // n, 64 // n, 3, stride=1, padding=1, bias=False)),
                ('bn2', nn.BatchNorm1d(64 // n)),
                ('relu2', nn.ReLU(inplace=True)),
                ('conv3', nn.Conv1d(64 // n, inplanes, 3, stride=1, padding=1, bias=False)),
                ('bn3', nn.BatchNorm1d(inplanes)),
                ('relu3', nn.ReLU(inplace=True)),
            ]
        else:
            layer0_modules = [
                ('conv1', nn.Conv1d(in_ch, inplanes, kernel_size=7, stride=1, padding=3, bias=False)),
                ('bn1', nn.BatchNorm1d(inplanes)),
                ('relu1', nn.ReLU(inplace=True)),
            ]
        self.layer0 = nn.Sequential(OrderedDict(layer0_modules))
        self.layer1 = self._make_layer(
            block,
            planes=64 // n,
            blocks=layers[0],
            groups=1,
            reduction=reduction,
            downsample_kernel_size=3,
            downsample_padding=1
        )
        self.layer2 = self._make_layer(
            block,
            planes=128 // n,
            blocks=layers[1],
            stride=1,
            groups=1,
            reduction=reduction,
            downsample_kernel_size=1,
            downsample_padding=0
        )
        # rnn_input_size was referenced but never defined in the original
        # snippet; it is reconstructed here from the layer2 configuration,
        # which outputs planes * block.expansion channels.
        rnn_input_size = 128 // n * block.expansion
        if rnn == 'lstm':
            self.rnn = torch.nn.LSTM(
                input_size=rnn_input_size,
                hidden_size=512,
                num_layers=4,
                dropout=dropout_lstm,
                batch_first=True,
                bidirectional=True)
        else:
            self.rnn = torch.nn.GRU(
                input_size=rnn_input_size,
                hidden_size=512,
                num_layers=4,
                dropout=dropout_lstm,
                batch_first=True,
                bidirectional=True)
        # post RNN penultimate FC
        self.penultimate = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(1024, 512)),
            ('bn1', nn.BatchNorm1d(512)),
            ('relu1', nn.ReLU()),
        ]))
        self.dropout = nn.Dropout(dropout_final) if dropout_final is not None else None
        # classifier
        self.cls = nn.Linear(512, self.num_classes)
        self._initialize()

    def _initialize(self):
        for m in self.modules():
            _initialize_module(m)

    def _make_layer(self, block, planes, blocks, groups, reduction, stride=1,
                    downsample_kernel_size=1, downsample_padding=0):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv1d(self.inplanes, planes * block.expansion,
                          kernel_size=downsample_kernel_size, stride=stride,
                          padding=downsample_padding, bias=False),
                nn.BatchNorm1d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, groups, reduction, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups, reduction))
        return nn.Sequential(*layers)

    def conv_features(self, x):
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        return x

    def rnn_features(self, x_seq, x_seq_len):
        x_seq = x_seq.transpose(2, 1)
        x = rnn_utils.pack_padded_sequence(x_seq, x_seq_len, batch_first=True)
        x, h = self.rnn(x)
        if not isinstance(self.rnn, nn.GRU):
            h = h[0]  # LSTM returns (h, c); keep just the hidden state
        hc = torch.cat((h[-2], h[-1]), 1)  # last layer's forward + backward hidden states
        return hc

    def logits(self, x):
        x = self.penultimate(x)
        if self.dropout is not None:
            x = self.dropout(x)
        x = self.cls(x)
        return x

    def forward(self, x_tuple):
        x_seq, x_seq_len = x_tuple
        x_seq = self.conv_features(x_seq)
        x = self.rnn_features(x_seq, x_seq_len)
        x = self.logits(x)
        return x


class BasicSeqStrokeNet(nn.Module):
    ''' A QuickDraw stroke based RNN with 1d CNN blocks.
    General structure inspired by
    https://github.com/tensorflow/models/blob/master/tutorials/rnn/quickdraw/train_model.py
    '''

    def __init__(self, dropout=0.1, num_classes=340, rnn='lstm'):
        super(BasicSeqStrokeNet, self).__init__()
        self.num_classes = num_classes
        conv_params = [(4, 8, 3), (8, 16, 3), (16, 32, 3), (32, 64, 3)]  # (in_ch, out_ch, kernel)
        conv_list = []
        for n, p in enumerate(conv_params):
            conv_list.append(('conv%d' % n, nn.Conv1d(p[0], p[1], p[2], stride=1, padding=1 if p[2] == 3 else 0)))
            conv_list.append(('bn%d' % n, nn.BatchNorm1d(p[1])))
            conv_list.append(('act%d' % n, nn.ReLU()))
        self.conv = torch.nn.Sequential(OrderedDict(conv_list))
        rnn_input_size = conv_params[-1][1]
        if rnn == 'lstm':
            self.rnn = torch.nn.LSTM(
                input_size=rnn_input_size,
                hidden_size=512,
                num_layers=4,
                dropout=dropout,
                batch_first=True,
                bidirectional=True)
        else:
            self.rnn = torch.nn.GRU(
                input_size=rnn_input_size,
                hidden_size=512,
                num_layers=4,
                dropout=dropout,
                batch_first=True,
                bidirectional=True)
        self.cls = nn.Sequential(
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, self.num_classes))
        self._initialize()

    def _initialize(self):
        for m in self.modules():
            _initialize_module(m)

    def forward(self, x):
        x_seq, x_seq_len = x
        x_seq = self.conv(x_seq)
        x_seq.transpose_(2, 1)
        x = rnn_utils.pack_padded_sequence(x_seq, x_seq_len, batch_first=True)
        x, h = self.rnn(x)
        if not isinstance(self.rnn, nn.GRU):
            h = h[0]  # LSTM returns (h, c); keep just the hidden state
        hc = torch.cat((h[-2], h[-1]), 1)

        # This extra work with the output was not necessary; it is simpler to
        # use the last hidden state:
        #
        # xo, xl = rnn_utils.pad_packed_sequence(x, batch_first=True)
        #
        # idx = (xl - 1).view(-1, 1).expand(xl.size(0), xo.size(2) // 2).unsqueeze(1).cuda()
        #
        # # Shape: (batch_size, rnn_hidden_dim)
        # last_output_forward = xo[:, :, :512].gather(1, idx).squeeze(1)
        # last_output_reverse = xo[:, 0, 512:]
        #
        # last_output = torch.cat((last_output_forward, last_output_reverse), 1)
        # assert torch.all(hc == last_output)

        x = self.cls(hc)
        return x
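

# --- Not part of the original gist: a minimal smoke-test sketch. ---
# Runs both models on a random batch built by the hypothetical make_dummy_batch
# helper above; the class count and batch dimensions are arbitrary assumptions.
if __name__ == '__main__':
    seqs, lengths = make_dummy_batch(batch_size=4, channels=4, max_len=128)
    for factory in (basic_seq_stroke_net, se_seq_stroke_net):
        model = factory(num_classes=340)
        model.eval()  # put BatchNorm/Dropout in inference mode
        with torch.no_grad():
            logits = model((seqs, lengths))
        print(factory.__name__, tuple(logits.shape))  # expect (4, 340)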