Proof of Concept for State Caching in RWKV
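The script below demonstrates state caching with RWKV's recurrent state: the shared context is fed through the model once, the returned state is kept, and a deepcopy of that state is reused for each question, so the context never has to be re-processed per query.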
#!/usr/bin/python
# -*- coding: UTF-8 -*-
import os, gc
from huggingface_hub import hf_hub_download
from pynvml import *
import torch
from copy import deepcopy

nvmlInit()
gpu_h = nvmlDeviceGetHandleByIndex(0)  # NVML handle for GPU monitoring (not used below)
ctx_limit = 4096
model = "RWKV-4-Raven-7B-v9-Eng49%-Chn50%-Other1%-20230414-ctx4096"

os.environ["RWKV_JIT_ON"] = '1'
os.environ["RWKV_CUDA_ON"] = '0'  # if '1' then use CUDA kernel for seq mode (much faster)

from rwkv.model import RWKV

model_path = hf_hub_download(repo_id="BlinkDL/rwkv-4-raven", filename=f"{model}.pth")
strategy = "cuda fp16"
model = RWKV(model=model_path, strategy=strategy)

from rwkv.utils import PIPELINE, PIPELINE_ARGS

# "20B_tokenizer.json" must be available in the working directory
# (it ships with the RWKV repositories)
pipeline = PIPELINE(model, "20B_tokenizer.json")
def evaluate(
    state=None,
    instruction="",
    token_count=200,
    temperature=1.0,
    top_p=0.5,
    presencePenalty=0.2,
    countPenalty=0.2,
):
    args = PIPELINE_ARGS(
        temperature=max(0.2, float(temperature)),
        top_p=float(top_p),
        alpha_frequency=countPenalty,
        alpha_presence=presencePenalty,
        token_ban=[],    # ban the generation of some tokens
        token_stop=[0],  # stop generation whenever you see any token here
    )

    all_tokens = []
    out_last = 0
    out_str = ''
    occurrence = {}
for i in range(int(token_count)): | |
out, _state = model.forward(pipeline.encode(instruction)[-ctx_limit:] if i == 0 else [token], state) | |
for n in occurrence: | |
out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency) | |
token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p) | |
if token in args.token_stop: | |
break | |
all_tokens += [token] | |
if token not in occurrence: | |
occurrence[token] = 1 | |
else: | |
occurrence[token] += 1 | |
tmp = pipeline.decode(all_tokens[out_last:]) | |
if '\ufffd' not in tmp: | |
out_str += tmp | |
out_last = i + 1 | |
if "\n" in out_str: | |
break | |
gc.collect() | |
torch.cuda.empty_cache() | |
return out_str, _state | |
# now first inject a context and cache the resulting state
CONTEXT = """Answer the question based on the context below. Keep the answer short and concise. Respond \"Unsure about answer\" if not sure about the answer.
Context: Teplizumab traces its roots to a New Jersey drug company called Ortho Pharmaceutical. There, scientists generated an early version of the antibody, dubbed OKT3. Originally sourced from mice, the molecule was able to bind to the surface of T cells and limit their cell-killing potential. In 1986, it was approved to help prevent organ rejection after kidney transplants, making it the first therapeutic antibody allowed for human use.
"""

def formulate_question(question):
    return f"""
Question: {question}
Answer:"""

# let the model ingest CONTEXT; token_count=1 because we only need the state, not the text
_, state = evaluate(None, CONTEXT, 1, 0.5, 0.5, 0.2, 0.2)
question_list = [
    "What was OKT3 originally sourced from?",
    "What was the use of OKT3?",
    "How did OKT3 help prevent organ rejection?",
]

for question in question_list:
    q = formulate_question(question)
    # deepcopy so each question starts from the pristine cached context state
    s = deepcopy(state)
    answer, _ = evaluate(s, q, 100, 0.5, 0.5, 0.2, 0.2)
    print(f"Question: {question}")
    print(f"Answer: {answer}")
    print("")