Last active
August 31, 2019 00:03
-
-
Save Lexie88rus/944b34797b4093e6b4350645d64e6713 to your computer and use it in GitHub Desktop.
Train the model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import numpy as np | |
import math | |
import matplotlib.pyplot as plt | |
# Rebind the model to the result of .to(device); for a PyTorch module this
# places its parameters on that device (rnn and device are defined elsewhere).
rnn = rnn.to(device)

# Training schedule: total number of single-sample training steps, and how
# often (in steps) progress is printed / an averaged loss point is recorded.
n_iters = 1100000
print_every = 1000
plot_every = 1000

# Loss bookkeeping: per-window averages collected for the final plot, and the
# running sum of losses inside the current plotting window.
all_losses = []
current_loss = 0
def timeSince(since):
    """Format the wall-clock time elapsed since *since* as 'Xm Ys'.

    Args:
        since: A timestamp in seconds, as returned by ``time.time()``.

    Returns:
        A string such as ``'2m 5s'``: whole minutes followed by the
        remaining seconds, truncated to an integer by the ``%d`` format.
    """
    elapsed = time.time() - since
    minutes = math.floor(elapsed / 60)
    seconds = elapsed - minutes * 60
    return '%dm %ds' % (minutes, seconds)
# Draw one random permutation of the sample indices; the same order is reused
# every time the step counter wraps around the dataset.
indices = np.random.permutation(len(sequences))
n_sequences = len(sequences)  # loop-invariant, so computed once

start = time.time()

# Run training procedure.
# The counter is named `step` rather than `iter` so the builtin iter() is not
# rebound at module scope for any code that runs after this loop.
for step in range(1, n_iters + 1):
    # Pick the sample for this step, cycling through the shuffled order
    index = indices[step % n_sequences]

    # Run one training step on a single (sequence, target) pair.
    # NOTE(review): targets[index][0] is assumed to be a tensor holding the
    # class index expected by train() -- confirm against train()'s definition.
    output, loss = train(sequences[index], targets[index][0].long(), device)

    # NOTE(review): if train() returns the loss as a tensor rather than a
    # Python float, this sum may keep autograd state alive -- confirm.
    current_loss += loss

    # Periodically report progress; the loss printed here is the loss of the
    # latest step only, not the running average.
    if step % print_every == 0:
        # NOTE(review): guess / guess_i are not used below; the call is kept
        # in case wordFromOutput has side effects -- remove if it is pure.
        guess, guess_i = wordFromOutput(output)
        print('%d %d%% (%s) Loss: %.4f' % (step, step / n_iters * 100, timeSince(start), loss))

    # Record the average loss over the last plot_every steps, then reset
    if step % plot_every == 0:
        all_losses.append(current_loss / plot_every)
        current_loss = 0

# Plot training loss (one point per plot_every steps)
plt.figure()
plt.plot(all_losses)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment