Last active
August 31, 2019 00:12
-
-
Save Lexie88rus/5699e10792d1bbbdb4c9c581ad706efd to your computer and use it in GitHub Desktop.
Sample titles from the model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Sample title from the trained model
def sample(num_words=10):
    """Sample one title from the trained model, one word at a time.

    Sampling starts from an all-zero input step and a pair of all-zero
    hidden-state tensors. At every step the model output is converted to a
    probability distribution with softmax and the next word is drawn at
    random from it; the drawn word is then fed back in as the next input.
    Sampling stops once '.' has been drawn (it is kept in the result) or
    after ``num_words`` words, whichever comes first.

    Relies on the module-level ``rnn``, ``vocab``, ``vocab_size``,
    ``n_hidden``, ``device`` and ``wordToTensor``.

    Args:
        num_words: Maximum number of words to sample. Defaults to 10.

    Returns:
        The sampled words as a list, in generation order.

    Raises:
        KeyError: If the sampled index has no corresponding word in ``vocab``.
    """
    # Invert the vocabulary once instead of scanning it for every sampled
    # word. setdefault keeps the first word seen for an index, which matches
    # the result of a first-match linear scan over vocab.items().
    index_to_word = {}
    for word, index in vocab.items():
        index_to_word.setdefault(index, word)

    # Initialize input step and hidden state
    step_input = torch.zeros(1, 1, vocab_size)
    hidden = (
        torch.zeros(1, 1, n_hidden).to(device),
        torch.zeros(1, 1, n_hidden).to(device),
    )

    sentence = []
    output_word = None
    # Sampling is inference only, so there is no need to record an autograd
    # graph across the generated steps.
    with torch.no_grad():
        # Sample words from the model
        while output_word != '.' and len(sentence) < num_words:
            step_input = step_input.to(device)
            output, hidden = rnn(step_input[0], hidden)
            # Use the probabilities from the output to choose the next word
            probabilities = f.softmax(output, dim=1).detach().cpu().numpy().ravel()
            # An int first argument samples as if from np.arange(vocab_size).
            idx = int(np.random.choice(vocab_size, p=probabilities))
            output_word = index_to_word[idx]
            sentence.append(output_word)
            # Feed the sampled word back in as the next input step
            step_input = wordToTensor(output_word)
    return sentence
# Draw 15 titles from the model; each title is printed on its own line,
# followed by a separate print of a newline character as a visual separator.
for _ in range(15):
    title = ' '.join(sample())
    print(title)
    print("\n")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment