Created December 23, 2018 18:57
-
-
Save keunwoochoi/d8bea1b0f0503f5cd3dacec9250ac084 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Net(nn.Module):
    """A 1-D conv model that carries its own loss and train/evaluate loops.

    Sketch of the "model trains itself" pattern: the layers, the loss and the
    per-epoch training/evaluation loops all live on one ``nn.Module``.
    ``_fwd_2`` and ``_compute_loss`` are placeholders to fill in for a task.

    Attributes:
        src: Stored as given; not used by any method shown here.
        a: Scalar offset added to the ``conv1`` output in ``_fwd_1``.
        conv1: ``nn.Conv1d`` with 1 input channel, 2 output channels,
            kernel size 3 and stride 4.
        optimizer: ``None`` until the caller assigns one, e.g.
            ``net.optimizer = torch.optim.SGD(net.parameters(), lr=1e-3)``.
            It is set from outside because it needs ``self.parameters()``.
        loss_history: Python-float loss of every training batch, in order.
    """

    def __init__(self, src, a: float = 0.0) -> None:
        # ``super(net, self)`` referenced a nonexistent name; use zero-arg super().
        super().__init__()
        self.src = src
        # ``a`` was used but never defined; it is now an optional argument.
        self.a = a
        self.conv1 = nn.Conv1d(1, 2, kernel_size=3, stride=4)
        self.optimizer = None
        self.loss_history = []

    def _fwd_1(self, x):
        """Apply ``conv1`` to ``x`` and add the scalar offset ``self.a``."""
        x = self.conv1(x)
        # The original fell off the end and returned None to ``forward``.
        return x + self.a

    def _fwd_2(self, x):
        """Second stage of the forward pass; currently returns ``x`` unchanged.

        TODO: replace with a stage analogous to ``_fwd_1``.
        """
        return x

    def forward(self, x):
        """Run ``_fwd_1`` then ``_fwd_2``.

        ``conv1`` has one input channel, so ``x`` is (batch, 1, length) or
        (1, length).
        """
        return self._fwd_2(self._fwd_1(x))

    def _compute_loss(self, y, target):
        """Return a scalar loss tensor for prediction ``y`` against ``target``.

        Placeholder: the loss is task-specific. Override in a subclass or
        replace this body, e.g. with ``nn.functional.mse_loss(y, target)``.

        Raises:
            NotImplementedError: Always, until a loss is provided.
        """
        raise NotImplementedError("Net._compute_loss must be implemented")

    def _update_loss_history(self, loss) -> None:
        """Append the value of the scalar tensor ``loss`` to ``loss_history``."""
        # ``.item()`` stores a plain float, so the autograd graph is not kept alive.
        self.loss_history.append(loss.item())

    def learn_epoch(self, tr_loader, epoch) -> None:
        """Train for one pass over ``tr_loader``.

        Args:
            tr_loader: Iterable of batches ``(x1, x2, ...)``. ``x1`` is the
                model input and ``x2`` the target given to ``_compute_loss``;
                any further items are ignored.
            epoch: Index of the current epoch (not used by this loop).

        Raises:
            RuntimeError: If ``self.optimizer`` has not been assigned.
        """
        if self.optimizer is None:
            raise RuntimeError(
                "Net.optimizer is not set; assign e.g. "
                "torch.optim.SGD(net.parameters(), lr=...) before training"
            )
        self.train()
        for x1, x2, *_ in tr_loader:
            # Clear the previous step's gradients; otherwise they accumulate.
            self.optimizer.zero_grad()
            y = self(x1)  # __call__ rather than .forward() so module hooks run
            loss = self._compute_loss(y, x2)
            loss.backward()
            self.optimizer.step()
            self._update_loss_history(loss)

    def test(self, test_loader):
        """Evaluate on ``test_loader`` and return the mean batch loss.

        Runs in eval mode without recording autograd graphs. Batches have the
        same ``(x1, x2, ...)`` layout as in ``learn_epoch``.

        Returns:
            Mean of the per-batch losses as a float, or ``None`` if
            ``test_loader`` yields no batches.
        """
        import torch  # only needed here, for no_grad()

        self.eval()
        batch_losses = []
        with torch.no_grad():
            for x1, x2, *_ in test_loader:
                batch_losses.append(self._compute_loss(self(x1), x2).item())
        # TODO: plot / report results as needed.
        if not batch_losses:
            return None
        return sum(batch_losses) / len(batch_losses)
if __name__ == "__main__":
    # Usage sketch. The guard keeps importing this module free of side effects.
    # Net.__init__ takes ``src``; the original ``Net(a, b, ..)`` did not match it.
    import torch

    src = None  # TODO: the value Net should store as ``src``
    net = Net(src)
    # learn_epoch uses ``net.optimizer``, so assign one before training. It is
    # created after the model so it can receive ``net.parameters()``.
    net.optimizer = torch.optim.SGD(net.parameters(), lr=1e-3)  # TODO: tune
    data_loader = []  # TODO: iterable of (x1, x2, ...) batches, e.g. a DataLoader
    num_epochs = 1
    for epoch in range(num_epochs):
        net.learn_epoch(data_loader, epoch)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment