"""Convert a Torch7 (.t7) model into an equivalent PyTorch model.

Writes two files: a .py source file defining the converted network, and a
.pth file holding its weights. Note that this script relies on
torch.legacy.nn and torch.utils.serialization.load_lua, both of which were
removed in PyTorch 1.0, so it requires PyTorch 0.4.x or earlier.
"""
from __future__ import print_function
import os
import math
import torch
import argparse
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.legacy.nn as lnn
import torch.nn.functional as F
from functools import reduce
from torch.autograd import Variable
from torch.utils.serialization import load_lua
class LambdaBase(nn.Sequential):
    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        output = []
        for module in self._modules.values():
            output.append(module(input))
        return output if output else input

class Lambda(LambdaBase):
    def forward(self, input):
        return self.lambda_func(self.forward_prepare(input))

class LambdaMap(LambdaBase):
    def forward(self, input):
        # result is a list of Variables [Variable1, Variable2, ...]
        return list(map(self.lambda_func, self.forward_prepare(input)))

class LambdaReduce(LambdaBase):
    def forward(self, input):
        # result is a single Variable
        return reduce(self.lambda_func, self.forward_prepare(input))
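# A minimal sketch of how these wrappers behave (illustrative only; the layer
# sizes and lambdas below are assumptions, not part of the converter itself):
#
#   branches = LambdaMap(lambda x: x, nn.Conv2d(3, 8, 1), nn.Conv2d(3, 8, 1))
#   merge = LambdaReduce(lambda x, y: x + y)
#   out = merge(branches(torch.randn(1, 3, 32, 32)))  # elementwise sum of both branches
#
# LambdaMap plays the role of Torch's nn.ConcatTable (fan one input out to all
# children), while LambdaReduce folds the resulting list back into one tensor,
# as nn.CAddTable or nn.Concat would.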
def copy_param(m, n):
    # copy weights, bias and (for batch norm) running statistics
    # from the loaded lua module m into the pytorch module n
    if m.weight is not None: n.weight.data.copy_(m.weight)
    if m.bias is not None: n.bias.data.copy_(m.bias)
    if hasattr(n, 'running_mean'): n.running_mean.copy_(m.running_mean)
    if hasattr(n, 'running_var'): n.running_var.copy_(m.running_var)

def add_submodule(seq, *args):
    # append modules to seq, naming them '0', '1', ... in order
    for n in args:
        seq.add_module(str(len(seq._modules)), n)
def lua_recursive_model(module, seq):
    # walk the loaded lua module tree and append equivalent pytorch modules to seq
    for m in module.modules:
        name = type(m).__name__
        real = m
        if name == 'TorchObject':
            name = m._typename.replace('cudnn.', '')
            m = m._obj

        if name == 'SpatialConvolution' or name == 'nn.SpatialConvolutionMM':
            if not hasattr(m, 'groups') or m.groups is None: m.groups = 1
            n = nn.Conv2d(m.nInputPlane, m.nOutputPlane, (m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), 1, m.groups, bias=(m.bias is not None))
            copy_param(m, n)
            add_submodule(seq, n)
        elif name == 'SpatialBatchNormalization':
            n = nn.BatchNorm2d(m.running_mean.size(0), m.eps, m.momentum, m.affine)
            copy_param(m, n)
            add_submodule(seq, n)
        elif name == 'VolumetricBatchNormalization':
            n = nn.BatchNorm3d(m.running_mean.size(0), m.eps, m.momentum, m.affine)
            copy_param(m, n)
            add_submodule(seq, n)
        elif name == 'ReLU':
            n = nn.ReLU()
            add_submodule(seq, n)
        elif name == 'Sigmoid':
            n = nn.Sigmoid()
            add_submodule(seq, n)
        elif name == 'SpatialMaxPooling':
            n = nn.MaxPool2d((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), ceil_mode=m.ceil_mode)
            add_submodule(seq, n)
        elif name == 'SpatialAveragePooling':
            n = nn.AvgPool2d((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), ceil_mode=m.ceil_mode)
            add_submodule(seq, n)
        elif name == 'SpatialUpSamplingNearest':
            n = nn.UpsamplingNearest2d(scale_factor=m.scale_factor)
            add_submodule(seq, n)
        elif name == 'View':
            n = Lambda(lambda x: x.view(x.size(0), -1))
            add_submodule(seq, n)
        elif name == 'Reshape':
            n = Lambda(lambda x: x.view(x.size(0), -1))
            add_submodule(seq, n)
        elif name == 'Linear':
            # nn.Linear in PyTorch only accepts 2D input, so prepend a reshape for 1D input
            n1 = Lambda(lambda x: x.view(1, -1) if 1 == len(x.size()) else x)
            n2 = nn.Linear(m.weight.size(1), m.weight.size(0), bias=(m.bias is not None))
            copy_param(m, n2)
            n = nn.Sequential(n1, n2)
            add_submodule(seq, n)
        elif name == 'Dropout':
            m.inplace = False
            n = nn.Dropout(m.p)
            add_submodule(seq, n)
        elif name == 'SoftMax':
            n = nn.Softmax()
            add_submodule(seq, n)
        elif name == 'Identity':
            n = Lambda(lambda x: x)  # do nothing
            add_submodule(seq, n)
        elif name == 'SpatialFullConvolution':
            n = nn.ConvTranspose2d(m.nInputPlane, m.nOutputPlane, (m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), (m.adjW, m.adjH))
            copy_param(m, n)
            add_submodule(seq, n)
        elif name == 'VolumetricFullConvolution':
            n = nn.ConvTranspose3d(m.nInputPlane, m.nOutputPlane, (m.kT, m.kW, m.kH), (m.dT, m.dW, m.dH), (m.padT, m.padW, m.padH), (m.adjT, m.adjW, m.adjH), m.groups)
            copy_param(m, n)
            add_submodule(seq, n)
        elif name == 'SpatialReplicationPadding':
            n = nn.ReplicationPad2d((m.pad_l, m.pad_r, m.pad_t, m.pad_b))
            add_submodule(seq, n)
        elif name == 'SpatialReflectionPadding':
            n = nn.ReflectionPad2d((m.pad_l, m.pad_r, m.pad_t, m.pad_b))
            add_submodule(seq, n)
        elif name == 'Copy':
            n = Lambda(lambda x: x)  # do nothing
            add_submodule(seq, n)
        elif name == 'Narrow':
            n = Lambda(lambda x, a=(m.dimension, m.index, m.length): x.narrow(*a))
            add_submodule(seq, n)
        elif name == 'SpatialCrossMapLRN':
            lrn = lnn.SpatialCrossMapLRN(m.size, m.alpha, m.beta, m.k)
            n = Lambda(lambda x, lrn=lrn: Variable(lrn.forward(x.data)))
            add_submodule(seq, n)
        elif name == 'Sequential':
            n = nn.Sequential()
            lua_recursive_model(m, n)
            add_submodule(seq, n)
        elif name == 'ConcatTable':  # output is a list
            n = LambdaMap(lambda x: x)
            lua_recursive_model(m, n)
            add_submodule(seq, n)
        elif name == 'CAddTable':  # input is a list
            n = LambdaReduce(lambda x, y: x + y)
            add_submodule(seq, n)
        elif name == 'Concat':
            dim = m.dimension
            n = LambdaReduce(lambda x, y, dim=dim: torch.cat((x, y), dim))
            lua_recursive_model(m, n)
            add_submodule(seq, n)
        elif name == 'TorchObject':
            print('Not implemented:', name, real._typename)
        else:
            print('Not implemented:', name)
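# For example (illustrative, not output from a real model): a Torch7 layer built
# as nn.SpatialConvolution(3, 64, 3, 3, 1, 1, 1, 1) maps to
# nn.Conv2d(3, 64, (3, 3), (1, 1), (1, 1)) with its weights copied over, and an
# nn.ConcatTable followed by nn.CAddTable becomes a LambdaMap feeding a LambdaReduce.
# Typical call, assuming `model` is a container already loaded via load_lua:
#
#   seq = nn.Sequential()
#   lua_recursive_model(model, seq)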
def lua_recursive_source(module):
    # walk the loaded lua module tree and emit equivalent pytorch source lines
    s = []
    for m in module.modules:
        name = type(m).__name__
        real = m
        if name == 'TorchObject':
            name = m._typename.replace('cudnn.', '')
            m = m._obj

        if name == 'SpatialConvolution' or name == 'nn.SpatialConvolutionMM':
            if not hasattr(m, 'groups') or m.groups is None: m.groups = 1
            s += ['nn.Conv2d({},{},{},{},{},{},{},bias={}),#Conv2d'.format(m.nInputPlane,
                m.nOutputPlane, (m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), 1, m.groups, m.bias is not None)]
        elif name == 'SpatialBatchNormalization':
            s += ['nn.BatchNorm2d({},{},{},{}),#BatchNorm2d'.format(m.running_mean.size(0), m.eps, m.momentum, m.affine)]
        elif name == 'VolumetricBatchNormalization':
            s += ['nn.BatchNorm3d({},{},{},{}),#BatchNorm3d'.format(m.running_mean.size(0), m.eps, m.momentum, m.affine)]
        elif name == 'ReLU':
            s += ['nn.ReLU()']
        elif name == 'Sigmoid':
            s += ['nn.Sigmoid()']
        elif name == 'SpatialMaxPooling':
            s += ['nn.MaxPool2d({},{},{},ceil_mode={}),#MaxPool2d'.format((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), m.ceil_mode)]
        elif name == 'SpatialAveragePooling':
            s += ['nn.AvgPool2d({},{},{},ceil_mode={}),#AvgPool2d'.format((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), m.ceil_mode)]
        elif name == 'SpatialUpSamplingNearest':
            s += ['nn.UpsamplingNearest2d(scale_factor={})'.format(m.scale_factor)]
        elif name == 'View':
            s += ['Lambda(lambda x: x.view(x.size(0),-1)), # View']
        elif name == 'Reshape':
            s += ['Lambda(lambda x: x.view(x.size(0),-1)), # Reshape']
        elif name == 'Linear':
            s1 = 'Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x )'
            s2 = 'nn.Linear({},{},bias={})'.format(m.weight.size(1), m.weight.size(0), (m.bias is not None))
            s += ['nn.Sequential({},{}),#Linear'.format(s1, s2)]
        elif name == 'Dropout':
            s += ['nn.Dropout({})'.format(m.p)]
        elif name == 'SoftMax':
            s += ['nn.Softmax()']
        elif name == 'Identity':
            s += ['Lambda(lambda x: x), # Identity']
        elif name == 'SpatialFullConvolution':
            s += ['nn.ConvTranspose2d({},{},{},{},{},{})'.format(m.nInputPlane,
                m.nOutputPlane, (m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), (m.adjW, m.adjH))]
        elif name == 'VolumetricFullConvolution':
            s += ['nn.ConvTranspose3d({},{},{},{},{},{},{})'.format(m.nInputPlane,
                m.nOutputPlane, (m.kT, m.kW, m.kH), (m.dT, m.dW, m.dH), (m.padT, m.padW, m.padH), (m.adjT, m.adjW, m.adjH), m.groups)]
        elif name == 'SpatialReplicationPadding':
            s += ['nn.ReplicationPad2d({})'.format((m.pad_l, m.pad_r, m.pad_t, m.pad_b))]
        elif name == 'SpatialReflectionPadding':
            s += ['nn.ReflectionPad2d({})'.format((m.pad_l, m.pad_r, m.pad_t, m.pad_b))]
        elif name == 'Copy':
            s += ['Lambda(lambda x: x), # Copy']
        elif name == 'Narrow':
            s += ['Lambda(lambda x,a={}: x.narrow(*a))'.format((m.dimension, m.index, m.length))]
        elif name == 'SpatialCrossMapLRN':
            lrn = 'lnn.SpatialCrossMapLRN(*{})'.format((m.size, m.alpha, m.beta, m.k))
            s += ['Lambda(lambda x,lrn={}: Variable(lrn.forward(x.data)))'.format(lrn)]
        elif name == 'Sequential':
            s += ['nn.Sequential( # Sequential']
            s += lua_recursive_source(m)
            s += [')']
        elif name == 'ConcatTable':
            s += ['LambdaMap(lambda x: x, # ConcatTable']
            s += lua_recursive_source(m)
            s += [')']
        elif name == 'CAddTable':
            s += ['LambdaReduce(lambda x,y: x+y), # CAddTable']
        elif name == 'Concat':
            dim = m.dimension
            s += ['LambdaReduce(lambda x,y,dim={}: torch.cat((x,y),dim), # Concat'.format(m.dimension)]
            s += lua_recursive_source(m)
            s += [')']
        else:
            # the original appended a bare string here, which extends the list
            # character by character; wrap it in a list instead
            s += ['# ' + name + ' not implemented']
    s = map(lambda x: '\t{}'.format(x), s)
    return s
def simplify_source(s):
    s = map(lambda x: x.replace(',(1, 1),(0, 0),1,1,bias=True),#Conv2d',')'),s)
    s = map(lambda x: x.replace(',(0, 0),1,1,bias=True),#Conv2d',')'),s)
    s = map(lambda x: x.replace(',1,1,bias=True),#Conv2d',')'),s)
    s = map(lambda x: x.replace(',bias=True),#Conv2d',')'),s)
    s = map(lambda x: x.replace('),#Conv2d',')'),s)
    s = map(lambda x: x.replace(',1e-05,0.1,True),#BatchNorm2d',')'),s)
    s = map(lambda x: x.replace('),#BatchNorm2d',')'),s)
    s = map(lambda x: x.replace(',(0, 0),ceil_mode=False),#MaxPool2d',')'),s)
    s = map(lambda x: x.replace(',ceil_mode=False),#MaxPool2d',')'),s)
    s = map(lambda x: x.replace('),#MaxPool2d',')'),s)
    s = map(lambda x: x.replace(',(0, 0),ceil_mode=False),#AvgPool2d',')'),s)
    s = map(lambda x: x.replace(',ceil_mode=False),#AvgPool2d',')'),s)
    s = map(lambda x: x.replace(',bias=True)),#Linear',')), # Linear'),s)
    s = map(lambda x: x.replace(')),#Linear',')), # Linear'),s)

    s = map(lambda x: '{},\n'.format(x),s)
    s = map(lambda x: x[1:],s)
    s = reduce(lambda x,y: x+y, s)
    return s
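# For instance (an illustrative generated line, not real output from a specific
# model): simplify_source reduces
#   'nn.Conv2d(3,64,(3, 3),(1, 1),(1, 1),1,1,bias=True),#Conv2d'
# to
#   'nn.Conv2d(3,64,(3, 3),(1, 1),(1, 1))'
# by stripping the trailing dilation, groups and bias arguments, which all hold
# their default values; non-default stride and padding are left in place.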
def torch_to_pytorch(t7_filename, outputname=None):
    model = load_lua(t7_filename, unknown_classes=True)
    if type(model).__name__ == 'hashable_uniq_dict': model = model.model
    model.gradInput = None
    slist = lua_recursive_source(lnn.Sequential().add(model))
    s = simplify_source(slist)
    header = '''
import torch
import torch.nn as nn
import torch.legacy.nn as lnn
from functools import reduce
from torch.autograd import Variable

class LambdaBase(nn.Sequential):
    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        output = []
        for module in self._modules.values():
            output.append(module(input))
        return output if output else input

class Lambda(LambdaBase):
    def forward(self, input):
        return self.lambda_func(self.forward_prepare(input))

class LambdaMap(LambdaBase):
    def forward(self, input):
        return list(map(self.lambda_func,self.forward_prepare(input)))

class LambdaReduce(LambdaBase):
    def forward(self, input):
        return reduce(self.lambda_func,self.forward_prepare(input))
'''
    varname = t7_filename.replace('.t7', '').replace('.', '_').replace('-', '_')
    s = '{}\n\n{} = {}'.format(header, varname, s[:-2])

    if outputname is None: outputname = varname
    with open(outputname + '.py', "w") as pyfile:
        pyfile.write(s)

    # rebuild the model in pytorch and save its weights next to the source
    n = nn.Sequential()
    lua_recursive_model(model, n)
    torch.save(n.state_dict(), outputname + '.pth')
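# The generated pair can then be used like this (a sketch; the names assume a
# model called 'vgg16.t7' was converted in the current directory):
#
#   from vgg16 import vgg16  # the generated .py defines a variable named after the file
#   vgg16.load_state_dict(torch.load('vgg16.pth'))
#   vgg16.eval()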
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Convert a Torch7 .t7 model to PyTorch')
    parser.add_argument('--model', '-m', type=str, required=True,
                        help='torch model file in t7 format')
    parser.add_argument('--output', '-o', type=str, default=None,
                        help='output file name prefix, xxx.py xxx.pth')
    args = parser.parse_args()
    torch_to_pytorch(args.model, args.output)
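# Example invocation (the script and model file names are placeholders):
#
#   python convert_torch.py -m vgg16.t7
#
# which writes vgg16.py (the model definition) and vgg16.pth (the weights).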