@hellosaumil
Created March 19, 2020 09:39
class myNNUnwrapper():
    """Traces how a tensor's shape changes through a sequence of NN modules."""

    def __init__(self, inputTensor, nnModules, nnParams, verbose=False):
        self.inputTensor = inputTensor
        self.nnModules, self.nnParams = nnModules, nnParams
        self.verbose = verbose
        self.resolveAll()

    def __str__(self):
        return ''
    def resolveAll(self):
        # Work on a copy so the original input shape is not mutated in place.
        self.currentTensor = list(self.inputTensor)

        self.module_to_nnResolver = {'linear': self.resolveLinear,
                                     'preprocess': self.resolveLinear,
                                     'view': self.resolveView,
                                     'conv2d': self.resolveConv2d,
                                     'convout': self.resolveConv2dOut}

        print("\nResolved Dimensions: \ninputTensor: \t{}\n".format(self.currentTensor))

        # Dispatch each module name to its resolver, appending verbose as the last positional arg.
        for (idx, (nnModule, params)) in enumerate(zip(self.nnModules, self.nnParams)):
            moduleName = nnModule.lower()
            nnModuleResolver = self.module_to_nnResolver[moduleName]
            print("\n{}. {}".format(idx + 1, nnModuleResolver(*(params + (self.verbose,)))))
    def resolveLinear(self, in_dim, out_dim, verbose=False):
        # A Linear layer only changes the feature (last) dimension.
        self.currentTensor[1] = out_dim

        if not verbose:
            return "linear: \t{}".format(self.currentTensor)
        else:
            return "{} : Linear(in_features={}, out_features={})".format(self.currentTensor, in_dim, out_dim)
    def resolveView(self, extra, new_dim, x, y, verbose=False):
        # Reshape [batch, features] into [batch, new_dim, x, y]; only valid if the element counts match.
        [batch, out_dim] = self.currentTensor[0], self.currentTensor[1]

        if out_dim // new_dim == x * y:
            self.currentTensor = [batch, new_dim, x, y]

            if not verbose:
                return "view: \t[{}, {}, {}, {}]".format(*(self.currentTensor))
            else:
                return "[{}, {}, {}, {}] : View(batch={}, samples={}, dim=({}, {}))".format(*(self.currentTensor * 2))
        else:
            print("\nError: Dimension Mismatch. {}, {}".format(out_dim // new_dim, x * y))
            print("\textra:{}, new_dim:{}, x:{}, y:{}\n".format(extra, new_dim, x, y))
            return "view: \tdimension mismatch"
    def resolveConv2dOut(self, *args):
        # Same shape arithmetic as a regular Conv2d; only the printed label differs.
        return self.resolveConv2d(*args, convOut=True)

    def resolveConv2d(self, C_in, C_out, kernel_size, stride=1, padding=0,
                      pixel_shuffle=False, verbose=False, convOut=False):
        f, p, s = kernel_size, padding, stride
        d = dilation = 1

        [N, C_in, H_in, W_in] = self.currentTensor

        # Standard Conv2d output size: floor((H_in + 2p - d*(f-1) - 1) / s) + 1
        H_out = (H_in + 2*p - d*(f-1) - 1) // s + 1
        W_out = (W_in + 2*p - d*(f-1) - 1) // s + 1
        self.currentTensor = [N, C_out, H_out, W_out]

        if pixel_shuffle:
            # PixelShuffle with upscale factor 2: channels shrink by 4, spatial dims double.
            shuffle_by = 2
            self.currentTensor = [N, C_out // shuffle_by**2, H_out * shuffle_by, W_out * shuffle_by]

        if not verbose:
            return "{}: \t{}".format(("conv_out" if convOut else "conv2d"), self.currentTensor)
        else:
            return "{} : {}(in_channels={}, out_channels={}, kernel_size=({}, {}), stride=({}, {}), padding=({}, {}), dilation={})".format(
                self.currentTensor, ("ConvOut" if convOut else "Conv2d"),
                C_in, C_out, f, f, s, s, p, p, d)
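

# --- Minimal usage sketch (not part of the original gist) ---
# The layer sequence, shapes, and parameter values below are invented for illustration.
# Each nnParams entry must be a tuple matching the corresponding resolver's positional
# arguments, because resolveAll appends verbose as the last positional argument.
if __name__ == "__main__":
    unwrapper = myNNUnwrapper(
        inputTensor=[1, 100],                    # [batch, features]
        nnModules=['linear', 'view', 'conv2d'],
        nnParams=[(100, 1024),                   # Linear: in_dim, out_dim
                  (None, 64, 4, 4),              # View: extra, new_dim, x, y  (1024 == 64 * 4 * 4)
                  (64, 32, 3, 1, 1)],            # Conv2d: C_in, C_out, kernel_size, stride, padding
        verbose=False)

    # Expected printed shapes (roughly):
    #   inputTensor:  [1, 100]
    #   1. linear:    [1, 1024]
    #   2. view:      [1, 64, 4, 4]
    #   3. conv2d:    [1, 32, 4, 4]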