@cavit99
Created February 22, 2025 03:34
sample beam search decoding with MLX
# 2025 Cavit Erginsoy, MIT License
import time
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from transformers import AutoTokenizer, AutoConfig
from dataclasses import dataclass
from huggingface_hub import hf_hub_download
import heapq
# Constants
DEFAULT_PROMPT = "Find a very original name for an imaginary city in a sci-fi novel featuring an omnipotent AI."
DEFAULT_MAX_TOKENS = 100
DEFAULT_NUM_BEAMS = 20
DEFAULT_SHOW_CANDIDATES = 3 # Must be <= DEFAULT_NUM_BEAMS
MODEL_NAME = "mlx-community/Qwen2.5-0.5B-Instruct-bf16"
@dataclass
class ModelArgs:
    model_type: str
    hidden_size: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    rms_norm_eps: float
    vocab_size: int
    num_key_value_heads: Optional[int] = None
    rope_theta: float = 1000000.0
    rope_traditional: bool = False
    tie_word_embeddings: bool = True

def create_attention_mask(h: mx.array):
    seq_length = h.shape[1]
    mask = mx.arange(seq_length)[:, None] < mx.arange(seq_length)[None, :]
    mask = mx.where(mask, -1e9, 0).astype(h.dtype)
    return mask
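
# Worked example (illustrative, not executed): for seq_length = 3 the mask above is
#     [[   0, -1e9, -1e9],
#      [   0,    0, -1e9],
#      [   0,    0,    0]]
# Adding it to the attention scores before the softmax prevents each position
# from attending to later positions (a standard causal mask).
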
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        dim = args.hidden_size
        self.n_heads = args.num_attention_heads
        self.n_kv_heads = args.num_key_value_heads or args.num_attention_heads
        head_dim = dim // self.n_heads
        self.scale = head_dim**-0.5
        self.q_proj = nn.Linear(dim, self.n_heads * head_dim, bias=True)
        self.k_proj = nn.Linear(dim, self.n_kv_heads * head_dim, bias=True)
        self.v_proj = nn.Linear(dim, self.n_kv_heads * head_dim, bias=True)
        self.o_proj = nn.Linear(self.n_heads * head_dim, dim, bias=False)
        self.rope = nn.RoPE(head_dim, traditional=args.rope_traditional, base=args.rope_theta)

    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
        B, L, D = x.shape
        queries = self.q_proj(x).reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = self.k_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = self.v_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        if self.n_heads != self.n_kv_heads:
            # Grouped-query attention: replicate key/value heads to match the query heads
            repeat_times = self.n_heads // self.n_kv_heads
            keys = mx.repeat(keys, repeat_times, axis=1)
            values = mx.repeat(values, repeat_times, axis=1)
        queries = self.rope(queries)
        keys = self.rope(keys)
        scores = (queries @ keys.transpose(0, 1, 3, 2)) * self.scale
        if mask is not None:
            scores = scores + mask
        scores = mx.softmax(scores, axis=-1)
        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)

class MLP(nn.Module):
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)

    def __call__(self, x) -> mx.array:
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))

class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.self_attn = Attention(args)
        self.mlp = MLP(args.hidden_size, args.intermediate_size)
        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
        h = x + self.self_attn(self.input_layernorm(x), mask)
        out = h + self.mlp(self.post_attention_layernorm(h))
        return out

class Qwen2Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [TransformerBlock(args) for _ in range(args.num_hidden_layers)]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(self, inputs: mx.array, mask: Optional[mx.array] = None):
        h = self.embed_tokens(inputs)
        if mask is None:
            mask = create_attention_mask(h)
        for layer in self.layers:
            h = layer(h, mask)
        return self.norm(h)

class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model = Qwen2Model(args)
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(self, inputs: mx.array, mask: Optional[mx.array] = None):
        out = self.model(inputs, mask)
        if self.args.tie_word_embeddings:
            # Tied embeddings: reuse the input embedding matrix as the output projection
            return self.model.embed_tokens.as_linear(out)
        return self.lm_head(out)

def load(model_path: str):
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    model_file = hf_hub_download(repo_id=model_path, filename="model.safetensors")
    args = ModelArgs(
        model_type=config.model_type,
        hidden_size=config.hidden_size,
        num_hidden_layers=config.num_hidden_layers,
        intermediate_size=config.intermediate_size,
        num_attention_heads=config.num_attention_heads,
        rms_norm_eps=config.rms_norm_eps,
        vocab_size=config.vocab_size,
        num_key_value_heads=config.num_key_value_heads,
        rope_theta=getattr(config, "rope_theta", 1000000),
        rope_traditional=getattr(config, "rope_traditional", False),
        tie_word_embeddings=getattr(config, "tie_word_embeddings", True),
    )
    model = Model(args)
    weights = mx.load(model_file)
    if args.tie_word_embeddings:
        weights.pop("lm_head.weight", None)
    model.load_weights(list(weights.items()))
    # Ensure the model's vocab size matches the loaded weights
    embed_weight = weights["model.embed_tokens.weight"]
    actual_vocab_size = embed_weight.shape[0]
    if actual_vocab_size != args.vocab_size:
        print(f"Warning: Config vocab_size ({args.vocab_size}) differs from model weights ({actual_vocab_size}). Using model's vocab size.")
        args.vocab_size = actual_vocab_size
        model.args.vocab_size = actual_vocab_size
    model.eval()
    return model, tokenizer

class BeamHypotheses:
"""Manages beam search hypotheses and their scores.
Args:
num_beams: Number of beams to maintain
length_penalty: Penalty factor for sequence length
early_stopping: Whether to stop early when good candidates are found
max_length: Maximum allowed sequence length
"""
    def __init__(self, num_beams: int, length_penalty: float = 1.0,
                 early_stopping: bool = False, max_length: Optional[int] = None) -> None:
        self.num_beams = num_beams
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.max_length = max_length
        self.beams = []
        self.worst_score = 1e9

    def add(self, sequence: list, score: float) -> None:
        """Add a new hypothesis to the beam."""
        heapq.heappush(self.beams, (score, sequence))
        # Track the worst (lowest) score added so far; is_done() compares against it
        self.worst_score = min(score, self.worst_score)
    def is_done(self, best_score: float, cur_len: int) -> bool:
        """Check if beam search is complete."""
        if self.early_stopping == "never":
            return False
        elif self.early_stopping:
            return len(self.beams) >= self.num_beams
        else:
            if self.max_length and cur_len >= self.max_length:
                return True
            current_score = best_score / ((cur_len + 1) ** self.length_penalty)
            return current_score <= self.worst_score and len(self.beams) >= self.num_beams

class Generator:
"""Text generation using various decoding strategies."""
def __init__(self, model: Model, tokenizer: AutoTokenizer,
max_new_tokens: int = DEFAULT_MAX_TOKENS) -> None:
self.model = model
self.tokenizer = tokenizer
self.max_new_tokens = max_new_tokens
self.eos_token_id = tokenizer.eos_token_id
def _prepare_prompt(self, prompt: str) -> mx.array:
"""Prepare prompt by applying chat template and tokenization."""
messages = [{"role": "system", "content": ""}, {"role": "user", "content": prompt}]
if hasattr(self.tokenizer, 'apply_chat_template'):
formatted_prompt = self.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
else:
formatted_prompt = f"User: {prompt}\nAssistant:"
tokens = self.tokenizer.encode(formatted_prompt)
return mx.array(tokens)
    def clean_response(self, text: str) -> str:
        """
        Post-process generated text to extract only the assistant's response.
        Handles both special tokens and role labels that might appear in plain text.
        """
        text = text.replace('<|im_start|>', '').replace('<|im_end|>', '').strip()
        if 'assistant' in text:
            _, _, assistant_part = text.rpartition('assistant')
            assistant_part = assistant_part.strip()
            return assistant_part.split('user')[0].split('system')[0].strip()
        return text
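
    # Illustrative example for clean_response (assuming the decoded text looks like
    # "system ... user <prompt> assistant <reply>"): everything up to and including
    # the last "assistant" marker is dropped, so only "<reply>" is returned.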
    def generate_greedy(self, prompt: str) -> str:
        """Generate text using greedy decoding strategy."""
        input_ids = self._prepare_prompt(prompt)
        generated_ids = input_ids
        new_tokens = []
        for _ in range(self.max_new_tokens):
            # Re-run the full sequence each step (no KV cache) and take the last-position logits
            outputs = self.model(generated_ids[None, :])[0, -1, :]
            mx.eval(outputs)
            next_token = mx.argmax(outputs).item()
            new_tokens.append(next_token)
            generated_ids = mx.concatenate([generated_ids, mx.array([next_token])])
            if next_token == self.eos_token_id:
                break
        full_text = self.tokenizer.decode(generated_ids.tolist(), skip_special_tokens=True)
        return self.clean_response(full_text)
    def generate_beam(self, prompt: str, num_beams: int = DEFAULT_NUM_BEAMS,
                      early_stopping: bool = True, temperature: float = 1.0,
                      length_penalty: float = 0.7) -> list[str]:
        """Generate text using beam search decoding.

        Args:
            prompt: Input text prompt
            num_beams: Number of beams for search
            early_stopping: Whether to stop early when good candidates are found
            temperature: Sampling temperature
            length_penalty: Penalty factor for sequence length

        Returns:
            List of generated text candidates
        """
        input_ids = self._prepare_prompt(prompt)
        vocab_size = self.model.args.vocab_size
        input_ids = mx.tile(input_ids[None, :], (num_beams, 1))
        beam_scores = mx.zeros((num_beams,), dtype=mx.float32)
        beam_scores[1:] = -1e9
        beam_hyps = BeamHypotheses(num_beams, length_penalty=length_penalty, early_stopping=early_stopping)
        finished = mx.zeros((num_beams,), dtype=mx.bool_)
        for step in range(self.max_new_tokens):
            if mx.all(finished):
                break
            logits = self.model(input_ids)[:, -1, :]
            next_token_scores = nn.log_softmax(logits / temperature, axis=-1)
            next_scores = next_token_scores + beam_scores[:, None]
            next_scores = mx.where(finished[:, None], -1e9, next_scores)
            top_k = min(num_beams * 2, vocab_size)
            flat_next_scores = next_scores.reshape(-1)
            top_indices = mx.argpartition(-flat_next_scores, top_k - 1)[:top_k]
            top_scores = mx.take(flat_next_scores, top_indices)
            sorted_indices = mx.argsort(-top_scores)[:num_beams]
            beam_indices = top_indices[sorted_indices] // vocab_size
            next_tokens = top_indices[sorted_indices] % vocab_size
            next_scores = top_scores[sorted_indices]
            new_input_ids = []
            for i in range(num_beams):
                beam_idx = beam_indices[i].item()
                new_seq = mx.concatenate([input_ids[beam_idx], mx.array([next_tokens[i].item()])])
                new_input_ids.append(new_seq)
            input_ids = mx.stack(new_input_ids)
            if mx.any(next_tokens == self.eos_token_id):
                for i in range(num_beams):
                    if next_tokens[i].item() == self.eos_token_id:
                        # Score the finished hypothesis with the updated cumulative
                        # log-prob for this beam slot (which includes the EOS token)
                        adjusted_score = next_scores[i].item() / ((input_ids.shape[1] + 1) ** length_penalty)
                        beam_hyps.add(input_ids[i].tolist(), adjusted_score)
            finished = (next_tokens == self.eos_token_id) | mx.take(finished, beam_indices)
            beam_scores = next_scores
            if beam_hyps.is_done(beam_scores.max().item(), input_ids.shape[1]):
                break
        # Collect all candidate sequences from finished and current beams
        finished_beams = beam_hyps.beams
        current_beams = []
        cur_len = input_ids.shape[1]
        for i in range(num_beams):
            if not finished[i].item():
                raw_score = beam_scores[i].item()
                adjusted_score = raw_score / ((cur_len + 1) ** length_penalty)
                current_beams.append((adjusted_score, input_ids[i].tolist()))
        all_candidates = finished_beams + current_beams
        num_candidates = min(DEFAULT_SHOW_CANDIDATES, num_beams)
        top_candidates = heapq.nlargest(num_candidates, all_candidates, key=lambda x: x[0])
        results = []
        for score, seq in top_candidates:
            text = self.tokenizer.decode(seq, skip_special_tokens=True)
            results.append(self.clean_response(text))
        # Fallback if no candidates found
        if not results:
            best_idx = mx.argmax(beam_scores).item()
            best_sequence = input_ids[best_idx].tolist()
            text = self.tokenizer.decode(best_sequence, skip_special_tokens=True)
            results.append(self.clean_response(text))
        return results[:num_candidates]
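
# Illustrative usage (the same calls main() makes below; values are examples):
#     model, tokenizer = load(MODEL_NAME)
#     generator = Generator(model, tokenizer)
#     text = generator.generate_greedy("Hello")
#     candidates = generator.generate_beam("Hello", num_beams=4)
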
def main():
    try:
        print("Loading model from", MODEL_NAME)
        model, tokenizer = load(MODEL_NAME)
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Failed to load model: {str(e)}")
        return
    generator = Generator(model, tokenizer)
    while True:
        prompt = input("\nEnter your prompt (or press Enter for default, 'q' to quit): ").strip()
        if prompt.lower() == 'q':
            break
        if not prompt:
            prompt = DEFAULT_PROMPT
            print(f"Using default prompt: {DEFAULT_PROMPT}")
        print("\nGenerating responses...")

        # Greedy decoding
        print("\n--- Greedy Decoding ---")
        start_time = time.time()
        try:
            greedy_result = generator.generate_greedy(prompt)
            greedy_time = time.time() - start_time
            print(greedy_result)
            print(f"Generation time: {greedy_time:.2f} seconds")
        except Exception as e:
            print(f"Greedy generation failed: {str(e)}")

        # Beam search decoding
        num_candidates = min(DEFAULT_SHOW_CANDIDATES, DEFAULT_NUM_BEAMS)
        print(f"\n--- Beam Search Decoding (num_beams={DEFAULT_NUM_BEAMS}) ---")
        start_time = time.time()
        try:
            beam_results = generator.generate_beam(prompt, num_beams=DEFAULT_NUM_BEAMS)
            beam_time = time.time() - start_time
            print(f"\nTop {num_candidates} candidates:")
            for i, result in enumerate(beam_results, 1):
                print(f"\n{i}. {result}")
            print(f"\nGeneration time: {beam_time:.2f} seconds")
        except Exception as e:
            print(f"Beam search generation failed: {str(e)}")

if __name__ == "__main__":
    main()
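
# To run this script, the packages imported above (mlx, transformers,
# huggingface_hub) must be installed and the model must be downloadable from
# the Hugging Face Hub, e.g. (hypothetical filename):
#     python beam_search_mlx.py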