# NOTE: Imports added for completeness; the mlx_lm module paths below assume the
# package layout at the time this gist was written and may need adjusting.
import time
from typing import Callable, Optional, Union

import mlx.core as mx
import mlx.nn as nn
from transformers import PreTrainedTokenizer

from mlx_lm.models import cache
from mlx_lm.models.cache import trim_prompt_cache
from mlx_lm.tokenizer_utils import TokenizerWrapper
from mlx_lm.utils import generation_stream, wired_limit


def generate_speculative(
    model: nn.Module,
    draft_model: nn.Module,
    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
    prompt: str,
    max_tokens: int = 100,
    verbose: bool = False,
    formatter: Optional[Callable] = None,
    **kwargs,
) -> str:
""" | |
Generate a complete response from the model. | |
Args: | |
model (nn.Module): The language model. | |
tokenizer (PreTrainedTokenizer): The tokenizer. | |
prompt (str): The string prompt. | |
max_tokens (int): The maximum number of tokens. Default: ``100``. | |
verbose (bool): If ``True``, print tokens and timing information. | |
Default: ``False``. | |
formatter (Optional[Callable]): A function which takes a token and a | |
probability and displays it. | |
kwargs: The remaining options get passed to :func:`generate_step`. | |
See :func:`generate_step` for more details. | |
""" | |
    if not isinstance(tokenizer, TokenizerWrapper):
        tokenizer = TokenizerWrapper(tokenizer)

    if verbose:
        print("=" * 10)
        print("Prompt:", prompt)

    prompt_tokens = mx.array(tokenizer.encode(prompt))
    detokenizer = tokenizer.detokenizer

    with wired_limit(model, [generation_stream]):
        detokenizer.reset()
"""for n, (token, logprobs) in zip( | |
range(max_tokens), | |
generate_speculative_step(prompt_tokens, model, **kwargs), | |
): | |
if n == 0: | |
prompt_time = time.perf_counter() - tic | |
tic = time.perf_counter() | |
if token == tokenizer.eos_token_id: | |
break | |
detokenizer.add_token(token) | |
if verbose: | |
if formatter: | |
# We have to finalize so that the prob corresponds to the last segment | |
detokenizer.finalize() | |
prob = mx.exp(logprobs[token]).item() | |
formatter(detokenizer.last_segment, prob) | |
else: | |
print(detokenizer.last_segment, end="", flush=True)""" | |
        y = prompt_tokens
        tokens = []

        # Create the KV caches for the main and draft models.
        prompt_cache = cache.make_prompt_cache(model, max_kv_size=None)
        draft_cache = cache.make_prompt_cache(draft_model, max_kv_size=None)

        # Prefill: run the prompt through both models. The main model stops one
        # token early so its next forward pass starts from the last prompt token.
        start_time = time.time()
        k = 4  # Number of draft tokens proposed per verification step.
        model(y[None, :-1], cache=prompt_cache)
        last_model_token = y[-1].item()
        draft_logits = draft_model(y[None], cache=draft_cache)[0, -1, :]
        while len(tokens) < max_tokens:
            # Let the draft model propose k tokens greedily.
            draft_tokens = []
            for i in range(k):
                token = mx.argmax(draft_logits, axis=-1)
                draft_tokens.append(token.item())
                y = mx.array([token.item()])
                draft_logits = draft_model(y[None], cache=draft_cache)

            # Verify all k draft tokens with the main model in one forward pass.
            y = mx.array([last_model_token] + draft_tokens)
            model_logits = model(y[None], cache=prompt_cache)
            model_choices = mx.argmax(model_logits, axis=-1)[0]

            # Find the first position where the main model disagrees with the draft.
            divergence_point = k
            for i in range(k):
                if model_choices[i] != draft_tokens[i]:
                    divergence_point = i
                    break

            # Accept the agreeing prefix plus the main model's own token at the
            # divergence point.
            new_tokens = draft_tokens[:divergence_point] + [
                model_choices[divergence_point].item()
            ]
            tokens.extend(new_tokens)

            # Stream the accepted tokens, stopping at EOS.
            need_to_break = False
            for token in new_tokens:
                if token == tokenizer.eos_token_id:
                    need_to_break = True
                    break
                detokenizer.add_token(token)
                print(detokenizer.last_segment, flush=True, end="")
            if need_to_break:
                break

            # Roll both KV caches back to the divergence point, discarding the
            # rejected draft tokens.
            trim_prompt_cache(prompt_cache, k - divergence_point)
            trim_prompt_cache(draft_cache, k - divergence_point)

            last_model_token = model_choices[divergence_point].item()
            draft_logits = draft_model(
                mx.array([last_model_token])[None], cache=draft_cache
            )[0, -1, :]

        detokenizer.finalize()
        print(detokenizer.last_segment, flush=True)
        print("=" * 10)

        end_time = time.time()
        tokens_per_second = len(tokens) / (end_time - start_time)
        print(f"Tokens per second: {tokens_per_second}")

    # The signature promises a string, so return the full generated text.
    return detokenizer.text