Created
November 10, 2023 15:05
-
-
Save ArthurZucker/159dedfcb908467e5f484cf1c143155e to your computer and use it in GitHub Desktop.
Script that automatically converts and re-uploads Helsinki-NLP Marian models, comparing each model's new generation output against the previous transformers version's output before pushing a fix.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/bin/bash
#
# Compare Marian checkpoints across two transformers versions.
# Two conda environments are expected, built as follows:
#
#   conda create -n 4.29 python==3.9
#   source activate 4.29
#   pip install transformers==4.29.2
#   pip install torch accelerate sentencepiece tokenizers colorama sacremoses googletrans
#
#   conda create -n 4.34 python==3.9
#   source activate 4.34
#   pip install transformers==4.34
#   pip install torch accelerate sentencepiece tokenizers colorama sacremoses googletrans

# ANSI color codes for terminal status messages.
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color

# Emojis used in the status messages.
SMILEY='😊'
THUMBS_UP='👍'
WARNING='⚠️'

# Language codes accepted by googletrans; a model whose source language is
# not in this list falls back to a default English prompt.
LANGUAGES=(
    'af' 'sq' 'am' 'ar' 'hy' 'az' 'eu' 'be' 'bn' 'bs' 'bg' 'ca' 'ceb' 'ny'
    'zh-cn' 'zh-tw' 'co' 'hr' 'cs' 'da' 'nl' 'en' 'eo' 'et' 'tl' 'fi' 'fr'
    'fy' 'gl' 'ka' 'de' 'el' 'gu' 'ht' 'ha' 'haw' 'iw' 'he' 'hi' 'hmn' 'hu'
    'is' 'ig' 'id' 'ga' 'it' 'ja' 'jw' 'kn' 'kk' 'km' 'ko' 'ku' 'ky' 'lo'
    'la' 'lv' 'lt' 'lb' 'mk' 'mg' 'ms' 'ml' 'mt' 'mi' 'mr' 'mn' 'my' 'ne'
    'no' 'or' 'ps' 'fa' 'pl' 'pt' 'pa' 'ro' 'ru' 'sm' 'gd' 'sr' 'st' 'sn'
    'sd' 'si' 'sk' 'sl' 'so' 'es' 'su' 'sw' 'sv' 'tg' 'ta' 'te' 'th' 'tr'
    'uk' 'ur' 'ug' 'uz' 'vi' 'cy' 'xh' 'yi' 'yo' 'zu'
)
# Task 1: Get list of model_ids.
#
# NOTE(review): the original script first fetched the ids with
# huggingface_hub.HfApi.list_models and then immediately overwrote the
# result with the raw HTTP query below — a dead assignment and a wasted
# network round-trip on every run. The dead call is removed; the HTTP
# endpoint is kept because it exposes 'downloadsAllTime', which lets us
# process the most-downloaded models first.
model_ids=$(python -c "
import requests
params = {
    'author': 'Helsinki-NLP',
    'other': 'marian',
    'expand[]': 'downloadsAllTime',
}
response = requests.get('https://huggingface.co/api/models', params=params)
models = response.json()
# Most downloaded first; '-big' checkpoints are deliberately excluded.
model_list = sorted(models, key=lambda e: e['downloadsAllTime'], reverse=True)
model_list = [model['id'] for model in model_list if 'big' not in model['id']]
print('\n'.join(model_list))
")
# Base of the conda installation; used to select per-env python binaries.
CONDA_PATH=$(conda info --base)
# Device literal interpolated verbatim into the python heredocs below.
DEVICE="'cuda'"
# ---------------------------------------------------------------------------
# Main loop. For each model id:
#   1. derive the checkpoint's source-language code from the id,
#   2. obtain a test sentence in that language via googletrans, or fall back
#      to a fixed English prompt,
#   3. generate with transformers 4.29 (reference output),
#   4. generate with the newer transformers env, re-tie lm_head to the shared
#      embeddings, and — if that changes the output — push the fixed weights
#      to the Hub as an automatically merged PR.
# Loop through model_ids
for model_id in $model_ids; do
    echo -e " - $model_id"
    # Source-language code = second-to-last '-'-separated field of the id,
    # e.g. "en" for Helsinki-NLP/opus-mt-en-fr.
    source_lang=$(echo "$model_id" | awk -F'-' '{print $(NF-1)}')
    # Check if source_lang length is greater than 3
    # (longer tokens are multi-language / irregular ids this script skips).
    if [ ${#source_lang} -gt 3 ]; then
        echo -e "${YELLOW}${WARNING} Skipping: $model_id with $source_lang${NC}"
        continue
    fi
    # Check if source_lang is in the list of valid languages
    # (substring match against the space-joined array; the surrounding
    # spaces keep short codes from matching inside longer ones).
    if [[ " ${LANGUAGES[@]} " =~ " $source_lang " ]]; then
        # Attempt translation with try-except block.
        # NOTE(review): env name "py39" does not match the "4.29"/"4.34"
        # envs described in the setup comments above — confirm env naming.
        translation=$($CONDA_PATH/envs/py39/bin/python - <<EOF
from googletrans import Translator
from logging import captureWarnings, getLogger, ERROR
# Silence warnings so stdout carries only the translated sentence.
logger = getLogger('py.warnings')
logger.setLevel(ERROR)
captureWarnings(True)
translator = Translator()
# source_lang below is substituted by the shell before python runs.
source_lang = "$source_lang"
try:
    translation = translator.translate("Hey! Let\'s learn together", dest=source_lang)
    from colorama import Fore, Back, Style
    print(translation.text)
except Exception as e:
    print(f"Error translating {source_lang}: {str(e)}")
    translation = None
EOF
)
        echo -e "${GREEN}${SMILEY}Testing input $model_id with input sentence from $source_lang: $translation ${NC}"
    else
        # Unknown language: use a fixed English prompt (with the Marian
        # ">>en<<" target-language prefix token).
        translation="' >>en<< Hey how are you?'"
        echo -e "${YELLOW} \tFailed to translate to $source_lang, using the default english prompt $translation ${NC}"
    fi
    captured_output="''"
    # Task 3: Run scripts with MarianMTModel
    formatted_model_name=$(echo $model_id | tr '/' '_')
    # NOTE(review): $conda_env is never defined anywhere in this script, so
    # output_file is always "<model>_.pt" and the two runs below overwrite
    # the same file — confirm whether it was meant to differ per env.
    output_file="${formatted_model_name}_${conda_env}.pt"
    # Reference run under transformers 4.29.x.
    python_path="$CONDA_PATH/envs/4.29/bin/python"
    captured_output=$(TRANSFORMERS_VERBOSITY=error $python_path - <<EOF
from transformers import AutoTokenizer, MarianMTModel
from logging import captureWarnings, getLogger, ERROR
import torch, os, transformers
logger = getLogger('py.warnings')
logger.setLevel(ERROR)
captureWarnings(True)
transformers.utils.logging.set_verbosity_error()
tokenizer = AutoTokenizer.from_pretrained("$model_id")
inputs = tokenizer("$translation", return_tensors="pt", padding=True).to($DEVICE)
model = MarianMTModel.from_pretrained("$model_id").to($DEVICE)
# Save the reference logits for optional offline comparison.
torch.save(model(**inputs, decoder_input_ids = inputs["input_ids"]).logits.detach(),'Arthur/$output_file')
print(tokenizer.batch_decode(model.generate(**inputs)))
EOF
)
    # PR body attached to the weight-fix commit; backticks are escaped so
    # they survive the shell's double-quote expansion.
    commit_description="""Following the merge of [a PR](https://github.com/huggingface/transformers/pull/24310) in \`transformers\` it appeared that \
this model was not properly converted. This PR will fix the inference and was tested using the following script:
\`\`\`python
>>> from transformers import AutoTokenizer, MarianMTModel
>>> tokenizer = AutoTokenizer.from_pretrained('$model_id')
>>> inputs = tokenizer(\"$translation\", return_tensors=\"pt\", padding=True)
>>> model = MarianMTModel.from_pretrained('$model_id')
>>> print(tokenizer.batch_decode(model.generate(**inputs)))
"$captured_output"
\`\`\`
"""
    echo -e "${YELLOW}🤗 transformers == 4.29.1: $captured_output ${NC}"
    # Second run: the newer transformers env (same "py39" naming caveat as
    # above). This run also pushes the fix to the Hub when warranted.
    python_path="$CONDA_PATH/envs/py39/bin/python"
    TRANSFORMERS_VERBOSITY=error $python_path - <<EOF
from transformers import AutoTokenizer, MarianMTModel, MarianModel
from logging import captureWarnings, getLogger, ERROR
import torch, os, transformers
from huggingface_hub import HfApi
api = HfApi()
logger = getLogger('py.warnings')
logger.setLevel(ERROR)
captureWarnings(True)
transformers.utils.logging.set_verbosity_error()
tokenizer = AutoTokenizer.from_pretrained("$model_id")
inputs = tokenizer("$translation", return_tensors="pt", padding=True).to($DEVICE)
model = MarianMTModel.from_pretrained("$model_id", torch_dtype="auto").to($DEVICE)
logits = model(**inputs, decoder_input_ids = inputs["input_ids"]).logits
torch.save(logits.detach(),'Arthur/$output_file')
from colorama import Fore, Back, Style
translated = model.generate(**inputs)
# captured_output is spliced in by the shell as a python list literal (the
# repr printed by the 4.29 run), so old and new generations compare directly.
color = Fore.GREEN if $captured_output == tokenizer.batch_decode(translated) else Fore.RED
print(color + "🤗 transformers == 4.34-before:\t", tokenizer.batch_decode(translated))
# The fix: reload the bare MarianModel and re-tie lm_head to the shared
# embedding matrix.
model_base = MarianModel.from_pretrained("$model_id", torch_dtype="auto").to($DEVICE)
model.model = model_base
model.lm_head.weight.data = model.model.shared.weight.data
translated_fixed = model.generate(**inputs)
color = Fore.GREEN if $captured_output == tokenizer.batch_decode(translated_fixed) else Fore.RED
print(color + "🤗 transformers == 4.34-fixed:\t", tokenizer.batch_decode(translated_fixed))
print(Style.RESET_ALL)
# Push only when re-tying actually changes the generation, and the fixed
# output is not degenerately long.
if tokenizer.batch_decode(translated) != tokenizer.batch_decode(translated_fixed) and len(translated_fixed)<40:
    print("writing $model_id to a file")
    del model.config.transformers_version
    del model.config._name_or_path
    commit_details = model.push_to_hub("$model_id", create_pr = True, commit_message="Update checkpoint for transformers>=4.29", commit_description="""${commit_description}""")
    print(commit_details.pr_url)
    api.merge_pull_request("$model_id", discussion_num = commit_details.pr_num, comment = "Automatically merging the PR.")
    with open("Arthur/files_to_update.txt", 'a+') as f:
        f.write(f"$model_id\n")
EOF
    # Clear Hugging Face cache
    # (free disk space before downloading the next checkpoint).
    # NOTE(review): hard-coded user path — breaks on any other machine.
    rm -rf /Users/arthur/.cache/huggingface/hub/models--Helsinki-NLP--opus-mt-*
done
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment