Semantic segmentation with ENet in PyTorch
#!/usr/bin/env python
"""
A quick, partial implementation of ENet (https://arxiv.org/abs/1606.02147) using PyTorch.
The original Torch ENet implementation can process a 480x360 image in ~12 ms (on a P2 AWS
instance). The TensorFlow implementation takes ~35 ms. This PyTorch implementation takes
~25 ms: faster than TensorFlow, but still slower than the original Torch.
"""
from __future__ import absolute_import

import sys
import time

import torch
import torch.optim
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
def _encoder_params(input_channels, output_channels, downsample, dropout_prob):
    # Every encoder bottleneck in this simplified port shares the same
    # internal scale and the (so far unimplemented) dilation/asymmetry flags.
    return {
        'internal_scale': 4,
        'use_relu': True,
        'asymmetric': False,
        'dilated': False,
        'input_channels': input_channels,
        'output_channels': output_channels,
        'downsample': downsample,
        'dropout_prob': dropout_prob,
    }


ENCODER_PARAMS = (
    # Stage 1: one downsampling block to 64 channels, then four
    # non-downsampling blocks.
    [_encoder_params(16, 64, True, 0.01)]
    + [_encoder_params(64, 64, False, 0.01) for _ in range(4)]
    # Stages 2-3: one downsampling block to 128 channels, then seventeen
    # non-downsampling blocks.
    + [_encoder_params(64, 128, True, 0.1)]
    + [_encoder_params(128, 128, False, 0.1) for _ in range(17)]
)
DECODER_PARAMS = [
    {
        'input_channels': 128,
        'output_channels': 128,
        'upsample': False,
        'pooling_module': None
    },
    {
        'input_channels': 128,
        'output_channels': 64,
        'upsample': True,
        'pooling_module': None
    },
    {
        'input_channels': 64,
        'output_channels': 64,
        'upsample': False,
        'pooling_module': None
    },
    {
        'input_channels': 64,
        'output_channels': 64,
        'upsample': False,
        'pooling_module': None
    },
    {
        'input_channels': 64,
        'output_channels': 16,
        'upsample': True,
        'pooling_module': None
    },
    {
        'input_channels': 16,
        'output_channels': 16,
        'upsample': False,
        'pooling_module': None
    }
]
class InitialBlock(nn.Module):
    """The initial ENet block: a strided 3x3 convolution (13 feature maps)
    concatenated with a 2x2 max pooling of the input (3 channels), for 16
    output channels."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(
            3, 13, (3, 3),
            stride=2, padding=1, bias=True)
        self.pool = nn.MaxPool2d(2, stride=2)
        self.batch_norm = nn.BatchNorm2d(16, eps=1e-3)

    def forward(self, input):
        # Concatenate along the channel dimension: 13 conv maps + 3 pooled
        # input channels = 16 channels.
        output = torch.cat([self.conv(input), self.pool(input)], 1)
        output = self.batch_norm(output)
        return F.relu(output)
class EncoderMainPath(nn.Module):
    """Main (convolutional) branch of an encoder bottleneck: a 1x1 (or,
    when downsampling, strided 2x2) projection, a 3x3 convolution, a 1x1
    expansion, then spatial dropout."""

    def __init__(self, internal_scale=None, use_relu=None, asymmetric=None,
                 dilated=None, input_channels=None, output_channels=None,
                 downsample=None, dropout_prob=None):
        super().__init__()
        internal_channels = output_channels // internal_scale
        input_stride = 2 if downsample else 1
        # Stash all constructor arguments as attributes.
        self.__dict__.update(locals())
        del self.self
        self.input_conv = nn.Conv2d(
            input_channels, internal_channels, input_stride,
            stride=input_stride, padding=0, bias=False)
        self.input_batch_norm = nn.BatchNorm2d(internal_channels, eps=1e-03)
        # TODO: use dilated and asymmetric convolutions, as in the
        # original implementation. For now just add a 3x3 convolution.
        self.middle_conv = nn.Conv2d(
            internal_channels, internal_channels, 3,
            stride=1, padding=1, bias=True)
        self.middle_batch_norm = nn.BatchNorm2d(internal_channels, eps=1e-03)
        self.output_conv = nn.Conv2d(
            internal_channels, output_channels, 1,
            stride=1, padding=0, bias=False)
        self.output_batch_norm = nn.BatchNorm2d(output_channels, eps=1e-03)
        self.dropout = nn.Dropout2d(dropout_prob)

    def forward(self, input):
        output = self.input_conv(input)
        output = self.input_batch_norm(output)
        output = F.relu(output)
        output = self.middle_conv(output)
        output = self.middle_batch_norm(output)
        output = F.relu(output)
        output = self.output_conv(output)
        output = self.output_batch_norm(output)
        output = self.dropout(output)
        return output
class EncoderOtherPath(nn.Module):
    """Skip branch of an encoder bottleneck: identity, or (when
    downsampling) max pooling with channel replication to match the main
    branch. The pooling indices are kept for unpooling in the decoder."""

    def __init__(self, internal_scale=None, use_relu=None, asymmetric=None,
                 dilated=None, input_channels=None, output_channels=None,
                 downsample=None, **kwargs):
        super().__init__()
        self.__dict__.update(locals())
        del self.self
        if downsample:
            self.pool = nn.MaxPool2d(2, stride=2, return_indices=True)

    def forward(self, input):
        output = input
        if self.downsample:
            output, self.indices = self.pool(input)
        if self.output_channels != self.input_channels:
            # Replicate channels so the skip branch matches the main
            # branch's width (the paper zero-pads instead).
            new_size = [1, 1, 1, 1]
            new_size[1] = self.output_channels // self.input_channels
            output = output.repeat(*new_size)
        return output
class EncoderModule(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.main = EncoderMainPath(**kwargs)
        self.other = EncoderOtherPath(**kwargs)

    def forward(self, input):
        # Sum the two branches, then apply the block's nonlinearity.
        main = self.main(input)
        other = self.other(input)
        return F.relu(main + other)
class Encoder(nn.Module):
    def __init__(self, params, nclasses):
        super().__init__()
        self.initial_block = InitialBlock()
        self.layers = []
        for i, params in enumerate(params):
            layer_name = 'encoder_{:02d}'.format(i)
            layer = EncoderModule(**params)
            # Register via __setattr__ so the layer's parameters are
            # tracked; self.layers is a plain list kept for iteration.
            super().__setattr__(layer_name, layer)
            self.layers.append(layer)
        self.output_conv = nn.Conv2d(
            128, nclasses, 1,
            stride=1, padding=0, bias=True)

    def forward(self, input, predict=False):
        output = self.initial_block(input)
        for layer in self.layers:
            output = layer(output)
        if predict:
            # For encoder-only training, project straight to class scores.
            output = self.output_conv(output)
        return output
class DecoderMainPath(nn.Module):
    """Main branch of a decoder bottleneck; the middle convolution becomes
    a strided transposed convolution when upsampling."""

    def __init__(self, input_channels=None, output_channels=None,
                 upsample=None, pooling_module=None):
        super().__init__()
        internal_channels = output_channels // 4
        input_stride = 2 if upsample else 1
        self.__dict__.update(locals())
        del self.self
        self.input_conv = nn.Conv2d(
            input_channels, internal_channels, 1,
            stride=1, padding=0, bias=False)
        self.input_batch_norm = nn.BatchNorm2d(internal_channels, eps=1e-03)
        if not upsample:
            self.middle_conv = nn.Conv2d(
                internal_channels, internal_channels, 3,
                stride=1, padding=1, bias=True)
        else:
            self.middle_conv = nn.ConvTranspose2d(
                internal_channels, internal_channels, 3,
                stride=2, padding=1, output_padding=1,
                bias=True)
        self.middle_batch_norm = nn.BatchNorm2d(internal_channels, eps=1e-03)
        self.output_conv = nn.Conv2d(
            internal_channels, output_channels, 1,
            stride=1, padding=0, bias=False)
        self.output_batch_norm = nn.BatchNorm2d(output_channels, eps=1e-03)

    def forward(self, input):
        output = self.input_conv(input)
        output = self.input_batch_norm(output)
        output = F.relu(output)
        output = self.middle_conv(output)
        output = self.middle_batch_norm(output)
        output = F.relu(output)
        output = self.output_conv(output)
        output = self.output_batch_norm(output)
        return output
class DecoderOtherPath(nn.Module):
    """Skip branch of a decoder bottleneck: a 1x1 projection when channel
    counts change, plus max unpooling (reusing the matching encoder block's
    pooling indices) when upsampling."""

    def __init__(self, input_channels=None, output_channels=None,
                 upsample=None, pooling_module=None):
        super().__init__()
        self.__dict__.update(locals())
        del self.self
        if output_channels != input_channels or upsample:
            self.conv = nn.Conv2d(
                input_channels, output_channels, 1,
                stride=1, padding=0, bias=False)
            self.batch_norm = nn.BatchNorm2d(output_channels, eps=1e-03)
        if upsample and pooling_module:
            self.unpool = nn.MaxUnpool2d(2, stride=2, padding=0)

    def forward(self, input):
        output = input
        if self.output_channels != self.input_channels or self.upsample:
            output = self.conv(output)
            output = self.batch_norm(output)
        if self.upsample and self.pooling_module:
            # Double the spatial dimensions using the encoder's pooling
            # indices.
            output_size = list(output.size())
            output_size[2] *= 2
            output_size[3] *= 2
            output = self.unpool(
                output, self.pooling_module.indices,
                output_size=output_size)
        return output
class DecoderModule(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.main = DecoderMainPath(**kwargs)
        self.other = DecoderOtherPath(**kwargs)

    def forward(self, input):
        main = self.main(input)
        other = self.other(input)
        return F.relu(main + other)
class Decoder(nn.Module):
    def __init__(self, params, nclasses, encoder):
        super().__init__()
        self.encoder = encoder
        # Collect the encoder's downsampling skip branches so their pooling
        # indices can be reused, in reverse order, for unpooling.
        self.pooling_modules = []
        for mod in self.encoder.modules():
            try:
                if mod.other.downsample:
                    self.pooling_modules.append(mod.other)
            except AttributeError:
                pass
        self.layers = []
        for i, params in enumerate(params):
            if params['upsample']:
                params['pooling_module'] = self.pooling_modules.pop(-1)
            layer = DecoderModule(**params)
            self.layers.append(layer)
            layer_name = 'decoder{:02d}'.format(i)
            super().__setattr__(layer_name, layer)
        self.output_conv = nn.ConvTranspose2d(
            16, nclasses, 2,
            stride=2, padding=0, output_padding=0, bias=True)

    def forward(self, input):
        output = self.encoder(input, predict=False)
        for layer in self.layers:
            output = layer(output)
        output = self.output_conv(output)
        return output
class ENet(nn.Module):
    def __init__(self, nclasses):
        super().__init__()
        self.encoder = Encoder(ENCODER_PARAMS, nclasses)
        self.decoder = Decoder(DECODER_PARAMS, nclasses, self.encoder)

    def forward(self, input, only_encode=False, predict=True):
        if only_encode:
            return self.encoder(input, predict=predict)
        else:
            return self.decoder(input)
if __name__ == '__main__':
    nclasses = 15
    train = True
    niter = 100
    times = torch.FloatTensor(niter)
    batch_size = 1
    nchannels = 3
    height = 360
    width = 480

    model = ENet(nclasses)
    # NLLLoss2d expects log-probabilities, so log_softmax is applied to the
    # model's raw scores before the loss is computed below.
    loss = nn.NLLLoss2d()
    softmax = nn.Softmax(dim=1)

    model.cuda()
    loss.cuda()

    optimizer = torch.optim.Adam(model.parameters())

    for i in range(niter):
        x = torch.randn(batch_size, nchannels, height, width)
        y = torch.LongTensor(batch_size, height, width)
        y.random_(nclasses)
        # pin_memory() returns a new tensor rather than pinning in place.
        x = x.pin_memory()
        y = y.pin_memory()
        # `async` was renamed `non_blocking`; `async` is a reserved word in
        # Python 3.7+.
        input = Variable(x.cuda(non_blocking=True))
        target = Variable(y.cuda(non_blocking=True))
        sys.stdout.write('\r{}/{}'.format(i, niter))
        start = time.time()
        if train:
            optimizer.zero_grad()
            model.train()
        else:
            model.eval()
        output = model(input)
        if train:
            loss_ = loss(F.log_softmax(output, dim=1), target)
            loss_.backward()
            optimizer.step()
        else:
            # Per-pixel class probabilities over the channel dimension.
            pred = softmax(output)
        # CUDA kernels run asynchronously; synchronize before reading the
        # clock so the timings are meaningful.
        torch.cuda.synchronize()
        times[i] = time.time() - start
    sys.stdout.write('\n')
    print('average time per image {:.4f} sec'.format(times.mean()))
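For inference, the decoder's raw scores become a per-pixel label map by taking an argmax over the channel dimension. A minimal sketch, assuming the ENet class defined above, a PyTorch version that accepts plain tensors, and a stand-in input batch:

    model = ENet(nclasses=15).cuda()
    model.eval()
    image = torch.randn(1, 3, 360, 480).cuda()  # stand-in for a real RGB batch
    scores = model(image)                       # shape (1, nclasses, 360, 480)
    _, labels = scores.max(dim=1)               # per-pixel class indices, (1, 360, 480)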
I used this code to train on the CamVid dataset, but the results are worse than FCN-32s even though it is fast enough. Has this implementation been tested on real data?
Shameless plug for my PyTorch ENet implementation (based on a converter): https://github.com/bermanmaxim/Enet-PyTorch
Hi,
Is it possible to use your ENet implementation on any kind of image, for example on pairs of images (grayscale + ground-truth label)?
The images can be found here:
https://github.com/chromosome-seg/DeepFISH/tree/master/dataset/LowRes/train
Best regards,
jp pommier
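Regarding single-channel inputs: the network as written assumes 3-channel images, because InitialBlock concatenates 13 convolution maps with the 3 pooled input channels to reach 16. One possible adaptation is an initial block that emits 15 convolution maps, so that 15 + 1 pooled grayscale channel still totals 16. A hedged, untested sketch (GrayInitialBlock is a hypothetical name, not part of the gist):

    class GrayInitialBlock(nn.Module):
        # Hypothetical variant of InitialBlock for 1-channel (grayscale)
        # input: 15 conv maps + 1 pooled channel = 16 channels, matching
        # what the rest of the encoder expects.
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(1, 15, (3, 3), stride=2, padding=1, bias=True)
            self.pool = nn.MaxPool2d(2, stride=2)
            self.batch_norm = nn.BatchNorm2d(16, eps=1e-3)

        def forward(self, input):
            output = torch.cat([self.conv(input), self.pool(input)], 1)
            output = self.batch_norm(output)
            return F.relu(output)

The ground-truth label images would then serve as the integer class targets for the loss, as in the training loop above.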