
@brando90
Created February 5, 2025 18:36
train.py

Suhas Kotha  Monday at 5:54 PM
I've found this code to be a super simple and functioning multi-GPU training script: https://github.com/ZitongYang/Synthetic_Continued_Pretraining/blob/main/train.py
scripts/train.sh calls train.py. The number of GPUs is pulled from the number of available GPUs, and it uses the FSDP config specified in scripts/config/fsdp_config.json. (The full train.py is reproduced in the comment below.)

ZitongYang/Synthetic_Continued_Pretraining: https://github.com/ZitongYang/Synthetic_Continued_Pretraining | Added by GitHub

Suhas Kotha  Monday at 5:54 PM
Shoutout to @Zitong Yang and @Neil Band for such simple code
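The multi-GPU behaviour described in the first message comes from the Hugging Face Trainer's built-in FSDP support: scripts/train.sh launches one process per visible GPU and passes the FSDP settings through to transformers.TrainingArguments. Below is a minimal sketch of that wiring done directly in Python; the sharding options and output path are assumptions for illustration, not values copied from the repo's scripts/train.sh or scripts/config/fsdp_config.json.

# Minimal sketch, assuming FSDP is configured through TrainingArguments.
# The option strings and paths below are illustrative placeholders.
import torch
import transformers

num_gpus = torch.cuda.device_count()  # "the number of gpus is pulled from the number of available gpus"

args = transformers.TrainingArguments(
    output_dir="outputs/example-run",               # placeholder path
    per_device_train_batch_size=1,
    bf16=True,
    fsdp="full_shard auto_wrap",                    # assumed sharding options
    fsdp_config="scripts/config/fsdp_config.json",  # path mentioned in the message above
)
# A launcher such as `torchrun --nproc_per_node=<num_gpus> train.py ...`
# (presumably what scripts/train.sh does) starts one process per GPU, and
# the Trainer then wraps the model with FSDP according to these arguments.

Because train.py parses TrainingArguments with HfArgumentParser, the same fields can equivalently be supplied as command-line flags (--fsdp, --fsdp_config, and so on) from the shell script.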

brando90 commented Feb 5, 2025

from dataclasses import dataclass, field, asdict
from typing import Optional
import transformers
import os
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

from data.cptdata import get_task_data_module

@dataclass
class TrainingConfig:
    task_name: str
    block_size: int
    rehersal_rate: float
    model_name: str
    subsample_ratio: float
    wandb_project: Optional[str] = field(default="synthetic-continued-pretraining")

    def __post_init__(self):
        os.environ['WANDB_PROJECT'] = self.wandb_project


def train():
    # parse the TrainingConfig fields and the standard HF TrainingArguments from the CLI
    parser = transformers.HfArgumentParser((TrainingConfig, transformers.TrainingArguments))
    config, args = parser.parse_args_into_dataclasses()
    log_config = {**asdict(config), **asdict(args)}
    logging.info(f"Training config: {log_config}")

    # loading model
    model = transformers.AutoModelForCausalLM.from_pretrained(
        config.model_name)
    # loading dataset (train/eval datasets plus collator for the given task)
    data_module = get_task_data_module(**asdict(config))

    # setting up trainer and running training
    trainer = transformers.Trainer(model=model, args=args, **data_module)
    trainer.train()
    trainer.save_model(output_dir=args.output_dir)
    trainer.accelerator.wait_for_everyone()


if __name__ == "__main__":
    train()
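The data_module returned by get_task_data_module is unpacked straight into transformers.Trainer, so it has to be a dict of Trainer keyword arguments, typically train_dataset, eval_dataset, and data_collator. The real implementation lives in data/cptdata.py in the repo; the sketch below is only an assumed stand-in (a hypothetical example_task_data_module with a toy dataset) to show the expected shape of that return value.

# Hypothetical stand-in for data.cptdata.get_task_data_module; it only
# illustrates the assumed return shape consumed by Trainer(**data_module).
import torch
import transformers

def example_task_data_module(model_name: str, block_size: int, **kwargs):
    # **kwargs absorbs the remaining TrainingConfig fields (task_name,
    # rehersal_rate, subsample_ratio, wandb_project) passed via **asdict(config).
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # needed so the collator can pad

    class ToyDataset(torch.utils.data.Dataset):
        def __init__(self, texts):
            self.examples = [
                tokenizer(t, truncation=True, max_length=block_size,
                          return_tensors="pt")["input_ids"][0]
                for t in texts
            ]

        def __len__(self):
            return len(self.examples)

        def __getitem__(self, idx):
            return {"input_ids": self.examples[idx]}

    # Causal-LM collator: pads each batch and builds labels from input_ids.
    collator = transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)
    return {
        "train_dataset": ToyDataset(["example text"] * 8),
        "eval_dataset": ToyDataset(["example text"] * 2),
        "data_collator": collator,
    }

Any dict with these keys works, which is why train() can simply call transformers.Trainer(model=model, args=args, **data_module).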
