sample beam search decoding with MLX
# 2025 Cavit Erginsoy, MIT License

import time
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from transformers import AutoTokenizer, AutoConfig
from dataclasses import dataclass
from huggingface_hub import hf_hub_download
import heapq

# Constants
DEFAULT_PROMPT = "Find a very original name for an imaginary city in a sci-fi novel featuring an omnipotent AI."
DEFAULT_MAX_TOKENS = 100
DEFAULT_NUM_BEAMS = 20
DEFAULT_SHOW_CANDIDATES = 3  # Must be <= DEFAULT_NUM_BEAMS
MODEL_NAME = "mlx-community/Qwen2.5-0.5B-Instruct-bf16"


@dataclass
class ModelArgs:
    model_type: str
    hidden_size: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    rms_norm_eps: float
    vocab_size: int
    num_key_value_heads: Optional[int] = None
    rope_theta: float = 1000000.0
    rope_traditional: bool = False
    tie_word_embeddings: bool = True


def create_attention_mask(h: mx.array):
    seq_length = h.shape[1]
    mask = mx.arange(seq_length)[:, None] < mx.arange(seq_length)[None, :]
    mask = mx.where(mask, -1e9, 0).astype(h.dtype)
    return mask
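
# Illustration (not part of the original gist): for a sequence of length 3 the
# mask above evaluates to
#   [[   0, -1e9, -1e9],
#    [   0,    0, -1e9],
#    [   0,    0,    0]]
# so adding it to the attention scores drives every key position to the right
# of the query position to ~zero probability after the softmax (a causal mask).
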
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        dim = args.hidden_size
        self.n_heads = args.num_attention_heads
        self.n_kv_heads = args.num_key_value_heads or args.num_attention_heads
        head_dim = dim // self.n_heads
        self.scale = head_dim**-0.5
        self.q_proj = nn.Linear(dim, self.n_heads * head_dim, bias=True)
        self.k_proj = nn.Linear(dim, self.n_kv_heads * head_dim, bias=True)
        self.v_proj = nn.Linear(dim, self.n_kv_heads * head_dim, bias=True)
        self.o_proj = nn.Linear(self.n_heads * head_dim, dim, bias=False)
        self.rope = nn.RoPE(head_dim, traditional=args.rope_traditional, base=args.rope_theta)

    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
        B, L, D = x.shape
        queries = self.q_proj(x).reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = self.k_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = self.v_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        if self.n_heads != self.n_kv_heads:
            repeat_times = self.n_heads // self.n_kv_heads
            keys = mx.repeat(keys, repeat_times, axis=1)
            values = mx.repeat(values, repeat_times, axis=1)
        queries = self.rope(queries)
        keys = self.rope(keys)
        scores = (queries @ keys.transpose(0, 1, 3, 2)) * self.scale
        if mask is not None:
            scores = scores + mask
        scores = mx.softmax(scores, axis=-1)
        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)
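
# Note (assumption based on the published Qwen2.5-0.5B config, not stated in the
# original gist): this checkpoint uses grouped-query attention with 14 query
# heads and 2 key/value heads, so the mx.repeat calls above tile each KV head
# 7 times before the scaled dot-product.
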
class MLP(nn.Module):
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)

    def __call__(self, x) -> mx.array:
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.self_attn = Attention(args)
        self.mlp = MLP(args.hidden_size, args.intermediate_size)
        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
        h = x + self.self_attn(self.input_layernorm(x), mask)
        out = h + self.mlp(self.post_attention_layernorm(h))
        return out


class Qwen2Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [TransformerBlock(args) for _ in range(args.num_hidden_layers)]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(self, inputs: mx.array, mask: Optional[mx.array] = None):
        h = self.embed_tokens(inputs)
        if mask is None:
            mask = create_attention_mask(h)
        for layer in self.layers:
            h = layer(h, mask)
        return self.norm(h)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model = Qwen2Model(args)
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(self, inputs: mx.array, mask: Optional[mx.array] = None):
        out = self.model(inputs, mask)
        if self.args.tie_word_embeddings:
            return self.model.embed_tokens.as_linear(out)
        return self.lm_head(out)
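
# Note (not part of the original gist): with tie_word_embeddings=True, which this
# Qwen2.5-0.5B-Instruct config appears to set, the output projection reuses the
# embedding matrix via nn.Embedding.as_linear, so no separate lm_head.weight is
# required (load() below drops that key from the checkpoint for the same reason).
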
def load(model_path: str):
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    model_file = hf_hub_download(repo_id=model_path, filename="model.safetensors")
    args = ModelArgs(
        model_type=config.model_type,
        hidden_size=config.hidden_size,
        num_hidden_layers=config.num_hidden_layers,
        intermediate_size=config.intermediate_size,
        num_attention_heads=config.num_attention_heads,
        rms_norm_eps=config.rms_norm_eps,
        vocab_size=config.vocab_size,
        num_key_value_heads=config.num_key_value_heads,
        rope_theta=getattr(config, "rope_theta", 1000000),
        rope_traditional=getattr(config, "rope_traditional", False),
        tie_word_embeddings=getattr(config, "tie_word_embeddings", True),
    )
    model = Model(args)
    weights = mx.load(model_file)
    if args.tie_word_embeddings:
        weights.pop("lm_head.weight", None)
    model.load_weights(list(weights.items()))
    # Ensure model's vocab size matches weights
    embed_weight = weights["model.embed_tokens.weight"]
    actual_vocab_size = embed_weight.shape[0]
    if actual_vocab_size != args.vocab_size:
        print(f"Warning: Config vocab_size ({args.vocab_size}) differs from model weights ({actual_vocab_size}). Using model's vocab size.")
        args.vocab_size = actual_vocab_size
        model.args.vocab_size = actual_vocab_size
    model.eval()
    return model, tokenizer
class BeamHypotheses:
    """Manages beam search hypotheses and their scores.

    Args:
        num_beams: Number of beams to maintain
        length_penalty: Penalty factor for sequence length
        early_stopping: Whether to stop early when good candidates are found
        max_length: Maximum allowed sequence length
    """

    def __init__(self, num_beams: int, length_penalty: float = 1.0,
                 early_stopping: bool = False, max_length: Optional[int] = None) -> None:
        self.num_beams = num_beams
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.max_length = max_length
        self.beams = []
        self.worst_score = 1e9

    def add(self, sequence: list, score: float) -> None:
        """Add a new hypothesis to the beam, keeping at most num_beams of them."""
        heapq.heappush(self.beams, (score, sequence))
        if len(self.beams) > self.num_beams:
            # Drop the lowest-scoring hypothesis (the min-heap root).
            heapq.heappop(self.beams)
        # Track the worst retained score so is_done() can compare against it.
        self.worst_score = self.beams[0][0]

    def is_done(self, best_score: float, cur_len: int) -> bool:
        """Check if beam search is complete."""
        if self.early_stopping == "never":
            return False
        elif self.early_stopping:
            return len(self.beams) >= self.num_beams
        else:
            if self.max_length and cur_len >= self.max_length:
                return True
            current_score = best_score / ((cur_len + 1) ** self.length_penalty)
            return current_score <= self.worst_score and len(self.beams) >= self.num_beams
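
# Worked example (not from the original gist): with length_penalty=0.7, a
# hypothesis whose cumulative log-probability is -12.0 at cur_len=20 is ranked
# by -12.0 / (20 + 1) ** 0.7 ≈ -12.0 / 8.42 ≈ -1.42. The penalty softens the
# advantage shorter sequences get from accumulating fewer negative
# log-probabilities, so longer hypotheses can still compete.
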
class Generator:
    """Text generation using various decoding strategies."""

    def __init__(self, model: Model, tokenizer: AutoTokenizer,
                 max_new_tokens: int = DEFAULT_MAX_TOKENS) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.max_new_tokens = max_new_tokens
        self.eos_token_id = tokenizer.eos_token_id

    def _prepare_prompt(self, prompt: str) -> mx.array:
        """Prepare prompt by applying chat template and tokenization."""
        messages = [{"role": "system", "content": ""}, {"role": "user", "content": prompt}]
        if hasattr(self.tokenizer, 'apply_chat_template'):
            formatted_prompt = self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        else:
            formatted_prompt = f"User: {prompt}\nAssistant:"
        tokens = self.tokenizer.encode(formatted_prompt)
        return mx.array(tokens)

    def clean_response(self, text: str) -> str:
        """
        Post-process generated text to extract only the assistant's response.
        Handles both special tokens and role labels that might appear in plain text.
        """
        text = text.replace('<|im_start|>', '').replace('<|im_end|>', '').strip()
        if 'assistant' in text:
            _, _, assistant_part = text.rpartition('assistant')
            assistant_part = assistant_part.strip()
            return assistant_part.split('user')[0].split('system')[0].strip()
        return text

    def generate_greedy(self, prompt: str) -> str:
        """Generate text using greedy decoding strategy."""
        input_ids = self._prepare_prompt(prompt)
        generated_ids = input_ids
        new_tokens = []
        for _ in range(self.max_new_tokens):
            outputs = self.model(generated_ids[None, :])[0, -1, :]
            mx.eval(outputs)
            next_token = mx.argmax(outputs).item()
            new_tokens.append(next_token)
            generated_ids = mx.concatenate([generated_ids, mx.array([next_token])])
            if next_token == self.eos_token_id:
                break
        full_text = self.tokenizer.decode(generated_ids.tolist(), skip_special_tokens=True)
        return self.clean_response(full_text)
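
    # Note (observation, not a fix): neither decoding path keeps a key/value
    # cache, so every step re-runs the full prompt plus all previously generated
    # tokens through the model. That keeps the code short, but the per-step cost
    # grows with the current sequence length instead of staying roughly constant
    # per token as in cached implementations (mlx-lm, for example).
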
    def generate_beam(self, prompt: str, num_beams: int = DEFAULT_NUM_BEAMS,
                      early_stopping: bool = True, temperature: float = 1.0,
                      length_penalty: float = 0.7) -> list[str]:
        """Generate text using beam search decoding.

        Args:
            prompt: Input text prompt
            num_beams: Number of beams for search
            early_stopping: Whether to stop early when good candidates are found
            temperature: Sampling temperature
            length_penalty: Penalty factor for sequence length

        Returns:
            List of generated text candidates
        """
        input_ids = self._prepare_prompt(prompt)
        vocab_size = self.model.args.vocab_size
        input_ids = mx.tile(input_ids[None, :], (num_beams, 1))
        beam_scores = mx.zeros((num_beams,), dtype=mx.float32)
        beam_scores[1:] = -1e9  # Only beam 0 starts "live" so the beams don't begin as duplicates
        beam_hyps = BeamHypotheses(num_beams, length_penalty=length_penalty, early_stopping=early_stopping)
        finished = mx.zeros((num_beams,), dtype=mx.bool_)
        for step in range(self.max_new_tokens):
            if mx.all(finished):
                break
            logits = self.model(input_ids)[:, -1, :]
            next_token_scores = nn.log_softmax(logits / temperature, axis=-1)
            next_scores = next_token_scores + beam_scores[:, None]
            next_scores = mx.where(finished[:, None], -1e9, next_scores)
            top_k = min(num_beams * 2, vocab_size)
            flat_next_scores = next_scores.reshape(-1)
            top_indices = mx.argpartition(-flat_next_scores, top_k - 1)[:top_k]
            top_scores = mx.take(flat_next_scores, top_indices)
            sorted_indices = mx.argsort(-top_scores)[:num_beams]
            beam_indices = top_indices[sorted_indices] // vocab_size
            next_tokens = top_indices[sorted_indices] % vocab_size
            next_scores = top_scores[sorted_indices]
            new_input_ids = []
            for i in range(num_beams):
                beam_idx = beam_indices[i].item()
                new_seq = mx.concatenate([input_ids[beam_idx], mx.array([next_tokens[i].item()])])
                new_input_ids.append(new_seq)
            input_ids = mx.stack(new_input_ids)
            if mx.any(next_tokens == self.eos_token_id):
                for i in range(num_beams):
                    if next_tokens[i] == self.eos_token_id:
                        # Score the finished hypothesis with its new cumulative score
                        # (next_scores, not the previous step's beam_scores).
                        adjusted_score = next_scores[i].item() / ((input_ids.shape[1] + 1) ** length_penalty)
                        beam_hyps.add(input_ids[i].tolist(), adjusted_score)
            finished = (next_tokens == self.eos_token_id) | mx.take(finished, beam_indices)
            beam_scores = next_scores
            if beam_hyps.is_done(beam_scores.max().item(), input_ids.shape[1]):
                break
        # Collect all candidate sequences from finished and current beams
        finished_beams = beam_hyps.beams
        current_beams = []
        cur_len = input_ids.shape[1]
        for i in range(num_beams):
            if not finished[i].item():
                raw_score = beam_scores[i].item()
                adjusted_score = raw_score / ((cur_len + 1) ** length_penalty)
                current_beams.append((adjusted_score, input_ids[i].tolist()))
        all_candidates = finished_beams + current_beams
        num_candidates = min(DEFAULT_SHOW_CANDIDATES, num_beams)
        top_candidates = heapq.nlargest(num_candidates, all_candidates, key=lambda x: x[0])
        results = []
        for score, seq in top_candidates:
            text = self.tokenizer.decode(seq, skip_special_tokens=True)
            results.append(self.clean_response(text))
        # Fallback if no candidates found
        if not results:
            best_idx = mx.argmax(beam_scores).item()
            best_sequence = input_ids[best_idx].tolist()
            text = self.tokenizer.decode(best_sequence, skip_special_tokens=True)
            results.append(self.clean_response(text))
        return results[:num_candidates]
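
# Example (sketch, not part of the original gist): using Generator directly
# instead of the interactive main() loop below. The prompt text, num_beams, and
# max_new_tokens here are arbitrary illustrative values.
#
#   model, tokenizer = load(MODEL_NAME)
#   gen = Generator(model, tokenizer, max_new_tokens=64)
#   print(gen.generate_greedy("Name an imaginary sci-fi city."))
#   for rank, candidate in enumerate(gen.generate_beam("Name an imaginary sci-fi city.", num_beams=8), 1):
#       print(f"{rank}. {candidate}")
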
def main():
    try:
        print("Loading model from", MODEL_NAME)
        model, tokenizer = load(MODEL_NAME)
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Failed to load model: {str(e)}")
        return
    generator = Generator(model, tokenizer)
    while True:
        prompt = input("\nEnter your prompt (or press Enter for default, 'q' to quit): ").strip()
        if prompt.lower() == 'q':
            break
        if not prompt:
            prompt = DEFAULT_PROMPT
            print(f"Using default prompt: {DEFAULT_PROMPT}")
        print("\nGenerating responses...")
        # Greedy decoding
        print("\n--- Greedy Decoding ---")
        start_time = time.time()
        try:
            greedy_result = generator.generate_greedy(prompt)
            greedy_time = time.time() - start_time
            print(greedy_result)
            print(f"Generation time: {greedy_time:.2f} seconds")
        except Exception as e:
            print(f"Greedy generation failed: {str(e)}")
        # Beam search decoding
        num_candidates = min(DEFAULT_SHOW_CANDIDATES, DEFAULT_NUM_BEAMS)
        print(f"\n--- Beam Search Decoding (num_beams={DEFAULT_NUM_BEAMS}) ---")
        start_time = time.time()
        try:
            beam_results = generator.generate_beam(prompt, num_beams=DEFAULT_NUM_BEAMS)
            beam_time = time.time() - start_time
            print(f"\nTop {num_candidates} candidates:")
            for i, result in enumerate(beam_results, 1):
                print(f"\n{i}. {result}")
            print(f"\nGeneration time: {beam_time:.2f} seconds")
        except Exception as e:
            print(f"Beam search generation failed: {str(e)}")


if __name__ == "__main__":
    main()