# Source: GitHub gist by @keunwoochoi, created December 23, 2018
# https://gist.github.com/keunwoochoi/d8bea1b0f0503f5cd3dacec9250ac084
class Net(nn.Module):
    """Small 1-D conv model that also carries its own train / eval loops.

    The forward pass is split into two stages, ``_fwd_1`` then ``_fwd_2``.
    ``learn_epoch`` and ``test`` drive the model over a data loader.
    ``_compute_loss`` and ``_fwd_2`` are extension points that the original
    sketch left open.

    Args:
        src: Arbitrary source object. It is stored as ``self.src`` and is
            not otherwise used in this class.
        a: Constant added to the output of ``conv1`` in ``_fwd_1``.
        optimizer_factory: Optional callable that receives
            ``self.parameters()`` and returns an optimizer. If omitted,
            assign ``self.optimizer`` before calling ``learn_epoch``.
    """

    def __init__(self, src, a=0.0, optimizer_factory=None):
        super().__init__()
        self.src = src
        self.a = a
        # Positional Conv1d args: in_channels=1, out_channels=2,
        # kernel_size=3, stride=4.
        self.conv1 = nn.Conv1d(1, 2, 3, 4)
        # Created after the layers so the optimizer sees their parameters.
        self.optimizer = (
            optimizer_factory(self.parameters())
            if optimizer_factory is not None
            else None
        )
        # Per-batch training losses, stored as plain floats. Keeping floats
        # rather than tensors avoids holding autograd state alive.
        self.loss_history = []

    def _fwd_1(self, x):
        """First stage: apply ``conv1``, then add the constant ``a``."""
        x = self.conv1(x)
        x = x + self.a
        return x

    def _fwd_2(self, x):
        """Second stage placeholder; currently returns ``x`` unchanged.

        TODO: implement. The original sketch describes this stage as
        "similar to _fwd_1".
        """
        return x

    def forward(self, x):
        """Run ``_fwd_1`` then ``_fwd_2`` and return the result."""
        return self._fwd_2(self._fwd_1(x))

    def _compute_loss(self, y, *targets):
        """Return a scalar loss tensor for model output ``y``.

        ``targets`` are the remaining items of the batch tuple, i.e.
        everything after the model input. Override this method (or assign
        it on the instance) before calling ``learn_epoch`` or ``test``.

        Raises:
            NotImplementedError: Always, until overridden.
        """
        raise NotImplementedError("Net._compute_loss must be implemented")

    def _update_loss_history(self, loss):
        """Append one batch loss, converted to ``float``, to ``loss_history``."""
        self.loss_history.append(float(loss))

    def learn_epoch(self, tr_loader, epoch):
        """Train for one pass over ``tr_loader``.

        Args:
            tr_loader: Iterable of batch tuples ``(x, *targets)``. ``x`` is
                fed to the model and ``targets`` are forwarded to
                ``_compute_loss``.
            epoch: Epoch index; not used by this method.

        Raises:
            RuntimeError: If ``self.optimizer`` has not been set.
        """
        if self.optimizer is None:
            raise RuntimeError(
                "Net.optimizer is not set; pass optimizer_factory or assign it"
            )
        self.train()
        for x, *targets in tr_loader:
            # backward() accumulates into .grad, so clear the previous batch.
            self.optimizer.zero_grad()
            y = self(x)  # __call__ rather than .forward so module hooks run
            loss = self._compute_loss(y, *targets)
            loss.backward()
            self.optimizer.step()
            self._update_loss_history(loss.item())

    def test(self, test_loader):
        """Evaluate on ``test_loader`` and return the per-batch losses.

        Runs in eval mode without autograd tracking. Batches have the same
        ``(x, *targets)`` layout as in ``learn_epoch``. Plotting or other
        reporting is left to the caller.

        Returns:
            A list of per-batch losses as floats; empty for an empty loader.
        """
        import torch

        self.eval()
        batch_losses = []
        with torch.no_grad():
            for x, *targets in test_loader:
                y = self(x)
                batch_losses.append(self._compute_loss(y, *targets).item())
        return batch_losses
def main(src, data_loader, epoch=0, **net_kwargs):
    """Build a ``Net`` from *src*, train it for one epoch, and return it.

    Wrapping the driver in a function keeps importing this module free of
    training side effects.

    Args:
        src: Passed through to ``Net`` as its first argument.
        data_loader: Loader handed to ``Net.learn_epoch``.
        epoch: Epoch index forwarded to ``learn_epoch``.
        **net_kwargs: Extra keyword arguments forwarded to ``Net``.

    Returns:
        The trained ``Net`` instance.
    """
    net = Net(src, **net_kwargs)
    net.learn_epoch(data_loader, epoch)
    return net
# Comments on this gist can be left on GitHub (sign-in required).