from transformers import AutoModelForCausalLM, AutoTokenizer
import logging

model_name = 'gpt2-xl'
dev = 'cuda'
#dev = 'cpu'

tokenizer = AutoTokenizer.from_pretrained(model_name)
# AutoModelWithLMHead is deprecated in recent transformers releases;
# AutoModelForCausalLM is the current equivalent for GPT-2.
model = AutoModelForCausalLM.from_pretrained(model_name).to(dev)
model.eval()

# Suppress the verbose INFO/WARNING output emitted during generation.
logging.getLogger().setLevel(logging.ERROR)


def generate(context, max_length=50):
    """Sample one completion of `context`, up to `max_length` tokens total."""
    input_ids = tokenizer.encode(context, return_tensors="pt").to(dev)
    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        # Optional sampling controls, left at their defaults here:
        #temperature=1.0,
        #top_k=50,
        #top_p=1.0,
        #repetition_penalty=1.0,
        do_sample=True,
        num_return_sequences=1,
    )
    text = tokenizer.decode(output_sequences[0],
                            clean_up_tokenization_spaces=True,
                            skip_special_tokens=True)
    return text


def run_interactive(max_length=50):
    """Read a context from stdin and print two sampled completions, until 'quit'."""
    inp = input("Enter context to be completed (or quit): ")
    while inp != 'quit':
        for i in range(2):
            print("=== GENERATED SEQUENCE {} ===".format(i + 1))
            print(generate(inp, max_length=max_length))
        inp = input("Enter context to be completed (or quit): ")
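
# A minimal non-interactive sketch, for reference: sample several completions
# in one model.generate() call by passing explicit sampling controls. The
# helper name and the temperature/top_k/top_p values below are illustrative
# assumptions, not settings taken from the script above.
def sample_completions(context, n=3, max_length=50):
    input_ids = tokenizer.encode(context, return_tensors="pt").to(dev)
    outputs = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        do_sample=True,
        temperature=0.9,   # illustrative value
        top_k=50,          # illustrative value
        top_p=0.95,        # illustrative value
        num_return_sequences=n,
    )
    return [tokenizer.decode(seq, skip_special_tokens=True) for seq in outputs]

# Example:
#   for text in sample_completions("The meaning of life is"):
#       print(text)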

if __name__ == "__main__":
    run_interactive()