import time
from typing import Callable, Optional, Union

import mlx.core as mx
import mlx.nn as nn
from transformers import PreTrainedTokenizer

# These module paths assume the mlx-lm package layout this gist was written against
from mlx_lm.models import cache
from mlx_lm.models.cache import trim_prompt_cache
from mlx_lm.tokenizer_utils import TokenizerWrapper
from mlx_lm.utils import generation_stream, wired_limit


def generate_speculative(
    model: nn.Module,
    draft_model: nn.Module,
    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
    prompt: str,
    max_tokens: int = 100,
    verbose: bool = False,
    formatter: Optional[Callable] = None,
    **kwargs,
) -> str:
"""
Generate a complete response from the model.
Args:
model (nn.Module): The language model.
tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (str): The string prompt.
max_tokens (int): The maximum number of tokens. Default: ``100``.
verbose (bool): If ``True``, print tokens and timing information.
Default: ``False``.
formatter (Optional[Callable]): A function which takes a token and a
probability and displays it.
kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details.
"""
    if not isinstance(tokenizer, TokenizerWrapper):
        tokenizer = TokenizerWrapper(tokenizer)

    if verbose:
        print("=" * 10)
        print("Prompt:", prompt)

    prompt_tokens = mx.array(tokenizer.encode(prompt))
    detokenizer = tokenizer.detokenizer

    with wired_limit(model, [generation_stream]):
        detokenizer.reset()
"""for n, (token, logprobs) in zip(
range(max_tokens),
generate_speculative_step(prompt_tokens, model, **kwargs),
):
if n == 0:
prompt_time = time.perf_counter() - tic
tic = time.perf_counter()
if token == tokenizer.eos_token_id:
break
detokenizer.add_token(token)
if verbose:
if formatter:
# We have to finalize so that the prob corresponds to the last segment
detokenizer.finalize()
prob = mx.exp(logprobs[token]).item()
formatter(detokenizer.last_segment, prob)
else:
print(detokenizer.last_segment, end="", flush=True)"""
        y = prompt_tokens
        tokens = []

        # Create the KV caches for generation
        prompt_cache = cache.make_prompt_cache(model, max_kv_size=None)
        draft_cache = cache.make_prompt_cache(draft_model, max_kv_size=None)

        # Prefill both models with the prompt. The main model is fed all but
        # the last prompt token so that the last token can lead the first
        # verification batch below.
        start_time = time.perf_counter()
        k = 4  # number of tokens the draft model speculates per round
        model(y[None, :-1], cache=prompt_cache)
        last_model_token = y[-1].item()
        draft_logits = draft_model(y[None], cache=draft_cache)[0, -1, :]
        while len(tokens) < max_tokens:
            # Draft phase: greedily propose k tokens with the draft model
            draft_tokens = []
            for i in range(k):
                token = mx.argmax(draft_logits, axis=-1)
                draft_tokens.append(token.item())
                y = mx.array([token.item()])
                draft_logits = draft_model(y[None], cache=draft_cache)[0, -1, :]

            # Verification phase: score the last accepted token plus all k
            # draft tokens with the main model in a single forward pass
            y = mx.array([last_model_token] + draft_tokens)
            model_logits = model(y[None], cache=prompt_cache)
            model_choices = mx.argmax(model_logits, axis=-1)[0]
            # Find the point of divergence: model_choices[i] is the main
            # model's greedy choice for position i, which should match
            # draft_tokens[i]. E.g. with drafts [A, B, C, D] and choices
            # [A, B, X, ...], the divergence point is 2: accept [A, B, X].
            divergence_point = k
            for i in range(k):
                if model_choices[i].item() != draft_tokens[i]:
                    divergence_point = i
                    break

            # Keep the agreed-upon prefix plus the main model's token at the
            # divergence point (a bonus token when all k drafts are accepted)
            new_tokens = draft_tokens[:divergence_point] + [
                model_choices[divergence_point].item()
            ]
            tokens.extend(new_tokens)

            # Stream the accepted tokens, stopping at EOS
            need_to_break = False
            for token in new_tokens:
                if token == tokenizer.eos_token_id:
                    need_to_break = True
                    break
                detokenizer.add_token(token)
                if verbose:
                    print(detokenizer.last_segment, flush=True, end="")
            if need_to_break:
                break
            # Roll back both KV caches past the rejected draft tokens: the
            # main model cached k + 1 new entries of which divergence_point + 1
            # were kept, and the draft cached k of which divergence_point were
            # kept, so both need k - divergence_point entries trimmed
            trim_prompt_cache(prompt_cache, k - divergence_point)
            trim_prompt_cache(draft_cache, k - divergence_point)

            # Feed the corrected token to the draft model to seed the next
            # round of speculation
            last_model_token = model_choices[divergence_point].item()
            draft_logits = draft_model(
                mx.array([last_model_token])[None], cache=draft_cache
            )[0, -1, :]
        detokenizer.finalize()
        if verbose:
            print(detokenizer.last_segment, flush=True)
            print("=" * 10)
            end_time = time.perf_counter()
            tokens_per_second = len(tokens) / (end_time - start_time)
            print(f"Tokens per second: {tokens_per_second:.2f}")

    return detokenizer.text