@kevmo314
Last active February 24, 2025 21:11
Llama inference in 150 lines.

It turns out that if you're just doing inference, Llama can be written very concisely. This implementation includes paged attention. Speculative decoding could be added for a further speed boost, but it is fairly verbose and was left out to keep the implementation clean.
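
To make the paged-attention bookkeeping easier to follow, here is a minimal, standalone sketch of the idea (the names are illustrative and not taken from the gist): the KV cache is a pool of fixed-size physical blocks, and each sequence owns a block table mapping logical block indices to physical blocks, so sequences can grow without contiguous preallocation. In llama.py below, kv_block_free and kv_block_table play these roles, and the actual lookup is done inside flash_attn_with_kvcache via its block_table argument.

# Minimal sketch of paged-attention bookkeeping (illustrative names, not from the gist).
import collections

BLOCK_SIZE = 256                            # tokens per KV block
free_blocks = collections.deque(range(16))  # pool of physical block ids
block_table: dict[int, list[int]] = {}      # seq_id -> physical block ids owned by that sequence

def cache_slot(seq_id: int, position: int) -> tuple[int, int]:
    """Return (physical block, offset) where this token's K/V entry should live."""
    logical_block = position // BLOCK_SIZE
    table = block_table.setdefault(seq_id, [])
    while len(table) <= logical_block:      # allocate lazily as the sequence grows
        table.append(free_blocks.pop())
    return table[logical_block], position % BLOCK_SIZE

# e.g. token 300 of sequence 0 lands in the second block owned by that sequence
print(cache_slot(0, 300))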

Download the Llama model files and place them in a directory named ./Llama3.2-3B (or whichever flavor of Llama you want).

Your directory structure should look like:

./Llama3.2-3B/consolidated.00.pth
./Llama3.2-3B/tokenizer.model
./Llama3.2-3B/params.json
./llama.py
./tokenizer.py
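
If you want to sanity-check the layout before running anything, a tiny hypothetical helper like the following works (MODEL_DIR and the file list simply mirror the tree above; it is not part of the gist):

import os

MODEL_DIR = './Llama3.2-3B'
for name in ('consolidated.00.pth', 'tokenizer.model', 'params.json'):
    path = os.path.join(MODEL_DIR, name)
    # Fail early with a clear message if a required file is missing.
    assert os.path.isfile(path), f'missing {path}'
print('all model files found')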

Then, install some dependencies:

pip install flash-attn tiktoken

Then run the inference script (results from an RTX 4090):

$ python3 llama.py

...

Time taken: 2.74s (295.61 tokens/s)

This implementation takes some inspiration from llama3-naive.

If you are totally confused by the implementation of forward(), as I was, referencing the architecture diagram may be helpful:

[Figure: Llama transformer architecture diagram]

The two files below are llama.py followed by tokenizer.py.


import json
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_with_kvcache
from tokenizer import Tokenizer
import collections
import time
device = 'cuda'
model_name = './Llama3.2-3B'
tokenizer_path = f'{model_name}/tokenizer.model'
tokenizer = Tokenizer(model_path=tokenizer_path)
model = torch.load(f'{model_name}/consolidated.00.pth', map_location=device, mmap=False, weights_only=True)
with open(f'{model_name}/params.json', 'r') as f:
    config = json.load(f)
head_dim = config['dim'] // config['n_heads'] # 4096 // 32 = 128
max_seq_len = 256
block_size = 256
stop_tokens = torch.tensor(list(tokenizer.stop_tokens), device=device)
# Precompute freqs cis for rope
zero_to_one_split_into_64_parts = torch.tensor(range(head_dim//2), device=device)/(head_dim//2)
freqs = 1.0 / (torch.tensor(config['rope_theta'], device=device) ** zero_to_one_split_into_64_parts)
freqs_for_each_token = torch.outer(torch.arange(max_seq_len, device=device), freqs)
freqs_cis_max = torch.polar(torch.ones_like(freqs_for_each_token), freqs_for_each_token)
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # pair the last dimension into complex numbers
    complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # apply the rotary embedding, flatten back into reals.
    return torch.view_as_real(complex * freqs_cis.unsqueeze(2)).flatten(3).type_as(x)
num_blocks = (max_seq_len // block_size) * 16 * config['n_layers']
k_cache = torch.randn(num_blocks, block_size, config['n_kv_heads'], head_dim, device=device, dtype=torch.bfloat16)
v_cache = torch.randn(num_blocks, block_size, config['n_kv_heads'], head_dim, device=device, dtype=torch.bfloat16)
kv_block_free = collections.deque(reversed(range(0, num_blocks, config['n_layers'])))
# Generate next token i.e. do one forward pass of llama
def forward(tokens, pos, block_table, n=1):
    bsz, T = tokens.shape
    final_embedding = F.embedding(tokens, weight=model['tok_embeddings.weight'])
    for layer in range(config['n_layers']):
        # Pre-attention RMSNorm.
        layer_embedding_norm = F.rms_norm(
            final_embedding,
            normalized_shape=final_embedding.shape[-1:],
            weight=model[f'layers.{layer}.attention_norm.weight'],
            eps=config['norm_eps'],
        )
        # Project into query/key/value heads.
        q = (layer_embedding_norm @ model[f'layers.{layer}.attention.wq.weight'].T).view(bsz, T, config['n_heads'], head_dim)
        k = (layer_embedding_norm @ model[f'layers.{layer}.attention.wk.weight'].T).view(bsz, T, config['n_kv_heads'], head_dim)
        v = (layer_embedding_norm @ model[f'layers.{layer}.attention.wv.weight'].T).view(bsz, T, config['n_kv_heads'], head_dim)
        # Rotary position embeddings for the positions being processed.
        freqs = freqs_cis_max[pos + torch.arange(n, device=device).repeat(bsz, 1)]
        q, k = apply_rotary_emb(q, freqs), apply_rotary_emb(k, freqs)
        # Paged attention: flash-attn appends k/v into the paged cache and attends over it.
        stacked_qkv_attention = flash_attn_with_kvcache(
            q=q,
            k=k,
            v=v,
            k_cache=k_cache,
            v_cache=v_cache,
            cache_seqlens=pos[:, 0],
            block_table=block_table + layer,
            causal=True,
        ).view(bsz, T, config['dim'])
        # Attention output projection plus residual connection.
        embedding_after_edit = final_embedding + torch.matmul(stacked_qkv_attention, model[f'layers.{layer}.attention.wo.weight'].T)
        embedding_after_edit_normalized = F.rms_norm(
            embedding_after_edit,
            normalized_shape=embedding_after_edit.shape[-1:],
            weight=model[f'layers.{layer}.ffn_norm.weight'],
            eps=config['norm_eps'],
        )
        # SwiGLU feed-forward network plus residual connection.
        w1, w2, w3 = model[f'layers.{layer}.feed_forward.w1.weight'], model[f'layers.{layer}.feed_forward.w2.weight'], model[f'layers.{layer}.feed_forward.w3.weight']
        feed_forward = torch.matmul(F.silu(torch.matmul(embedding_after_edit_normalized, w1.T)) * torch.matmul(embedding_after_edit_normalized, w3.T), w2.T)
        final_embedding = embedding_after_edit + feed_forward
    # Final RMSNorm, project to the vocabulary, and greedily pick the most likely token(s).
    return torch.argmax(torch.matmul(F.rms_norm(
        final_embedding,
        normalized_shape=final_embedding.shape[-1:],
        weight=model['norm.weight'],
        eps=config['norm_eps'],
    ), model['output.weight'].T), dim=-1)[:, -n:]
completion_requests = [
    tokenizer.encode("Do you know the muffin", bos=True, eos=False),
    tokenizer.encode("def levenshtein(a, b):\n", bos=True, eos=False),
    tokenizer.encode("The definition of surrepetitious is", bos=True, eos=False),
    tokenizer.encode("function crossEntropy(x, y) {\n", bos=True, eos=False),
]
t0 = time.time()
token_count = 0
tokens = torch.zeros(len(completion_requests), max_seq_len, dtype=torch.long, device=device)
kv_block_table = -torch.ones(len(completion_requests), (max_seq_len + block_size - 1) // block_size, dtype=torch.int32, device=device)
prediction_pos = torch.zeros(len(completion_requests), 1, dtype=torch.int64, device=device)
for i, t in enumerate(completion_requests):
    tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device=device)
    # Reserve KV blocks for the prompt tokens.
    for j in range(0, len(t), block_size):
        kv_block_table[i, j // block_size] = kv_block_free.pop()
    prediction_pos[i, 0] = len(t)
prefill = max(len(t) for t in completion_requests)
while tokens.shape[0] > 0:
    if prefill:
        # First pass: run the whole prompt through at once (prefill).
        output = forward(tokens[:, :prefill], torch.zeros(len(completion_requests), 1, device=device, dtype=torch.int32), kv_block_table, n=prefill)
        tokens.scatter_(1, prediction_pos, output)
    else:
        # Subsequent passes: decode one token per sequence.
        output = forward(tokens.gather(1, prediction_pos - 1), prediction_pos.type(dtype=torch.int32) - 1, kv_block_table)
        tokens.scatter_(1, prediction_pos, output)
    prediction_pos += 1
    eos_reached = torch.any(torch.isin(output, stop_tokens), dim=-1) | (prediction_pos.squeeze(-1) >= max_seq_len)
    token_count += tokens.shape[0]
    for i in range(tokens.shape[0]):
        if eos_reached[i]:
            # Sequence finished: print it and free KV blocks back to the pool.
            print(tokenizer.decode(tokens[i, :prediction_pos[i, 0]].tolist()))
            kv_block_free.extend(block for block in kv_block_table.unique() if block != -1)
        elif kv_block_table[i, prediction_pos[i, 0] // block_size] == -1:
            # we need to allocate a new block for the next prediction
            kv_block_table[i, prediction_pos[i, 0] // block_size] = kv_block_free.pop()
    # Drop finished sequences from the batch.
    tokens = tokens[~eos_reached, :]
    kv_block_table = kv_block_table[~eos_reached, :]
    prediction_pos = prediction_pos[~eos_reached, :]
    prefill = None
dt = time.time() - t0
print(f'Time taken: {dt:.2f}s ({(token_count / dt):.2f} tokens/s)')
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import os
from logging import getLogger
from pathlib import Path
from typing import (
    AbstractSet,
    cast,
    Collection,
    Dict,
    Iterator,
    List,
    Literal,
    Sequence,
    TypedDict,
    Union,
)
import tiktoken
from tiktoken.load import load_tiktoken_bpe
logger = getLogger(__name__)
Role = Literal["system", "user", "assistant"]
class Message(TypedDict):
    role: Role
    content: str
Dialog = Sequence[Message]
class Tokenizer:
    """
    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
    """

    special_tokens: Dict[str, int]

    num_reserved_special_tokens = 256

    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501

    def __init__(self, model_path: str):
        """
        Initializes the Tokenizer with a Tiktoken model.

        Args:
            model_path (str): The path to the Tiktoken model file.
        """
        assert os.path.isfile(model_path), model_path

        mergeable_ranks = load_tiktoken_bpe(model_path)
        num_base_tokens = len(mergeable_ranks)
        special_tokens = [
            "<|begin_of_text|>",
            "<|end_of_text|>",
            "<|reserved_special_token_0|>",
            "<|reserved_special_token_1|>",
            "<|reserved_special_token_2|>",
            "<|reserved_special_token_3|>",
            "<|start_header_id|>",
            "<|end_header_id|>",
            "<|reserved_special_token_4|>",
            "<|eot_id|>",  # end of turn
        ] + [
            f"<|reserved_special_token_{i}|>"
            for i in range(5, self.num_reserved_special_tokens - 5)
        ]
        self.special_tokens = {
            token: num_base_tokens + i for i, token in enumerate(special_tokens)
        }
        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )
        logger.info(f"Reloaded tiktoken model from {model_path}")

        self.n_words: int = self.model.n_vocab
        # BOS / EOS token IDs
        self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
        self.eos_id: int = self.special_tokens["<|end_of_text|>"]
        self.pad_id: int = -1
        self.stop_tokens = {
            self.special_tokens["<|end_of_text|>"],
            self.special_tokens["<|eot_id|>"],
        }
        logger.info(
            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
        )
    def encode(
        self,
        s: str,
        *,
        bos: bool,
        eos: bool,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = (),
    ) -> List[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            s (str): The input string to be encoded.
            bos (bool): Whether to prepend the beginning-of-sequence token.
            eos (bool): Whether to append the end-of-sequence token.
            allowed_special ("all"|set[str]): special tokens allowed to appear in the string
            disallowed_special ("all"|set[str]): special tokens that raise an error when found in the string

        Returns:
            list[int]: A list of token IDs.

        By default, setting disallowed_special=() encodes a string by ignoring
        special tokens. Specifically:
        - Setting `disallowed_special` to () will cause all text corresponding
          to special tokens to be encoded as natural text (instead of raising
          an error).
        - Setting `allowed_special` to "all" will cause all text corresponding
          to special tokens to be encoded as special tokens.
        """
        assert type(s) is str

        # The tiktoken tokenizer can handle <=400k chars without
        # pyo3_runtime.PanicException.
        TIKTOKEN_MAX_ENCODE_CHARS = 400_000

        # https://github.com/openai/tiktoken/issues/195
        # Here we iterate over subsequences and split if we exceed the limit
        # of max consecutive non-whitespace or whitespace characters.
        MAX_NO_WHITESPACES_CHARS = 25_000

        substrs = (
            substr
            for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
            for substr in self._split_whitespaces_or_nonwhitespaces(
                s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
            )
        )
        t: List[int] = []
        for substr in substrs:
            t.extend(
                self.model.encode(
                    substr,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )
        if bos:
            t.insert(0, self.bos_id)
        if eos:
            t.append(self.eos_id)
        return t
    def decode(self, t: Sequence[int]) -> str:
        """
        Decodes a list of token IDs into a string.

        Args:
            t (List[int]): The list of token IDs to be decoded.

        Returns:
            str: The decoded string.
        """
        # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
        return self.model.decode(cast(List[int], t))
    @staticmethod
    def _split_whitespaces_or_nonwhitespaces(
        s: str, max_consecutive_slice_len: int
    ) -> Iterator[str]:
        """
        Splits the string `s` so that each substring contains no more than
        `max_consecutive_slice_len` consecutive whitespaces or consecutive
        non-whitespaces.
        """
        current_slice_len = 0
        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
        slice_start = 0

        for i in range(len(s)):
            is_now_space = s[i].isspace()

            if current_slice_is_space ^ is_now_space:
                current_slice_len = 1
                current_slice_is_space = is_now_space
            else:
                current_slice_len += 1
                if current_slice_len > max_consecutive_slice_len:
                    yield s[slice_start:i]
                    slice_start = i
                    current_slice_len = 1
        yield s[slice_start:]
class ChatFormat:
    def __init__(self, tokenizer: Tokenizer):
        self.tokenizer = tokenizer

    def encode_header(self, message: Message) -> List[int]:
        tokens = []
        tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
        tokens.extend(self.tokenizer.encode(message["role"], bos=False, eos=False))
        tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
        tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
        return tokens

    def encode_message(self, message: Message) -> List[int]:
        tokens = self.encode_header(message)
        tokens.extend(
            self.tokenizer.encode(message["content"].strip(), bos=False, eos=False)
        )
        tokens.append(self.tokenizer.special_tokens["<|eot_id|>"])
        return tokens

    def encode_dialog_prompt(self, dialog: Dialog) -> List[int]:
        tokens = []
        tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
        for message in dialog:
            tokens.extend(self.encode_message(message))
        # Add the start of an assistant message for the model to complete.
        tokens.extend(self.encode_header({"role": "assistant", "content": ""}))
        return tokens