llama3 in numpy
import numpy as np

class ModelArgs:
    # hyperparameters of the stories15M checkpoint loaded in main()
    dim = 288
    n_layers = 6
    n_heads = 6
    norm_eps = 1e-6

def build_cos_sin_cache(head_dim, seq_len, base=10000):
    # rotary embedding tables: angle(position m, pair i) = m / base^(2i / head_dim)
    theta = 1. / (base ** (np.arange(0, head_dim, 2, dtype=np.float32) / head_dim))
    seq_idx = np.arange(seq_len, dtype=np.float32)
    idx_theta = np.outer(seq_idx, theta)
    return np.cos(idx_theta), np.sin(idx_theta)

cos_cached, sin_cached = build_cos_sin_cache(ModelArgs.dim // ModelArgs.n_heads, seq_len=256)

def rope(x, start_pos):
    # rotate each (even, odd) channel pair of q/k by its absolute position angle
    seq_len = x.shape[1]
    r = np.zeros_like(x)
    cos = cos_cached[start_pos: start_pos + seq_len][None, :, None, :]
    sin = sin_cached[start_pos: start_pos + seq_len][None, :, None, :]
    r[:, :, :, 0::2] = x[:, :, :, 0::2] * cos - x[:, :, :, 1::2] * sin
    r[:, :, :, 1::2] = x[:, :, :, 1::2] * cos + x[:, :, :, 0::2] * sin
    return r

def softmax(x):
    e = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return e / np.sum(e, axis=-1, keepdims=True)

def silu(x):
    return x * (1 / (1 + np.exp(-x)))

def ffn(x, up_wgt, gate_wgt, down_wgt):
    # SwiGLU feed-forward: silu(x @ gate) * (x @ up), projected back down
    return (silu(x @ gate_wgt) * (x @ up_wgt)) @ down_wgt

def rmsnorm(x, eps=ModelArgs.norm_eps):
    # RMSNorm without the learned scale; callers multiply by the weight
    return x / np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)

def attn(x, start_pos, q_wgt, k_wgt, v_wgt, o_wgt, cache):
    q = x @ q_wgt
    k = x @ k_wgt
    v = x @ v_wgt
    # split the model dimension into heads
    B, L, d = x.shape
    q = q.reshape((B, L, ModelArgs.n_heads, d // ModelArgs.n_heads))
    k = k.reshape((B, L, ModelArgs.n_heads, d // ModelArgs.n_heads))
    v = v.reshape((B, L, ModelArgs.n_heads, d // ModelArgs.n_heads))
    q = rope(q, start_pos)
    k = rope(k, start_pos)
    if cache:
        # append the new keys/values after the cached ones along the sequence axis
        k_cache, v_cache = cache
        k = np.concatenate([k_cache, k], axis=1)
        v = np.concatenate([v_cache, v], axis=1)
    cache[:] = [k, v]
    # grouped-query attention would repeat k/v heads here; with
    # n_kv_heads == n_heads this is a no-op
    n_rep = q.shape[-2] // k.shape[-2]
    k = np.repeat(k, n_rep, axis=-2)
    v = np.repeat(v, n_rep, axis=-2)
    x = np.einsum('...qhd,...khd->...hqk', q, k)
    if L > 1:
        # causal mask, only needed when more than one query position is processed
        mask = (1 - np.tri(x.shape[-1], dtype=x.dtype)) * -1e10
    else:
        mask = 0
    x = softmax(x * q.shape[-1] ** -0.5 + mask)
    x = np.einsum('...hqk,...khd->...qhd', x, v)
    x = x.reshape(x.shape[:-2] + (-1,))  # merge heads back into the model dim
    return x @ o_wgt

def block(x, start_pos, layer_id, weights, cache):
    rms_wgt_in = weights[f"model.layers.{layer_id}.input_layernorm.weight"]
    q_wgt = weights[f"model.layers.{layer_id}.self_attn.q_proj.weight"]
    k_wgt = weights[f"model.layers.{layer_id}.self_attn.k_proj.weight"]
    v_wgt = weights[f"model.layers.{layer_id}.self_attn.v_proj.weight"]
    o_wgt = weights[f"model.layers.{layer_id}.self_attn.o_proj.weight"]
    rms_wgt_out = weights[f"model.layers.{layer_id}.post_attention_layernorm.weight"]
    up_wgt = weights[f"model.layers.{layer_id}.mlp.up_proj.weight"]
    gate_wgt = weights[f"model.layers.{layer_id}.mlp.gate_proj.weight"]
    down_wgt = weights[f"model.layers.{layer_id}.mlp.down_proj.weight"]
    # pre-norm residual blocks: x + attn(norm(x)), then x + ffn(norm(x))
    norm_x = rmsnorm(x) * rms_wgt_in
    x += attn(norm_x, start_pos, q_wgt, k_wgt, v_wgt, o_wgt, cache)
    norm_x = rmsnorm(x) * rms_wgt_out
    x += ffn(norm_x, up_wgt, gate_wgt, down_wgt)
    return x

def llama3(x, start_pos, weights, caches):
    x = weights["model.embed_tokens.weight"][x]
    for i in range(ModelArgs.n_layers):
        x = block(x, start_pos, layer_id=i, weights=weights, cache=caches[i])
    x = rmsnorm(x) * weights["model.norm.weight"]
    return x @ weights["lm_head.weight"]

def main():
    from tokenizer import Tokenizer
    tokenizer = Tokenizer("./tokenizer.model.np")
    weights = dict(np.load("./stories15M.model.npz"))
    # projection weights are stored as (out_features, in_features);
    # transpose once so the x @ w convention above works
    for k in weights:
        if k.endswith('proj.weight') or k == "lm_head.weight":
            weights[k] = weights[k].T
    prompt = "I have a dream"
    print(prompt, end="", flush=True)
    x = np.array([tokenizer.encode(prompt)])
    caches = [[] for _ in range(ModelArgs.n_layers)]  # one (k, v) cache per layer
    for start_pos in range(x.shape[1], 56):
        # prefill the whole prompt at position 0; after that the single new
        # token sits right after the cached positions (the raw loop counter
        # would be off by one here)
        start_pos = 0 if not caches[0] else caches[0][0].shape[1]
        logits = llama3(x, start_pos, weights, caches)
        x = np.argmax(logits[:, -1, :], axis=-1, keepdims=True)
        print(tokenizer.decode(x[0]), end="", flush=True)

if __name__ == "__main__":
    main()
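
As a quick sanity check (not part of the gist above), the interleaved even/odd rotation in rope() is the usual complex-number form of RoPE: each (even, odd) channel pair is multiplied by exp(1j * position * theta). Below is a minimal sketch, assuming the definitions above (ModelArgs, cos_cached, rope) are already in scope and the same base of 10000; rope_complex is a hypothetical helper introduced only for this check.

import numpy as np

def rope_complex(x, start_pos, base=10000):
    # pair even/odd channels as real/imag parts and rotate each pair
    # by its absolute position angle
    B, L, H, D = x.shape
    theta = 1. / (base ** (np.arange(0, D, 2, dtype=np.float32) / D))
    angles = np.outer(np.arange(start_pos, start_pos + L, dtype=np.float32), theta)
    z = (x[..., 0::2] + 1j * x[..., 1::2]) * np.exp(1j * angles)[None, :, None, :]
    out = np.empty_like(x)
    out[..., 0::2] = z.real
    out[..., 1::2] = z.imag
    return out

x = np.random.randn(1, 4, ModelArgs.n_heads, ModelArgs.dim // ModelArgs.n_heads).astype(np.float32)
assert np.allclose(rope(x, start_pos=3), rope_complex(x, start_pos=3), atol=1e-4)

This only verifies the rotation itself; it says nothing about the weights, the cache handling, or the attention math.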