#!/usr/bin/env python3

# Dependencies
# =============================
# pip install nltk transformers
# (running the model also needs PyTorch: pip install torch;
#  sentence splitting needs the NLTK punkt data: python -m nltk.downloader punkt)

import argparse
import sys
from pathlib import Path

# SEE ALSO: https://huggingface.co/docs/transformers/model_doc/nllb
#
# facebook/nllb-200-3.3B                    # larger translation model; better quality than the distilled version, at the cost of slower inference
# facebook/nllb-200-distilled-600M          # distilled version of NLLB-200-3.3B: retains much of the larger model's performance at a smaller size
# facebook/nllb-200-1.3B                    # another variant of the NLLB-200 series, also aimed at translation tasks
# michaelfeil/ct2fast-nllb-200-3.3B         # another NLLB-200 variant (CTranslate2 conversion of the 3.3B model)
# facebook/mbart-large-50-many-to-many-mmt  # many-to-many multilingual machine translation model
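# Example invocation (a sketch: "translate_nllb200.py" is a placeholder for whatever name
# this script is saved under; text.txt / eng_Latn / spa_Latn are sample values, see the
# language table in the usage text below):
#
#   python3 translate_nllb200.py text.txt --source-language eng_Latn --target-language spa_Latn
#
# Each translated segment is printed to stdout, and the full result is written to a file
# next to the input, named after the model and the language pair.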
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog='nllb200 translator',
        description='''DESCRIPTION:
        Translate using https://huggingface.co/docs/transformers/model_doc/nllb.''',
        usage="""
        SEE FULL language list: https://github.com/facebookresearch/flores/blob/main/flores200/README.md#languages-in-flores-200

        | Language              | FLORES-200 code |
        |-----------------------|-----------------|
        | Catalan               | cat_Latn        |
        | Chinese (Simplified)  | zho_Hans        |
        | Chinese (Traditional) | zho_Hant        |
        | English               | eng_Latn        |
        | French                | fra_Latn        |
        | German                | deu_Latn        |
        | Portuguese            | por_Latn        |
        | Spanish               | spa_Latn        |
        """,
        epilog='Model card: https://huggingface.co/docs/transformers/model_doc/nllb')

    parser.add_argument('source',
                        nargs='?',
                        action='store',
                        help='input file (default: text.txt)',
                        default='text.txt')
    parser.add_argument('--source-language', '-s',
                        help='REQUIRED: source language (FLORES-200 code)',
                        action='store',
                        required=True)
    parser.add_argument('--target-language', '-t',
                        help='REQUIRED: target language (FLORES-200 code)',
                        action='store',
                        required=True)
    parser.add_argument('--verbose', '-v',
                        action='store_true',
                        help='verbosely list details')
    parser.add_argument('--force-overwrite', '-f',
                        action='store_true',
                        help='force overwrite')

    conf = parser.parse_args(sys.argv[1:])
    print(conf, file=sys.stderr)
    path_in = Path(conf.source)
    text_source = path_in.read_text(encoding='utf-8')

    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
    from nltk import sent_tokenize

    # split the source text into sentences
    segments = sent_tokenize(text_source, 'english')
    print(f'Generated {len(segments)} segments', file=sys.stderr, flush=True)

    print('loading model...', end='')
    # https://huggingface.co/facebook/nllb-200-3.3B
    MODEL = "facebook/nllb-200-3.3B"
    print(f'Using model: {MODEL}')
    tokenizer = AutoTokenizer.from_pretrained(MODEL,
                                              src_lang=conf.source_language)  # BCP-47 code
    # MORE LANGUAGE codes:
    # https://github.com/facebookresearch/flores/blob/main/flores200/README.md#languages-in-flores-200
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL)
    print('Ok')
    translated = []
    for idx, segment in enumerate(segments):
        # pass empty segments and segments starting with a digit (e.g. list numbers) through untranslated
        if not len(segment) or segment[0].isdigit():
            decoded = segment
        else:
            inputs = tokenizer(segment, return_tensors="pt")
            # SEE LANGUAGE CODES in transformers_fb_nllb200distilled600M_text.md
            # NOTE: on transformers releases newer than 4.37.0 use
            # tokenizer.convert_tokens_to_ids(conf.target_language) instead (see the comments below)
            translated_tokens = model.generate(
                **inputs,
                forced_bos_token_id=tokenizer.lang_code_to_id[conf.target_language],
                max_length=1024
            )
            decoded = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
        translated.append(decoded)
        print(decoded)

    path_out = path_in.with_suffix(f'.{MODEL.split("/")[1]}.{conf.source_language}-{conf.target_language}.txt')
    path_out.write_text('\n'.join(translated), encoding='utf-8')
    print(f'Wrote {path_out.absolute()}')
@meiyou313, there is a better model: check out madlad400 (400+ languages!). There is also a Rust-based, CPU-optimized version that can be run via this fork (the official original repo is https://github.com/huggingface/candle, examples/quantized-t5). The fork allows taking an input file as a parameter; it uses stdout for output, so the result needs to be redirected via the shell. The full command line would be something like:
cargo run --example quantized-t5v \
    --release \
    --features accelerate -- \
    --model-id "jbochi/madlad400-3b-mt" \
    --weight-file "model-q4k.gguf" \
    --target-language es \
    --file-path input.txt \
    --temperature 0 > result.txt
`quantized-t5v` is the modified version of the `quantized-t5` example in the original repo.
all languages supported by madlad-400:
Because of huggingface/transformers#31348 you will get an `AttributeError: 'NllbTokenizerFast' object has no attribute 'lang_code_to_id'` error. You can either downgrade `transformers` to 4.37.0 or replace

`forced_bos_token_id=tokenizer.lang_code_to_id[conf.target_language],`

with

`forced_bos_token_id=tokenizer.convert_tokens_to_ids(conf.target_language),`
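For clarity, here is a minimal sketch of the patched call, assuming the same `tokenizer`, `model`, `inputs`, and `conf` objects used in the script above:

```python
# On transformers releases newer than 4.37.0 (per the comment above), NllbTokenizerFast
# no longer exposes lang_code_to_id, so look the target-language token id up in the
# vocabulary instead.
translated_tokens = model.generate(
    **inputs,
    forced_bos_token_id=tokenizer.convert_tokens_to_ids(conf.target_language),
    max_length=1024,
)
decoded = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
```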