#!/usr/bin/env python3
# Dependencies
# =============================
# pip install nltk transformers torch
import argparse
import sys
from pathlib import Path

# SEE ALSO: https://huggingface.co/docs/transformers/model_doc/nllb
#
# facebook/nllb-200-3.3B                    # the full-size NLLB-200 translation model; typically better quality than the distilled variants, at the cost of slower inference
# facebook/nllb-200-distilled-600M          # a distilled version of NLLB-200-3.3B; retains much of the larger model's quality at a fraction of the size
# facebook/nllb-200-1.3B                    # a mid-size variant of the NLLB-200 series
# michaelfeil/ct2fast-nllb-200-3.3B         # a CTranslate2 conversion of NLLB-200-3.3B for faster CPU/GPU inference
# facebook/mbart-large-50-many-to-many-mmt  # mBART-50 model for many-to-many multilingual machine translation
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog='nllb200 translator',
        description='''DESCRIPTION:
        Translate using https://huggingface.co/docs/transformers/model_doc/nllb.''',
        usage="""
    SEE FULL language list: https://github.com/facebookresearch/flores/blob/main/flores200/README.md#languages-in-flores-200

    | Language                           | FLORES-200 code |
    |------------------------------------|-----------------|
    | Catalan                            | cat_Latn        |
    | Chinese (Simplified)               | zho_Hans        |
    | Chinese (Traditional)              | zho_Hant        |
    | English                            | eng_Latn        |
    | French                             | fra_Latn        |
    | German                             | deu_Latn        |
    | Portuguese                         | por_Latn        |
    | Spanish                            | spa_Latn        |
    """,
        epilog='Model card: https://huggingface.co/docs/transformers/model_doc/nllb')
    parser.add_argument('source',
                        nargs='?',
                        action='store',
                        help='input file (default: %(default)s)',
                        default='text.txt')
    parser.add_argument('--source-language', '-s',
                        help='REQUIRED: source language (FLORES-200 code, e.g. eng_Latn)',
                        action='store',
                        required=True)
    parser.add_argument('--target-language', '-t',
                        help='REQUIRED: target language (FLORES-200 code, e.g. spa_Latn)',
                        action='store',
                        required=True)
    parser.add_argument('--verbose', '-v',
                        action='store_true',
                        help='verbosely list details')
    parser.add_argument('--force-overwrite', '-f',
                        action='store_true',
                        help='force overwrite of the output file')
    conf = parser.parse_args(sys.argv[1:])
    print(conf, file=sys.stderr)
    path_in = Path(conf.source)
    text_source = path_in.read_text(encoding='utf-8')

    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
    import nltk
    from nltk import sent_tokenize

    # sent_tokenize needs the punkt tokenizer data; fetch it if missing
    nltk.download('punkt', quiet=True)

    # NOTE: sentence splitting uses the English rules regardless of the source language
    segments = sent_tokenize(text_source, 'english')
    print(f'Generated {len(segments)} segments', file=sys.stderr, flush=True)

    print('loading model...', end='', flush=True)
    # https://huggingface.co/facebook/nllb-200-3.3B
    MODEL = "facebook/nllb-200-3.3B"
    print(f'Using model: {MODEL}')
    tokenizer = AutoTokenizer.from_pretrained(MODEL,
                                              src_lang=conf.source_language)  # FLORES-200 code
    # MORE LANGUAGE codes:
    # https://github.com/facebookresearch/flores/blob/main/flores200/README.md#languages-in-flores-200
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL)
    print('Ok')

    translated = []
    for idx, segment in enumerate(segments):
        if not len(segment) or segment[0].isdigit():
            # pass empty segments and segments starting with a digit (e.g. subtitle indexes) through untranslated
            decoded = segment
        else:
            inputs = tokenizer(segment, return_tensors="pt")
            # SEE LANGUAGE CODES in transformers_fb_nllb200distilled600M_text.md
            # NOTE: on transformers >= 4.38 lang_code_to_id was removed (huggingface/transformers#31348);
            # use tokenizer.convert_tokens_to_ids(conf.target_language) there (see the comments below)
            translated_tokens = model.generate(
                **inputs,
                forced_bos_token_id=tokenizer.lang_code_to_id[conf.target_language],
                max_length=1024
            )
            decoded = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
        translated.append(decoded)
        print(decoded)

    path_out = path_in.with_suffix(f'.{MODEL.split("/")[1]}.{conf.source_language}-{conf.target_language}.txt')
    path_out.write_text('\n'.join(translated), encoding='utf-8')
    print(f'Wrote {path_out.absolute()}')
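For quick experiments, the same NLLB checkpoints can also be driven through the transformers translation pipeline instead of the manual tokenize/generate loop in the script. A minimal sketch, assuming the distilled 600M checkpoint and example language codes (both are illustrative, not part of the script above):

from transformers import pipeline

# Translation pipeline with an NLLB checkpoint; src_lang/tgt_lang take
# the same FLORES-200 codes as the script's -s/-t arguments.
translator = pipeline("translation",
                      model="facebook/nllb-200-distilled-600M",
                      src_lang="eng_Latn",
                      tgt_lang="spa_Latn",
                      max_length=1024)
print(translator("The weather is nice today.")[0]["translation_text"])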
@meiyou313, there is a better model: check out madlad400 (400+ languages!). There is also a Rust-based, CPU-optimized version that can be run via this fork (the official original repo is https://github.com/huggingface/candle, examples/quantized-t5). The fork allows taking an input file as a parameter; it uses stdout for output, so the result needs to be redirected via the shell. The full command line would be something like:
cargo run --example quantized-t5v \
    --release \
    --features accelerate -- \
    --model-id "jbochi/madlad400-3b-mt" \
    --weight-file "model-q4k.gguf" \
    --target-language es \
    --file-path input.txt \
    --temperature 0 > result.txt

quantized-t5v is the modified version of the quantized-t5 example in the original repo.
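For comparison, madlad400 can also be run from plain Python with transformers. A hedged sketch based on the jbochi/madlad400-3b-mt model card (the target language is selected by prepending a <2xx> token to the input, here <2es> for Spanish):

from transformers import T5ForConditionalGeneration, T5Tokenizer

# madlad400 is a T5-style model; the <2es> prefix requests Spanish output
model = T5ForConditionalGeneration.from_pretrained("jbochi/madlad400-3b-mt")
tokenizer = T5Tokenizer.from_pretrained("jbochi/madlad400-3b-mt")

inputs = tokenizer("<2es> The weather is nice today.", return_tensors="pt")
outputs = model.generate(**inputs, max_length=1024)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))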
The full list of languages supported by madlad-400 is in the model card: https://huggingface.co/jbochi/madlad400-3b-mt
Because of huggingface/transformers#31348, you will get an AttributeError: 'NllbTokenizerFast' object has no attribute 'lang_code_to_id' error. You can downgrade transformers to 4.37.0, or replace

forced_bos_token_id=tokenizer.lang_code_to_id[conf.target_language],

with

forced_bos_token_id=tokenizer.convert_tokens_to_ids(conf.target_language),
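Or, to keep one script working on both sides of that change, a small sketch (the helper name target_bos_id is illustrative):

def target_bos_id(tokenizer, lang_code):
    # transformers <= 4.37 still has the lang_code_to_id mapping
    if hasattr(tokenizer, "lang_code_to_id"):
        return tokenizer.lang_code_to_id[lang_code]
    # transformers >= 4.38: look the language token up directly
    return tokenizer.convert_tokens_to_ids(lang_code)

# in the script:
# forced_bos_token_id=target_bos_id(tokenizer, conf.target_language),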