Skip to content

Instantly share code, notes, and snippets.

@aymanosman
Created August 14, 2024 20:25
Show Gist options
  • Save aymanosman/00b525b9a2a430870c247321ddbafb7b to your computer and use it in GitHub Desktop.
Save aymanosman/00b525b9a2a430870c247321ddbafb7b to your computer and use it in GitHub Desktop.
NLLB in Bumblebee demo

Try NLLB

# Notebook setup: pull in Bumblebee from the fork/branch that carries the
# M2M100 and NLLB model support, plus EXLA as the compiler/backend.
Mix.install(
  [
    # Forked Bumblebee branch, fetched straight from GitHub.
    {:bumblebee, [github: "aymanosman/bumblebee", branch: "m2m100-and-nllb"]},
    # Any published EXLA release is accepted.
    {:exla, ">= 0.0.0"}
  ],
  # Make EXLA the default Nx backend for every tensor created in this notebook.
  [config: [nx: [default_backend: EXLA.Backend]]]
)

Section

# Hugging Face repository that provides both the model and the tokenizer.
nllb_repo = {:hf, "facebook/nllb-200-distilled-600M"}

# Load the checkpoint with the sequence-to-sequence generation head.
{:ok, model_info} =
  nllb_repo
  |> Bumblebee.load_model(architecture: :for_conditional_generation)

# Load the matching tokenizer from the same repository.
{:ok, tokenizer} =
  nllb_repo
  |> Bumblebee.load_tokenizer()

21:24:16.049 [debug] the following parameters were missing:

  * language_modeling_head.logits_bias.bias


21:24:16.049 [debug] the following PyTorch parameters were unused:

  * lm_head.weight


{:ok,
 %Bumblebee.Text.PreTrainedTokenizer{
   native_tokenizer: #Tokenizers.Tokenizer<[
     vocab_size: 256204,
     byte_fallback: false,
     continuing_subword_prefix: nil,
     dropout: nil,
     end_of_word_suffix: nil,
     fuse_unk: true,
     model_type: "bpe",
     unk_token: "<unk>"
   ]>,
   type: :m2m_100,
   special_tokens: %{
     mask: "<mask>",
     pad: "<pad>",
     eos: "</s>",
     unk: "<unk>",
     sep: "</s>",
     cls: "<s>"
   },
   additional_special_tokens: MapSet.new(["ces_Latn", "est_Latn", "mag_Deva", "yue_Hant",
    "lij_Latn", "cym_Latn", "lao_Laoo", "mri_Latn", "bug_Latn", "acq_Arab", "zsm_Latn", "uzn_Latn",
    "bem_Latn", "prs_Arab", "hye_Armn", "bul_Cyrl", "jpn_Jpan", "dik_Latn", "lim_Latn", "aka_Latn",
    "ind_Latn", "eus_Latn", "kmr_Latn", "gaz_Latn", "ast_Latn", "amh_Ethi", "xho_Latn", "som_Latn",
    "bam_Latn", "arz_Arab", "lvs_Latn", "azj_Latn", "ita_Latn", "heb_Hebr", "awa_Deva", "ukr_Cyrl",
    "taq_Tfng", "hat_Latn", "nob_Latn", "dan_Latn", "mai_Deva", "pan_Guru", "ayr_Latn", "ban_Latn",
    ...]),
   add_special_tokens: true,
   length: nil,
   pad_direction: :right,
   truncate_direction: :right,
   return_attention_mask: true,
   return_token_type_ids: true,
   return_special_tokens_mask: false,
   return_offsets: false,
   return_length: false
 }}
# Fetch the generation defaults published alongside the checkpoint.
{:ok, generation_config} =
  {:hf, "facebook/nllb-200-distilled-600M"}
  |> Bumblebee.load_generation_config()
{:ok,
 %Bumblebee.Text.GenerationConfig{
   max_new_tokens: nil,
   min_new_tokens: nil,
   max_length: 200,
   min_length: nil,
   strategy: %{type: :greedy_search},
   decoder_start_token_id: 2,
   forced_bos_token_id: nil,
   forced_eos_token_id: nil,
   forced_token_ids: [],
   suppressed_token_ids: [],
   no_repeat_ngram_length: nil,
   temperature: nil,
   bos_token_id: 0,
   eos_token_id: 2,
   pad_token_id: 1,
   extra_config: nil
 }}
# Returns the vocabulary id of a language-code token such as "fra_Latn".
#
# The id is looked up directly in the tokenizer vocabulary instead of encoding
# the string and reading position [0, 0] of "input_ids". The lookup therefore
# does not depend on which special tokens the tokenizer template inserts, or
# on where it inserts them.
lang_code = fn string ->
  Bumblebee.Tokenizer.token_to_id(tokenizer, string)
end

# Kept for completeness; the servings below only need target-language ids.
en = lang_code.("eng_Latn")
fr = lang_code.("fra_Latn")
ja = lang_code.("jpn_Jpan")

# One serving per target language: forcing the first generated token to be the
# target language code is what selects the output language. The option is set
# through Bumblebee.configure/2 so that it is validated, rather than through a
# raw struct update.
fr_serving =
  Bumblebee.Text.generation(
    model_info,
    tokenizer,
    Bumblebee.configure(generation_config, forced_bos_token_id: fr)
  )

ja_serving =
  Bumblebee.Text.generation(
    model_info,
    tokenizer,
    Bumblebee.configure(generation_config, forced_bos_token_id: ja)
  )
%Nx.Serving{
  module: Nx.Serving.Default,
  arg: #Function<0.132067394/2 in Bumblebee.Text.TextGeneration.generation/4>,
  client_preprocessing: #Function<1.132067394/1 in Bumblebee.Text.TextGeneration.generation/4>,
  client_postprocessing: #Function<3.132067394/2 in Bumblebee.Text.TextGeneration.maybe_stream/4>,
  streaming: nil,
  batch_size: nil,
  distributed_postprocessing: &Function.identity/1,
  process_options: [batch_keys: [:default]],
  defn_options: []
}
# Run the sentence through the serving configured with the French language code.
fr_serving
|> Nx.Serving.run("The bank of the river is beautiful in spring")
%{
  results: [
    %{
      text: "La rive du fleuve est magnifique au printemps.",
      token_summary: %{input: 11, output: 15, padding: 0}
    }
  ]
}
# Run the same sentence through the serving configured with the Japanese language code.
ja_serving
|> Nx.Serving.run("The bank of the river is beautiful in spring")
%{
  results: [
    %{text: "川の岸は春に美しい.", token_summary: %{input: 11, output: 11, padding: 0}}
  ]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment