faster every day
# Copyright © 2023-2024 Apple Inc.
import contextlib
import time
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_reduce
from transformers import PreTrainedTokenizer

from mlx_lm import load

# Local imports
from mlx_lm.models import cache
from mlx_lm.sample_utils import make_logits_processors, make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper
"""model, tokenizer = load("Qwen2.5-1.5B-Instruct-bf16") | |
prompt = "Write a story about Einstein" | |
messages = [{"role": "user", "content": prompt}] | |
prompt = tokenizer.apply_chat_template( | |
messages, tokenize=False, add_generation_prompt=True | |
) | |
response = generate(model, tokenizer, prompt=prompt, verbose=True)""" | |
def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
    if (
        kv_bits is not None
        and not isinstance(prompt_cache[0], cache.QuantizedKVCache)
        and prompt_cache[0].offset > quantized_kv_start
    ):
        for i in range(len(prompt_cache)):
            prompt_cache[i] = prompt_cache[i].to_quantized(
                group_size=kv_group_size, bits=kv_bits
            )
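# Illustrative sketch (not part of the original gist): this hook is driven from
# generate_step via its kv_bits / kv_group_size / quantized_kv_start kwargs; the
# particular values below are assumptions.
"""for token, logprobs in generate_step(
    prompt_tokens, model, kv_bits=4, kv_group_size=64, quantized_kv_start=512
):
    ...  # once the cache holds more than 512 tokens it is converted to 4-bit groups
"""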
generation_stream = mx.new_stream(mx.default_device())


@contextlib.contextmanager
def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None):
    """
    A context manager to temporarily change the wired limit.

    Note, the wired limit should not be changed during an async eval. If an
    async eval could be running, pass in the streams to synchronize with prior
    to exiting the context manager.
    """
    model_bytes = tree_reduce(
        lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
    )
    max_rec_size = mx.metal.device_info()["max_recommended_working_set_size"]
    if model_bytes > 0.9 * max_rec_size:
        model_mb = model_bytes // 2**20
        max_rec_mb = max_rec_size // 2**20
        print(
            f"[WARNING] Generating with a model that requires {model_mb} MB "
            f"which is close to the maximum recommended size of {max_rec_mb} "
            "MB. This can be slow. See the documentation for possible work-arounds: "
            "https://github.com/ml-explore/mlx-examples/tree/main/llms#large-models"
        )
    old_limit = mx.metal.set_wired_limit(max_rec_size)
    try:
        yield None
    finally:
        if streams is not None:
            for s in streams:
                mx.synchronize(s)
        else:
            mx.synchronize()
        mx.metal.set_wired_limit(old_limit)
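# Illustrative sketch (not part of the original gist): wired_limit is used below to
# keep the model wired in memory for the duration of a generation, synchronizing the
# generation stream before the old limit is restored.
"""with wired_limit(model, [generation_stream]):
    for token, logprobs in generate_step(prompt_tokens, model):
        ...
"""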
def generate_step(
    prompt: mx.array,
    model: nn.Module,
    temp: float = 0.0,
    repetition_penalty: Optional[float] = None,
    repetition_context_size: Optional[int] = 20,
    top_p: float = 1.0,
    min_p: float = 0.0,
    min_tokens_to_keep: int = 1,
    prefill_step_size: int = 512,
    max_kv_size: Optional[int] = None,
    prompt_cache: Optional[Any] = None,
    logit_bias: Optional[Dict[int, float]] = None,
    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
    kv_bits: Optional[int] = None,
    kv_group_size: int = 64,
    quantized_kv_start: int = 0,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
    """
    A generator producing token ids based on the given prompt from the model.

    Args:
        prompt (mx.array): The input prompt.
        model (nn.Module): The model to use for generation.
        temp (float): The temperature for sampling, if 0 the argmax is used.
          Default: ``0``.
        repetition_penalty (float, optional): The penalty factor for repeating
          tokens.
        repetition_context_size (int, optional): The number of tokens to
          consider for repetition penalty. Default: ``20``.
        top_p (float, optional): Nucleus sampling: higher values mean the model
          considers more of the less likely words.
        min_p (float, optional): The minimum value (scaled by the top token's
          probability) that a token probability must have to be considered.
        min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
          be filtered by min_p sampling.
        prefill_step_size (int): Step size for processing the prompt.
        max_kv_size (int, optional): Maximum size of the key-value cache. Old
          entries (except the first 4 tokens) will be overwritten.
        prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
          provided, the cache will be updated in place.
        logit_bias (dictionary, optional): Additive logit bias.
        logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
          A list of functions that take tokens and logits and return the processed
          logits. Default: ``None``.
        kv_bits (int, optional): Number of bits to use for KV cache quantization.
          None implies no cache quantization. Default: ``None``.
        kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
        quantized_kv_start (int): Step to begin using a quantized KV cache
          when ``kv_bits`` is non-None. Default: ``0``.

    Yields:
        Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
    """
    y = prompt
    tokens = None

    # Create the KV cache for generation
    if prompt_cache is None:
        prompt_cache = cache.make_prompt_cache(
            model,
            max_kv_size=max_kv_size,
        )
    elif len(prompt_cache) != len(model.layers):
        raise ValueError("Wrong number of layers in the prompt cache.")

    sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep)
    logits_processors = logits_processors or []
    logits_processors.extend(
        make_logits_processors(logit_bias, repetition_penalty, repetition_context_size)
    )

    def _step(y):
        with mx.stream(generation_stream):
            logits = model(y[None], cache=prompt_cache)
            logits = logits[:, -1, :]

            if logits_processors:
                nonlocal tokens
                tokens = mx.concat([tokens, y]) if tokens is not None else y

                for processor in logits_processors:
                    logits = processor(tokens, logits)

            maybe_quantize_kv_cache(
                prompt_cache, quantized_kv_start, kv_group_size, kv_bits
            )

            logprobs = logits - mx.logsumexp(logits, keepdims=True)
            y = sampler(logprobs)
            return y, logprobs.squeeze(0)

    # Prefill the prompt in chunks to bound peak memory
    while y.size > prefill_step_size:
        model(y[:prefill_step_size][None], cache=prompt_cache)
        mx.eval([c.state for c in prompt_cache])
        y = y[prefill_step_size:]
        mx.metal.clear_cache()

    y, logprobs = _step(y)
    mx.async_eval(y, logprobs)

    n = 0
    while True:
        next_y, next_logprobs = _step(y)
        mx.async_eval(next_y, next_logprobs)
        yield y.item(), logprobs
        if n % 256 == 0:
            mx.metal.clear_cache()
        n += 1
        y, logprobs = next_y, next_logprobs
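# Illustrative sketch (not part of the original gist): generate_step can be consumed
# directly when only raw token ids are needed. The prompt text is an assumption.
"""prompt_tokens = mx.array(tokenizer.encode("Hello"))
with wired_limit(model, [generation_stream]):
    for _, (token, logprobs) in zip(range(32), generate_step(prompt_tokens, model)):
        if token == tokenizer.eos_token_id:
            break
        print(tokenizer.decode([token]), end="")
"""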
def stream_generate(
    model: nn.Module,
    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
    prompt: Union[str, List[int]],
    max_tokens: int = 100,
    **kwargs,
) -> Generator[Tuple[str, int, mx.array], None, None]:
    """
    A generator producing text based on the given prompt from the model.

    Args:
        model (nn.Module): The model to use for generation.
        tokenizer (PreTrainedTokenizer): The tokenizer.
        prompt (Union[str, List[int]]): The input prompt string or integer tokens.
        max_tokens (int): The maximum number of tokens. Default: ``100``.
        kwargs: The remaining options get passed to :func:`generate_step`.
          See :func:`generate_step` for more details.

    Yields:
        Tuple[str, int, mx.array]:
          The next text segment, token, and vector of log probabilities.
    """
    if not isinstance(tokenizer, TokenizerWrapper):
        tokenizer = TokenizerWrapper(tokenizer)

    prompt_tokens = mx.array(
        prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
    )
    detokenizer = tokenizer.detokenizer

    with wired_limit(model, [generation_stream]):
        detokenizer.reset()
        for n, (token, logits) in zip(
            range(max_tokens),
            generate_step(prompt_tokens, model, **kwargs),
        ):
            if token == tokenizer.eos_token_id:
                break
            detokenizer.add_token(token)

            if n == (max_tokens - 1):
                break

            yield detokenizer.last_segment, token, logits

        detokenizer.finalize()
        yield detokenizer.last_segment, token, logits
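# Illustrative sketch (not part of the original gist): streaming text to stdout.
# The prompt variable is assumed to hold an already chat-templated string.
"""for segment, token, logprobs in stream_generate(model, tokenizer, prompt, max_tokens=256):
    print(segment, end="", flush=True)
"""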
def generate(
    model: nn.Module,
    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
    prompt: str,
    max_tokens: int = 100,
    verbose: bool = False,
    formatter: Optional[Callable] = None,
    **kwargs,
) -> str:
    """
    Generate a complete response from the model.

    Args:
        model (nn.Module): The language model.
        tokenizer (PreTrainedTokenizer): The tokenizer.
        prompt (str): The string prompt.
        max_tokens (int): The maximum number of tokens. Default: ``100``.
        verbose (bool): If ``True``, print tokens and timing information.
          Default: ``False``.
        formatter (Optional[Callable]): A function which takes a token and a
          probability and displays it.
        kwargs: The remaining options get passed to :func:`generate_step`.
          See :func:`generate_step` for more details.
    """
    if not isinstance(tokenizer, TokenizerWrapper):
        tokenizer = TokenizerWrapper(tokenizer)

    if verbose:
        print("=" * 10)
        print("Prompt:", prompt)

    prompt_tokens = mx.array(tokenizer.encode(prompt))
    detokenizer = tokenizer.detokenizer

    with wired_limit(model, [generation_stream]):
        tic = time.perf_counter()
        detokenizer.reset()
        for n, (token, logprobs) in zip(
            range(max_tokens),
            generate_step(prompt_tokens, model, **kwargs),
        ):
            if n == 0:
                prompt_time = time.perf_counter() - tic
                tic = time.perf_counter()
            if token == tokenizer.eos_token_id:
                break
            detokenizer.add_token(token)

            if verbose:
                if formatter:
                    # We have to finalize so that the prob corresponds to the last segment
                    detokenizer.finalize()
                    prob = mx.exp(logprobs[token]).item()
                    formatter(detokenizer.last_segment, prob)
                else:
                    print(detokenizer.last_segment, end="", flush=True)

        token_count = n + 1
        detokenizer.finalize()

        if verbose:
            gen_time = time.perf_counter() - tic
            print(detokenizer.last_segment, flush=True)
            print("=" * 10)
            if token_count == 0:
                print("No tokens generated for this prompt")
                return
            prompt_tps = prompt_tokens.size / prompt_time
            gen_tps = (token_count - 1) / gen_time
            print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec")
            print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
            peak_mem = mx.metal.get_peak_memory() / 1e9
            print(f"Peak memory: {peak_mem:.3f} GB")

        return detokenizer.text
def can_trim_prompt_cache(cache: List[Any]) -> bool:
    """
    Check if the model's cache can be trimmed.
    """
    return all(c.is_trimmable() for c in cache)


def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
    """
    Trim the model's cache by the given number of tokens.

    This function will trim the cache if possible (in-place) and return the
    number of tokens that were trimmed.

    Args:
        cache (List[Any]): The model's cache.
        num_tokens (int): The number of tokens to trim.

    Returns:
        (int): The number of tokens that were trimmed.
    """
    return [c.trim(num_tokens) for c in cache][0]
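# Illustrative sketch (not part of the original gist): when the target model rejects
# some draft tokens, the speculative loop below rolls both caches back by the number
# of rejected tokens so the caches stay aligned with the accepted text.
"""rejected = k - divergence_point
if rejected > 0 and can_trim_prompt_cache(prompt_cache):
    trim_prompt_cache(prompt_cache, rejected)
    trim_prompt_cache(draft_cache, rejected)
"""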
@mx.compile
def compute_divergence_point(model_choices, draft_tokens, tokens_generated):
    # General form: index of the first draft token the target model disagrees with.
    # If every draft token matches, the divergence point is ``tokens_generated``.
    eq_mask = model_choices[:tokens_generated] == draft_tokens[1:]
    divergence_point = eq_mask.argmin() + tokens_generated * eq_mask.all()
    return divergence_point


@mx.compile
def compute_divergence_point_2(model_choices, draft_tokens):
    # Branch-free special case for k = 2 draft tokens: returns 0, 1, or 2 accepted tokens.
    return (model_choices[0] == draft_tokens[1]) * (
        1 + (model_choices[1] == draft_tokens[2])
    )


@mx.compile
def get_new_tokens(d, divergence_point, model_choices):
    # Keep the accepted draft tokens and append the target model's own choice at the
    # divergence point.
    d[:divergence_point] = d[1 : 1 + divergence_point]
    d[divergence_point] = model_choices[divergence_point]
    return d[: divergence_point + 1]
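# Illustrative worked example (not part of the original gist), with made-up token ids.
# Suppose the last accepted token is 11 and the draft model proposes [22, 33]:
#   d             = [11, 22, 33]
#   model_choices = target argmax after seeing 11, 22, 33 -> e.g. [22, 40, ...]
# The target agrees with the first draft token (22) but not the second (33 != 40),
# so compute_divergence_point_2 returns 1 and get_new_tokens yields [22, 40]:
# one accepted draft token plus the target model's correction.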
def generate_speculative(
    model: nn.Module,
    draft_model: nn.Module,
    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
    prompt: str,
    max_tokens: int = 100,
    verbose: bool = False,
    formatter: Optional[Callable] = None,
    **kwargs,
) -> str:
    """
    Generate a complete response from the model using speculative decoding.

    Args:
        model (nn.Module): The language model.
        draft_model (nn.Module): The smaller draft model that proposes tokens.
        tokenizer (PreTrainedTokenizer): The tokenizer.
        prompt (str): The string prompt.
        max_tokens (int): The maximum number of tokens. Default: ``100``.
        verbose (bool): If ``True``, print tokens and timing information.
          Default: ``False``.
        formatter (Optional[Callable]): A function which takes a token and a
          probability and displays it.
        kwargs: The remaining options get passed to :func:`generate_step`.
          See :func:`generate_step` for more details.
    """
    if not isinstance(tokenizer, TokenizerWrapper):
        tokenizer = TokenizerWrapper(tokenizer)

    if verbose:
        print("=" * 10)
        print("Prompt:", prompt)

    prompt_tokens = mx.array(tokenizer.encode(prompt))
    detokenizer = tokenizer.detokenizer

    with wired_limit(model, [generation_stream]):
        detokenizer.reset()
        y = prompt_tokens
        tokens = []

        # Create the KV caches for generation
        prompt_cache = cache.make_prompt_cache(
            model,
            max_kv_size=None,
        )
        draft_cache = cache.make_prompt_cache(
            draft_model,
            max_kv_size=None,
        )

        # Number of draft tokens proposed per step
        k = 2

        # Pass the prompt through the target model (all but the last token) and
        # through the draft model (full prompt) to warm up both caches.
        _ = model(y[None, :-1], cache=prompt_cache)
        last_model_token = y[-1]
        draft_logits = draft_model(y[None], cache=draft_cache)[0, -1, :][None, None]

        accepted = 0
        max_accepted = 0
        all_accepted = 0
        total_runs = 0
        d = mx.zeros((k + 1,), dtype=mx.uint32)
        tokens_added_without_cache_clear = 0
        mx.eval(draft_logits)
        mx.eval(_)

        start_time = time.perf_counter()
        while len(tokens) < max_tokens:
            if tokens_added_without_cache_clear > 255:
                mx.metal.clear_cache()
                tokens_added_without_cache_clear = 0

            # Draft k tokens greedily with the small model
            d[0] = last_model_token
            for i in range(k):
                token = mx.argmax(draft_logits, axis=-1)
                mx.async_eval(token)
                d[i + 1] = token[0, 0]
                draft_logits = draft_model(token, cache=draft_cache)

            # Verify the drafted tokens with the target model in a single forward pass
            model_logits = model(d[None], cache=prompt_cache)
            model_choices = mx.argmax(model_logits, axis=-1)[0]
            mx.async_eval(model_choices)
            divergence_point = compute_divergence_point_2(model_choices, d).item()
            accepted += divergence_point
            max_accepted += k
            if divergence_point == k:
                all_accepted += 1
            total_runs += 1

            new_tokens = get_new_tokens(d, divergence_point, model_choices)
            need_to_break = False
            for token in new_tokens:
                if token == tokenizer.eos_token_id:
                    need_to_break = True
                    break
                item = token.item()
                detokenizer.add_token(item)
                tokens.append(item)
                tokens_added_without_cache_clear += 1
            print(detokenizer.last_segment, flush=True, end="")

            # Roll back both caches for any rejected draft tokens
            if k != divergence_point:
                trim_prompt_cache(prompt_cache, k - divergence_point)
                trim_prompt_cache(draft_cache, k - divergence_point)
            if need_to_break:
                break

            last_model_token = model_choices[divergence_point]
            draft_logits = draft_model(last_model_token[None, None], cache=draft_cache)

        detokenizer.finalize()
        print(detokenizer.last_segment, flush=True)
        print("=" * 10)
        end_time = time.perf_counter()
        tokens_per_second = len(tokens) / (end_time - start_time)
        print(f"Tokens per second: {tokens_per_second}")
        print(f"Accepted Percentage: {(accepted / max_accepted) * 100:.2f}%")
        print(f"All Accepted Percentage: {(all_accepted / total_runs) * 100:.2f}%")
        print(f"Max Memory: {mx.metal.get_peak_memory() / 1e9:.3f} GB")
        return detokenizer.text
model, tokenizer = load("Qwen2.5-Coder-32B-Instruct-4bit")
draft_model, _ = load("Qwen2.5-0.5B-Instruct-8bit")

prompt = "Write quicksort in python."
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
response = generate_speculative(
    model, draft_model, tokenizer, prompt=prompt, verbose=True, max_tokens=1024
)
# true_response = generate(model, tokenizer, prompt=prompt, verbose=True, max_tokens=1024)