# this file contains a Julia port of https://github.com/karpathy/llama2.c
# all credit goes to Andrej Karpathy (author of llama2.c)

using LinearAlgebra
using StatsBase
using Printf

# Transformer and RunState structs, and related memory management

struct Config
    dim::Int        # transformer dimension
    hidden_dim::Int # for ffn layers
    n_layers::Int   # number of layers
    n_heads::Int    # number of query heads
    n_kv_heads::Int # number of key/value heads (can be < query heads because of multiquery)
    vocab_size::Int # vocabulary size, usually 256 (byte-level)
    seq_len::Int    # max sequence length
end

read_config(f::IOStream) = Config(
    read(f, Int32),
    read(f, Int32),
    read(f, Int32),
    read(f, Int32),
    read(f, Int32),
    read(f, Int32),
    read(f, Int32),
)
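
# Example: read just the model header from a checkpoint (a minimal sketch; assumes a
# checkpoint exported by llama2.c, e.g. "stories15M.bin", exists in the working directory):
#
#     config = open(read_config, "stories15M.bin")
#     @show config.dim config.n_layers config.vocab_size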

@kwdef struct TransformerWeights
    # token embedding table
    token_embedding_table::Matrix{Float32} # (vocab_size, dim)
    # weights for rmsnorms
    rms_att_weight::Matrix{Float32} # (layer, dim) rmsnorm weights
    rms_ffn_weight::Matrix{Float32} # (layer, dim)
    # weights for matmuls
    wq::Array{Float32,3} # (layer, dim, dim)
    wk::Array{Float32,3} # (layer, dim, dim)
    wv::Array{Float32,3} # (layer, dim, dim)
    wo::Array{Float32,3} # (layer, dim, dim)
    # weights for ffn
    w1::Array{Float32,3} # (layer, hidden_dim, dim)
    w2::Array{Float32,3} # (layer, dim, hidden_dim)
    w3::Array{Float32,3} # (layer, hidden_dim, dim)
    # final rmsnorm
    rms_final_weight::Vector{Float32} # (dim,)
    # freq_cis for RoPE relative positional embeddings
    freq_cis_real::Matrix{Float32} # (seq_len, dim/2)
    freq_cis_imag::Matrix{Float32} # (seq_len, dim/2)
end

TransformerWeights(p::Config) = TransformerWeights(;
    token_embedding_table = zeros(Float32, p.dim, p.vocab_size),
    rms_att_weight = zeros(Float32, p.dim, p.n_layers),
    rms_ffn_weight = zeros(Float32, p.dim, p.n_layers),
    wq = zeros(Float32, p.dim, p.dim, p.n_layers),
    wk = zeros(Float32, p.dim, p.dim, p.n_layers),
    wv = zeros(Float32, p.dim, p.dim, p.n_layers),
    wo = zeros(Float32, p.dim, p.dim, p.n_layers),
    w1 = zeros(Float32, p.dim, p.hidden_dim, p.n_layers),
    w2 = zeros(Float32, p.hidden_dim, p.dim, p.n_layers),
    w3 = zeros(Float32, p.dim, p.hidden_dim, p.n_layers),
    rms_final_weight = zeros(Float32, p.dim),
    freq_cis_real = zeros(Float32, (p.dim ÷ p.n_heads) ÷ 2, p.seq_len),
    freq_cis_imag = zeros(Float32, (p.dim ÷ p.n_heads) ÷ 2, p.seq_len),
)

function checkpoint_init_weights!(w::TransformerWeights, f::IOStream)
    read!(f, w.token_embedding_table)
    read!(f, w.rms_att_weight)
    read!(f, w.wq)
    read!(f, w.wk)
    read!(f, w.wv)
    read!(f, w.wo)
    read!(f, w.rms_ffn_weight)
    read!(f, w.w1)
    read!(f, w.w2)
    read!(f, w.w3)
    read!(f, w.rms_final_weight)
    read!(f, w.freq_cis_real)
    read!(f, w.freq_cis_imag)
    return nothing
end

@kwdef struct RunState
    # current wave of activations
    x::Vector{Float32}      # activation at current time stamp (dim,)
    xb::Vector{Float32}     # same, but inside a residual branch (dim,)
    xb2::Vector{Float32}    # an additional buffer just for convenience (dim,)
    hb::Vector{Float32}     # buffer for hidden dimension in the ffn (hidden_dim,)
    hb2::Vector{Float32}    # buffer for hidden dimension in the ffn (hidden_dim,)
    q::Vector{Float32}      # query (dim,)
    k::Vector{Float32}      # key (dim,)
    v::Vector{Float32}      # value (dim,)
    att::Vector{Float32}    # buffer for scores/attention values (seq_len,)
    logits::Vector{Float32} # output logits
    # kv cache
    key_cache::Array{Float32,3}   # (layer, seq_len, dim)
    value_cache::Array{Float32,3} # (layer, seq_len, dim)
end

RunState(p::Config) = RunState(;
    x = zeros(Float32, p.dim),
    xb = zeros(Float32, p.dim),
    xb2 = zeros(Float32, p.dim),
    hb = zeros(Float32, p.hidden_dim),
    hb2 = zeros(Float32, p.hidden_dim),
    q = zeros(Float32, p.dim),
    k = zeros(Float32, p.dim),
    v = zeros(Float32, p.dim),
    att = zeros(Float32, p.seq_len),
    logits = zeros(Float32, p.vocab_size),
    key_cache = zeros(Float32, p.dim, p.seq_len, p.n_layers),
    value_cache = zeros(Float32, p.dim, p.seq_len, p.n_layers),
)

function rmsnorm!(o, x, weight)
    # calculate sum of squares
    ss = dot(x, x)
    ss /= length(x)
    ss += 1f-5
    ss = 1f0 / sqrt(ss)
    # normalize and scale
    o .= weight .* (ss .* x)
    return nothing
end

function softmax!(x)
    # find max value (for numerical stability)
    max_val = maximum(x)
    # exp
    x .= exp.(x .- max_val)
    # normalize
    x ./= sum(x)
    return nothing
end
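
# Quick sanity check for the two helpers above (illustrative only, not part of the original port):
#
#     x = Float32[1, 2, 3]
#     o = similar(x)
#     rmsnorm!(o, x, ones(Float32, 3)) # o ≈ x ./ sqrt(sum(abs2, x) / 3 + 1f-5)
#     softmax!(x)                      # x is now a probability vector (sums to 1)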

@views function transformer!(token::Int, pos::Int, p::Config, s::RunState, w::TransformerWeights)
    # a few convenience variables
    x = s.x
    dim = p.dim
    hidden_dim = p.hidden_dim
    head_size = dim ÷ p.n_heads
    # copy the token embedding into x
    copyto!(x, w.token_embedding_table[:, token])
    # pluck out the "pos" row of freq_cis_real and freq_cis_imag
    freq_cis_real_row = w.freq_cis_real[:, pos]
    freq_cis_imag_row = w.freq_cis_imag[:, pos]
    # forward all the layers
    for l in 1:p.n_layers
        # attention rmsnorm
        rmsnorm!(s.xb, x, w.rms_att_weight[:, l])
        # qkv matmuls for this position
        mul!(s.q, w.wq[:, :, l]', s.xb)
        mul!(s.k, w.wk[:, :, l]', s.xb)
        mul!(s.v, w.wv[:, :, l]', s.xb)
        # apply RoPE rotation to the q and k vectors for each head
        for h in 1:p.n_heads
            # get the q and k vectors for this head
            q = s.q[((h-1) * head_size + 1):(h * head_size)]
            k = s.k[((h-1) * head_size + 1):(h * head_size)]
            # rotate q and k by the freq_cis_real and freq_cis_imag
            for i in 1:(head_size ÷ 2)
                q0 = q[2*i-1]
                q1 = q[2*i]
                k0 = k[2*i-1]
                k1 = k[2*i]
                fcr = freq_cis_real_row[i]
                fci = freq_cis_imag_row[i]
                q[2*i-1] = q0 * fcr - q1 * fci
                q[2*i]   = q0 * fci + q1 * fcr
                k[2*i-1] = k0 * fcr - k1 * fci
                k[2*i]   = k0 * fci + k1 * fcr
            end
        end
        # save key,value at this time step (pos) to our kv cache
        copyto!(s.key_cache[:, pos, l], s.k)
        copyto!(s.value_cache[:, pos, l], s.v)
        # multihead attention. iterate over all heads
        for h in 1:p.n_heads
            # get the query vector for this head
            q = s.q[((h-1) * head_size + 1):(h * head_size)]
            # iterate over all timesteps, including the current one
            for t in 1:pos
                # get the key vector for this head and at this timestep
                k = s.key_cache[((h-1) * head_size + 1):(h * head_size), t, l]
                # calculate the attention score as the dot product of q and k
                score = dot(q, k) / sqrt(Float32(head_size))
                # save the score to the attention buffer
                s.att[t] = score
            end
            # softmax the scores to get attention weights over timesteps 1..pos inclusively
            softmax!(s.att[1:pos])
            # weighted sum of the values, store back into xb
            mul!(
                s.xb[((h-1) * head_size + 1):(h * head_size)],
                s.value_cache[((h-1) * head_size + 1):(h * head_size), 1:pos, l],
                s.att[1:pos],
            )
        end
        # final matmul to get the output of the attention
        mul!(s.xb2, w.wo[:, :, l]', s.xb)
        # residual connection back into x
        x .+= s.xb2
        # ffn rmsnorm
        rmsnorm!(s.xb, x, w.rms_ffn_weight[:, l])
        # Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
        # first calculate self.w1(x) and self.w3(x)
        mul!(s.hb, w.w1[:, :, l]', s.xb)
        mul!(s.hb2, w.w3[:, :, l]', s.xb)
        # F.silu: silu(x) = x * σ(x), where σ(x) is the logistic sigmoid
        for i in 1:hidden_dim
            s.hb[i] = s.hb[i] * (1f0 / (1f0 + exp(-s.hb[i])))
        end
        s.hb .*= s.hb2
        # final matmul to get the output of the ffn
        mul!(s.xb, w.w2[:, :, l]', s.hb)
        # residual connection
        x .+= s.xb
    end
    # final rmsnorm
    rmsnorm!(x, x, w.rms_final_weight)
    # classifier into logits
    mul!(s.logits, w.token_embedding_table', x)
    return nothing
end

function main(
    checkpoint_filename::AbstractString,
    tokenizer_filename::AbstractString;
    temperature::Float32 = 0.9f0,
)
    config = nothing
    weights = nothing
    # read in the model.bin file
    open(checkpoint_filename) do file
        config = read_config(file)
        weights = TransformerWeights(config)
        checkpoint_init_weights!(weights, file)
    end
    # read in the tokenizer.bin file
    vocab = Vector{Vector{UInt8}}(undef, config.vocab_size)
    open(tokenizer_filename) do file
        for i in 1:config.vocab_size
            len = read(file, Int32)
            vocab[i] = read(file, len)
        end
    end
    # create and init the application RunState
    state = RunState(config)
    # start timing the generation loop
    time_start = time_ns()
    token = 1 # 1 = BOS token in Llama-2 sentencepiece
    for pos in 1:config.seq_len
        # forward the transformer to get logits for the next token
        transformer!(token, pos, config, state, weights)
        # sample the next token
        if temperature == 0f0
            # greedy argmax sampling
            next = argmax(state.logits)
        else
            # apply the temperature to the logits
            state.logits ./= temperature
            # apply softmax to the logits to get the probabilities for next token
            softmax!(state.logits)
            # we now want to sample from this distribution to get the next token
            next = wsample(1:config.vocab_size, state.logits)
        end
        print(String(copy(vocab[next])))
        # advance forward
        token = next
    end
    print('\n')
    # report our achieved tok/s
    time_end = time_ns()
    @printf "achieved tok/s: %f\n" config.seq_len / (time_end - time_start) * 1e9
    return nothing
end
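
# Example invocation (a sketch; the filenames below are placeholders for a checkpoint exported
# by llama2.c's export script and its accompanying tokenizer.bin):
#
#     main("model.bin", "tokenizer.bin"; temperature = 0.9f0)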
Pretty cool! Thanks for sharing.
Great Work!
Currently, the code cannot run properly because the tokenizer.bin format has changed. I have updated the code to match the new features (prompt support) in llama2.c.
Here is the gist link: https://gist.github.com/BangBOOM/7f04e081a4d4c151794e0bf430cae27d
Maybe it makes sense to move this to a proper repo? That could make it easier to collaborate and ensure that it keeps working.
I think it would be good to create a repo. The Julia version of llama is much more intelligible to someone who wants to know how llama works. Also, since the original llama2.c is actively updated, we need a repo to track the updates.
I agree. Development will continue at https://github.com/cafaxo/Llama2.jl
I have also opened an issue concerning the tokenizer format change: cafaxo/Llama2.jl#1
awesome!