-
-
Save manniru/2e0cb40b34cf5413eb5d39de47bb4e22 to your computer and use it in GitHub Desktop.
Text translation with facebook/nllb-200-3.3B model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# Dependencies | |
# ============================= | |
# pip install nltk transformers | |
import argparse | |
import sys | |
from pathlib import Path | |
# SEE ALSO: https://huggingface.co/docs/transformers/model_doc/nllb | |
# facebook/nllb-200-3.3B # This is a larger model designed for translation tasks. It has a larger size than the distilled version, which might result in better quality translations but at the cost of slower inference | |
# facebook/nllb-200-distilled-600M # This model is designed for translation tasks and has been updated recently. It's a distilled version of the larger NLLB-200-3.3B model, which means it retains much of the performance of the larger model but with a smaller size | |
# facebook/nllb-200-1.3B # This model is another variant of the NLLB-200 series. It's designed for translation tasks and has been updated recently huggingface.co. | |
# michaelfeil/ct2fast-nllb-200-3.3B # This model is also a variant of the NLLB-200 series. It's designed for translation tasks and has been updated recently
# facebook/mbart-large-50-many-to-many-mmt # This model is designed for many-to-many multilingual machine translation tasks. It has been updated recently | |
if __name__ == "__main__":
    # Command-line translator built on Meta's NLLB-200 models: reads a UTF-8
    # text file, splits it into sentences with NLTK, translates each sentence,
    # and writes the result next to the input file.
    parser = argparse.ArgumentParser(
        prog='nllb200 translator',
        description='''DESCRIPTION:
Translate using https://huggingface.co/docs/transformers/model_doc/nllb.''',
        usage="""
SEE FULL language list: https://github.com/facebookresearch/flores/blob/main/flores200/README.md#languages-in-flores-200
| Language                           | FLORES-200 code |
|------------------------------------|-----------------|
| Catalan                            | cat_Latn        |
| Chinese (Simplified)               | zho_Hans        |
| Chinese (Traditional)              | zho_Hant        |
| English                            | eng_Latn        |
| French                             | fra_Latn        |
| German                             | deu_Latn        |
| Portuguese                         | por_Latn        |
| Spanish                            | spa_Latn        |
""",
        epilog='Model card: https://huggingface.co/docs/transformers/model_doc/nllb')
    # BUGFIX: this positional has nargs='?' plus a default, so it is optional;
    # the old help text wrongly claimed it was REQUIRED.
    parser.add_argument('source',
                        nargs='?',
                        action='store',
                        help='input file (default: %(default)s)',
                        default='text.txt')
    parser.add_argument('--source-language', '-s',
                        help='REQUIRED: source language (FLORES-200 code, e.g. eng_Latn)',
                        action='store',
                        required=True)
    parser.add_argument('--target-language', '-t',
                        help='REQUIRED: target language (FLORES-200 code, e.g. spa_Latn)',
                        action='store',
                        required=True)
    parser.add_argument('--verbose', '-v',
                        action='store_true',
                        help='verbosely list details')
    parser.add_argument('--force-overwrite', '-f',
                        action='store_true',
                        help='force overwrite of an existing output file')
    conf = parser.parse_args(sys.argv[1:])
    print(conf, file=sys.stderr)

    path_in = Path(conf.source)
    text_source = path_in.read_text(encoding='utf-8')

    # https://huggingface.co/facebook/nllb-200-3.3B
    MODEL = "facebook/nllb-200-3.3B"

    # BUGFIX: --force-overwrite was accepted but never honoured.  Compute the
    # output path up front and refuse to clobber an existing file BEFORE
    # spending minutes loading a multi-gigabyte model.
    path_out = path_in.with_suffix(
        f'.{MODEL.split("/")[1]}.{conf.source_language}-{conf.target_language}.txt')
    if path_out.exists() and not conf.force_overwrite:
        sys.exit(f'Refusing to overwrite {path_out} (use --force-overwrite/-f)')

    # Deferred imports: argument errors should fail fast without paying the
    # heavy transformers/nltk import cost.
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
    from nltk import sent_tokenize

    # NOTE(review): sentence splitting always uses the English Punkt model,
    # even when --source-language is not English — confirm this is intended.
    segments = sent_tokenize(text_source, 'english')
    print(f'Generated {len(segments)} segments', file=sys.stderr, flush=True)

    # BUGFIX: flush so the progress message appears before the long load
    # instead of sitting in the stdout buffer (end='' suppresses the newline
    # that would otherwise trigger line-buffered flushing).
    print('loading model...', end='', flush=True)
    print(f'Using model: {MODEL}')
    tokenizer = AutoTokenizer.from_pretrained(MODEL,
                                              src_lang=conf.source_language)  # FLORES-200 / BCP-47 code
    # MORE LANGUAGE codes:
    # https://github.com/facebookresearch/flores/blob/main/flores200/README.md#languages-in-flores-200
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL)
    print('Ok')

    translated = []
    for idx, segment in enumerate(segments):
        if not segment or segment[0].isdigit():
            # Pass empty segments and digit-led segments (e.g. numeric
            # headings) through untranslated.
            decoded = segment
        else:
            inputs = tokenizer(segment, return_tensors="pt")
            # BUGFIX: tokenizer.lang_code_to_id was removed in recent
            # transformers releases; convert_tokens_to_ids is the supported
            # way to obtain the forced BOS (target-language) token id.
            translated_tokens = model.generate(
                **inputs,
                forced_bos_token_id=tokenizer.convert_tokens_to_ids(conf.target_language),
                max_length=1024
            )
            decoded = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
        if conf.verbose:
            # --verbose was previously parsed but unused; emit progress on
            # stderr so stdout stays clean translated text.
            print(f'[{idx + 1}/{len(segments)}]', file=sys.stderr, flush=True)
        translated.append(decoded)
        print(decoded)

    # BUGFIX: write with an explicit encoding so the output does not depend on
    # the platform default (the read side already used utf-8).
    path_out.write_text('\n'.join(translated), encoding='utf-8')
    print(f'Wrote {path_out.absolute()}')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment