Ray HuggingFaceTrainer problem
import os

from datasets import load_dataset
import transformers
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
import ray
from ray import tune, air
from ray.train.huggingface import HuggingFaceTrainer
from ray.air.config import ScalingConfig
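# NOTE: written against the Ray 2.x AIR-era API (HuggingFaceTrainer,
# ray.air.RunConfig). These import paths were deprecated and moved in later
# Ray releases, so pin a matching Ray version to reproduce.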
# If using GPUs, set this to True.
use_gpu = False

model_checkpoint = "gpt2"
tokenizer_checkpoint = "sgugger/gpt2-like-tokenizer"
block_size = 128

raw_datasets = load_dataset("wikitext", "wikitext-2-raw-v1")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)


def tokenize_function(examples):
    return tokenizer(examples["text"])


tokenized_datasets = raw_datasets.map(
    tokenize_function, batched=True, num_proc=1, remove_columns=["text"]
)
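
# Concatenate the whole tokenized corpus and re-split it into fixed-size
# chunks, so every training example is exactly block_size tokens long and
# no padding is needed for causal LM training.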
def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # Drop the small remainder; we could pad instead if the model supported
    # it. Customize this step to your needs.
    total_length = (total_length // block_size) * block_size
    # Split into chunks of block_size.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    # For causal LM the model shifts the labels internally, so labels start
    # out as a plain copy of input_ids.
    result["labels"] = result["input_ids"].copy()
    return result
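
# With batched=True, group_texts receives batch_size (1000) rows at a time
# as lists of token lists, which is what the sum(..., []) concatenation
# above relies on.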
lm_datasets = tokenized_datasets.map(
    group_texts,
    batched=True,
    batch_size=1000,
    num_proc=1,
)

ray_train_ds = ray.data.from_huggingface(lm_datasets["train"])
ray_evaluation_ds = ray.data.from_huggingface(lm_datasets["validation"])
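

# trainer_init_per_worker runs once on every Ray training worker. `config`
# receives the trial's trainer_init_config from the Tuner's param_space
# below, and the train/eval datasets arrive as per-worker shards of the Ray
# Datasets created above.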
def trainer_init_per_worker(train_dataset, eval_dataset, **config):
    model_config = AutoConfig.from_pretrained(model_checkpoint)
    model = AutoModelForCausalLM.from_config(model_config)
    args = transformers.TrainingArguments(
        output_dir=f"/tmp/{model_checkpoint}-wikitext2",
        # evaluation_strategy="epoch",
        # save_strategy="epoch",
        # logging_strategy="epoch",
        save_steps=2,
        logging_steps=2,
        metric_for_best_model="loss",
        save_total_limit=1,
        learning_rate=config.get("learning_rate"),
        weight_decay=config.get("weight_decay"),
        max_steps=30,
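        # max_steps takes precedence over num_train_epochs in
        # transformers.TrainingArguments, so each trial stops after 30
        # optimizer steps.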
        num_train_epochs=3,
        no_cuda=(not use_gpu),
    )
    return transformers.Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )
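

# Three data-parallel training workers; each one calls
# trainer_init_per_worker on its own dataset shard.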
scaling_config = ScalingConfig(num_workers=3, use_gpu=use_gpu)

trainer = HuggingFaceTrainer(
    trainer_init_per_worker=trainer_init_per_worker,
    scaling_config=scaling_config,
    datasets={"train": ray_train_ds, "evaluation": ray_evaluation_ds},
)

if __name__ == "__main__":
    # S3_BUCKET is read but not used directly below; presumably upload_dir
    # points into this bucket.
    S3_BUCKET = os.environ["S3_BUCKET"]
    upload_dir = os.environ["UPLOAD_DIR"]
    name = os.environ["EXPERIMENT_NAME"]
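
    # Two grid_search axes with two values each expand to 2 x 2 = 4 trials;
    # num_samples=1 runs each grid point once.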
    tuner = tune.Tuner(
        trainer,
        param_space={
            "trainer_init_config": {
                "weight_decay": tune.grid_search([0.01, 0.02]),
                "learning_rate": tune.grid_search([2e-5, 2e-4]),
            },
        },
        tune_config=tune.TuneConfig(
            num_samples=1,
            max_concurrent_trials=20,
        ),
        run_config=air.RunConfig(
            name=name,
            local_dir="/tmp/experiment_dir",
            sync_config=tune.SyncConfig(
                upload_dir=upload_dir,
            ),
            checkpoint_config=air.CheckpointConfig(
                num_to_keep=2,
                checkpoint_score_attribute="loss",
                checkpoint_score_order="min",
            ),
            failure_config=air.FailureConfig(
                max_failures=1,
            ),
        ),
    )

    results = tuner.fit()
    print(results.get_best_result(metric="loss", mode="min").config)
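
# To reproduce, set the required environment variables first, e.g.
# (bucket/path/name values below are placeholders):
#   S3_BUCKET=my-bucket UPLOAD_DIR=s3://my-bucket/ray-results \
#   EXPERIMENT_NAME=hf-trainer-repro python this_script.py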