Skip to content

Instantly share code, notes, and snippets.

@jiamingkong
Created May 17, 2023 06:46
Show Gist options
  • Save jiamingkong/41d0bcf1f52be104a335bd0fa288c407 to your computer and use it in GitHub Desktop.
Save jiamingkong/41d0bcf1f52be104a335bd0fa288c407 to your computer and use it in GitHub Desktop.
Proof of Concept for State Caching in RWKV
#!/usr/bin/python
# -*- coding: UTF-8 -*-
import os, gc
from huggingface_hub import hf_hub_download
from pynvml import *
import torch
from copy import deepcopy
# NVML handle for GPU 0.
# NOTE(review): `gpu_h` is not read anywhere in the code shown here --
# presumably a leftover from VRAM reporting; confirm before removing.
nvmlInit()
gpu_h = nvmlDeviceGetHandleByIndex(0)

# Maximum number of prompt tokens `evaluate` feeds to the model.
ctx_limit = 4096

# Checkpoint file name (without ".pth") in the BlinkDL/rwkv-4-raven repo.
# Kept under its own name instead of reusing `model`, which is bound to the
# loaded RWKV instance below.
model_name = "RWKV-4-Raven-7B-v9-Eng49%-Chn50%-Other1%-20230414-ctx4096"

# rwkv reads these flags when `rwkv.model` is imported, so they must be set
# before the import that follows.
os.environ["RWKV_JIT_ON"] = '1'
os.environ["RWKV_CUDA_ON"] = '0'  # if '1' then use CUDA kernel for seq mode (much faster)
from rwkv.model import RWKV

model_path = hf_hub_download(repo_id="BlinkDL/rwkv-4-raven", filename=f"{model_name}.pth")
strategy = "cuda fp16"
model = RWKV(model=model_path, strategy=strategy)

from rwkv.utils import PIPELINE, PIPELINE_ARGS

# NOTE(review): relative path -- presumably resolved against the current
# working directory, so the tokenizer JSON must be present there.
pipeline = PIPELINE(model, "20B_tokenizer.json")
def evaluate(
    state=None,
    instruction="",
    token_count=200,
    temperature=1.0,
    top_p=0.5,
    presencePenalty=0.2,
    countPenalty=0.2,
):
    """Feed *instruction* to the RWKV model and sample a short continuation.

    On the first step the encoded instruction (truncated to its last
    ``ctx_limit`` tokens) is passed to ``model.forward``. On every later step
    the previously sampled token is fed back instead. The state returned by
    each forward pass is threaded into the next one.

    Generation stops at the first of these:
    - ``token_count`` steps have run;
    - a token in ``token_stop`` is sampled;
    - the decoded text contains a newline.

    Args:
        state: RWKV state to resume from, or ``None`` to start from scratch.
            NOTE(review): ``model.forward`` may update a passed-in state in
            place (the question loop deep-copies the cached state before each
            call). Pass a copy if the original must be reused.
        instruction: Prompt text fed on the first step.
        token_count: Maximum number of tokens to sample.
        temperature: Sampling temperature; values below 0.2 are raised to 0.2.
        top_p: Top-p (nucleus) threshold passed to ``pipeline.sample_logits``.
        presencePenalty: Flat logit penalty for any token already generated.
        countPenalty: Extra logit penalty per previous occurrence of a token.

    Returns:
        ``(text, state)``.

        ``text`` is the decoded output. If generation ended on a newline, it
        includes that newline (and anything decoded along with it).

        ``state`` is the state after the last forward pass. The final sampled
        token has *not* been fed into it. With ``token_count <= 0`` no forward
        pass runs, and the input ``state`` is returned unchanged.
    """
    args = PIPELINE_ARGS(
        temperature=max(0.2, float(temperature)),
        top_p=float(top_p),
        alpha_frequency=countPenalty,
        alpha_presence=presencePenalty,
        token_ban=[],  # ban the generation of some tokens
        token_stop=[0],  # stop generation whenever you see any token here
    )

    all_tokens = []
    out_last = 0  # index in all_tokens of the first token not yet emitted as text
    out_str = ''
    occurrence = {}  # token id -> number of times it has been sampled so far

    for i in range(int(token_count)):
        tokens = pipeline.encode(instruction)[-ctx_limit:] if i == 0 else [token]
        # Feed back the *returned* state; otherwise a call that starts from
        # state=None would restart from a blank state on every step.
        out, state = model.forward(tokens, state)

        # Apply the token ban and the repetition penalties to the logits.
        for n in args.token_ban:
            out[n] = -float('inf')
        for n, count in occurrence.items():
            out[n] -= args.alpha_presence + count * args.alpha_frequency

        token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
        if token in args.token_stop:
            break
        all_tokens.append(token)
        occurrence[token] = occurrence.get(token, 0) + 1

        # Only emit text once it decodes cleanly: U+FFFD indicates a
        # multi-byte character still split across tokens.
        tmp = pipeline.decode(all_tokens[out_last:])
        if '\ufffd' not in tmp:
            out_str += tmp
            out_last = len(all_tokens)
        if "\n" in out_str:
            break

    gc.collect()
    torch.cuda.empty_cache()
    return out_str, state
# Shared prompt prefix: instructions plus the reference passage. It is fed to
# the model once, and the resulting state is deep-copied and reused for every
# question below, so the passage is not re-processed per question.
CONTEXT = """Answer the question based on the context below. Keep the answer short and concise. Respond \"Unsure about answer\" if not sure about the answer.
Context: Teplizumab traces its roots to a New Jersey drug company called Ortho Pharmaceutical. There, scientists generated an early version of the antibody, dubbed OKT3. Originally sourced from mice, the molecule was able to bind to the surface of T cells and limit their cell-killing potential. In 1986, it was approved to help prevent organ rejection after kidney transplants, making it the first therapeutic antibody allowed for human use.
"""
def formulate_question(question):
    """Build the Question/Answer prompt fed after the cached CONTEXT state.

    The prompt starts on a fresh line and ends with ``"Answer:"`` so the
    model's continuation is the answer itself.
    """
    template = "\nQuestion: {}\nAnswer:"
    return template.format(question)
# Prime the model with CONTEXT once. With token_count=1, `evaluate` runs a
# single forward pass over the whole context. The one sampled token is
# discarded; only the resulting state is kept as the cache.
_, state = evaluate(
    state=None,
    instruction=CONTEXT,
    token_count=1,
    temperature=0.5,
    top_p=0.5,
    presencePenalty=0.2,
    countPenalty=0.2,
)

question_list = [
    "What was OKT3 originally sourced from?",
    "What was the use of OKT3?",
    "How did OKT3 help prevent organ rejection?",
]

for question in question_list:
    prompt = formulate_question(question)
    # Give each question its own copy of the cached context state, so every
    # answer starts from the same point.
    cached_state = deepcopy(state)
    answer, _ = evaluate(
        state=cached_state,
        instruction=prompt,
        token_count=100,
        temperature=0.5,
        top_p=0.5,
        presencePenalty=0.2,
        countPenalty=0.2,
    )
    print(f"Question: {question}")
    print(f"Answer: {answer}")
    print("")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment