PyTorch char-rnn as a script, based on examples from Kyle McDonald, Laurent Dinh, and Sean Robertson.
# Special thanks to Kyle McDonald, this is based on his example
# https://gist.github.com/kylemcdonald/2d06dc736789f0b329e11d504e8dee9f
# Thanks to Laurent Dinh for examples of parameter saving/loading in PyTorch
# Thanks to Sean Robertson for https://github.com/spro/practical-pytorch
from tqdm import tqdm
from torch.autograd import Variable
import torch.nn as nn
import torch
import numpy as np
import math
import os
import argparse

parser = argparse.ArgumentParser(description='PyTorch char-rnn')
parser.add_argument('--eval', action='store_true', help='evaluate instead of train')
parser.add_argument('--temperature', type=float, default=0.8)
parser.add_argument('--eval_len', type=int, default=500)
parser.add_argument('--seq_length', type=int, default=50)
parser.add_argument('--batch_size', type=int, default=50)
parser.add_argument('--rnn_size', type=int, default=128)
parser.add_argument('--max_epochs', type=int, default=10)
parser.add_argument('--num_layers', type=int, default=2)
parser.add_argument('--learning_rate', type=float, default=2e-3)
# from https://raw.githubusercontent.com/jcjohnson/torch-rnn/master/data/tiny-shakespeare.txt
parser.add_argument('--input', '-i', type=str, default='tiny-shakespeare.txt')
parser.add_argument('--output', '-o', type=str, default='.')
parser.add_argument('--seed', type=str, default='a')
args = parser.parse_args()
use_cuda = torch.cuda.is_available()

# try to get deterministic runs
torch.manual_seed(1999)
random_state = np.random.RandomState(1999)

seq_length = args.seq_length
batch_size = args.batch_size
hidden_size = args.rnn_size
epoch_count = args.max_epochs
n_layers = args.num_layers
lr = args.learning_rate
input_filename = args.input
checkpoint_path = os.path.join(args.output, 'checkpoint.pth.tar')
final_path = os.path.join(args.output, 'final.pth.tar')

with open(input_filename, 'r') as f:
    text = f.read()

chars = sorted(list(set(text)))
chars_len = len(chars)
char_to_index = {}
index_to_char = {}
for i, c in enumerate(chars):
    char_to_index[c] = i
    index_to_char[i] = c
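
# Helper used for batching: walk l in steps of n and yield non-overlapping slices
# of length n; anything at the end that does not fill a complete slice is dropped.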
def chunks(l, n):
    for i in range(0, len(l) - n, n):
        yield l[i:i + n]
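
# Wrap a single character index in a 1 x 1 LongTensor (seq_len=1, batch=1),
# the shape the model expects when feeding one character at a time during sampling.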
def index_to_tensor(index):
    tensor = torch.zeros(1, 1).long()
    tensor[0, 0] = index
    return Variable(tensor)
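
# Character-level language model: an embedding maps each character index to a
# hidden_size vector, a stack of GRU layers processes the sequence, and a linear
# decoder projects each timestep back to logits over the character vocabulary.
# forward() flattens the output to (seq_len * batch, vocab) so it can be passed
# directly to CrossEntropyLoss.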
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, n_layers):
        super(RNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.encoder = nn.Embedding(input_size, hidden_size)
        self.cells = nn.GRU(hidden_size, hidden_size, n_layers)
        self.decoder = nn.Linear(hidden_size, output_size)

    def forward(self, input, hidden):
        input = self.encoder(input)
        output, hidden = self.cells(input, hidden)
        output = self.decoder(output.view(output.size(0) * output.size(1), output.size(2)))
        return output, hidden

    def create_hidden(self, batch_size):
        # should this be small random instead of zeros
        # should this also be stored in the class rather than being passed around?
        return torch.zeros(self.n_layers, batch_size, self.hidden_size)
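
# Training: the text is converted to indices, cut into (seq_length + 1)-character
# pieces, and grouped into batches of batch_size pieces. Within each piece the first
# seq_length characters are the inputs and the same characters shifted by one are the
# targets. Batches are shuffled every epoch, a checkpoint is written after each epoch,
# and an existing checkpoint is resumed automatically.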
def train():
    # convert all characters to indices
    batches = [char_to_index[char] for char in text]
    # chunk into sequences of length seq_length + 1
    batches = list(chunks(batches, seq_length + 1))
    # chunk sequences into batches
    batches = list(chunks(batches, batch_size))
    # convert batches to tensors and transpose
    # each batch is (sequence_length + 1) x batch_size
    batches = [torch.LongTensor(batch).transpose_(0, 1) for batch in batches]

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_function = nn.CrossEntropyLoss()
    hidden = Variable(model.create_hidden(batch_size))

    if use_cuda:
        hidden = hidden.cuda()
        model.cuda()

    if os.path.exists(checkpoint_path):
        print('Parameters found at {}... loading'.format(checkpoint_path))
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])

    all_losses = []
    try:
        epoch_progress = tqdm(range(1, epoch_count + 1))
        for epoch in epoch_progress:
            random_state.shuffle(batches)
            batches_progress = tqdm(batches)
            best_loss = float('inf')
            for batch, batch_tensor in enumerate(batches_progress):
                if use_cuda:
                    batch_tensor = batch_tensor.cuda()

                # reset the model
                model.zero_grad()

                # everything except the last
                input_variable = Variable(batch_tensor[:-1])
                # everything except the first, flattened;
                # .contiguous() copies the slice into contiguous memory so .view() can flatten it
                target_variable = Variable(batch_tensor[1:].contiguous().view(-1))
                # prediction and calculate loss
                output, _ = model(input_variable, hidden)
                loss = loss_function(output, target_variable)

                # backprop and optimize
                loss.backward()
                optimizer.step()

                loss = loss.data[0]
                best_loss = min(best_loss, loss)
                all_losses.append(loss)
                batches_progress.set_postfix(loss='{:.03f}'.format(loss))

            epoch_progress.set_postfix(loss='{:.03f}'.format(best_loss))
            torch.save({
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }, checkpoint_path)

    except KeyboardInterrupt:
        pass

    # final save
    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict()
    }, final_path)
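
# Turn the decoder's logits into a probability distribution with a temperature-scaled
# softmax (divide by temperature, exponentiate, normalize) and draw one character index
# from it. Lower temperatures make sampling greedier; higher temperatures make it more random.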
def sample_temperature(x, temperature=1.0):
    x = x.reshape(-1).astype(np.float64)
    x /= temperature
    x = np.exp(x)
    x /= np.sum(x)
    x = random_state.multinomial(1, x)
    x = np.argmax(x)
    return x.astype(np.int64)
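
# Sampling: load the final checkpoint, prime the hidden state with the seed string,
# then repeatedly sample the next character with sample_temperature and feed it back
# in until predict_len characters have been generated.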
def evaluate(prime_str, predict_len=100, temperature=0.8):
    if os.path.exists(final_path):
        print('Final parameters found at {}... loading'.format(final_path))
        checkpoint = torch.load(final_path)
        model.load_state_dict(checkpoint['model'])
    else:
        raise ValueError('Training was not finalized, no file found at {}. Run without --eval first to train a model'.format(final_path))

    hidden = Variable(model.create_hidden(1), volatile=True)
    prime_tensors = [index_to_tensor(char_to_index[char]) for char in prime_str]
    for prime_tensor in prime_tensors[:-1]:
        _, hidden = model(prime_tensor, hidden)

    inp = prime_tensors[-1]
    predicted = prime_str
    for p in range(predict_len):
        output, hidden = model(inp, hidden)

        # Sample from the network as a multinomial distribution
        # output_dist = output.data.view(-1).div(temperature).exp()
        # top_i = torch.multinomial(output_dist, 1)[0]
        # Alternative: use numpy
        top_i = sample_temperature(output.data.numpy(), temperature)

        # Add predicted character to string and use as next input
        predicted_char = index_to_char[top_i]
        predicted += predicted_char
        inp = index_to_tensor(char_to_index[predicted_char])

    return predicted

model = RNN(chars_len, hidden_size, chars_len, n_layers)

if args.eval:
    print(evaluate(args.seed, args.eval_len, temperature=args.temperature))
else:
    train()
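
# Example invocations (assuming the script is saved as char_rnn.py and that
# tiny-shakespeare.txt sits in the working directory; both names are illustrative):
#   python char_rnn.py --input tiny-shakespeare.txt --max_epochs 10    # train
#   python char_rnn.py --eval --seed 'ROMEO:' --eval_len 500           # sample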