Skip to content

Instantly share code, notes, and snippets.

@Hanrui-Wang
Created July 25, 2019 02:32
Show Gist options
  • Save Hanrui-Wang/53a01a8e96ff00a4911decf6fc9c750e to your computer and use it in GitHub Desktop.
Save Hanrui-Wang/53a01a8e96ff00a4911decf6fc9c750e to your computer and use it in GitHub Desktop.
truncated BPTT
# Truncated backpropagation
def detach(states):
    """Return the hidden states cut loose from the autograd graph.

    Each tensor in *states* is detached so that a later backward pass
    stops at this point (truncated BPTT). Always returns a list,
    regardless of the input container type.
    """
    detached = []
    for state in states:
        detached.append(state.detach())
    return detached
# Train the model with truncated backpropagation through time:
# hidden states carry across windows, but gradients are cut at each one.
for epoch in range(num_epochs):
    # Fresh zero-initialized (h, c) pair at the start of every epoch.
    # NOTE(review): assumes an LSTM-style model taking a (h, c) tuple — confirm.
    hidden = (torch.zeros(num_layers, batch_size, hidden_size).to(device),
              torch.zeros(num_layers, batch_size, hidden_size).to(device))

    for pos in range(0, ids.size(1) - seq_length, seq_length):
        # Mini-batch slice: targets are the inputs shifted right by one token.
        x_batch = ids[:, pos:pos + seq_length].to(device)
        y_batch = ids[:, (pos + 1):(pos + 1) + seq_length].to(device)

        # Sever the autograd graph so backward() stops at this window.
        hidden = detach(hidden)
        logits, hidden = model(x_batch, hidden)
        loss = criterion(logits, y_batch.reshape(-1))

        # Backward and optimize, with gradient-norm clipping at 0.5.
        model.zero_grad()
        loss.backward()
        clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        # Periodic progress report (perplexity = exp of cross-entropy loss).
        step = (pos + 1) // seq_length
        if step % 100 == 0:
            print('Epoch [{}/{}], Step[{}/{}], Loss: {:.4f}, Perplexity: {:5.2f}'
                  .format(epoch + 1, num_epochs, step, num_batches,
                          loss.item(), np.exp(loss.item())))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment