@N8python
Created December 31, 2024 00:39
Prototypes.
import time
from typing import Callable, List, Optional, Union

import mlx.core as mx
import mlx.nn as nn
from transformers import PreTrainedTokenizer

# NOTE: these helpers come from mlx-lm; the module paths below match the layout
# of recent mlx-lm releases but may differ between versions.
from mlx_lm.models import cache
from mlx_lm.sample_utils import make_logits_processors, make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper
from mlx_lm.utils import maybe_quantize_kv_cache, wired_limit


def generate_batched(
    model: nn.Module,
    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
    prompt: str,
    batch_size: int,
    *,
    verbose: bool = False,
    formatter: Optional[Callable] = None,
    max_tokens: int = 256,
    temp: float = 0.0,
    top_p: float = 0.0,
    min_p: float = 0.0,
    min_tokens_to_keep: int = 1,
    repetition_penalty: float = 1.0,
    repetition_context_size: int = 20,
    kv_bits: Optional[int] = None,
    kv_group_size: int = 64,
    quantized_kv_start: int = 0,
    max_kv_size: Optional[int] = None,
    prefill_step_size: int = 512,
    **kwargs,
) -> List[str]:
"""
Generate multiple responses in parallel from the same prompt.
Args:
model (nn.Module): The language model.
tokenizer (PreTrainedTokenizer or TokenizerWrapper): The tokenizer.
prompt (str): The string prompt.
batch_size (int): Number of parallel sequences to generate.
verbose (bool): If True, prints tokens and timing information. Default: False.
formatter (Callable): Deprecated. (No longer used)
max_tokens (int): The maximum number of tokens to generate. Default: 256.
temp (float): Temperature for sampling. Default: 0.0.
top_p (float): Nucleus sampling top-p parameter. Default: 0.0.
min_p (float): Minimum cumulative probability cutoff. Default: 0.0.
min_tokens_to_keep (int): Ensures a minimum number of tokens remain after filtering. Default: 1.
repetition_penalty (float): Repetition penalty. Default: 1.0.
repetition_context_size (int): The context size to consider for the repetition penalty. Default: 20.
kv_bits (int): Number of bits for KV cache quantization. None = disabled. Default: None.
kv_group_size (int): Group size for KV cache quantization. Default: 64.
quantized_kv_start (int): The step to begin using a quantized KV cache. Default: 0.
max_kv_size (int): The maximum size of the KV cache. Old tokens get overwritten. Default: None.
prefill_step_size (int): Step size used when processing the prompt (prefill). Default: 512.
**kwargs: Unused extra kwargs, included for API-compatibility.
Returns:
List[str]: A list of decoded text strings of length `batch_size`.
"""
    if formatter is not None:
        print(
            "[Warning] Text formatting is deprecated and no longer used. "
            "The argument will be removed in a future version."
        )

    # Ensure we have a TokenizerWrapper
    if not isinstance(tokenizer, TokenizerWrapper):
        tokenizer = TokenizerWrapper(tokenizer)

    # Encode the prompt
    prompt_ids = mx.array(tokenizer.encode(prompt))

    # Prepare to replicate the single prompt for all batch sequences
    # Shape: (batch_size, prompt_length)
    # batched_prompt_ids = mx.repeat(prompt_ids[None, :], repeats=batch_size, axis=0)

    # We'll maintain the partially decoded tokens for each batch element
    # and the final decoded strings
    decoded_texts = ["" for _ in range(batch_size)]

    # Bookkeeping for which sequences have ended
    ended = [False] * batch_size

    # Create the sampler and any logits processors
    sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep)
    logits_processors = make_logits_processors(
        None, repetition_penalty, repetition_context_size
    )

    # Create the prompt cache. The usual approach in MLX is a single cache for the
    # entire batch, provided the model handles a batch dimension B in dimension 0
    # of its inputs. We create that single cache here:
    prompt_cache = cache.make_prompt_cache(model, max_kv_size=max_kv_size)
    # If the model's forward pass is multi-batch aware, this single prompt_cache is
    # used with shape (B, ...) internally. Otherwise, additional logic would be
    # needed to replicate caches. We assume the model can handle batch dimension B.
    # Timing for verbose reporting, and a dedicated stream for generation
    generation_stream = mx.new_stream(mx.default_device())

    with wired_limit(model, [generation_stream]):
        tic = time.perf_counter()

        # Step 1. "Prefill": process the prompt in chunks to fill the KV cache.
        # All but the last prompt token are prefilled here; the last token is fed
        # at the first decode step below so that its logits can be sampled from.
        total_prompt_len = prompt_ids.shape[0]
        prefill_len = max(total_prompt_len - 1, 0)
        processed = 0
        while processed < prefill_len:
            chunk_end = min(processed + prefill_step_size, prefill_len)
            # Forward pass over a chunk of shape (1, chunk_size); the KV cache is
            # replicated across the batch once prefill is done.
            inputs_chunk = prompt_ids[processed:chunk_end]
            with mx.stream(generation_stream):
                _ = model(inputs_chunk[None], cache=prompt_cache)
                maybe_quantize_kv_cache(
                    prompt_cache, quantized_kv_start, kv_group_size, kv_bits
                )
            mx.eval([c.state for c in prompt_cache])
            processed = chunk_end
            mx.metal.clear_cache()

        # The time spent so far was for prompt processing
        prompt_time = time.perf_counter() - tic
        if total_prompt_len == 0:
            prompt_tps = 0.0
        else:
            prompt_tps = (batch_size * total_prompt_len) / prompt_time
        # For decoding, we feed in one token at a time (the last token from each
        # sequence), so we pick up from the last token of the prompt. The "current
        # token" for each sequence is stored in y. For simplicity, an empty prompt
        # is not supported here (you could substitute a BOS token if your model
        # needs one).
        """if total_prompt_len > 0:
            y = batched_prompt_ids[:, -1:]
        else:
            raise ValueError("Cannot decode from an empty prompt in this example. "
                             "Add or adapt code to handle an empty initial token if needed.")"""
        y = mx.repeat(prompt_ids[-1:][None, :], repeats=batch_size, axis=0)

        # Replicate the prefilled KV cache across the batch dimension. (The cache
        # is still empty when the prompt has a single token, in which case there is
        # nothing to replicate yet.)
        for c in prompt_cache:
            if c.keys is not None:
                c.keys = mx.repeat(c.keys, repeats=batch_size, axis=0)
                c.values = mx.repeat(c.values, repeats=batch_size, axis=0)

        # We keep track of the tokens generated for each sequence (the prompt
        # itself is not stored here); new tokens are appended as they are generated.
        # tokens_so_far = batched_prompt_ids
        tokens_so_far = [[] for _ in range(batch_size)]
        # Step 2. Generate new tokens (decode) until all sequences have ended or
        # max_tokens is reached
        tic = time.perf_counter()
        n = 0
        generated_positions_to_batch_idxs = [i for i in range(batch_size)]
        while True:
            if n >= max_tokens:
                break

            # Forward pass for the current token(s). The model expects shape (B, L),
            # with L=1 for incremental decoding, and returns shape (B, L, V).
            with mx.stream(generation_stream):
                logits = model(y, cache=prompt_cache)
                # We only need the last position's logits for sampling
                logits = logits[:, -1, :]  # shape: (B, vocab_size)
                maybe_quantize_kv_cache(
                    prompt_cache, quantized_kv_start, kv_group_size, kv_bits
                )
                mx.async_eval(logits)

            # Per-sequence loop for sampling; store the results
            next_tokens_list = []
            to_remove = []
            for i in range(len(generated_positions_to_batch_idxs)):
                # 1. Apply any logits processors. Row i of `logits` refers to the
                #    i-th *active* sequence, so its tokens are looked up via the
                #    original batch index.
                proc_logits = logits[i]
                for processor in logits_processors:
                    proc_logits = processor(
                        tokens_so_far[generated_positions_to_batch_idxs[i]],
                        proc_logits[None],
                    )
                # 2. Convert to logprobs
                logprobs = proc_logits - mx.logsumexp(proc_logits, keepdims=True)
                # 3. Sample the next token
                sampled_token = sampler(logprobs)  # shape: (1,)
                sampled_token_val = sampled_token.item()  # an int
                # 4. Check if EOS
                if sampled_token_val in tokenizer.eos_token_ids:
                    to_remove.append(generated_positions_to_batch_idxs[i])
                else:
                    tokens_so_far[generated_positions_to_batch_idxs[i]].append(
                        sampled_token_val
                    )
                next_tokens_list.append(sampled_token_val)
                # if not ended[i]:

            if len(to_remove) > 0:
                # Drop finished sequences: keep only the still-active rows of the
                # batch and compact the KV cache accordingly
                indices = []
                for i in range(len(generated_positions_to_batch_idxs)):
                    if generated_positions_to_batch_idxs[i] not in to_remove:
                        indices.append(i)
                if len(indices) == 0:
                    break
                next_tokens_list = [next_tokens_list[i] for i in indices]
                for c in prompt_cache:
                    # c.keys = mx.repeat(c.keys, repeats=batch_size, axis=0)
                    # c.values = mx.repeat(c.values, repeats=batch_size, axis=0)
                    # mx.take keeps the batch axis even when only one sequence
                    # remains, so no extra reshape is needed afterwards.
                    c.keys = mx.take(c.keys, indices=mx.array(indices), axis=0)
                    c.values = mx.take(c.values, indices=mx.array(indices), axis=0)
                # Remove finished sequences from generated_positions_to_batch_idxs
                for i in to_remove:
                    generated_positions_to_batch_idxs.remove(i)
                mx.metal.clear_cache()

            # Convert next_tokens_list -> (B, 1) array and prepare the next iteration
            next_tokens = mx.array(next_tokens_list).reshape(len(next_tokens_list), 1)
            y = next_tokens
            n += 1
            if n % 256 == 0:
                mx.metal.clear_cache()
        # Done with generation
        generation_time = time.perf_counter() - tic

        # Decode all sequences
        total_generated_tokens = 0
        for i in range(batch_size):
            to_decode = tokens_so_far[i]
            total_generated_tokens += len(to_decode)
            decoded_texts[i] = tokenizer.decode(to_decode)

        # Optionally print verbose info
        if verbose:
            for i, txt in enumerate(decoded_texts):
                print("=" * 10)
                print(f"Batch {i}: {txt}")
                print("=" * 10)
            if len(decoded_texts) == 0:
                print("No text generated for this prompt.")
            else:
                # All sequences have the same number of prompt tokens in this
                # design, so prompt TPS can be measured across the whole batch
                print(
                    f"Prompt tokens (per sequence): {total_prompt_len}, "
                    f"Prompt TPS (across all sequences): {prompt_tps:.3f}"
                )
                # Each sequence generated at most n tokens (some may have ended
                # earlier), so this is a rough aggregate TPS:
                print(
                    f"Generation tokens (max per sequence): {n}, "
                    f"Generation TPS (across all sequences): "
                    f"{total_generated_tokens / generation_time:.3f}"
                )
                peak_mem = mx.metal.get_peak_memory() / 1e9
                print(f"Peak memory: {peak_mem:.3f} GB")

    return decoded_texts
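

# Example usage (a minimal sketch): this assumes the mlx-lm package is installed
# and that `mlx_lm.load` returns a (model, tokenizer) pair, as in recent mlx-lm
# releases. The model repo below is only illustrative; substitute any MLX model.
if __name__ == "__main__":
    from mlx_lm import load

    model, tokenizer = load("mlx-community/Llama-3.2-1B-Instruct-4bit")
    completions = generate_batched(
        model,
        tokenizer,
        prompt="Write a haiku about the ocean.",
        batch_size=4,
        max_tokens=64,
        temp=0.8,
        top_p=0.9,
        verbose=True,
    )
    for i, text in enumerate(completions):
        print(f"--- completion {i} ---")
        print(text)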