

@elyase
Forked from vgel/r1.py
Created January 26, 2025 12:25
Script to run DeepSeek-R1 with a min-thinking-tokens parameter, replacing </think> with a random continuation string to extend the model's chain of thought.
import argparse
import random

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

parser = argparse.ArgumentParser()
parser.add_argument("question", type=str)
parser.add_argument(
    "-m", "--model-name", default="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
)
parser.add_argument("-d", "--device", default="auto")
parser.add_argument(
    "-r", "--replacements", nargs="+", default=["\nWait, but", "\nHmm", "\nSo"]
)
parser.add_argument("-t", "--min-thinking-tokens", type=int, default=128)
parser.add_argument("-p", "--prefill", default="")
args = parser.parse_args()

tokenizer = AutoTokenizer.from_pretrained(args.model_name)
model = AutoModelForCausalLM.from_pretrained(
    args.model_name, torch_dtype=torch.bfloat16, device_map=args.device
)

# The tokenizer prepends a BOS token, so encoding "<think></think>" yields
# three ids: [bos, <think>, </think>]. We only need the id of </think>.
_, _start_think_token, end_think_token = tokenizer.encode("<think></think>")


@torch.inference_mode()
def reasoning_effort(question: str, min_thinking_tokens: int):
    # Prefill the conversation up to and including the opening <think> tag,
    # so generation begins inside the chain of thought.
    tokens = tokenizer.apply_chat_template(
        [
            {"role": "user", "content": question},
            {"role": "assistant", "content": "<think>\n" + args.prefill},
        ],
        continue_final_message=True,
        return_tensors="pt",
    )
    tokens = tokens.to(model.device)
    kv = DynamicCache()
    n_thinking_tokens = 0

    yield tokenizer.decode(list(tokens[0]))
    while True:
        # Only the new tokens are fed each step; the KV cache carries the rest.
        out = model(input_ids=tokens, past_key_values=kv, use_cache=True)
        next_token = torch.multinomial(
            torch.softmax(out.logits[0, -1, :], dim=-1), 1
        ).item()
        kv = out.past_key_values

        if (
            next_token in (end_think_token, model.config.eos_token_id)
            and n_thinking_tokens < min_thinking_tokens
        ):
            # The model tried to stop thinking too early: drop the </think>
            # (or EOS) token and splice in a continuation string instead.
            replacement = random.choice(args.replacements)
            yield replacement
            # Encode without special tokens so a stray BOS isn't inserted
            # mid-stream (this tokenizer adds one by default).
            replacement_tokens = tokenizer.encode(
                replacement, add_special_tokens=False
            )
            n_thinking_tokens += len(replacement_tokens)
            tokens = torch.tensor([replacement_tokens]).to(tokens.device)
        elif next_token == model.config.eos_token_id:
            break
        else:
            yield tokenizer.decode([next_token])
            n_thinking_tokens += 1
            tokens = torch.tensor([[next_token]]).to(tokens.device)


for chunk in reasoning_effort(args.question, args.min_thinking_tokens):
    print(chunk, end="", flush=True)
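
A typical invocation might look like the following (a sketch, assuming the script is saved as r1.py, that transformers and torch are installed, and that the question and -t value are illustrative; all flags are those defined by the argparse options above):

    python r1.py "How many r's are in the word strawberry?" -t 512

Raising -t / --min-thinking-tokens forces more end-of-thought tokens to be replaced with continuation strings, lengthening the chain of thought before the model is allowed to emit </think> and answer.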