@sebington
Last active February 8, 2025 16:25
CPU-friendly version of vgel's r1.py script (https://gist.github.com/vgel/8a2497dc45b1ded33287fa7bb6cc1adc)
# Generated with Claude 3.5 Sonnet using vgel's r1.py script + the following prompt:
# "Can you modify this script to improve inference speed on a CPU-only PC?"
# The number of CPU threads (= cores) to use can be set on the command line with --threads
# Example run : python r1-cpu.py -t 32 "What is 1+1?" --threads 4
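# Tip (not part of the original header): --threads feeds torch.set_num_threads() below.
# Values above the number of physical cores rarely help, and os.cpu_count() reports
# *logical* cores, so on hyper-threaded machines roughly half of that is a sane default.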
import argparse
import random
import sys
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
import torch
parser = argparse.ArgumentParser()
parser.add_argument("question", type=str)
parser.add_argument(
"-m", "--model-name", default="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
)
parser.add_argument("-r", "--replacements", nargs="+", default=["\nWait, but", "\nHmm", "\nSo"])
parser.add_argument("-t", "--min-thinking-tokens", type=int, default=128)
parser.add_argument("-p", "--prefill", default="")
parser.add_argument("--threads", type=int, default=4, help="Number of CPU threads")
args = parser.parse_args()
# Set number of threads for CPU inference
torch.set_num_threads(args.threads)
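# Optional variation (not in the original gist): also cap inter-op parallelism so
# PyTorch worker pools don't oversubscribe the CPU. torch.set_num_interop_threads()
# must be called before any parallel work starts, i.e. roughly here, before the model
# is loaded; the // 2 split below is only an illustrative choice.
# torch.set_num_interop_threads(max(1, args.threads // 2))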
# Initialize tokenizer with caching
tokenizer = AutoTokenizer.from_pretrained(
    args.model_name,
    use_fast=True,  # Use fast tokenizer
    model_max_length=2048,  # Limit context size
)
# Load model with CPU optimizations
model = AutoModelForCausalLM.from_pretrained(
    args.model_name,
    torch_dtype=torch.float32,  # Use float32 for CPU
    low_cpu_mem_usage=True,
    device_map="cpu"
)
# Disable TorchScript profiling/graph-optimization passes (private PyTorch APIs;
# this has little effect on plain eager-mode inference but does no harm)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
torch._C._set_graph_executor_optimize(False)
# encode("<think></think>") yields three ids for this tokenizer (a leading special
# token plus the <think> / </think> markers); only the </think> id is used below
_, _start_think_token, end_think_token = tokenizer.encode("<think></think>")
@torch.inference_mode()
def reasoning_effort(question: str, min_thinking_tokens: int):
    # Build the prompt: chat template with the assistant turn forced open at "<think>"
    tokens = tokenizer.apply_chat_template(
        [
            {"role": "user", "content": question},
            {"role": "assistant", "content": "<think>\n" + args.prefill},
        ],
        continue_final_message=True,
        return_tensors="pt",
    )
    kv = DynamicCache()
    n_thinking_tokens = 0
    batch_size = 1  # single-sequence, token-by-token decoding (currently unused)

    yield tokenizer.decode(list(tokens[0]))
    while True:
        with torch.no_grad():  # Explicit no_grad for safety (redundant under inference_mode)
            out = model(
                input_ids=tokens,
                past_key_values=kv,
                use_cache=True
            )
        next_token = torch.multinomial(
            torch.softmax(out.logits[0, -1, :], dim=-1),
            1
        ).item()
        kv = out.past_key_values

        if (
            next_token in (end_think_token, model.config.eos_token_id)
            and n_thinking_tokens < min_thinking_tokens
        ):
            replacement = random.choice(args.replacements)
            yield replacement
            replacement_tokens = tokenizer.encode(replacement)
            n_thinking_tokens += len(replacement_tokens)
            tokens = torch.tensor([replacement_tokens])
        elif next_token == model.config.eos_token_id:
            break
        else:
            yield tokenizer.decode([next_token])
            n_thinking_tokens += 1
            tokens = torch.tensor([[next_token]])
# Main execution with error handling
try:
    for chunk in reasoning_effort(args.question, args.min_thinking_tokens):
        print(chunk, end="", flush=True)
except Exception as e:
    print(f"\nError during inference: {str(e)}", file=sys.stderr)
    sys.exit(1)
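To see what the --threads setting buys on a given machine, the same prompt can be timed with different values (flags as defined above; results will vary with hardware):

    time python r1-cpu.py -t 128 "What is 1+1?" --threads 2
    time python r1-cpu.py -t 128 "What is 1+1?" --threads 8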