Skip to content

Instantly share code, notes, and snippets.

@bkaankuguoglu
Created April 1, 2021 21:02
Show Gist options
  • Save bkaankuguoglu/e68ba906b89435f1d97e257ca710694e to your computer and use it in GitHub Desktop.
Save bkaankuguoglu/e68ba906b89435f1d97e257ca710694e to your computer and use it in GitHub Desktop.
def train(self, train_loader, val_loader, batch_size=64, n_epochs=50, n_features=1):
    """Train ``self.model`` for ``n_epochs`` epochs, validating after each one.

    Every epoch runs ``self.train_step`` on each batch from ``train_loader``,
    then evaluates ``self.loss_fn`` on every batch from ``val_loader`` with
    gradient tracking disabled. The per-epoch mean losses are appended to
    ``self.train_losses`` and ``self.val_losses``; progress is printed for the
    first 10 epochs and for every 50th epoch. When training finishes, the
    model's ``state_dict`` is saved under ``models/`` (the directory is created
    up front if it does not exist, so a bad path fails before training).

    Args:
        train_loader: Iterable of ``(x, y)`` tensor pairs used for training.
        val_loader: Iterable of ``(x, y)`` tensor pairs used for validation.
        batch_size: Kept for backward compatibility only. Each batch is now
            reshaped using its own leading dimension, so a final, smaller
            batch (e.g. a loader without ``drop_last=True``) is handled
            correctly.
        n_epochs: Number of passes over ``train_loader``.
        n_features: Size of the last dimension inputs are reshaped to, i.e.
            each input batch becomes ``(batch, -1, n_features)``.

    NOTE(review): relies on a module-level ``device`` and on
    ``self.train_step`` returning a numeric loss; both are defined outside
    this view.
    """
    from pathlib import Path

    # Build a filesystem-safe name: the model's class name instead of its
    # multi-line repr, and a timestamp without spaces or colons (colons are
    # not allowed in Windows file names).
    model_name = type(self.model).__name__
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    model_dir = Path("models")
    model_dir.mkdir(parents=True, exist_ok=True)
    model_path = model_dir / f"{model_name}_{timestamp}"

    for epoch in range(1, n_epochs + 1):
        batch_losses = []
        for x_batch, y_batch in train_loader:
            # Use the batch's actual size: the last batch may be smaller than
            # ``batch_size``, and a fixed leading dimension would either raise
            # or silently reshape the data into the wrong layout.
            x_batch = x_batch.view([x_batch.size(0), -1, n_features]).to(device)
            y_batch = y_batch.to(device)
            batch_losses.append(self.train_step(x_batch, y_batch))
        training_loss = np.mean(batch_losses)
        self.train_losses.append(training_loss)

        with torch.no_grad():
            # Switch to evaluation mode once per epoch rather than per batch.
            # NOTE(review): assumes train_step puts the model back into
            # training mode before the next epoch — confirm.
            self.model.eval()
            batch_val_losses = []
            for x_val, y_val in val_loader:
                x_val = x_val.view([x_val.size(0), -1, n_features]).to(device)
                y_val = y_val.to(device)
                yhat = self.model(x_val)
                # NOTE(review): PyTorch loss modules take (input, target); this
                # call passes (target, prediction). Harmless for symmetric
                # losses such as MSE; keep consistent with train_step if changed.
                val_loss = self.loss_fn(y_val, yhat).item()
                batch_val_losses.append(val_loss)
            validation_loss = np.mean(batch_val_losses)
            self.val_losses.append(validation_loss)

        if epoch <= 10 or epoch % 50 == 0:
            print(
                f"[{epoch}/{n_epochs}] Training loss: {training_loss:.4f}\t Validation loss: {validation_loss:.4f}"
            )

    torch.save(self.model.state_dict(), model_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment