@toranb · Last active November 14, 2023 15:56
A Bumblebee fine-tuning example with one of the smaller PyTorch pre-trained BERT variants
defmodule Training.Example do
  def train() do
    Nx.default_backend(EXLA.Backend)

    # load the spec for a smaller pre-trained BERT variant and configure it
    # for sequence classification with 5 labels (Yelp star ratings)
    {:ok, spec} =
      Bumblebee.load_spec({:hf, "prajjwal1/bert-medium"},
        module: Bumblebee.Text.Bert,
        architecture: :for_sequence_classification
      )

    spec = Bumblebee.configure(spec, num_labels: 5)

    {:ok, model} = Bumblebee.load_model({:hf, "prajjwal1/bert-medium"}, spec: spec)
    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "bert-base-cased"})

    # training data
    batch_size = 32
    sequence_length = 64

    train_data =
      Whisper.Yelp.load("priv/yelp/train.csv", tokenizer,
        batch_size: batch_size,
        sequence_length: sequence_length
      )

    test_data =
      Whisper.Yelp.load("priv/yelp/test.csv", tokenizer,
        batch_size: batch_size,
        sequence_length: sequence_length
      )

    # limit the number of batches to keep the run short; test_data is
    # prepared here but not wired into an evaluation loop in this example
    train_data = Enum.take(train_data, 250)
    test_data = Enum.take(test_data, 50)

    ## fine tune bert
    %{model: model, params: params} = model

    # sanity check: confirm the model's output shape against one batch
    [{input, _}] = Enum.take(train_data, 1)
    Axon.get_output_shape(model, input)

    # the Bumblebee model returns a map of outputs; extract the logits for training
    logits_model = Axon.nx(model, & &1.logits)

    # sparse categorical cross entropy: labels are class indices, not one-hot
    loss =
      &Axon.Losses.categorical_cross_entropy(&1, &2,
        reduction: :mean,
        from_logits: true,
        sparse: true
      )

    optimizer = Axon.Optimizers.adam(5.0e-5)
    accuracy = &Axon.Metrics.accuracy(&1, &2, from_logits: true, sparse: true)

    trained_model_state =
      logits_model
      |> Axon.Loop.trainer(loss, optimizer, log: 1)
      |> Axon.Loop.metric(accuracy, "accuracy")
      |> Axon.Loop.checkpoint(event: :epoch_completed)
      |> Axon.Loop.run(train_data, params, epochs: 3, compiler: EXLA, strict?: false)

    :ok
  end
end
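
The Whisper.Yelp.load/3 helper above comes from the author's project and is not shown in the gist. Below is a minimal sketch of what such a loader might look like, assuming a two-column CSV (star label, review text) and naive comma splitting; real Yelp data with quoted fields would likely call for a proper CSV library such as NimbleCSV. Every name and detail here beyond the function signature is an assumption.

defmodule Whisper.Yelp do
  # Hypothetical sketch of the data-loading helper referenced above.
  def load(path, tokenizer, opts) do
    batch_size = Keyword.fetch!(opts, :batch_size)
    sequence_length = Keyword.fetch!(opts, :sequence_length)

    path
    |> File.stream!()
    |> Stream.map(&parse_line/1)
    |> Stream.chunk_every(batch_size, batch_size, :discard)
    |> Stream.map(fn rows ->
      {labels, texts} = Enum.unzip(rows)

      # tokenize the batch to padded/truncated tensors of fixed length
      tokenized = Bumblebee.apply_tokenizer(tokenizer, texts, length: sequence_length)

      # {input, target} tuples are the shape Axon.Loop.run expects
      {tokenized, Nx.tensor(labels)}
    end)
  end

  defp parse_line(line) do
    # naive split; assumes no commas inside the label column
    [label, text] = String.split(String.trim(line), ",", parts: 2)
    # assuming 1..5 star labels; shift to 0-based class indices
    {String.to_integer(label) - 1, text}
  end
end

With both modules compiled, training kicks off with Training.Example.train() in iex.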
toranb commented May 17, 2023

@lorenzosinisi that's more of a Bumblebee question, as I'm truly not sure which architectures are supported for generation. In my example above you can see I'm using architecture: :for_sequence_classification explicitly
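
For anyone exploring the generation side, here is a minimal sketch of the standard Bumblebee text-generation flow. It assumes a recent Bumblebee (0.3+) and a GPT-2 checkpoint, neither of which is part of this gist:

{:ok, model_info} = Bumblebee.load_model({:hf, "gpt2"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "gpt2"})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "gpt2"})

# cap the number of new tokens produced per prompt
generation_config = Bumblebee.configure(generation_config, max_new_tokens: 20)

serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config)
Nx.Serving.run(serving, "The capital of France is")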
