# faster every day
# Copyright © 2023-2024 Apple Inc.
import contextlib
import time
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_reduce
from transformers import PreTrainedTokenizer

# Local imports
from mlx_lm import load
from mlx_lm.models import cache
from mlx_lm.sample_utils import make_logits_processors, make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper

"""model, tokenizer = load("Qwen2.5-1.5B-Instruct-bf16")
prompt = "Write a story about Einstein"
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
response = generate(model, tokenizer, prompt=prompt, verbose=True)"""
def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
    if (
        kv_bits is not None
        and not isinstance(prompt_cache[0], cache.QuantizedKVCache)
        and prompt_cache[0].offset > quantized_kv_start
    ):
        for i in range(len(prompt_cache)):
            prompt_cache[i] = prompt_cache[i].to_quantized(
                group_size=kv_group_size, bits=kv_bits
            )

generation_stream = mx.new_stream(mx.default_device())
@contextlib.contextmanager
def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None):
    """
    A context manager to temporarily change the wired limit.

    Note: the wired limit should not be changed during an async eval. If an
    async eval could be running, pass in the streams to synchronize with prior
    to exiting the context manager.
    """
    model_bytes = tree_reduce(
        lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
    )
    max_rec_size = mx.metal.device_info()["max_recommended_working_set_size"]
    if model_bytes > 0.9 * max_rec_size:
        model_mb = model_bytes // 2**20
        max_rec_mb = max_rec_size // 2**20
        print(
            f"[WARNING] Generating with a model that requires {model_mb} MB "
            f"which is close to the maximum recommended size of {max_rec_mb} "
            "MB. This can be slow. See the documentation for possible work-arounds: "
            "https://github.com/ml-explore/mlx-examples/tree/main/llms#large-models"
        )
    old_limit = mx.metal.set_wired_limit(max_rec_size)
    try:
        yield None
    finally:
        if streams is not None:
            for s in streams:
                mx.synchronize(s)
        else:
            mx.synchronize()
        mx.metal.set_wired_limit(old_limit)

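# Hedged usage sketch for `wired_limit` (the model name below is an assumption):
# pass in any streams that may still be running async evals so they are
# synchronized before the old limit is restored on exit.
"""
model, tokenizer = load("Qwen2.5-1.5B-Instruct-bf16")
with wired_limit(model, [generation_stream]):
    # run generation here; the previous wired limit is restored on exit
    ...
"""
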
def generate_step(
    prompt: mx.array,
    model: nn.Module,
    temp: float = 0.0,
    repetition_penalty: Optional[float] = None,
    repetition_context_size: Optional[int] = 20,
    top_p: float = 1.0,
    min_p: float = 0.0,
    min_tokens_to_keep: int = 1,
    prefill_step_size: int = 512,
    max_kv_size: Optional[int] = None,
    prompt_cache: Optional[Any] = None,
    logit_bias: Optional[Dict[int, float]] = None,
    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
    kv_bits: Optional[int] = None,
    kv_group_size: int = 64,
    quantized_kv_start: int = 0,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
    """
    A generator producing token ids based on the given prompt from the model.

    Args:
        prompt (mx.array): The input prompt.
        model (nn.Module): The model to use for generation.
        temp (float): The temperature for sampling, if 0 the argmax is used.
            Default: ``0``.
        repetition_penalty (float, optional): The penalty factor for repeating
            tokens.
        repetition_context_size (int, optional): The number of tokens to
            consider for repetition penalty. Default: ``20``.
        top_p (float, optional): Nucleus sampling, higher means the model
            considers more, less likely words.
        min_p (float, optional): The minimum value (scaled by the top token's
            probability) that a token probability must have to be considered.
        min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
            be filtered by min_p sampling.
        prefill_step_size (int): Step size for processing the prompt.
        max_kv_size (int, optional): Maximum size of the key-value cache. Old
            entries (except the first 4 tokens) will be overwritten.
        prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
            provided, the cache will be updated in place.
        logit_bias (dictionary, optional): Additive logit bias.
        logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
            A list of functions that take tokens and logits and return the processed
            logits. Default: ``None``.
        kv_bits (int, optional): Number of bits to use for KV cache quantization.
            None implies no cache quantization. Default: ``None``.
        kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
        quantized_kv_start (int): Step to begin using a quantized KV cache
            when ``kv_bits`` is non-None. Default: ``0``.

    Yields:
        Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
    """
    y = prompt
    tokens = None

    # Create the KV cache for generation
    if prompt_cache is None:
        prompt_cache = cache.make_prompt_cache(
            model,
            max_kv_size=max_kv_size,
        )
    elif len(prompt_cache) != len(model.layers):
        raise ValueError("Wrong number of layers in the prompt cache.")

    sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep)
    logits_processors = logits_processors or []
    logits_processors.extend(
        make_logits_processors(logit_bias, repetition_penalty, repetition_context_size)
    )

    def _step(y):
        with mx.stream(generation_stream):
            logits = model(y[None], cache=prompt_cache)
            logits = logits[:, -1, :]

            if logits_processors:
                nonlocal tokens
                tokens = mx.concat([tokens, y]) if tokens is not None else y
                for processor in logits_processors:
                    logits = processor(tokens, logits)

            maybe_quantize_kv_cache(
                prompt_cache, quantized_kv_start, kv_group_size, kv_bits
            )

            logprobs = logits - mx.logsumexp(logits, keepdims=True)
            y = sampler(logprobs)
            return y, logprobs.squeeze(0)

    # Prefill the prompt in chunks to bound memory use
    while y.size > prefill_step_size:
        model(y[:prefill_step_size][None], cache=prompt_cache)
        mx.eval([c.state for c in prompt_cache])
        y = y[prefill_step_size:]
        mx.metal.clear_cache()

    y, logprobs = _step(y)
    mx.async_eval(y, logprobs)
    n = 0
    while True:
        next_y, next_logprobs = _step(y)
        mx.async_eval(next_y, next_logprobs)
        yield y.item(), logprobs
        if n % 256 == 0:
            mx.metal.clear_cache()
        n += 1
        y, logprobs = next_y, next_logprobs

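# Hedged usage sketch for `generate_step` (the model path and prompt below are
# assumptions): drive the generator directly and decode tokens as they arrive.
"""
model, tokenizer = load("Qwen2.5-1.5B-Instruct-bf16")
prompt_tokens = mx.array(tokenizer.encode("Hello"))
for _, (token, logprobs) in zip(range(20), generate_step(prompt_tokens, model, temp=0.7)):
    print(tokenizer.decode([token]), end="", flush=True)
"""
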
def stream_generate(
    model: nn.Module,
    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
    prompt: Union[str, List[int]],
    max_tokens: int = 100,
    **kwargs,
) -> Generator[Tuple[str, int, mx.array], None, None]:
    """
    A generator producing text based on the given prompt from the model.

    Args:
        model (nn.Module): The model to use for generation.
        tokenizer (PreTrainedTokenizer): The tokenizer.
        prompt (Union[str, List[int]]): The input prompt string or integer tokens.
        max_tokens (int): The maximum number of tokens. Default: ``100``.
        kwargs: The remaining options get passed to :func:`generate_step`.
            See :func:`generate_step` for more details.

    Yields:
        Tuple[str, int, mx.array]:
            The next text segment, token, and vector of log probabilities.
    """
    if not isinstance(tokenizer, TokenizerWrapper):
        tokenizer = TokenizerWrapper(tokenizer)

    prompt_tokens = mx.array(
        prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
    )
    detokenizer = tokenizer.detokenizer

    with wired_limit(model, [generation_stream]):
        detokenizer.reset()
        for n, (token, logits) in zip(
            range(max_tokens),
            generate_step(prompt_tokens, model, **kwargs),
        ):
            if token == tokenizer.eos_token_id:
                break
            detokenizer.add_token(token)
            if n == (max_tokens - 1):
                break
            yield detokenizer.last_segment, token, logits

        detokenizer.finalize()
        yield detokenizer.last_segment, token, logits

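# Hedged usage sketch for `stream_generate` (model name and prompt are
# assumptions): it yields detokenized text segments as they are produced.
"""
model, tokenizer = load("Qwen2.5-1.5B-Instruct-bf16")
for segment, token, logprobs in stream_generate(model, tokenizer, "Hello", max_tokens=50):
    print(segment, end="", flush=True)
"""
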
def generate(
    model: nn.Module,
    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
    prompt: str,
    max_tokens: int = 100,
    verbose: bool = False,
    formatter: Optional[Callable] = None,
    **kwargs,
) -> str:
    """
    Generate a complete response from the model.

    Args:
        model (nn.Module): The language model.
        tokenizer (PreTrainedTokenizer): The tokenizer.
        prompt (str): The string prompt.
        max_tokens (int): The maximum number of tokens. Default: ``100``.
        verbose (bool): If ``True``, print tokens and timing information.
            Default: ``False``.
        formatter (Optional[Callable]): A function which takes a token and a
            probability and displays it.
        kwargs: The remaining options get passed to :func:`generate_step`.
            See :func:`generate_step` for more details.
    """
    if not isinstance(tokenizer, TokenizerWrapper):
        tokenizer = TokenizerWrapper(tokenizer)

    if verbose:
        print("=" * 10)
        print("Prompt:", prompt)

    prompt_tokens = mx.array(tokenizer.encode(prompt))
    detokenizer = tokenizer.detokenizer

    with wired_limit(model, [generation_stream]):
        tic = time.perf_counter()
        detokenizer.reset()
        for n, (token, logprobs) in zip(
            range(max_tokens),
            generate_step(prompt_tokens, model, **kwargs),
        ):
            if n == 0:
                prompt_time = time.perf_counter() - tic
                tic = time.perf_counter()
            if token == tokenizer.eos_token_id:
                break
            detokenizer.add_token(token)
            if verbose:
                if formatter:
                    # We have to finalize so that the prob corresponds to the last segment
                    detokenizer.finalize()
                    prob = mx.exp(logprobs[token]).item()
                    formatter(detokenizer.last_segment, prob)
                else:
                    print(detokenizer.last_segment, end="", flush=True)

        token_count = n + 1
        detokenizer.finalize()

        if verbose:
            gen_time = time.perf_counter() - tic
            print(detokenizer.last_segment, flush=True)
            print("=" * 10)
            if token_count == 0:
                print("No tokens generated for this prompt")
                return
            prompt_tps = prompt_tokens.size / prompt_time
            gen_tps = (token_count - 1) / gen_time
            print(
                f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec"
            )
            print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
            peak_mem = mx.metal.get_peak_memory() / 1e9
            print(f"Peak memory: {peak_mem:.3f} GB")

        return detokenizer.text

def can_trim_prompt_cache(cache: List[Any]) -> bool:
    """
    Check if the model's cache can be trimmed.
    """
    return all(c.is_trimmable() for c in cache)


def trim_prompt_cache(cache: List[Any], num_tokens: int) -> int:
    """
    Trim the model's cache by the given number of tokens.

    This function will trim the cache if possible (in-place) and return the
    number of tokens that were trimmed.

    Args:
        cache (List[Any]): The model's cache.
        num_tokens (int): The number of tokens to trim.

    Returns:
        (int): The number of tokens that were trimmed.
    """
    return [c.trim(num_tokens) for c in cache][0]

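# Hedged sketch of trimming a prompt cache (the model and token counts here are
# assumptions): only trimmable cache types support rolling back cached tokens.
"""
prompt_cache = cache.make_prompt_cache(model)
model(prompt_tokens[None], cache=prompt_cache)
if can_trim_prompt_cache(prompt_cache):
    trimmed = trim_prompt_cache(prompt_cache, 3)  # drop the last 3 cached tokens
"""
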
@mx.compile
def compute_divergence_point(model_choices, draft_tokens, tokens_generated):
    # Index of the first draft token the main model disagrees with; if all
    # draft tokens match, the divergence point is `tokens_generated`.
    eq_mask = model_choices[:tokens_generated] == draft_tokens[1:]
    divergence_point = eq_mask.argmin() + tokens_generated * eq_mask.all()
    return divergence_point


@mx.compile
def compute_divergence_point_2(model_choices, draft_tokens):
    # Specialization for k = 2 draft tokens: count how many of the two drafted
    # tokens the main model accepts (0, 1, or 2).
    return (model_choices[0] == draft_tokens[1]) * (1 + (model_choices[1] == draft_tokens[2]))


@mx.compile
def get_new_tokens(d, divergence_point, model_choices):
    # Keep the accepted draft tokens and append the main model's own token at
    # the point of divergence.
    d[:divergence_point] = d[1 : 1 + divergence_point]
    d[divergence_point] = model_choices[divergence_point]
    return d[: divergence_point + 1]

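# Worked example (hypothetical token ids) of the acceptance logic above, with
# k = 2 drafted tokens:
#   d             = [last_model_token, draft_1, draft_2] = [11, 42, 7]
#   model_choices = [42, 9, 13]   # main model's argmax at each position of d
# draft_1 (42) matches model_choices[0], but draft_2 (7) does not match
# model_choices[1] (9), so compute_divergence_point_2 returns 1.
# get_new_tokens then emits [42, 9]: the accepted draft token plus the main
# model's correction at the divergence point.
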
def generate_speculative(
    model: nn.Module,
    draft_model: nn.Module,
    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
    prompt: str,
    max_tokens: int = 100,
    verbose: bool = False,
    formatter: Optional[Callable] = None,
    **kwargs,
) -> str:
    """
    Generate a complete response from the model, using the draft model to
    propose tokens for speculative decoding.

    Args:
        model (nn.Module): The language model.
        draft_model (nn.Module): A smaller model that proposes draft tokens
            which the main model then verifies.
        tokenizer (PreTrainedTokenizer): The tokenizer.
        prompt (str): The string prompt.
        max_tokens (int): The maximum number of tokens. Default: ``100``.
        verbose (bool): If ``True``, print tokens and timing information.
            Default: ``False``.
        formatter (Optional[Callable]): A function which takes a token and a
            probability and displays it.
        kwargs: The remaining options get passed to :func:`generate_step`.
            See :func:`generate_step` for more details.
    """
    if not isinstance(tokenizer, TokenizerWrapper):
        tokenizer = TokenizerWrapper(tokenizer)

    if verbose:
        print("=" * 10)
        print("Prompt:", prompt)

    prompt_tokens = mx.array(tokenizer.encode(prompt))
    detokenizer = tokenizer.detokenizer

    with wired_limit(model, [generation_stream]):
        detokenizer.reset()
        y = prompt_tokens
        tokens = []

        # Create the KV caches for generation
        prompt_cache = cache.make_prompt_cache(
            model,
            max_kv_size=None,
        )
        draft_cache = cache.make_prompt_cache(
            draft_model,
            max_kv_size=None,
        )

        # Pass the prompt through both models
        k = 2  # number of draft tokens proposed per verification step
        _ = model(y[None, :-1], cache=prompt_cache)
        last_model_token = y[-1]
        draft_logits = draft_model(y[None], cache=draft_cache)[0, -1, :][None, None]

        accepted = 0
        max_accepted = 0
        all_accepted = 0
        total_runs = 0
        d = mx.zeros((k + 1,), dtype=mx.uint32)
        tokens_added_without_cache_clear = 0
        mx.eval(draft_logits)
        mx.eval(_)

        start_time = time.perf_counter()
        while len(tokens) < max_tokens:
            if tokens_added_without_cache_clear > 255:
                mx.metal.clear_cache()
                tokens_added_without_cache_clear = 0

            # Draft k tokens with the draft model
            d[0] = last_model_token
            for i in range(k):
                token = mx.argmax(draft_logits, axis=-1)
                mx.async_eval(token)
                d[i + 1] = token[0, 0]
                draft_logits = draft_model(token, cache=draft_cache)

            # Verify the drafted tokens with the main model in a single pass
            model_logits = model(d[None], cache=prompt_cache)
            model_choices = mx.argmax(model_logits, axis=-1)[0]
            mx.async_eval(model_choices)

            divergence_point = compute_divergence_point_2(model_choices, d).item()
            accepted += divergence_point
            max_accepted += k
            if divergence_point == k:
                all_accepted += 1
            total_runs += 1

            new_tokens = get_new_tokens(d, divergence_point, model_choices)
            need_to_break = False
            for token in new_tokens:
                if token == tokenizer.eos_token_id:
                    need_to_break = True
                    break
                item = token.item()
                detokenizer.add_token(item)
                tokens.append(item)
                tokens_added_without_cache_clear += 1
            print(detokenizer.last_segment, flush=True, end="")

            # Roll back the caches past the rejected draft tokens
            if k != divergence_point:
                trim_prompt_cache(prompt_cache, k - divergence_point)
                trim_prompt_cache(draft_cache, k - divergence_point)
            if need_to_break:
                break

            last_model_token = model_choices[divergence_point]
            draft_logits = draft_model(last_model_token[None, None], cache=draft_cache)

        detokenizer.finalize()
        print(detokenizer.last_segment, flush=True)
        print("=" * 10)
        end_time = time.perf_counter()
        tokens_per_second = len(tokens) / (end_time - start_time)
        print(f"Tokens per second: {tokens_per_second}")
        print(f"Accepted Percentage: {(accepted / max_accepted) * 100:.2f}%")
        print(f"All Accepted Percentage: {(all_accepted / total_runs) * 100:.2f}%")
        print(f"Max Memory: {mx.metal.get_peak_memory() / 1e9:.3f} GB")
        return detokenizer.text

model, tokenizer = load("Qwen2.5-Coder-32B-Instruct-4bit")
draft_model, _ = load("Qwen2.5-0.5B-Instruct-8bit")
prompt = "Write quicksort in python."
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
response = generate_speculative(
    model, draft_model, tokenizer, prompt=prompt, verbose=True, max_tokens=1024
)
# true_response = generate(model, tokenizer, prompt=prompt, verbose=True, max_tokens=1024)