# LLaMA implemented in R TensorFlow and Keras
## Setup
Sys.setenv(CUDA_VISIBLE_DEVICES = '')
options(tensorflow.extract.warn_tensors_passed_asis = FALSE)

library(dplyr, warn.conflicts = FALSE)
library(purrr)
library(glue)
library(envir)
library(tensorflow)
library(tfautograph)
library(keras)

reticulate::use_virtualenv("./.venv", required = TRUE)

attach_eval({
  np <- reticulate::import("numpy", convert = FALSE)
  import_from(withr, with_options, local_options)
  import_from(keras$layers, Dense)
  import_from(tf$compiler$tf2xla$python$xla, dynamic_update_slice)
  nlist <- \(...) rlang::dots_list(..., .named = TRUE)
  seq_len0 <- \(x) seq.int(from = 0L, length.out = x)
})

precompute_rotary_freqs <- function(seqlen, feature_dim, theta = 10000) {
  repeat_each_twice <- function(x)
    tf$`repeat`(x, 2L, axis = -1L)

  t <- tf$range(seqlen, dtype = tf$float32)
  freqs <- tf$range(start = 0, limit = 1,
                    delta = 1 / (feature_dim %/% 2),
                    dtype = tf$float32)
  tf_assert(tf$size(freqs) == feature_dim %/% 2)
  freqs <- 1 / (theta ^ freqs)

  # outer product; (seqlen, head_size/2)
  freqs <- tf$einsum('a,b->ab', t, freqs)

  # prep to recycle across head_size axis and
  # broadcast across batch_size and n_heads axes
  list(cos = tf$cos(freqs),
       sin = tf$sin(freqs)) |>
    lapply(repeat_each_twice) |>
    lapply(\(m) m[tf$newaxis, , tf$newaxis, ]) # (1, seqlen, 1, head_size)
}

apply_rotary_embedding <- function(x, freqs) {
  rotate_every_two <- function(x) {
    x1 <- x[all_dims(), `::2`]
    x2 <- x[all_dims(), `2::2`]
    x_ <- tf$stack(list(-x2, x1), axis = -1L)
    tf$reshape(x_, tf$shape(x))
  }
  (x * freqs$cos) + (rotate_every_two(x) * freqs$sin)
}
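
# Optional sanity check of the two rotary helpers above (not part of the
# original gist; the toy sizes below are arbitrary): the precomputed frequencies
# have a broadcastable shape of (1, seqlen, 1, head_size), and applying the
# rotation leaves the input shape unchanged.
demo_freqs <- precompute_rotary_freqs(seqlen = 8L, feature_dim = 16L)
lapply(demo_freqs, dim)                          # each: 1 8 1 16
demo_x <- tf$random$normal(c(2L, 8L, 4L, 16L))   # (bsz, seqlen, n_heads, head_size)
dim(apply_rotary_embedding(demo_x, demo_freqs))  # 2 8 4 16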

make_mask <- function(seqlen, position_index = 0L, dtype = k_floatx()) {
  x <- tf$range(seqlen)
  i <- x[, tf$newaxis] + position_index
  j <- x[tf$newaxis, ]
  mask <- tf$where(i < j,
                   tf$constant(-Inf, dtype = dtype),
                   tf$constant(0, dtype = dtype))
  mask[tf$newaxis, tf$newaxis, , ] # (1, 1, seqlen_q, seqlen_q)
}
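
# Illustration (not in the original gist): for seqlen = 4 the mask is strictly
# upper-triangular -Inf, so that once it is added to the attention scores,
# position i can only attend to positions <= i.
make_mask(4L)[1, 1, , ] |> as.array()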

RMSNorm(keras$layers$Layer) %py_class% {
  initialize <-
    function(eps = 1e-6, ..., block_id = NULL, feeds_into = NULL) {
      super$initialize(...)
      self$eps <- eps
      self$block_id <- block_id
      self$feeds_into <- feeds_into
    }

  build <- function(input_shape) {
    # input_shape == (batch_size, seqlen, params$dim)
    # self$w will broadcast over the batch_size and seqlen dims.
    # w_shape == (1, 1, params$dim)
    w_shape <- rep(1L, length(input_shape))
    w_shape[length(input_shape)] <- as.integer(input_shape) |> tail(1L)

    # helper that will load the pretrained weights
    # if we supplied `block_id` and `feeds_into`
    import_from({self}, block_id, feeds_into)
    initializer <- if (is.null(self$block_id))
      "ones"
    else if (block_id >= 0) {
      \(...) weights_path("7B/layers.{block_id}.{feeds_into}_norm.weight.npy") |>
        np$load() |> np$expand_dims(0:1)
    } else if (block_id == -1)
      # load weights for the final output norm, which is not part of a TransformerBlock
      \(...) weights_path("7B/norm.weight.npy") |>
        np$load() |> np$expand_dims(0:1)

    self$w <- self$add_weight(shape = w_shape,
                              initializer = initializer,
                              trainable = TRUE)
  }

  rrms <- function(x) {
    # reciprocal root mean square along the last axis
    x %>%
      tf$math$square() %>%
      tf$reduce_mean(axis = -1L, keepdims = TRUE) %>%
      tf$math$add(self$eps) %>% # for numerical stability
      tf$math$rsqrt()
  }

  call <- function(x) {
    x * self$rrms(x) * self$w
  }
}
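
# Quick check of RMSNorm (not in the original gist). With block_id = NULL the
# scale `w` is initialized to ones, so the output should have a root mean
# square of ~1 along the features axis.
demo_norm <- RMSNorm()
demo_y <- demo_norm(tf$random$normal(c(1L, 3L, 8L)))
tf$sqrt(tf$reduce_mean(demo_y ^ 2, axis = -1L))  # ~1 everywhere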

FeedForward(keras$layers$Layer) %py_class% {
  initialize <- function(hidden_dim, multiple_of = 256L, ..., block_id = NULL) {
    super$initialize()
    if (!is.null(multiple_of)) {
      hidden_dim <- hidden_dim %>%
        { as.integer(. * (2/3)) } %>%
        { (. + multiple_of - 1) %/% multiple_of } %>%
        { . * multiple_of }
    }
    self$hidden_dim <- hidden_dim
    self$block_id <- block_id
  }

  build <- function(input_shape) {
    output_dim <- input_shape |> as.integer() |> tail(1)

    load_weight <- \(name) NULL
    if (!is.null(self$block_id))
      load_weight <- \(name) \(...) np$load(weights_path(
        "7B/layers.{self$block_id}.feed_forward.{name}.weight.npy"))$`T`

    self$w1 <- Dense(self$hidden_dim, use_bias = FALSE,
                     kernel_initializer = load_weight("w1"))
    self$w2 <- Dense(output_dim, use_bias = FALSE,
                     kernel_initializer = load_weight("w2"))
    self$w3 <- Dense(self$hidden_dim, use_bias = FALSE,
                     kernel_initializer = load_weight("w3"))
    super$build(input_shape)
  }

  call <- function(x) {
    import_from({self}, w1, w2, w3)
    import_from(tf$nn, silu)
    x %>%
      { silu(w1(.)) * w3(.) } %>% # SwiGLU
      w2()
  }
}
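
# Note (not in the original gist): the SwiGLU hidden width is shrunk to 2/3 of
# the requested `hidden_dim` and then rounded up to a multiple of `multiple_of`,
# mirroring the reference LLaMA implementation. For the 7B configuration
# (dim = 4096, so hidden_dim = 4 * 4096 = 16384) this works out to:
#   as.integer(16384 * (2/3))          # 10922
#   ((10922 + 256 - 1) %/% 256) * 256  # 11008, the 7B feed-forward width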

Attention(keras$layers$Layer) %py_class% {
  initialize <- function(head_size, n_heads, ..., block_id = NULL) {
    super$initialize(...)
    self$head_size <- head_size
    self$n_heads <- n_heads

    if (is.null(block_id))
      load_weight <- function(name) NULL
    else
      load_weight <- \(name) \(...) np$load(weights_path(
        "7B/layers.{block_id}.attention.{name}.weight.npy"))$`T`

    Dense <- function(name) keras$layers$Dense(
      units = n_heads * head_size,
      use_bias = FALSE,
      kernel_initializer = load_weight(name)
    )

    self$wq <- Dense("wq")
    self$wk <- Dense("wk")
    self$wv <- Dense("wv")
    self$wo <- Dense("wo")
  }

  call <- function(x, ...,
                   freqs = NULL,
                   cache = NULL,
                   cache_index = NULL,
                   mask = NULL) {
    c(batch_size, seqlen_q, n_features) %<-% tf$unstack(tf$shape(x))
    seqlen_k <- seqlen_v <- cache_index + seqlen_q

    split_heads_shape <- c(batch_size, seqlen_q, self$n_heads, self$head_size)
    q <- x |> self$wq() |> tf$reshape(split_heads_shape)
    k <- x |> self$wk() |> tf$reshape(split_heads_shape)
    v <- x |> self$wv() |> tf$reshape(split_heads_shape)

    q %<>% apply_rotary_embedding(freqs) # (bsz, seqlen_q, n_heads, head_size)
    k %<>% apply_rotary_embedding(freqs) # (bsz, seqlen_q, n_heads, head_size)

    if (!is.null(cache)) {
      # append k,v to respective caches; fetch full k,v from cache
      cache$k %<>% dynamic_update_slice(k, c(0L, cache_index, 0L, 0L))
      cache$v %<>% dynamic_update_slice(v, c(0L, cache_index, 0L, 0L))
      k <- cache$k[, NA:seqlen_k, , ]
      v <- cache$v[, NA:seqlen_v, , ]
    }

    v <- tf$transpose(v, c(0L, 2L, 1L, 3L)) # (bsz, n_heads, seqlen_v, head_size)
    q <- tf$transpose(q, c(0L, 2L, 1L, 3L)) # (bsz, n_heads, seqlen_q, head_size)
    k <- tf$transpose(k, c(0L, 2L, 3L, 1L)) # (bsz, n_heads, head_size, seqlen_k)

    scores <- (q %*% k) / sqrt(self$head_size) # (bsz, n_heads, seqlen_q, seqlen_k)

    # apply causal mask, so the model can't "look ahead" during training
    if (!is.null(mask))
      scores %<>% { . + mask }
    scores <- tf$nn$softmax(scores, axis = -1L)

    # adjust values tensor with attention scores
    # scores (bsz, n_heads, seqlen_q, seqlen_k)
    # v      (bsz, n_heads, seqlen_v, head_size)
    output <- scores %*% v # (bsz, n_heads, seqlen_q, head_size)

    # combine heads back into a single features dim,
    # so Attention output_shape == input_shape
    # (needed so that you can add residuals in TransformerBlock)
    output <- output |>
      tf$transpose(c(0L, 2L, 1L, 3L)) |> # (bsz, seqlen_q, n_heads, head_size)
      tf$reshape(c(batch_size, seqlen_q, # (bsz, seqlen_q, n_heads * head_size)
                   self$n_heads * self$head_size))

    # one more trainable linear projection for good luck
    output <- self$wo(output) # (bsz, seqlen_q, n_heads * head_size)

    if (is.null(cache))
      output
    else
      list(output, cache)
  }
}
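
# Aside (not in the original gist): dynamic_update_slice() is the XLA op used
# above to write the new k/v projections into the preallocated caches in place.
# A toy example with made-up values:
dynamic_update_slice(tf$zeros(c(1L, 6L), dtype = "int32"),
                     update = as_tensor(matrix(c(5L, 6L, 7L), nrow = 1), dtype = "int32"),
                     indices = c(0L, 2L))
# => [[0, 0, 5, 6, 7, 0]]  (the update is written starting at row 0, column 2)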

TransformerBlock(keras$layers$Layer) %py_class% {
  initialize <- function(attn_head_size, attn_n_heads,
                         norm_eps = k_epsilon(), ...,
                         block_id = NULL) {
    super$initialize(...)
    self$attention <- Attention(attn_head_size, attn_n_heads,
                                block_id = block_id)
    self$feed_forward <- FeedForward(
      hidden_dim = 4 * attn_head_size * attn_n_heads,
      block_id = block_id)
    self$attention_norm <- RMSNorm(eps = norm_eps, block_id = block_id,
                                   feeds_into = "attention")
    self$feed_forward_norm <- RMSNorm(eps = norm_eps, block_id = block_id,
                                      feeds_into = "ffn")
  }

  call <- function(x, ..., cache = NULL) {
    # norm and attention
    x2 <- x |>
      self$attention_norm() |>
      self$attention(..., cache = cache)

    # maybe unpack cache returned by Attention
    if (!is.null(cache))
      c(x2, cache) %<-% x2

    x <- x + x2 # add residual

    # norm and swiglu projection
    x2 <- x %>%
      self$feed_forward_norm() %>%
      self$feed_forward()

    x <- x + x2 # residual again

    if (is.null(cache)) x else list(x, cache)
  }
}

TransformerDecoder(keras$Model) %py_class% {
  initialize <- function(vocab_size, n_blocks, n_heads, head_size, norm_eps) {
    super$initialize()
    self$head_size <- head_size
    self$n_heads <- n_heads
    self$tok_embeddings <- keras$layers$Embedding(
      input_dim = vocab_size,
      output_dim = n_heads * head_size,
      embeddings_initializer =
        \(...) np$load(weights_path("7B/tok_embeddings.weight.npy")))
    self$blocks <- lapply(seq_len0(n_blocks), function(block_id) {
      TransformerBlock(attn_head_size = head_size,
                       attn_n_heads = n_heads,
                       norm_eps = norm_eps,
                       block_id = block_id)
    })
    self$norm <- RMSNorm(block_id = -1, eps = norm_eps)
    self$output_proj <- Dense(
      vocab_size, use_bias = FALSE,
      kernel_initializer = \(...)
        np$load(weights_path("7B/output.weight.npy"))$`T`)
    self$freqs <- precompute_rotary_freqs(feature_dim = head_size,
                                          seqlen = 2048L)
  }

  call <- function(tokens) {
    c(bsz, seqlen) %<-% tf$unstack(tf$shape(tokens))
    freqs <- self$freqs |> lapply(\(f) f[, NA:seqlen, , ])
    mask <- make_mask(seqlen)
    x <- tokens |>
      self$tok_embeddings()
    for (block in self$blocks)
      x <- block(x, freqs = freqs, mask = mask)
    local_options(c(tensorflow.extract.warn_negatives_pythonic = FALSE))
    x |>
      self$norm() |>
      _[, -1, ] |>
      self$output_proj()
  }

  call_with_cache <- function(tokens, cache, position) {
    c(batch_size, seqlen) %<-% tf$unstack(tf$shape(tokens))
    # Sanity check: after the initial seeding of the cache with the prompt, we
    # should only be running inference on one token at a time.
    tf_assert(position == 0 | seqlen == 1)

    if (is.numeric(position) && position == 0L) {
      # initial cache seeding
      mask <- make_mask(seqlen)
      freqs <- self$freqs |> lapply(\(f) f[, NA:seqlen, , ])
    } else {
      # inference with one token
      position %<>% as_tensor(dtype = "int32")
      freqs <- self$freqs |> lapply(\(f) f[, position, , ])
      mask <- NULL
    }

    blocks <- self$blocks
    stopifnot(is.list(cache), length(cache) == length(blocks))

    x <- tokens |>
      self$tok_embeddings()
    for (i in seq_along(blocks)) {
      c(x, cache[[i]]) %<-% blocks[[i]](x, cache = cache[[i]],
                                        cache_index = position,
                                        freqs = freqs,
                                        mask = mask)
    }

    local_options(c(tensorflow.extract.warn_negatives_pythonic = FALSE))
    output <- x |>
      self$norm() |>
      _[, -1, ] |>
      self$output_proj()
    list(output, cache)
  }

  .make_cache <- function(prompt_tokens, max_seqlen = 2048L) {
    c(batch_size, seqlen) %<-% tf$unstack(tf$shape(prompt_tokens))
    import_from({self}, head_size, n_heads)
    max_seqlen <- min(max_seqlen + seqlen, 2048L)
    cache_shape <- c(batch_size, max_seqlen, n_heads, head_size)
    cache <- lapply(seq_along(self$blocks), \(.) {
      list(k = tf$zeros(cache_shape), v = tf$zeros(cache_shape))
    })
    tokens_with_preallocated_space <-
      tf$zeros(c(batch_size, max_seqlen), dtype = "int32") |>
      dynamic_update_slice(update = prompt_tokens, indices = c(0L, 0L))
    # run first forward pass to seed cache with initial prompt
    # return (prompt_tokens, next_token_probs, cache)
    c(tokens_with_preallocated_space,
      self$call_with_cache(prompt_tokens, cache = cache, position = 0L))
  }

  private$sampler_fn <- \(logits) logits |>
    tf$argmax(axis = -1L, output_type = "int32") |>
    tf$expand_dims(-1L)

  sampler %<-active% function(fn) {
    if (missing(fn))
      private$sampler_fn
    else
      private$sampler_fn <- fn
  }

  generate <- function(prompt, max_len = 20L) {
    max_len %<>% as_tensor("int32")
    prompt %<>% as_tensor()

    # accept either tokens or a string
    if (prompt$dtype$name == "string") {
      if (length(dim(prompt)) == 0) # ensure a batch dim
        prompt %<>% .[tf$newaxis]
      tokens <- tokenizer$tokenize(prompt)$to_tensor()
    } else {
      tokens <- prompt
      if (length(dim(prompt)) == 1) # ensure a batch dim
        tokens %<>% .[tf$newaxis, ]
    }

    c(batch_size, initial_prompt_len) %<-% tf$unstack(tf$shape(tokens))
    max_seqlen <- min(max_len + initial_prompt_len, 2048L)
    c(tokens, next_token_probs, cache) %<-% self$.make_cache(tokens, max_len)

    i <- initial_prompt_len
    autograph({
      # enable `if` and `for` to accept tensors
      for (i in tf$range(initial_prompt_len, max_seqlen, dtype = "int32")) {
        next_token <- self$sampler(next_token_probs)
        tokens %<>% dynamic_update_slice(next_token, c(0L, i))
        if (any(next_token == 2L))
          break # end-of-sequence token
        c(next_token_probs, cache) %<-%
          self$call_with_cache(next_token, cache, i)
      }
    })

    tokens %<>% .[, NA:(i + 1)] # drop unused preallocated space
    if (prompt$dtype$name == "string")
      # return string if supplied a string
      tokenizer$detokenize(tokens)
    else
      tokens
  }
}
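
# The `sampler` active binding above makes the decoding strategy swappable at
# run time. For example (illustrative only, not in the original gist; 0.7 is an
# arbitrary temperature), once `llama` is built below you could replace greedy
# argmax with temperature sampling:
#   llama$sampler <- function(logits)
#     tf$random$categorical(logits / 0.7, num_samples = 1L, dtype = tf$int32)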

# ---- load
weights_path <- function(rel_path) {
  normalizePath(
    file.path(
      "~/github/facebookresearch/llama/weights/LLaMA/",
      glue::glue(rel_path, .envir = parent.frame())
    ),
    mustWork = TRUE
  )
}

params <- jsonlite::read_json(weights_path("7B/params.json"))

tf_text <- reticulate::import("tensorflow_text")
tokenizer_path <- weights_path("tokenizer.model")
tokenizer <- tf_text$SentencepieceTokenizer(
  tf$io$gfile$GFile(tokenizer_path, "rb")$read(),
  add_bos = TRUE, add_eos = FALSE
)
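
# Aside (not in the original gist): with add_bos = TRUE the tokenizer prepends a
# beginning-of-sequence id to everything it tokenizes; the id 2L that generate()
# checks for is the matching end-of-sequence id in this vocabulary. E.g.:
#   tokenizer$tokenize("The best way to attract bees")  # int32 token ids, BOS first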

llama <- TransformerDecoder(vocab_size = tokenizer$vocab_size(),
                            n_blocks = params$n_layers,
                            n_heads = params$n_heads,
                            head_size = params$dim %/% params$n_heads,
                            norm_eps = params$norm_eps)

prompt <- "The best way to attract bees"

test_generate <- function() {
  prompt |>
    tokenizer$tokenize() |>
    llama$generate(as_tensor(17L)) |>
    tokenizer$detokenize() |>
    as.character() |>
    strwrap(60) |> writeLines()
}

test_generate()
## expected output with the argmax() sampler:
# The best way to attract bees to your garden is to plant a
# variety of flowers that bloom at different times.

# Timings on CPU:
print(system.time(test_generate()))
#    user  system elapsed
#  99.562   0.149  89.057

# Compile to XLA
llama$generate %<>% tf_function(jit_compile = TRUE)

# First call includes tracing time
print(system.time(test_generate()))
#    user  system elapsed
#  64.944   0.809  55.314

# Second call is pure graph mode
print(system.time(test_generate()))
#    user  system elapsed
#  28.754   0.120  18.453