Skip to content

Instantly share code, notes, and snippets.

@VoVAllen
Last active July 16, 2017 15:46
Show Gist options
  • Save VoVAllen/1420e410e3dfd368b8dc9061ad0c206a to your computer and use it in GitHub Desktop.
Save VoVAllen/1420e410e3dfd368b8dc9061ad0c206a to your computer and use it in GitHub Desktop.
nn.Linear infer shape implementation
def is_cuda(operation):
    """Return True if *operation* lives on a CUDA device.

    Accepts an ``nn.Module`` (judged by its first parameter), a tensor, or an
    autograd ``Variable`` (judged by its ``.data``).

    Args:
        operation: the module, tensor or Variable to inspect.

    Returns:
        bool: ``True`` when the object is CUDA-resident, otherwise ``False``.
        A module that owns no parameters is reported as ``False``.

    Raises:
        TypeError: if *operation* is none of the supported kinds (a subclass
            of ``Exception``, so existing ``except Exception`` callers still work).
    """
    if isinstance(operation, th.nn.Module):
        # The ``next()`` builtin works on Python 2 and 3 (``.next()`` is Py2-only);
        # the ``None`` default avoids StopIteration for parameter-free modules.
        first_param = next(operation.parameters(), None)
        # Recurse so the parameter is judged by its storage, not by the
        # ``Parameter`` class name (which never contains "cuda").
        return first_param is not None and is_cuda(first_param)
    if th.is_tensor(operation):
        # Read the device flag instead of parsing the type name: on recent
        # PyTorch every tensor's type is ``torch.Tensor`` whatever its device.
        return bool(operation.is_cuda)
    if isinstance(operation, Variable):
        # Legacy autograd wrapper: inspect the tensor it wraps.
        return is_cuda(operation.data)
    raise TypeError("Operation is not nn.Module or Variable or Tensor")
class Dense(nn.Module):
    """Fully connected layer whose input width is inferred on first use.

    The wrapped ``nn.Linear`` is only built (and registered as the ``linear``
    submodule) during the first ``forward`` call, once the size of the
    input's second dimension is known.

    NOTE(review): until that first call the module owns no parameters, so an
    optimizer built from ``parameters()`` beforehand will not see the
    weights. Run one forward pass (e.g. on a dummy batch) before creating it.

    Args:
        hidden_size: number of output features of the linear layer.
    """

    def __init__(self, hidden_size):
        super(Dense, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = -1  # placeholder until the first forward pass
        self.not_initialized = True

    def forward(self, x):
        """Apply the layer to the 2-D input ``x``.

        On the first call, the ``linear`` submodule is created from ``x``'s
        second dimension.

        Raises:
            ValueError: if the first input seen is not 2-dimensional.
        """
        if self.not_initialized:
            self._build(x)
        return self.linear(x)

    def _build(self, x):
        """Create and register the ``linear`` submodule to match ``x``."""
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if x.dim() != 2:
            raise ValueError(
                "Dense expects a 2-D input, got %d dimension(s)" % x.dim())
        self.input_size = x.size(1)
        linear = nn.Linear(self.input_size, self.hidden_size)
        if x.is_cuda:
            # Put the weights on the *input's* GPU; a bare ``.cuda()`` would use
            # the current default device, which can differ with several GPUs.
            linear = linear.cuda(x.get_device())
        self.add_module("linear", linear)
        self.not_initialized = False

    def __repr__(self):
        return '{} ({} -> {})'.format(
            self.__class__.__name__, self.input_size, self.hidden_size)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment