Prototypes.
# Imports needed to run this function on its own. The original gist omitted
# them, so the exact module paths below are an assumption based on the
# standard mlx / mlx-lm layout.
import time
from typing import Callable, List, Optional, Union

import mlx.core as mx
import mlx.nn as nn
from transformers import PreTrainedTokenizer

from mlx_lm.models import cache
from mlx_lm.sample_utils import make_logits_processors, make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper
from mlx_lm.utils import maybe_quantize_kv_cache, wired_limit


def generate_batched(
    model: nn.Module,
    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
    prompt: str,
    batch_size: int,
    *,
    verbose: bool = False,
    formatter: Optional[Callable] = None,
    max_tokens: int = 256,
    temp: float = 0.0,
    top_p: float = 0.0,
    min_p: float = 0.0,
    min_tokens_to_keep: int = 1,
    repetition_penalty: float = 1.0,
    repetition_context_size: int = 20,
    kv_bits: Optional[int] = None,
    kv_group_size: int = 64,
    quantized_kv_start: int = 0,
    max_kv_size: Optional[int] = None,
    prefill_step_size: int = 512,
    **kwargs,
) -> List[str]:
""" | |
Generate multiple responses in parallel from the same prompt. | |
Args: | |
model (nn.Module): The language model. | |
tokenizer (PreTrainedTokenizer or TokenizerWrapper): The tokenizer. | |
prompt (str): The string prompt. | |
batch_size (int): Number of parallel sequences to generate. | |
verbose (bool): If True, prints tokens and timing information. Default: False. | |
formatter (Callable): Deprecated. (No longer used) | |
max_tokens (int): The maximum number of tokens to generate. Default: 256. | |
temp (float): Temperature for sampling. Default: 0.0. | |
top_p (float): Nucleus sampling top-p parameter. Default: 0.0. | |
min_p (float): Minimum cumulative probability cutoff. Default: 0.0. | |
min_tokens_to_keep (int): Ensures a minimum number of tokens remain after filtering. Default: 1. | |
repetition_penalty (float): Repetition penalty. Default: 1.0. | |
repetition_context_size (int): The context size to consider for the repetition penalty. Default: 20. | |
kv_bits (int): Number of bits for KV cache quantization. None = disabled. Default: None. | |
kv_group_size (int): Group size for KV cache quantization. Default: 64. | |
quantized_kv_start (int): The step to begin using a quantized KV cache. Default: 0. | |
max_kv_size (int): The maximum size of the KV cache. Old tokens get overwritten. Default: None. | |
prefill_step_size (int): Step size used when processing the prompt (prefill). Default: 512. | |
**kwargs: Unused extra kwargs, included for API-compatibility. | |
Returns: | |
List[str]: A list of decoded text strings of length `batch_size`. | |
""" | |
    if formatter is not None:
        print(
            "[Warning] Text formatting is deprecated and no longer used. "
            "The argument will be removed in a future version."
        )

    # Ensure we have a TokenizerWrapper
    if not isinstance(tokenizer, TokenizerWrapper):
        tokenizer = TokenizerWrapper(tokenizer)

    # Encode the prompt once; it is shared by every sequence in the batch.
    prompt_ids = mx.array(tokenizer.encode(prompt))

    # Final decoded strings, one per batch element.
    decoded_texts = ["" for _ in range(batch_size)]

    # Sampler and logits processors shared by all sequences.
    sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep)
    logits_processors = make_logits_processors(None, repetition_penalty, repetition_context_size)

    # A single prompt cache is created for the whole batch. The prompt is
    # prefilled once with batch dimension 1, and the cached keys/values are
    # replicated across the batch afterwards; the model's forward pass is
    # assumed to handle a batch dimension of B from then on.
    prompt_cache = cache.make_prompt_cache(model, max_kv_size=max_kv_size)

    # Timers for verbosity
    generation_stream = mx.new_stream(mx.default_device())
    with wired_limit(model, [generation_stream]):
        tic = time.perf_counter()
# Step 1. "Prefill" / process the prompt in chunks to fill the KV cache | |
total_prompt_len = prompt_ids.shape[0] | |
processed = 0 | |
while processed < total_prompt_len: | |
chunk_end = min(processed + prefill_step_size, total_prompt_len) | |
# Forward pass of shape: (batch_size, chunk_size) | |
inputs_chunk = prompt_ids[processed:chunk_end] | |
with mx.stream(generation_stream): | |
_ = model(inputs_chunk[None], cache=prompt_cache) | |
maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits) | |
mx.eval([c.state for c in prompt_cache]) | |
processed = chunk_end | |
mx.metal.clear_cache() | |
# The time spent so far was for prompt processing | |
prompt_time = time.perf_counter() - tic | |
if total_prompt_len == 0: | |
prompt_tps = 0.0 | |
else: | |
prompt_tps = (batch_size * total_prompt_len) / prompt_time | |
        # For decoding we feed one token at a time per sequence, starting from
        # the last prompt token (held back during prefill). An empty prompt is
        # not supported here; adapt this to feed a BOS token if your model
        # needs that.
        if total_prompt_len == 0:
            raise ValueError(
                "Cannot decode from an empty prompt in this example. "
                "Add or adapt code to handle an empty initial token if needed."
            )
        y = mx.repeat(prompt_ids[-1:][None, :], repeats=batch_size, axis=0)

        # Replicate the prefilled KV cache across the batch dimension.
        for c in prompt_cache:
            if c.keys is not None:  # the cache is empty for single-token prompts
                c.keys = mx.repeat(c.keys, repeats=batch_size, axis=0)
                c.values = mx.repeat(c.values, repeats=batch_size, axis=0)

        # Generated tokens for each batch element (prompt tokens are not stored).
        tokens_so_far = [[] for _ in range(batch_size)]

        # Step 2. Decode new tokens until every sequence has ended or
        # max_tokens is reached.
        tic = time.perf_counter()
        n = 0
        # Maps each row of the (shrinking) active batch to its original batch index.
        generated_positions_to_batch_idxs = list(range(batch_size))
        while True:
            if n >= max_tokens:
                break

            # Forward pass for the current tokens. The model expects shape
            # (B, L) with L = 1 for incremental decoding and returns (B, L, V).
            with mx.stream(generation_stream):
                logits = model(y, cache=prompt_cache)
                # Only the last position's logits are needed for sampling.
                logits = logits[:, -1, :]  # shape: (B, vocab_size)
                maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits)
            mx.async_eval(logits)

            # Per-sequence loop for logits processing and sampling.
            next_tokens_list = []
            to_remove = []
            for i in range(len(generated_positions_to_batch_idxs)):
                batch_idx = generated_positions_to_batch_idxs[i]
                # 1. Apply any logits processors. They expect the token history
                #    as an mx.array and logits of shape (1, vocab_size).
                proc_logits = logits[i][None]
                for processor in logits_processors:
                    proc_logits = processor(mx.array(tokens_so_far[batch_idx]), proc_logits)
                # 2. Convert to log-probabilities.
                logprobs = proc_logits - mx.logsumexp(proc_logits, keepdims=True)
                # 3. Sample the next token.
                sampled_token = sampler(logprobs)  # shape: (1,)
                sampled_token_val = sampled_token.item()  # an int
                # 4. Check for EOS; finished sequences are dropped from the
                #    active batch below.
                if sampled_token_val in tokenizer.eos_token_ids:
                    to_remove.append(batch_idx)
                else:
                    tokens_so_far[batch_idx].append(sampled_token_val)
                next_tokens_list.append(sampled_token_val)
            # Drop finished sequences from the active batch and from the cache.
            if len(to_remove) > 0:
                indices = []
                for i in range(len(generated_positions_to_batch_idxs)):
                    if generated_positions_to_batch_idxs[i] not in to_remove:
                        indices.append(i)
                if len(indices) == 0:
                    # Every remaining sequence has finished.
                    break
                next_tokens_list = [next_tokens_list[i] for i in indices]
                for c in prompt_cache:
                    if c.keys is not None:
                        # mx.take with an index array keeps the batch axis, so
                        # the result stays (len(indices), ...) even when only
                        # one sequence survives.
                        c.keys = mx.take(c.keys, indices=mx.array(indices), axis=0)
                        c.values = mx.take(c.values, indices=mx.array(indices), axis=0)
                # Remove the finished sequences from the active-batch mapping.
                for i in to_remove:
                    generated_positions_to_batch_idxs.remove(i)
                mx.metal.clear_cache()

            # Prepare the (possibly smaller) batch of tokens for the next step.
            y = mx.array(next_tokens_list).reshape(len(next_tokens_list), 1)
            n += 1
            if n % 256 == 0:
                mx.metal.clear_cache()
        # Done with generation
        generation_time = time.perf_counter() - tic
        total_generated_tokens = 0

        # Decode all sequences.
        for i in range(batch_size):
            to_decode = tokens_so_far[i]
            total_generated_tokens += len(to_decode)
            decoded_texts[i] = tokenizer.decode(to_decode)

        # Optionally print verbose info.
        if verbose:
            for i, txt in enumerate(decoded_texts):
                print("=" * 10)
                print(f"Batch {i}: {txt}")
                print("=" * 10)
            if total_generated_tokens == 0:
                print("No text generated for this prompt.")
            else:
                # All sequences share the same prompt length, so prompt TPS can
                # be reported across the whole batch.
                print(
                    f"Prompt tokens (per sequence): {total_prompt_len}, "
                    f"Prompt TPS (across all sequences): {prompt_tps:.3f}"
                )
                # Each sequence generates at most n tokens (some may end early),
                # so this is a rough aggregate TPS.
                print(
                    f"Generation tokens (max per sequence): {n}, "
                    f"Generation TPS (across all sequences): "
                    f"{total_generated_tokens / generation_time:.3f}"
                )
                peak_mem = mx.metal.get_peak_memory() / 1e9
                print(f"Peak memory: {peak_mem:.3f} GB")

    return decoded_texts
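For context, a minimal usage sketch follows. It assumes mlx-lm is installed and that this function is importable next to it; the model repo id and the prompt are illustrative, not part of the original gist.

if __name__ == "__main__":
    # Usage sketch: load a quantized model with mlx-lm and generate four
    # completions of the same prompt in parallel. The repo id is illustrative.
    from mlx_lm import load

    model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
    completions = generate_batched(
        model,
        tokenizer,
        prompt="Write a haiku about the ocean.",
        batch_size=4,
        max_tokens=128,
        temp=0.8,
        top_p=0.95,
        verbose=True,
    )
    for i, text in enumerate(completions):
        print(f"--- completion {i} ---")
        print(text)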