This code is still a work in progress and not finished; it probably won't work yet, but I am using it to learn how to create a transformer from scratch.
This gist is shared to accompany this tweet: https://twitter.com/LorenzoSinisi/status/1652756858459881473
Mix.install(
[
{:nx, "~> 0.5.3"},
{:req, "~> 0.3.6"},
{:kino_bumblebee, "~> 0.3.0"},
{:exla, "~> 0.5.1"}
],
config: [nx: [default_backend: EXLA.Backend]]
)
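# EXLA is configured as the default Nx backend above, so the tensor operations
# below run through XLA instead of the pure-Elixir binary backend.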
file_path = Path.absname("./input.txt")
text =
if File.exists?(file_path) do
IO.puts("File loaded from memory: #{file_path}")
File.read!(file_path)
else
IO.puts(
"File loaded from git: https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
)
Req.get!(
"https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
).body
end
defmodule Minidecoder do
@chars text |> String.codepoints() |> Enum.uniq() |> Enum.sort()
@vocab_size Enum.count(@chars)
@stoi Enum.reduce(@chars, %{}, fn ch, acc -> Map.put(acc, ch, Enum.count(acc)) end)
@itos Enum.reduce(@stoi, %{}, fn {ch, i}, acc -> Map.put(acc, i, ch) end)
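# Added helper (not in the original gist): expose the vocabulary size so later
# cells can read it instead of hard-coding 65.
def vocab_size, do: @vocab_size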
def encode(text) do
text |> String.codepoints() |> Enum.map(&@stoi[&1])
end
def decode(ids) do
ids |> Enum.map(&@itos[&1]) |> Enum.join()
end
def tensor(text) do
Nx.tensor(encode(text))
end
end
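# Quick sanity check (illustrative, not needed by the rest of the notebook):
# encoding and then decoding should round-trip the original string.
"hello world" |> Minidecoder.encode() |> Minidecoder.decode()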
data = Minidecoder.tensor(text)
n = Kernel.round(Nx.size(data) * 0.9)
# take the first n tokens (90% of the data) for training
train_data = Nx.slice(data, [0], [n])
# take the remaining tokens, from index n to the end, for validation
val_data = Nx.slice(data, [n], [Nx.size(data) - n])
{train_data, val_data}
block_size = 8
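# A chunk of block_size tokens yields block_size training pairs: the context
# grows from the first token up to the full block, and the target is always
# the next token.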
x = Nx.slice(train_data, [0], [block_size])
y = Nx.slice(train_data, [0], [block_size + 1])
Enum.map(0..(block_size - 1), fn t ->
context = Nx.slice(x, [0], [t + 1])
target = Nx.slice(y, [t + 1], [1])
{Nx.to_list(context) |> inspect(charlists: :as_lists),
Nx.to_list(target) |> inspect(charlists: :as_lists)}
end)
batch_size = 4
block_size = 8
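# get_batch samples batch_size random offsets and returns {x, y}, each of shape
# {batch_size, block_size}, where y is x shifted one position to the right.
# (Nx.random_uniform is deprecated in newer Nx releases in favour of Nx.Random,
# but it still works here.)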
get_batch = fn split ->
data = if(split == :train, do: train_data, else: val_data)
ix = Nx.random_uniform({batch_size}, 0, Nx.size(data) - block_size)
ix = Nx.to_list(ix)
x = Enum.map(ix, fn i -> Nx.slice(data, [i], [block_size]) end) |> Nx.stack()
y = Enum.map(ix, fn i -> Nx.slice(data, [i + 1], [block_size]) end) |> Nx.stack()
{x, y}
end
get_batch.(:train)
defmodule BigramLanguageModel do
alias Nx, as: N
defstruct [:params, :state, :opts]
def zeroes(_x, y) do
Nx.broadcast(0, {y, y})
end
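# The whole "model" is a {vocab_size, vocab_size} lookup table: row i holds the
# (unnormalized) logits for the token that follows token i.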
def create(vocab_size) do
Nx.random_uniform({vocab_size, vocab_size})
end
def slice_last_token_logits(logits) do
{batch_size, seq_length, vocab_size} = logits.shape
start_indices = [0, seq_length - 1, 0]
lengths = [batch_size, 1, vocab_size]
Nx.slice(logits, start_indices, lengths)
end
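# Plain softmax along the given axis (no max-subtraction for numerical stability).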
def softmax(t, axis) do
exp_t = Nx.exp(t)
sum_exp_t = Nx.sum(exp_t, axes: [axis], keep_axes: true)
Nx.divide(exp_t, sum_exp_t)
end
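# Greedily appends max_new_tokens tokens: at each step it takes the argmax of
# the last position's probabilities instead of sampling from them.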
def generate(model, idx, max_new_tokens) do
Enum.reduce(1..max_new_tokens, idx, fn _, idx ->
{logits, _loss} = forward(model, idx)
logits = slice_last_token_logits(logits)
probs = softmax(logits, -1)
Nx.concatenate([idx, Nx.argmax(probs, axis: -1)], axis: -1)
end)
end
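# Looking up the idx rows in the table gives logits of shape {B, T, vocab};
# when targets are provided, both are flattened and the mean sparse
# cross-entropy is computed with Axon.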
def forward(model, idx, targets \\ nil) do
logits = Nx.take(model, idx)
if is_nil(targets) do
{logits, nil}
else
{b, t, c} = N.shape(logits)
reshaped_logits = N.reshape(logits, {b * t, c})
reshaped_targets = N.reshape(targets, {b * t})
loss =
Axon.Losses.categorical_cross_entropy(reshaped_targets, reshaped_logits,
sparse: true,
from_logits: true,
reduction: :mean
)
{logits, loss}
end
end
end
model = BigramLanguageModel.create(Minidecoder.vocab_size())
{xb, yb} = get_batch.(:train)
BigramLanguageModel.forward(model, xb, yb)
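# The gist stops before training. As a rough sketch only (not part of the
# original): one plain gradient-descent step on the lookup table, using
# Nx.Defn.value_and_grad and the same Axon cross-entropy as forward/3.
# TrainSketch, step/4 and the 0.1 learning rate are illustrative choices,
# and the row lookup is written as a one-hot Nx.dot (equivalent to the
# Nx.take in forward/3) so the differentiation path is a plain dot product.
defmodule TrainSketch do
  import Nx.Defn

  # Mean sparse cross-entropy of the bigram table on one batch {xb, yb}.
  defn batch_loss(model, xb, yb) do
    vocab = Nx.axis_size(model, 0)
    # {B*T, vocab} one-hot encoding of the input tokens
    one_hot = Nx.equal(Nx.new_axis(Nx.flatten(xb), -1), Nx.iota({vocab}))
    # same as taking rows of the table: {B*T, vocab} logits
    logits = Nx.dot(one_hot, model)

    Axon.Losses.categorical_cross_entropy(
      Nx.flatten(yb),
      logits,
      sparse: true,
      from_logits: true,
      reduction: :mean
    )
  end

  # One vanilla gradient-descent update of the whole table.
  defn step(model, xb, yb, lr) do
    {loss, grad} = value_and_grad(model, fn m -> batch_loss(m, xb, yb) end)
    {model - lr * grad, loss}
  end
end

{model_after_step, loss_after_step} = TrainSketch.step(model, xb, yb, 0.1)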
max_new_tokens = 1
result = BigramLanguageModel.generate(model, xb, max_new_tokens) |> Nx.flatten() |> Nx.to_list()
Minidecoder.decode(result)
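# Purely illustrative: the table is still untrained, so longer samples are
# noise. Start generation from a single token 0 and greedily append 100 tokens.
BigramLanguageModel.generate(model, Nx.tensor([[0]]), 100)
|> Nx.flatten()
|> Nx.to_list()
|> Minidecoder.decode()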