Test Time Scaling with MLX LM and R1-based LLMs

Install MLX LM:

pip install mlx-lm

And run:

python reason.py

The default model is mlx-community/DeepSeek-R1-Distill-Qwen-7B-4bit. You can specify the model with --model.
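
For example, to try a smaller distilled model from the mlx-community organization (the exact quantized repo name here is an assumption; any MLX-converted R1 distill should work), run:

python reason.py --model mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-4bit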

To see all the options:

python reason.py --help
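
For example, a run that adjusts the sampling settings and the token budget might look like the following (the flag names come from the script below; the values are only placeholders):

python reason.py --temp 0.6 --top-p 0.95 --max-tokens 8192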
reason.py:

# Copyright © 2023-2024 Apple Inc.
import argparse
import json
import mlx.core as mx
from functools import partial
from mlx_lm.models.cache import make_prompt_cache, trim_prompt_cache
from mlx_lm.sample_utils import make_sampler
from mlx_lm.utils import load, stream_generate
DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0
DEFAULT_MAX_TOKENS = 4096
DEFAULT_MODEL = "mlx-community/DeepSeek-R1-Distill-Qwen-7B-4bit"
def setup_arg_parser():
    """Set up and return the argument parser."""
    parser = argparse.ArgumentParser(description="Chat with an LLM")
    parser.add_argument(
        "--model",
        type=str,
        help="The path to the local model directory or Hugging Face repo.",
        default=DEFAULT_MODEL,
    )
    parser.add_argument(
        "--adapter-path",
        type=str,
        help="Optional path for the trained adapter weights and config.",
    )
    parser.add_argument(
        "--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
    )
    parser.add_argument(
        "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
    )
    parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
    parser.add_argument(
        "--max-tokens",
        "-m",
        type=int,
        default=DEFAULT_MAX_TOKENS,
        help="Maximum number of tokens to generate",
    )
    return parser

def main():
    parser = setup_arg_parser()
    args = parser.parse_args()
    mx.random.seed(args.seed)
    model, tokenizer = load(
        args.model,
        adapter_path=args.adapter_path,
        tokenizer_config={"trust_remote_code": True},
    )

    # Token ids used to steer the reasoning trace: "Wait" extends thinking,
    # "</think>" ends it.
    wait_token = "Wait"
    wait_token_id = tokenizer.convert_tokens_to_ids(wait_token)
    end_think_token = "</think>"
    end_think_token_id = tokenizer.convert_tokens_to_ids(end_think_token)
    think_more_prompt = mx.array([wait_token_id], mx.uint32)
    end_think_prompt = mx.array(
        tokenizer.encode(end_think_token + "\n", add_special_tokens=False), mx.uint32
    )

    generator = partial(
        stream_generate,
        model=model,
        tokenizer=tokenizer,
        sampler=make_sampler(args.temp, args.top_p),
    )

    print(f"[INFO] Starting reasoning session with {args.model}. To exit, enter 'q'.")
    while True:
        prompt_cache = make_prompt_cache(model)

        query = input(">> ")
        if query == "q":
            break
        messages = [{"role": "user", "content": query}]
        prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

        while True:
            max_tokens = args.max_tokens
            end_think_idx = None

            # Stream tokens until the model emits "Wait" (a natural point to
            # cut the reasoning short) or finishes on its own. Record where
            # "</think>" appears so the reasoning can be resumed later.
            for response in generator(
                prompt=prompt,
                max_tokens=max_tokens,
                prompt_cache=prompt_cache,
            ):
                if response.token == wait_token_id:
                    break
                elif response.token == end_think_token_id:
                    end_think_idx = prompt_cache[0].offset
                print(response.text, flush=True, end="")
                max_tokens -= response.generation_tokens

            # If we got a wait token, insert </think> and generate the response
            if end_think_idx is None:
                print(end_think_token, flush=True)
                end_think_idx = prompt_cache[0].offset
                prompt = end_think_prompt

                # Trim the wait token from the cache
                trim_prompt_cache(prompt_cache, 1)

                # Generate answer
                for response in generator(
                    prompt=prompt,
                    max_tokens=max_tokens,
                    prompt_cache=prompt_cache,
                ):
                    print(response.text, flush=True, end="")
                    max_tokens -= response.generation_tokens

            think_more = input(
                "\n\n\033[31mWould you like me to think more? (y/n):\033[0m "
            )
            if think_more == "y":
                # Trim the prompt cache to just before the end of think token
                # and feed "Wait" back in to extend the reasoning.
                print("<think>")
                print(wait_token, flush=True, end="")
                num_to_trim = prompt_cache[0].offset - end_think_idx + 1
                max_tokens += num_to_trim
                trim_prompt_cache(prompt_cache, num_to_trim)
                prompt = think_more_prompt
            else:
                break
        print()


if __name__ == "__main__":
    main()