Skip to content

Instantly share code, notes, and snippets.

@umbra-scientia
Last active September 9, 2020 00:16
Show Gist options
  • Save umbra-scientia/62bb3506cfd6cd62a8ebe61ff364da5f to your computer and use it in GitHub Desktop.
#!/usr/bin/python3
# export PYTHONIOENCODING=UTF-8
# Interactive GPT-2 next-token explorer: at each step, show the top-p/top-k
# candidate next tokens for the current prompt and let the user either append
# one by index, append free text verbatim, or quit.
gpt2_size = "gpt2"
import numpy as np
from transformers import TFGPT2LMHeadModel, GPT2Tokenizer, GPT2Config

tokenizer = GPT2Tokenizer.from_pretrained(gpt2_size)
gpt2 = TFGPT2LMHeadModel.from_pretrained(gpt2_size)
buf = input("Provide a prompt: ")
sampler_p = 0.75   # cumulative cutoff over raw logits for the candidate set
sampler_k = 8      # hard cap on how many candidate tokens are shown
buffer_size = 256  # fixed context window (in tokens) fed to the model
doit = True
while doit:
    try:
        # NOTE: the keyword is max_length; the old max_len spelling was
        # removed from tokenizer.encode in modern transformers releases.
        ctx = tokenizer.encode(buf, max_length=buffer_size * 4)
        ctxlen = len(ctx)
        if ctxlen < buffer_size:
            # Right-pad short contexts with token id 0, then trim to size.
            ctx.extend([0] * buffer_size)
            ctx = ctx[:buffer_size]
        else:
            # Keep only the most recent buffer_size tokens.
            ctx = ctx[-buffer_size:]
            ctxlen = len(ctx)
        ctx = np.array([ctx])
        wavefront = gpt2.predict([ctx])[0]
        # Logits for the token that follows the last real prompt token.
        prob = wavefront[0][ctxlen - 1]
        keys = np.argsort(-prob)  # token ids sorted most-likely first
        option_keys = []
        option_prob = []
        total_p = 0.0
        # Gather candidates until the cumulative logit total exceeds
        # sampler_p or more than sampler_k options have been collected.
        for k in keys:
            total_p += prob[k]
            option_keys.append(k)
            option_prob.append(prob[k])
            if (total_p > sampler_p) or (len(option_keys) > sampler_k):
                break
        # Softmax over just the selected candidates (max-subtracted for
        # numerical stability) to obtain display percentages.
        option_prob = np.array(option_prob)
        option_prob = np.exp(option_prob - np.max(option_prob))
        option_prob /= option_prob.sum(axis=0)
        print("\nPrompt:\n%s" % buf)
        for opti in range(len(option_keys)):
            k = option_keys[opti]
            p = option_prob[opti]
            t = tokenizer.decode([k])
            print("\t[%d] \"%s\" (%.02f%%)" % (opti, t, p * 100.0))
        c = input("[choose 0 thru %d, q to quit, or type anything] " % (len(option_keys) - 1))
        if c == "q":
            doit = False
            break
        if c == "":
            c = "0"
        try:
            # Numeric input selects a candidate; anything else (including a
            # number out of range, which raises IndexError) falls through to
            # appending the raw input text. Keep c a string so the fallback
            # concatenation is always valid.
            idx = int(c)
            buf += tokenizer.decode([option_keys[idx]])
        except (ValueError, IndexError):
            buf += c
    except KeyboardInterrupt:
        doit = False
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment