r1.py, forked from vgel/r1.py by @secemp9, created January 22, 2025.
Script to run DeepSeek-R1 with a min-thinking-tokens parameter, replacing </think> with a random continuation string to extend the model's chain of thought.
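Typical invocation, assuming the script is saved as r1.py (the upstream gist's filename); the question is positional and -t/--min-thinking-tokens sets the thinking budget (default 1256):

    python r1.py "how many 'r's are in strawberry?" -t 2048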
import argparse
import random

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

# Continuation strings injected in place of </think> to keep the model thinking.
defo = [
    "\nWait, let's look at this from a system thinking approach:",
    "\nHmm let's look at this from a step by step approach:",
]
# ~ defo = ["\nWait, but", "\nHmm", "\nSo", "\nActually"]
parser = argparse.ArgumentParser()
parser.add_argument("question", type=str)
parser.add_argument(
    "-m", "--model-name", default="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
)
parser.add_argument("-d", "--device", default="auto")
parser.add_argument("-r", "--replacements", nargs="+", default=defo)
parser.add_argument("-t", "--min-thinking-tokens", type=int, default=1256)
args = parser.parse_args()
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
model = AutoModelForCausalLM.from_pretrained(
    args.model_name, torch_dtype=torch.bfloat16, device_map=args.device
)

# add_special_tokens=False keeps the tokenizer from prepending a BOS token,
# so this unpacks to exactly the two think-tag ids.
start_think_token, end_think_token = tokenizer.encode(
    "<think></think>", add_special_tokens=False
)
@torch.inference_mode()
def reasoning_effort(question: str, min_thinking_tokens: int):
    tokens = tokenizer.apply_chat_template(
        [{"role": "user", "content": question}],
        add_generation_prompt=True,
        return_tensors="pt",
    )
    # Force the model into thinking mode by appending <think> to the prompt.
    tokens = torch.cat((tokens, torch.tensor([[start_think_token]])), dim=-1)
    tokens = tokens.to(model.device)
    kv = DynamicCache()
    n_thinking_tokens = 0

    yield tokenizer.decode(list(tokens[0]))
    while True:
        out = model(input_ids=tokens, past_key_values=kv, use_cache=True)
        next_token = torch.multinomial(
            torch.softmax(out.logits[0, -1, :], dim=-1), 1
        ).item()
        kv = out.past_key_values

        if (
            next_token
            in (start_think_token, end_think_token, model.config.eos_token_id)
            and n_thinking_tokens < min_thinking_tokens
        ):
            # The model tried to close (or reopen) its chain of thought, or to
            # stop entirely, before the token budget was spent: swap the token
            # for a random continuation string and keep going.
            replacement = random.choice(args.replacements)
            yield replacement
            replacement_tokens = tokenizer.encode(
                replacement, add_special_tokens=False
            )
            n_thinking_tokens += len(replacement_tokens)
            tokens = torch.tensor([replacement_tokens]).to(tokens.device)
        elif next_token == model.config.eos_token_id:
            # Budget satisfied and the model is done: stop generating.
            break
        else:
            yield tokenizer.decode([next_token])
            n_thinking_tokens += 1
            tokens = torch.tensor([[next_token]]).to(tokens.device)
for chunk in reasoning_effort(args.question, args.min_thinking_tokens):
    print(chunk, end="", flush=True)
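The two-way unpack of the <think></think> encoding assumes each tag maps to a single token id, as it does in the DeepSeek-R1 distill vocabularies. A minimal sketch to verify that assumption against another tokenizer (not part of the script):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")
    ids = tok.encode("<think></think>", add_special_tokens=False)
    # Expect exactly two ids, decoding to "<think>" and "</think>".
    print(ids, [tok.decode([i]) for i in ids])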