Created
October 25, 2018 06:09
-
-
Save nulledge/528bf8402e6fc35c2acf8b2661469ece to your computer and use it in GitHub Desktop.
Torch7 to PyTorch converter
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import skimage\n", | |
"import skimage.io\n", | |
"import numpy as np\n", | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"from dotmap import DotMap\n", | |
"from torch.utils.serialization import load_lua\n", | |
"from operator import xor" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
class Layer():
    """Bit-flag constants for the Torch7 layer kinds this converter understands.

    Flags are plain ints combined/tested with bitwise operators; the two
    helpers translate between flags and Torch7 module reprs.
    """

    Identity = 0b000000001
    Convolution = 0b000000010
    Batch_Norm = 0b000000100
    ReLU = 0b000001000
    Sequential = 0b000010000
    Max_Pool = 0b000100000
    Add = 0b001000000
    Nearest_Upsample = 0b010000000
    Concat = 0b100000000

    @staticmethod
    def to_string(layer):
        """Return the label of the first known flag set in `layer`, else None."""
        labels = (
            (Layer.Identity, 'Identity'),
            (Layer.Convolution, 'Convolution'),
            (Layer.Batch_Norm, 'Batch_Norm'),
            (Layer.ReLU, 'ReLU'),
            (Layer.Sequential, 'Sequential'),
            (Layer.Max_Pool, 'Max_Pool'),
            (Layer.Add, 'Add'),
            (Layer.Nearest_Upsample, 'Nearest_Upsample'),
            (Layer.Concat, 'Concat'),
        )
        for flag, label in labels:
            if layer & flag:
                return label
        return None

    @staticmethod
    def from_name(name):
        """Map a Torch7 module repr (e.g. 'nn.ReLU') to its flag, else None."""
        prefixes = (
            ('nn.Identity', Layer.Identity),
            ('nn.SpatialConvolution', Layer.Convolution),
            ('nn.SpatialBatchNormalization', Layer.Batch_Norm),
            ('nn.ReLU', Layer.ReLU),
            ('nn.Sequential', Layer.Sequential),
            ('nn.SpatialMaxPooling', Layer.Max_Pool),
            ('nn.CAddTable', Layer.Add),
            ('nn.SpatialUpSamplingNearest', Layer.Nearest_Upsample),
            ('torch.legacy.nn.ConcatTable.ConcatTable', Layer.Concat),
        )
        for prefix, flag in prefixes:
            if name.startswith(prefix):
                return flag
        return None
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
class Node:
    """One node of the exported Torch7 forward graph.

    Pairs a topology line from forwardnodes.txt ("<id>;<space-separated
    child ids>") with the Torch7 module loaded from the .t7 checkpoint,
    and copies that module's parameters into a matching PyTorch module.
    """

    def __init__(self, forwardnode, module):
        # Keep only the first line of the topology record.
        forwardnode = forwardnode.split('\n')[0]
        self.id, children = forwardnode.split(';')
        # Split the child-id list, dropping empty tokens from repeated spaces.
        self.children = [word for word in children.split(' ') if word]
        self.data = module
        # Classify the module from its repr (e.g. 'nn.SpatialConvolution...').
        self.op = Layer.from_name(str(self.data))

        assert self.op is not None
        assert self.id.isdigit()
        assert all([child.isdigit() for child in self.children])

    def __str__(self):
        return '{node}; {operation}'.format(node=self.id, operation=Layer.to_string(self.op))

    @staticmethod
    def _get_param(module):
        """Recursively extract an (op, parameters) pair from a Torch7 module.

        Convolution -> (weight, bias); batch-norm -> a 5-tuple including
        running stats and momentum; containers recurse into their
        sub-modules; everything else carries no parameters (None).
        """
        op = Layer.from_name(str(module))
        if op & Layer.Convolution:
            param = module.weight, module.bias
            return op, param

        elif op & Layer.Batch_Norm:
            param = module.running_mean, module.running_var, module.weight, module.bias, module.momentum
            return op, param

        elif op & (Layer.Sequential | Layer.Concat):
            # Containers become a list of (op, param) pairs, one per child.
            sub_modules = [Node._get_param(sub_module) for sub_module in module.modules]
            return op, sub_modules

        else:
            param = None
            return op, param


    def get_param(self):
        """Extract this node's (op, parameters) tree."""
        return Node._get_param(self.data)


    @staticmethod
    def _copy_to_convolution(source, target):
        # source: (op, (weight, bias)) from _get_param; target: nn.Conv2d.
        op, param = source
        weight, bias = param

        assert target.weight.shape == weight.shape
        assert target.bias.shape == bias.shape

        target.weight.data = weight
        target.bias.data = bias


    @staticmethod
    def _copy_to_batch_norm(source, target):
        # source: (op, 5-tuple) from _get_param; target: nn.BatchNorm2d.
        op, param = source
        running_mean, running_var, weight, bias, momentum = param

        assert target.running_mean.shape == running_mean.shape
        assert target.running_var.shape == running_var.shape
        assert target.weight.shape == weight.shape
        assert target.bias.shape == bias.shape
        assert isinstance(target.momentum, float) and isinstance(momentum, float)

        target.running_mean = running_mean
        target.running_var = running_var
        target.weight.data = weight
        target.bias.data = bias
        target.momentum = momentum


    @staticmethod
    def _copy_to_residual(source, target):
        # The unpacking below implies the Torch7 residual block nests as
        # Sequential -> Concat -> [main Sequential, skip branch]
        # -- layout inferred from this code only; confirm against the dump.
        op, sub_modules = source
        op, sub_modules = sub_modules[0]
        op, sub_modules = sub_modules[0]  # main (conv) branch

        assert op & Layer.Sequential

        # Copy the main branch pairwise into target.resSeq; ReLU is stateless.
        for torch_module, pytorch_module in zip(sub_modules, target.resSeq):
            op, _ = torch_module

            if op & Layer.Batch_Norm:
                Node._copy_to_batch_norm(torch_module, pytorch_module)

            elif op & Layer.ReLU:
                continue

            elif op & Layer.Convolution:
                Node._copy_to_convolution(torch_module, pytorch_module)

            else:
                raise NotImplementedError()

        # Re-walk to the skip branch (second child of the Concat).
        op, sub_modules = source
        op, sub_modules = sub_modules[0]
        op, param = sub_modules[1]

        if op & Layer.Identity:
            # Channel counts match; no skip conv to copy.
            pass

        elif op & Layer.Sequential:
            # Skip branch is a Sequential holding exactly one 1x1 conv.
            sub_modules = param
            op, param = sub_modules[0]

            assert op & Layer.Convolution and len(sub_modules) == 1

            Node._copy_to_convolution(sub_modules[0], target.conv_skip)

    @staticmethod
    def _copy_to(source, target):
        """Dispatch the copy by layer kind; parameter-free layers are skipped."""
        op, param = source
        if op & Layer.Convolution:
            Node._copy_to_convolution(source, target)

        elif op & Layer.Batch_Norm:
            Node._copy_to_batch_norm(source, target)

        elif op & Layer.Sequential:
            if Node._is_residual(source):
                Node._copy_to_residual(source, target)
            else:
                sub_modules = param
                for torch_module, pytorch_module in zip(sub_modules, target):
                    Node._copy_to(torch_module, pytorch_module)
                # NOTE(review): raises unconditionally after copying; dead in
                # practice because _is_residual() always returns True --
                # confirm intent before enabling the non-residual path.
                raise NotImplementedError()

        elif op & (Layer.Identity | Layer.ReLU | Layer.Add | Layer.Max_Pool | Layer.Nearest_Upsample):
            # Stateless layers: nothing to copy.
            pass

        else:
            raise NotImplementedError()

    def copy_to(self, target):
        """Copy this node's Torch7 parameters into PyTorch module `target`."""
        Node._copy_to(self.get_param(), target)

    @staticmethod
    def _is_residual(source):
        # Stub: every Sequential in this checkpoint is treated as a residual block.
        return True
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
class Graph:
    """The flattened Torch7 network.

    Loads the .t7 checkpoint ('cpu.t7') and the dumped topology file
    ('forwardnodes.txt') from the working directory, and exposes helpers
    that copy weight groups into the PyTorch stacked-hourglass model.
    """

    def __init__(self):
        forwardnodes, modules = self.load_data()
        # One Node per (topology line, Torch7 module) pair, in dump order.
        self.node = [Node(forwardnode, module)
                     for forwardnode, module in zip(forwardnodes, modules)]

    def load_data(self):
        """Return (topology lines, Torch7 modules), skipping the dummy node."""
        modules = load_lua('cpu.t7')
        with open('forwardnodes.txt', 'r') as fd:
            lines = fd.readlines()
        forwardnodes = lines[1:]  # The 1st forwardnode is dummy, a input distributor
        return forwardnodes, modules

    def find_by_id(self, key):
        """Return the node whose id equals `key` (a digit string).

        Raises LookupError when no node matches.
        """
        # Plain iteration: the enumerate() index was never used.
        for node in self.node:
            if node.id == key:
                return node
        raise LookupError()

    def copy_to_hg(self, first_res_in_torch7, hg):
        """Copy one hourglass's residual modules into PyTorch hourglass `hg`.

        The offsets below are positions of the residual blocks within one
        hourglass's 32-node span of the flattened dump -- presumably fixed
        by the Torch7 export order; confirm against forwardnodes.txt.
        """
        res_in_torch7 = [
            0,   # 64x64 skip
            2,   # 32x32 res
            3,   # 32x32 skip
            5,   # 16x16 res
            6,   # 16x16 skip
            8,   # 8x8 res
            9,   # 8x8 skip
            11,  # 4x4 res
            12,  # 4x4 lowest
            13,  # 4x4 res
            16,  # 8x8 res
            19,  # 16x16 res
            22,  # 32x32 res
        ]
        res_in_pytorch = [
            hg.res1[0],
            hg.res2[0],
            hg.subHourglass.res1[0],
            hg.subHourglass.res2[0],
            hg.subHourglass.subHourglass.res1[0],
            hg.subHourglass.subHourglass.res2[0],
            hg.subHourglass.subHourglass.subHourglass.res1[0],
            hg.subHourglass.subHourglass.subHourglass.res2[0],
            hg.subHourglass.subHourglass.subHourglass.resWaist[0],
            hg.subHourglass.subHourglass.subHourglass.res3[0],
            hg.subHourglass.subHourglass.res3[0],
            hg.subHourglass.res3[0],
            hg.res3[0],
        ]
        for torch7_idx, pytorch_module in zip(res_in_torch7, res_in_pytorch):
            torch7_module = self.node[first_res_in_torch7 + torch7_idx]
            torch7_module.copy_to(pytorch_module)

    def copy_to_intermediate(self, first_conv_in_torch7, lin, htmap, llBar, htmapBar):
        """Copy one stack's head: lin (conv/bn/relu), heatmap conv, and --
        unless both are None (last stack) -- the two feedback convs."""
        self.node[first_conv_in_torch7 + 0].copy_to(lin[0])  # Conv
        self.node[first_conv_in_torch7 + 1].copy_to(lin[1])  # Batch-norm
        self.node[first_conv_in_torch7 + 2].copy_to(lin[2])  # ReLU, ll in Newell's

        self.node[first_conv_in_torch7 + 3].copy_to(htmap)  # Conv, tmpOut in Newell's

        # Identity comparison, not ==: None is a singleton.
        if llBar is None and htmapBar is None:
            return

        self.node[first_conv_in_torch7 + 4].copy_to(llBar)     # Conv, ll_ in Newell's
        self.node[first_conv_in_torch7 + 5].copy_to(htmapBar)  # Conv, tmpOut_ in Newell's
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Build the Torch7 graph (reads cpu.t7 and forwardnodes.txt from the CWD).
graph = Graph()
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
class CONFIG:
    """Hyper-parameters of the stacked-hourglass network."""
    nStacks = 8      # number of stacked hourglass modules
    nFeatures = 256  # feature channels inside each hourglass
    nModules = 1     # residual modules per position in the hourglass
    nJoints = 16     # heatmap channels (joint count)
    nDepth = 4       # recursion depth of each hourglass
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
class ResModule(nn.Module):
    """Pre-activation residual block (BN-ReLU-Conv x3) with a 1x1 skip conv.

    The skip conv is applied only when the channel counts differ, but it is
    always constructed so the weight converter can target it.
    """

    def __init__(self, in_channels, out_channels):
        super(ResModule, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv_skip = nn.Conv2d(in_channels, out_channels, kernel_size=1)

        # Bottleneck width: half the output channels.
        mid = out_channels // 2
        self.resSeq = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, mid, kernel_size=1),
            nn.BatchNorm2d(mid),
            nn.ReLU(),
            nn.Conv2d(mid, mid, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(mid),
            nn.ReLU(),
            nn.Conv2d(mid, out_channels, kernel_size=1),
        )

    def forward(self, x):
        """Return skip(x) + residual(x)."""
        shortcut = x if self.in_channels == self.out_channels else self.conv_skip(x)
        return shortcut + self.resSeq(x)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
class Hourglass(nn.Module):
    """One recursive hourglass: a full-resolution branch summed with a
    downsample -> (sub-hourglass | waist residual) -> upsample branch.
    """

    def __init__(self, hg_depth, nFeatures):
        super(Hourglass, self).__init__()
        self.hg_depth = hg_depth
        self.nFeatures = nFeatures

        # nModules chained residual blocks at constant width.
        def _res_chain():
            return nn.Sequential(
                *[ResModule(nFeatures, nFeatures) for _ in range(CONFIG.nModules)]
            )

        self.res1 = _res_chain()
        self.res2 = _res_chain()
        self.res3 = _res_chain()
        self.subHourglass = None
        self.resWaist = None
        if self.hg_depth > 1:
            # Recurse one level down in resolution.
            self.subHourglass = Hourglass(self.hg_depth - 1, nFeatures)
        else:
            # Bottom of the recursion: a plain residual chain at the waist.
            self.resWaist = _res_chain()

    def forward(self, x):
        """Return the sum of the upper branch and the processed lower branch."""
        upper = self.res1(x)

        lower = nn.MaxPool2d(kernel_size=2, stride=2)(x)
        lower = self.res2(lower)
        lower = self.subHourglass(lower) if self.hg_depth > 1 else self.resWaist(lower)
        lower = self.res3(lower)
        lower = nn.UpsamplingNearest2d(scale_factor=2)(lower)

        return upper + lower
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
class MainModel(nn.Module):
    """Stacked-hourglass pose network: a conv stem followed by
    CONFIG.nStacks hourglasses with intermediate heatmap supervision.
    """

    def __init__(self, in_channels=3):
        super(MainModel, self).__init__()

        # Stem: downsamples 4x (stride-2 conv then 2x2 max-pool) and widens
        # channels up to CONFIG.nFeatures.
        self.beforeHourglass = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(),
            ResModule(in_channels=64, out_channels=128),
            nn.MaxPool2d(kernel_size=2, stride=2),
            ResModule(128, 128),
            ResModule(128, CONFIG.nFeatures)
        )

        self.hgArray = nn.ModuleList([])        # one hourglass per stack
        self.linArray = nn.ModuleList([])       # conv-bn-relu after each hourglass
        self.htmapArray = nn.ModuleList([])     # 1x1 conv -> joint heatmaps
        self.llBarArray = nn.ModuleList([])     # feeds features back to the trunk
        self.htmapBarArray = nn.ModuleList([])  # feeds heatmaps back to the trunk

        for i in range(CONFIG.nStacks):
            self.hgArray.append(Hourglass(CONFIG.nDepth, CONFIG.nFeatures))
            self.linArray.append(self.lin(CONFIG.nFeatures, CONFIG.nFeatures))
            self.htmapArray.append(nn.Conv2d(CONFIG.nFeatures, CONFIG.nJoints, kernel_size=1, stride=1, padding=0))

        # The last stack has no feedback branch, hence nStacks - 1.
        for i in range(CONFIG.nStacks - 1):
            self.llBarArray.append(nn.Conv2d(CONFIG.nFeatures, CONFIG.nFeatures, kernel_size=1, stride=1, padding=0))
            self.htmapBarArray.append(nn.Conv2d(CONFIG.nJoints, CONFIG.nFeatures, kernel_size=1, stride=1, padding=0))

    def forward(self, x):
        """Return a list of CONFIG.nStacks heatmap tensors, one per stack."""
        inter = self.beforeHourglass(x)
        outHeatmap = []

        for i in range(CONFIG.nStacks):
            ll = self.hgArray[i](inter)
            ll = self.linArray[i](ll)
            htmap = self.htmapArray[i](ll)
            outHeatmap.append(htmap)

            # Re-inject features and predicted heatmaps into the trunk
            # for every stack except the last.
            if i < CONFIG.nStacks - 1:
                ll_ = self.llBarArray[i](ll)
                htmap_ = self.htmapBarArray[i](htmap)
                inter = inter + ll_ + htmap_

        return outHeatmap

    def lin(self, in_channels, out_channels):
        # 1x1 conv + BN + ReLU head segment.
        # NOTE(review): uses no instance state -- could be a @staticmethod.
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(num_features=out_channels),
            nn.ReLU()
        )
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Freshly-initialized PyTorch model; Torch7 weights are copied in below.
sh = MainModel()
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Node 0 is the dummy input distributor; nodes 1..7 are the stem layers,
# copied pairwise into the 7 modules of beforeHourglass.
for torch_module, pytorch_module in zip(graph.node[1:7+1], sh.beforeHourglass):
    torch_module.copy_to(pytorch_module)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Each hourglass occupies 32 consecutive nodes in the Torch7 dump, starting
# at node 8 -- stride presumably fixed by the export order; confirm.
for torch_idx, pytorch_hg in zip(range(8, 8 + 32 * CONFIG.nStacks + 1, 32), sh.hgArray):
    graph.copy_to_hg(first_res_in_torch7=torch_idx, hg=pytorch_hg)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Copy the per-stack heads: lin conv/bn/relu, heatmap conv and, for all but
# the last stack (index 7), the two feedback convs. Heads start at node 33,
# one every 32 nodes.
for torch_idx, pytorch_idx in zip(range(33, 33 + 32 * CONFIG.nStacks + 1, 32), range(0, 8, 1)):
    lin = sh.linArray[pytorch_idx]
    htmap = sh.htmapArray[pytorch_idx]
    llBar = sh.llBarArray[pytorch_idx] if pytorch_idx != 7 else None
    htmapBar = sh.htmapBarArray[pytorch_idx] if pytorch_idx != 7 else None
    graph.copy_to_intermediate(first_conv_in_torch7=torch_idx, lin=lin, htmap=htmap, llBar=llBar, htmapBar=htmapBar)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Load a test image as float RGB in [0, 1].
rgb = np.asarray(skimage.img_as_float(skimage.io.imread('asdf.jpg')))
# HWC -> CHW, then add a singleton batch dimension (NCHW).
rgb = np.expand_dims(rgb.transpose(2, 0, 1), axis=0)
rgb = torch.Tensor(rgb)
# Forward pass: one heatmap tensor per stack.
htmaps = sh(rgb)
# Keep only the final (most refined) stack's output.
htmaps = htmaps[-1]
# Drop the batch dimension.
htmaps = htmaps[0]
# Sanity check -- presumably (nJoints, 64, 64) given the 64x64 loop below; confirm.
htmaps.shape
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from itertools import product" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Collapse the per-joint heatmaps into channel 0 by taking the per-pixel
# maximum across joints (assumes 64x64 heatmaps -- confirm).
for x, y in product(range(64), range(64)):
    htmaps[0, y, x] = torch.max(htmaps[:, y, x])
# NOTE(review): rebinding `x` here shadows the loop variable above.
x = np.asarray(htmaps[0, :, :].data)
import imageio
# Write the combined heatmap as an image for visual inspection.
imageio.imwrite('pred.jpg', x)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Persist the converted weights as a regular PyTorch checkpoint.
torch.save(
    {
        'state': sh.state_dict(),
    },
    'torch7.save',
)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def beautify(pair, indent=0):
    """Render an (op, param) pair from Node.get_param() as an indented,
    human-readable tree string.

    `pair` is the recursive structure produced by Node._get_param;
    `indent` is the current tab depth. Returns the formatted string.
    """
    op, param = pair
    prefix = '\t' * indent

    if op == Layer.Convolution:
        weight, bias = param
        # Layer flags are plain ints and have no .name attribute;
        # Layer.to_string() supplies the label.
        sentence = '{indent}{operation}; {shape}'.format(indent=prefix, operation=Layer.to_string(op), shape=weight.shape)

    elif op == Layer.Batch_Norm:
        # Node._get_param packs FIVE values for batch-norm layers;
        # unpacking only two raised ValueError.
        running_mean, running_var, weight, bias, momentum = param
        sentence = '{indent}{operation}; {shape}'.format(indent=prefix, operation=Layer.to_string(op), shape=running_mean.shape)

    elif op == Layer.Sequential or op == Layer.Concat:
        # Containers render their children recursively inside brackets.
        sentence = '{indent}['.format(indent=prefix)
        for sub_module in param:
            sentence = sentence + '\n' + beautify(sub_module, indent=indent + 1)
        sentence = sentence + '\n' + '{indent}]'.format(indent=prefix)

    else:
        # Parameter-free layers: just the label.
        sentence = '{indent}{operation}'.format(indent=prefix, operation=Layer.to_string(op))

    return sentence
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Ad-hoc inspection: print node 30 and each of its children, skipping
# child '2' (presumably the shared input node -- confirm).
idx = 30
print(str(graph.node[idx]))
for child in graph.node[idx].children:
    if child == '2':
        continue
    print('\t', str(graph.find_by_id(child)))
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.5.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment