Skip to content

Instantly share code, notes, and snippets.

@goddoe
Created September 27, 2018 11:33
Show Gist options
  • Save goddoe/ec6ee2cfc43b7716bdae2f9ac3bc1bfe to your computer and use it in GitHub Desktop.
Save goddoe/ec6ee2cfc43b7716bdae2f9ac3bc1bfe to your computer and use it in GitHub Desktop.
class Net(torch.nn.Module):
    """Toy module with one learnable parameter ``a`` and one fixed tensor ``b``.

    ``b`` is registered as a (non-persistent) buffer instead of being stored
    as a plain tensor attribute. Every device/dtype conversion performed by
    ``torch.nn.Module`` then moves it together with ``a``, with no per-method
    overrides needed. This covers ``.to()``, ``.cuda()``, ``.cpu()``,
    ``.float()``, ``.double()``, ``.half()`` and the rest.
    """

    def __init__(self) -> None:
        # Module.__init__ must run first: it creates the internal registries
        # that Parameter/buffer assignment writes into. Assigning a Parameter
        # before it raises AttributeError.
        super().__init__()
        self.a = torch.nn.Parameter(torch.zeros(1))
        # Registered as a buffer so Module conversions (.to/.cuda/.cpu/.double
        # ...) move it automatically. persistent=False keeps it out of
        # state_dict(), matching the old plain-attribute behaviour so existing
        # checkpoints still load with strict=True; drop the flag if ``b``
        # should be saved.
        self.register_buffer("b", torch.zeros(1), persistent=False)

    # NOTE: the former ``cuda``/``to`` overrides are intentionally gone. The
    # inherited Module.cuda(device=None) and Module.to(*args, **kwargs) have
    # the same signatures, return ``self``, and now move ``b`` as well.

    def forward(self, inputs):
        # The original listing elided the computation ("do some stuff");
        # fail loudly instead of silently returning None.
        raise NotImplementedError(
            "Net.forward is a placeholder; implement the computation."
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment