# Install Bumblebee from a fork whose branch adds M2M100/NLLB support, plus the
# EXLA compiler; configure Nx so every tensor op defaults to the EXLA backend.
Mix.install(
  [
    {:bumblebee, github: "aymanosman/bumblebee", branch: "m2m100-and-nllb"},
    {:exla, ">= 0.0.0"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)
# Load the distilled 600M-parameter NLLB-200 checkpoint and its tokenizer from
# the Hugging Face Hub. The same repo tuple is used for both loads.
repo = {:hf, "facebook/nllb-200-distilled-600M"}

{:ok, model_info} =
  Bumblebee.load_model(repo, architecture: :for_conditional_generation)

{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
21:24:16.049 [debug] the following parameters were missing:
* language_modeling_head.logits_bias.bias
21:24:16.049 [debug] the following PyTorch parameters were unused:
* lm_head.weight
{:ok,
%Bumblebee.Text.PreTrainedTokenizer{
native_tokenizer: #Tokenizers.Tokenizer<[
vocab_size: 256204,
byte_fallback: false,
continuing_subword_prefix: nil,
dropout: nil,
end_of_word_suffix: nil,
fuse_unk: true,
model_type: "bpe",
unk_token: "<unk>"
]>,
type: :m2m_100,
special_tokens: %{
mask: "<mask>",
pad: "<pad>",
eos: "</s>",
unk: "<unk>",
sep: "</s>",
cls: "<s>"
},
additional_special_tokens: MapSet.new(["ces_Latn", "est_Latn", "mag_Deva", "yue_Hant",
"lij_Latn", "cym_Latn", "lao_Laoo", "mri_Latn", "bug_Latn", "acq_Arab", "zsm_Latn", "uzn_Latn",
"bem_Latn", "prs_Arab", "hye_Armn", "bul_Cyrl", "jpn_Jpan", "dik_Latn", "lim_Latn", "aka_Latn",
"ind_Latn", "eus_Latn", "kmr_Latn", "gaz_Latn", "ast_Latn", "amh_Ethi", "xho_Latn", "som_Latn",
"bam_Latn", "arz_Arab", "lvs_Latn", "azj_Latn", "ita_Latn", "heb_Hebr", "awa_Deva", "ukr_Cyrl",
"taq_Tfng", "hat_Latn", "nob_Latn", "dan_Latn", "mai_Deva", "pan_Guru", "ayr_Latn", "ban_Latn",
...]),
add_special_tokens: true,
length: nil,
pad_direction: :right,
truncate_direction: :right,
return_attention_mask: true,
return_token_type_ids: true,
return_special_tokens_mask: false,
return_offsets: false,
return_length: false
}}
# Generation defaults shipped with the checkpoint (per the inspect output below:
# greedy search, max_length 200, decoder_start_token_id 2).
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "facebook/nllb-200-distilled-600M"})
{:ok,
%Bumblebee.Text.GenerationConfig{
max_new_tokens: nil,
min_new_tokens: nil,
max_length: 200,
min_length: nil,
strategy: %{type: :greedy_search},
decoder_start_token_id: 2,
forced_bos_token_id: nil,
forced_eos_token_id: nil,
forced_token_ids: [],
suppressed_token_ids: [],
no_repeat_ngram_length: nil,
temperature: nil,
bos_token_id: 0,
eos_token_id: 2,
pad_token_id: 1,
extra_config: nil
}}
# Helper: tokenize an NLLB language tag (e.g. "fra_Latn") and return the numeric
# id of its first token — the id later forced as the decoder's BOS token.
lang_code = fn tag ->
  %{"input_ids" => ids} = Bumblebee.apply_tokenizer(tokenizer, tag)
  ids[[0, 0]] |> Nx.to_number()
end

# Token ids for the languages used in this transcript.
en = lang_code.("eng_Latn")
fr = lang_code.("fra_Latn")
ja = lang_code.("jpn_Jpan")
# Build one serving per target language: each forces the decoder's first
# generated (BOS) token to that language's id, which is how NLLB selects the
# translation target.
to_lang_serving = fn bos_id ->
  config = %{generation_config | forced_bos_token_id: bos_id}
  Bumblebee.Text.generation(model_info, tokenizer, config)
end

fr_serving = to_lang_serving.(fr)
ja_serving = to_lang_serving.(ja)
%Nx.Serving{
module: Nx.Serving.Default,
arg: #Function<0.132067394/2 in Bumblebee.Text.TextGeneration.generation/4>,
client_preprocessing: #Function<1.132067394/1 in Bumblebee.Text.TextGeneration.generation/4>,
client_postprocessing: #Function<3.132067394/2 in Bumblebee.Text.TextGeneration.maybe_stream/4>,
streaming: nil,
batch_size: nil,
distributed_postprocessing: &Function.identity/1,
process_options: [batch_keys: [:default]],
defn_options: []
}
# Translate to French (fr_serving forces the French BOS token); the output below
# renders "bank" as "rive" — the river-bank sense.
Nx.Serving.run(fr_serving, "The bank of the river is beautiful in spring")
%{
results: [
%{
text: "La rive du fleuve est magnifique au printemps.",
token_summary: %{input: 11, output: 15, padding: 0}
}
]
}
# Same sentence translated to Japanese via the serving with the Japanese BOS token.
Nx.Serving.run(ja_serving, "The bank of the river is beautiful in spring")
%{
results: [
%{text: "川の岸は春に美しい.", token_summary: %{input: 11, output: 11, padding: 0}}
]
}