Created
September 27, 2023 19:59
-
-
Save tos-kamiya/033d016094d0570e514a867fd37c21a9 to your computer and use it in GitHub Desktop.
A command-line translator using Facebook's NLLB LLM (proof of concept)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# ref https://zenn.dev/syoyo/articles/9a159ee747835a | |
import sys | |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
# NLLB model cards:
#   https://huggingface.co/facebook/nllb-200-distilled-1.3B
#   https://huggingface.co/facebook/nllb-200-distilled-600M
# Quoting the model card: "The model was trained with input lengths not
# exceeding 512 tokens, therefore translating longer sequences might result
# in quality degradation."
max_length = 512
def main():
    """Translate standard input line by line into the language given in argv[1].

    The first command-line argument is a language description: a prefix of
    either a language name or a FLORES-200 code, resolved with ``det_lang``.
    Each non-blank input line is translated on its own and written to standard
    output; blank lines are echoed so the paragraph layout is kept.

    Exits with a non-zero status when the argument is missing, matches no
    language, matches more than one language, or names a language the loaded
    tokenizer does not know.
    """
    # Validate the command line first so that usage errors are reported
    # immediately instead of after the (slow) model load.
    if len(sys.argv) < 2:
        print("Error: Provide a language description as the first argument.", file=sys.stderr)
        sys.exit(1)

    lang_candidates = det_lang(sys.argv[1])
    if not lang_candidates:
        sys.exit("Error: language not found")
    if len(lang_candidates) >= 2:
        print("Error: ambiguous language specification", file=sys.stderr)
        for item in lang_candidates:
            print("%s | %s" % item, file=sys.stderr)
        # An ambiguous specification is an error, so report failure to the caller.
        sys.exit(1)
    lang_code = lang_candidates[0][1]

    # Alternative checkpoints: "facebook/nllb-200-3.3B",
    # "facebook/nllb-200-distilled-600M".
    model_name = "facebook/nllb-200-distilled-1.3B"
    print("** Note: The license of NLLB LLM (https://huggingface.co/facebook/nllb-200-distilled-1.3B) is CC-BY-NC, that is, **NON-COMMERCIAL** ", file=sys.stderr, flush=True)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    # Look the language token up through the vocabulary; the older
    # ``tokenizer.lang_code_to_id`` mapping is not available in recent
    # transformers releases.
    target_token_id = tokenizer.convert_tokens_to_ids(lang_code)
    if target_token_id is None or target_token_id == tokenizer.unk_token_id:
        sys.exit("Error: language code %s is not supported by %s" % (lang_code, model_name))

    # Iterating over stdin (rather than splitting the whole text on '\n')
    # avoids emitting a spurious blank line for the final newline.
    for raw_line in sys.stdin:
        line = raw_line.strip()
        if not line:
            print("", flush=True)
            continue

        inputs = tokenizer(line, return_tensors="pt")
        translated_tokens = model.generate(
            **inputs, forced_bos_token_id=target_token_id, max_length=max_length
        )

        # Warn once the output uses more than half of the generation budget.
        if len(translated_tokens[0]) > max_length // 2:
            print("** Warning: long text, which may result in wrong translation.", file=sys.stderr, flush=True)

        ret = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
        # Flush per line so translations appear as soon as they are ready.
        print(ret, flush=True)
def _parse_flores200_table(table):
    """Parse a ``name | code`` markdown table into a list of (name, code) tuples.

    The first two lines of the stripped table (header and separator) are
    skipped. Only the first two pipe-separated fields of each row are used and
    surrounding whitespace is removed, so stray trailing columns or spacing
    differences cannot produce tuples of the wrong length. Rows without both a
    name and a code are ignored.
    """
    rows = []
    for raw_row in table.strip().split('\n')[2:]:
        fields = [field.strip() for field in raw_row.split('|')]
        if len(fields) >= 2 and fields[0] and fields[1]:
            rows.append((fields[0], fields[1]))
    return rows


# The following list was taken from
# https://github.com/facebookresearch/flores/blob/main/flores200/README.md
# Each entry is a (language name, FLORES-200 code) tuple.
flores200_lang_list = _parse_flores200_table("""
Language | FLORES-200 code
---|---
Acehnese (Arabic script) | ace_Arab
Acehnese (Latin script) | ace_Latn
Mesopotamian Arabic | acm_Arab
Ta’izzi-Adeni Arabic | acq_Arab
Tunisian Arabic | aeb_Arab
Afrikaans | afr_Latn
South Levantine Arabic | ajp_Arab
Akan | aka_Latn
Amharic | amh_Ethi
North Levantine Arabic | apc_Arab
Modern Standard Arabic | arb_Arab
Modern Standard Arabic (Romanized) | arb_Latn
Najdi Arabic | ars_Arab
Moroccan Arabic | ary_Arab
Egyptian Arabic | arz_Arab
Assamese | asm_Beng
Asturian | ast_Latn
Awadhi | awa_Deva
Central Aymara | ayr_Latn
South Azerbaijani | azb_Arab
North Azerbaijani | azj_Latn
Bashkir | bak_Cyrl
Bambara | bam_Latn
Balinese | ban_Latn
Belarusian | bel_Cyrl
Bemba | bem_Latn
Bengali | ben_Beng
Bhojpuri | bho_Deva
Banjar (Arabic script) | bjn_Arab
Banjar (Latin script) | bjn_Latn
Standard Tibetan | bod_Tibt
Bosnian | bos_Latn
Buginese | bug_Latn
Bulgarian | bul_Cyrl
Catalan | cat_Latn
Cebuano | ceb_Latn
Czech | ces_Latn
Chokwe | cjk_Latn
Central Kurdish | ckb_Arab
Crimean Tatar | crh_Latn
Welsh | cym_Latn
Danish | dan_Latn
German | deu_Latn
Southwestern Dinka | dik_Latn
Dyula | dyu_Latn
Dzongkha | dzo_Tibt
Greek | ell_Grek
English | eng_Latn
Esperanto | epo_Latn
Estonian | est_Latn
Basque | eus_Latn
Ewe | ewe_Latn
Faroese | fao_Latn
Fijian | fij_Latn
Finnish | fin_Latn
Fon | fon_Latn
French | fra_Latn
Friulian | fur_Latn
Nigerian Fulfulde | fuv_Latn
Scottish Gaelic | gla_Latn
Irish | gle_Latn
Galician | glg_Latn
Guarani | grn_Latn
Gujarati | guj_Gujr
Haitian Creole | hat_Latn
Hausa | hau_Latn
Hebrew | heb_Hebr
Hindi | hin_Deva
Chhattisgarhi | hne_Deva
Croatian | hrv_Latn
Hungarian | hun_Latn
Armenian | hye_Armn
Igbo | ibo_Latn
Ilocano | ilo_Latn
Indonesian | ind_Latn
Icelandic | isl_Latn
Italian | ita_Latn
Javanese | jav_Latn
Japanese | jpn_Jpan
Kabyle | kab_Latn
Jingpho | kac_Latn
Kamba | kam_Latn
Kannada | kan_Knda
Kashmiri (Arabic script) | kas_Arab
Kashmiri (Devanagari script) | kas_Deva
Georgian | kat_Geor
Central Kanuri (Arabic script) | knc_Arab
Central Kanuri (Latin script) | knc_Latn
Kazakh | kaz_Cyrl
Kabiyè | kbp_Latn
Kabuverdianu | kea_Latn
Khmer | khm_Khmr
Kikuyu | kik_Latn
Kinyarwanda | kin_Latn
Kyrgyz | kir_Cyrl
Kimbundu | kmb_Latn
Northern Kurdish | kmr_Latn
Kikongo | kon_Latn
Korean | kor_Hang
Lao | lao_Laoo
Ligurian | lij_Latn
Limburgish | lim_Latn
Lingala | lin_Latn
Lithuanian | lit_Latn
Lombard | lmo_Latn
Latgalian | ltg_Latn
Luxembourgish | ltz_Latn
Luba-Kasai | lua_Latn
Ganda | lug_Latn
Luo | luo_Latn
Mizo | lus_Latn
Standard Latvian | lvs_Latn
Magahi | mag_Deva
Maithili | mai_Deva
Malayalam | mal_Mlym
Marathi | mar_Deva
Minangkabau (Arabic script) | min_Arab
Minangkabau (Latin script) | min_Latn
Macedonian | mkd_Cyrl
Plateau Malagasy | plt_Latn
Maltese | mlt_Latn
Meitei (Bengali script) | mni_Beng
Halh Mongolian | khk_Cyrl
Mossi | mos_Latn
Maori | mri_Latn
Burmese | mya_Mymr
Dutch | nld_Latn
Norwegian Nynorsk | nno_Latn
Norwegian Bokmål | nob_Latn
Nepali | npi_Deva
Northern Sotho | nso_Latn
Nuer | nus_Latn
Nyanja | nya_Latn
Occitan | oci_Latn
West Central Oromo | gaz_Latn
Odia | ory_Orya
Pangasinan | pag_Latn
Eastern Panjabi | pan_Guru
Papiamento | pap_Latn
Western Persian | pes_Arab
Polish | pol_Latn
Portuguese | por_Latn
Dari | prs_Arab
Southern Pashto | pbt_Arab
Ayacucho Quechua | quy_Latn
Romanian | ron_Latn
Rundi | run_Latn
Russian | rus_Cyrl
Sango | sag_Latn
Sanskrit | san_Deva
Santali | sat_Olck
Sicilian | scn_Latn
Shan | shn_Mymr
Sinhala | sin_Sinh
Slovak | slk_Latn
Slovenian | slv_Latn
Samoan | smo_Latn
Shona | sna_Latn
Sindhi | snd_Arab
Somali | som_Latn
Southern Sotho | sot_Latn
Spanish | spa_Latn
Tosk Albanian | als_Latn
Sardinian | srd_Latn
Serbian | srp_Cyrl
Swati | ssw_Latn
Sundanese | sun_Latn
Swedish | swe_Latn
Swahili | swh_Latn
Silesian | szl_Latn
Tamil | tam_Taml
Tatar | tat_Cyrl
Telugu | tel_Telu
Tajik | tgk_Cyrl
Tagalog | tgl_Latn
Thai | tha_Thai
Tigrinya | tir_Ethi
Tamasheq (Latin script) | taq_Latn
Tamasheq (Tifinagh script) | taq_Tfng
Tok Pisin | tpi_Latn
Tswana | tsn_Latn
Tsonga | tso_Latn
Turkmen | tuk_Latn
Tumbuka | tum_Latn
Turkish | tur_Latn
Twi | twi_Latn
Central Atlas Tamazight | tzm_Tfng
Uyghur | uig_Arab
Ukrainian | ukr_Cyrl
Umbundu | umb_Latn
Urdu | urd_Arab
Northern Uzbek | uzn_Latn
Venetian | vec_Latn
Vietnamese | vie_Latn
Waray | war_Latn
Wolof | wol_Latn
Xhosa | xho_Latn
Eastern Yiddish | ydd_Hebr
Yoruba | yor_Latn
Yue Chinese | yue_Hant
Chinese (Simplified) | zho_Hans
Chinese (Traditional) | zho_Hant
Standard Malay | zsm_Latn
Zulu | zul_Latn
""")
def det_lang(desc: str, *, lang_list=None) -> list:
    """Return every language entry whose name or code starts with ``desc``.

    Matching is case-insensitive and ignores whitespace around ``desc``.

    Args:
        desc: Prefix of a language name (e.g. ``"japan"``) or of a
            FLORES-200 code (e.g. ``"jpn"``).
        lang_list: Optional iterable of ``(name, code)`` tuples to search.
            Defaults to the module-level ``flores200_lang_list``.

    Returns:
        A list of the matching ``(name, code)`` tuples in table order: empty
        when nothing matches, longer than one when ``desc`` is ambiguous.
    """
    if lang_list is None:
        lang_list = flores200_lang_list
    prefix = desc.strip().lower()
    return [
        (name, code)
        for name, code in lang_list
        if name.lower().startswith(prefix) or code.lower().startswith(prefix)
    ]
# Run the translator only when this file is executed as a script.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment