Skip to content

Instantly share code, notes, and snippets.

@williamFalcon
Created February 27, 2020 15:59
Show Gist options
  • Save williamFalcon/c5b9a7b0d748db28a1c4324194676cb2 to your computer and use it in GitHub Desktop.
def training_step(self, batch, batch_idx):
    """Toy seq2seq-with-attention training step (PyTorch Lightning hook).

    Illustrates a custom forward + loss inside ``training_step``: encode the
    input once, then decode step-by-step with attention, accumulating a loss
    per decoding step. Explicitly a non-working sketch — ``max_seq_len`` is
    assumed to be in scope, and ``torch.zeros(...)`` is a shape placeholder.

    Args:
        batch: tuple ``(x, y)`` — input sequence and per-step targets
            (assumes ``y`` is indexable by decoding step — TODO confirm).
        batch_idx: index of this batch within the epoch (unused here).

    Returns:
        dict with key ``'loss'`` holding the mean per-step loss.
    """
    x, y = batch

    # Encode the whole input once; the decoder attends over these states
    # at every decoding step.
    hidden_states = self.encoder(x)

    # (toy, non-working example to illustrate — e.g. a seq2seq + attn model)
    start_token = '<SOS>'
    # NOTE(review): placeholder init — real code needs a concrete shape here.
    last_hidden = torch.zeros(...)
    loss = 0
    for step in range(max_seq_len):  # NOTE(review): max_seq_len undefined in this snippet
        attn_context = self.attention_nn(hidden_states, start_token)
        decoder_out = self.decoder(start_token, attn_context, last_hidden)
        # Carry the raw decoder output forward as the next hidden state.
        last_hidden = decoder_out
        pred = self.predict_nn(decoder_out)
        # BUG FIX: original scored the raw hidden state (`last_hidden`)
        # against the target, leaving `pred` unused; score the prediction.
        loss += self.loss(pred, y[step])

    # Average the accumulated loss over the number of decoding steps.
    loss = loss / max_seq_len
    return {'loss': loss}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment