Suhas Kotha Monday at 5:54 PM
I've found this code to be a super simple, working multi-GPU training script: https://github.com/ZitongYang/Synthetic_Continued_Pretraining/blob/main/train.py
scripts/train.sh calls train.py. The number of GPUs is pulled from the number of available GPUs, and it uses the FSDP config specified in scripts/config/fsdp_config.json.

train.py:

from dataclasses import dataclass, field, asdict
from typing import Optional
import transformers
import os
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

from data.cptdata import get_task_data_module


@dataclass
class TrainingConfig:
    task_name: str
    block_size: int
    rehersal_rate: float
    model_name: str
    subsample_ratio: float
    wandb_project: Optional[str] = field(default="synthetic-continued-pretraining")

    def __post_init__(self):
        os.environ['WANDB_PROJECT'] = self.wandb_project


def train():
    # parsing input
    parser = transformers.HfArgumentParser((TrainingConfig, transformers.TrainingArguments))
    config, args = parser.parse_args_into_dataclasses()
    log_config = {**asdict(config), **asdict(args)}
    logging.info(f"Training config: {log_config}")

    # loading model
    model = transformers.AutoModelForCausalLM.from_pretrained(config.model_name)

    # loading dataset
    data_module = get_task_data_module(**asdict(config))

    # setting up trainer
    trainer = transformers.Trainer(model=model, args=args, **data_module)
    trainer.train()
    trainer.save_model(output_dir=args.output_dir)
    trainer.accelerator.wait_for_everyone()


if __name__ == "__main__":
    train()
Suhas Kotha Monday at 5:54 PM
shoutout @Zitong Yang and @Neil Band for such simple code
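One thing the snippet leaves implicit: trainer = transformers.Trainer(model=model, args=args, **data_module) means get_task_data_module has to return a dict of Trainer keyword arguments (train_dataset, and optionally eval_dataset / data_collator). A toy sketch of that shape, assuming a plain map-style dataset of fixed-length token blocks; this is not the repo's actual data pipeline in data/cptdata.py.

import torch
from torch.utils.data import Dataset

class ToyBlockDataset(Dataset):
    """Fixed-length blocks of token ids; labels = input_ids, causal-LM style."""
    def __init__(self, token_ids, block_size):
        self.blocks = [
            token_ids[i:i + block_size]
            for i in range(0, len(token_ids) - block_size + 1, block_size)
        ]
    def __len__(self):
        return len(self.blocks)
    def __getitem__(self, idx):
        ids = torch.tensor(self.blocks[idx], dtype=torch.long)
        return {"input_ids": ids, "labels": ids.clone()}

def toy_task_data_module(block_size=8):
    tokens = list(range(256))  # placeholder token ids; the repo builds these from the task corpus
    return {"train_dataset": ToyBlockDataset(tokens, block_size)}

data_module = toy_task_data_module()
# transformers.Trainer(model=model, args=args, **data_module) then receives train_dataset.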