@umbra-scientia · Last active June 28, 2020 03:32
#!/usr/bin/python3
#export PYTHONIOENCODING=UTF-8
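# A minimal terminal chatbot: keeps a rolling Q/A history, prepends it to every
# prompt, and samples the reply chunk-by-chunk from a Hugging Face causal LM.
# Slash-commands at the "> " prompt: /debug N, /context, /prefix [text],
# /model NAME, /blank, /checkpoint, /reset [N], /set name=value
# (top_p, top_k, chunk_size).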
print("Loading. This might take a while...")
import os
import sys
from transformers import AutoModelWithLMHead, AutoTokenizer
from datetime import date
#model_type = 'gpt2' # Much lower resource use! Also lower intelligence.
model_type = 'ctrl'
tokenizer = AutoTokenizer.from_pretrained(model_type)
model = AutoModelWithLMHead.from_pretrained(model_type)
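# Note: AutoModelWithLMHead is deprecated in newer transformers releases; on a
# current install the equivalent load would be (untested sketch):
#   from transformers import AutoModelForCausalLM
#   model = AutoModelForCausalLM.from_pretrained(model_type)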
print("\n[waking up]")
max_len = 512
max_input_len = 256
chunk_size = 8
history_len = 8
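# max_len: cap on the length of a generated answer (characters of decoded text).
# max_input_len: token budget for the prompt; older turns are dropped to fit.
# chunk_size: tokens sampled per generate() call before checking for a stop marker.
# history_len: past Q/A turns included in the prompt (auto-shrunk when too long).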
name = os.environ.get("AI_NAME")
if name is None:
    name = input("What is my name? ")
    if name == "": name = "Starlight"
    os.environ["AI_NAME"] = name
prefix = "Questions "
top_p = 0.618
top_k = 0
debug_mode = 0
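# Sampling settings: top_p=0.618 does nucleus sampling (keep the smallest token
# set whose cumulative probability exceeds p); top_k=0 disables the fixed-size
# top-k cutoff. Both can be changed at runtime with /set.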
initial_state = [
    ["Hello?", "Hello there!"],
    ["Who are you?", "My name is "+name+"."],
    ["What are you?", "I am a friendly AI."],
    ["What is the date?", date.today().strftime("Today is %A, %B %d, %Y.")],
    ["Are you sure?", "lol no."]
]
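# A few seed exchanges prime the model with the persona and Q/A format;
# /checkpoint replaces this seed with the current history, /reset restores it.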
history = list(initial_state)
while True:
    Q = input("> ").strip()
    Qs = Q.split(" ")
    # Slash-commands tweak the session without generating a reply.
    if Qs[0] == "/debug":
        debug_mode = int(Qs[1])
        continue
    if Qs[0] == "/context":
        print(history)
        continue
    if Q[0:7] == "/prefix":
        if len(Q) <= 7: prefix = ""
        else: prefix = Q[7:].strip() + " "  # drop the space after the command
        continue
if Qs[0] == "/model":
model_type = Qs[1]
tokenizer = AutoTokenizer.from_pretrained(model_type)
model = AutoModelWithLMHead.from_pretrained(model_type)
continue
if Qs[0] == "/blank":
history = []
continue
if Qs[0] == "/checkpoint":
initial_state = list(history)
continue
if Qs[0] == "/reset":
if len(Qs) > 1:
n = int(Qs[1])
history = history[:-n]
else:
history = list(initial_state)
continue
if Qs[0] == "/set":
for i in range(1, len(Qs)):
c = Qs[i].lower().split("=")
if c[0] == "top_p": top_p = float(c[1])
elif c[0] == "top_k": top_k = int(c[1])
elif c[0] == "chunk_size": chunk_size = int(c[1])
else: print("Unknown variable %s" % c[0])
continue
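    # Anything that is not a slash-command is treated as a question: build a
    # prompt from the recent history, then sample the answer chunk_size tokens
    # at a time until a stop marker appears or the answer reaches max_len.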
A = ""
try:
while len(A) < max_len:
history_len += 1
while True:
buf = prefix
for h in history[-history_len:]:
if h[0] == None:
buf += h[1] + "\n"
else:
buf += "Q: " + h[0] + "\nA: " + h[1] + "\n"
buf += "Q: " + Q + "\nA: " + A
ids = tokenizer.encode(buf, return_tensors='pt')
if (len(ids) > max_input_len - chunk_size) and (history_len > 1):
history_len -= 1
else:
break
if debug_mode != 0:
print("buf = %s\nids.shape =" % buf, ids.shape)
out = model.generate(input_ids=ids, max_length=ids.shape[1]+chunk_size, do_sample=True, top_p=top_p, top_k=top_k)
T = out[0][ids.shape[1]:]
S = tokenizer.decode(T)
            done = False
            # Cut the sample at the first stop marker: a newline or a new "A:" turn.
            while True:
                if "\n" in S:
                    S = S[:S.index("\n")]
                    done = True
                elif "A:" in S:
                    S = S[:S.index("A:")]
                    done = True
                else:
                    break
            S = S.strip()
            # CTRL's BPE marks mid-word pieces with a trailing "@@"; join them, otherwise add a space.
            if S[-2:] == "@@":
                S = S[:-2]
            else:
                S = S + " "
            A += S
            sys.stdout.write(S)
            if done:
                sys.stdout.write("\n")
                break
            else:
                sys.stdout.flush()
    except KeyboardInterrupt:
        # Ctrl-C stops generation early but keeps the partial answer.
        pass
    A = A.strip()
    history.append([Q, A])
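# Usage: run the script and type questions at the "> " prompt. Set AI_NAME in
# the environment (e.g. AI_NAME=Starlight) to skip the startup name prompt;
# Ctrl-C interrupts a long answer without exiting.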