Created
April 24, 2024 02:43
-
-
Save zjjott/9d26b31f99c5aaad7db6417a5bcc3ff9 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import json | |
import logging | |
import math | |
import os | |
import random | |
import time | |
from collections.abc import Mapping | |
from contextlib import contextmanager, nullcontext | |
from functools import partial | |
from itertools import chain | |
import datasets | |
import matplotlib.pyplot as plt | |
import torch | |
import transformers | |
import traceback | |
from datasets import load_dataset, load_from_disk | |
from instruction_dataset_utils import InstructionDataset | |
from matplotlib.ticker import MaxNLocator | |
from peft import LoraConfig, TaskType, get_peft_model | |
from torch.distributed.fsdp import FullStateDictConfig | |
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP | |
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision | |
from torch.distributed.fsdp import StateDictType | |
from torch.utils.collect_env import get_pretty_env_info | |
from torch.utils.data import DataLoader | |
from torch.utils.data.distributed import DistributedSampler | |
from torch.utils.tensorboard import SummaryWriter | |
from tqdm.auto import tqdm | |
from transformers import ( | |
CONFIG_MAPPING, | |
MODEL_MAPPING, | |
AutoConfig, | |
AutoModelForCausalLM, | |
AutoTokenizer, | |
SchedulerType, | |
default_data_collator, | |
get_scheduler, | |
set_seed, | |
) | |
# , LlamaAttention, LlamaMLP | |
from transformers.models.llama.modeling_llama import LlamaDecoderLayer | |
from transformers.utils import check_min_version | |
from transformers.utils.versions import require_version | |
import torch.distributed as dist | |
from torch.nn.parallel import DistributedDataParallel | |
import torch_xla | |
import torch.distributed as dist | |
from torch_xla import runtime as xr | |
import torch_xla.core.xla_model as xm | |
import torch_xla.debug.profiler as xp | |
import torch_xla.distributed.parallel_loader as pl | |
import torch_xla.test.test_utils as test_utils | |
from torch_xla.amp import autocast as xla_autocast, GradScaler | |
from torch.amp import autocast as torch_autocast | |
import torch_xla.distributed.xla_backend | |
from torch_xla._internal import pjrt | |
import functools | |
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP | |
import os | |
from torch_xla.amp import syncfree | |
# import atorch | |
# from atorch.auto import auto_accelerate | |
# from atorch.utils.meta_model_utils import init_empty_weights_with_disk_offload | |
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. | |
check_min_version("4.31.0") | |
logger = logging.getLogger(__name__) | |
require_version( | |
"datasets>=1.8.0", | |
"To fix: pip install -U -r requirements.txt", | |
) | |
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) | |
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) | |
TRAINING_ARGS_NAME = "training_args.bin" | |
def _train_update(device, x, loss, rate): | |
if is_main_process(): | |
test_utils.print_training_update( | |
device, x, loss.item(), rate, rate*world_size()) | |
def is_main_process(): | |
return dist.get_rank() == 0 | |
def is_local_main_process(): | |
return int(os.environ["LOCAL_RANK"]) == 0 | |
def local_rank(): | |
return int(os.environ["LOCAL_RANK"]) | |
def world_size(): | |
return dist.get_world_size() | |
def wait_for_everyone(): | |
torch.distributed.barrier() | |
def _goes_first(is_main): | |
if is_main is False: | |
wait_for_everyone() | |
yield | |
if is_main is True: | |
wait_for_everyone() | |
@contextmanager | |
def main_process_first(): | |
yield from _goes_first(is_main_process()) | |
def unwrap_model(model): | |
""" | |
Recursively unwraps a model from potential containers (as used in distributed training). | |
Args: | |
model (`torch.nn.Module`): The model to unwrap. | |
""" | |
# since there could be multiple levels of wrapping, unwrap recursively | |
if hasattr(model, "module"): | |
return unwrap_model(model.module) | |
else: | |
return model | |
def honor_type(obj, generator): | |
""" | |
Cast a generator to the same type as obj (list, tuple or namedtuple) | |
""" | |
try: | |
return type(obj)(generator) | |
except TypeError: | |
# Some objects may not be able to instantiate from a generator directly | |
return type(obj)(*list(generator)) | |
def recursively_apply( | |
func, | |
data, | |
*args, | |
test_type=lambda t: isinstance(t, torch.Tensor), | |
error_on_other_type=False, | |
**kwargs, | |
): | |
if isinstance(data, (tuple, list)): | |
return honor_type( | |
data, | |
( | |
recursively_apply( | |
func, | |
o, | |
*args, | |
test_type=test_type, | |
error_on_other_type=error_on_other_type, | |
**kwargs, | |
) | |
for o in data | |
), | |
) | |
elif isinstance(data, Mapping): | |
return type(data)( | |
{ | |
k: recursively_apply( | |
func, | |
v, | |
*args, | |
test_type=test_type, | |
error_on_other_type=error_on_other_type, | |
**kwargs, | |
) | |
for k, v in data.items() | |
} | |
) | |
elif test_type(data): | |
return func(data, *args, **kwargs) | |
elif error_on_other_type: | |
raise TypeError( | |
f"Can't apply {func.__name__} on object of type {type(data)}, only of nested list/tuple/dicts of objects " | |
f"that satisfy {test_type.__name__}." | |
) | |
return data | |
def gather(tensor): | |
def _gpu_gather_one(tensor): | |
if tensor.ndim == 0: | |
tensor = tensor.clone()[None] | |
output_tensors = [tensor.clone() | |
for _ in range(torch.distributed.get_world_size())] | |
torch.distributed.all_gather(output_tensors, tensor) | |
return torch.cat(output_tensors, dim=0) | |
return recursively_apply(_gpu_gather_one, tensor, error_on_other_type=True) | |
def model_parameters_num(model): | |
""" | |
Prints the number of trainable parameters in the model. | |
""" | |
trainable_params = 0 | |
all_param = 0 | |
for param in model.parameters(): | |
num_params = param.numel() | |
# if using DS Zero 3 and the weights are initialized empty | |
if num_params == 0 and hasattr(param, "ds_numel"): | |
num_params = param.ds_numel | |
all_param += num_params | |
if param.requires_grad: | |
trainable_params += num_params | |
return all_param, trainable_params | |
def parse_args(): | |
parser = argparse.ArgumentParser( | |
description="Finetune a transformers model on a causal language modeling task") | |
parser.add_argument( | |
"--not_save_model", | |
action="store_true", | |
help="Do not keep line breaks when using TXT files.", | |
) | |
parser.add_argument( | |
"--dataset_name", | |
type=str, | |
default=None, | |
help="The name of the dataset to use (via the datasets library).", | |
) | |
parser.add_argument( | |
"--dataset_config_name", | |
type=str, | |
default=None, | |
help="The configuration name of the dataset to use (via the datasets library).", | |
) | |
parser.add_argument( | |
"--dataset_path", | |
type=str, | |
default=None, | |
help="A dir containing dataset with .arrow format.", | |
) | |
parser.add_argument( | |
"--train_file", | |
type=str, | |
default=None, | |
help="A csv or a json file containing the training data.", | |
) | |
parser.add_argument( | |
"--validation_file", | |
type=str, | |
default=None, | |
help="A csv or a json file containing the validation data.", | |
) | |
parser.add_argument( | |
"--validation_split_percentage", | |
default=5, | |
help="The percentage of the train set used as validation set in case there's no validation split", | |
) | |
parser.add_argument( | |
"--model_name_or_path", | |
type=str, | |
help="Path to pretrained model or model identifier from huggingface.co/models.", | |
required=False, | |
) | |
parser.add_argument( | |
"--config_name", | |
type=str, | |
default=None, | |
help="Pretrained config name or path if not the same as model_name", | |
) | |
parser.add_argument( | |
"--tokenizer_name", | |
type=str, | |
default=None, | |
help="Pretrained tokenizer name or path if not the same as model_name", | |
) | |
parser.add_argument( | |
"--use_slow_tokenizer", | |
action="store_true", | |
help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).", | |
) | |
parser.add_argument( | |
"--trust_remote_code", | |
action="store_true", | |
help="If passed, will set trust_remote_code=True when calling from_pretrained.", | |
) | |
parser.add_argument( | |
"--per_device_train_batch_size", | |
type=int, | |
default=0, | |
help="Batch size (per device) for the training dataloader.", | |
) | |
parser.add_argument( | |
"--total_train_batch_size", | |
type=int, | |
default=8, | |
help="All batch size for the training dataloader. Equals to per_device_train_batch_size * world_size.", | |
) | |
parser.add_argument( | |
"--per_device_eval_batch_size", | |
type=int, | |
default=8, | |
help="Batch size (per device) for the evaluation dataloader.", | |
) | |
parser.add_argument( | |
"--learning_rate", | |
type=float, | |
default=5e-5, | |
help="Initial learning rate (after the potential warmup period) to use.", | |
) | |
parser.add_argument( | |
"--max_grad_norm", | |
type=float, | |
default=0, | |
help="Clips gradient norm of an iterable of parameters.", | |
) | |
parser.add_argument("--weight_decay", type=float, | |
default=0.0, help="Weight decay to use.") | |
parser.add_argument("--no_decay", nargs="*", | |
default=["bias", "LlamaRMSNorm.weight"], help="No decay params.") | |
parser.add_argument( | |
"--num_train_epochs", | |
type=int, | |
default=3, | |
help="Total number of training epochs to perform.", | |
) | |
parser.add_argument( | |
"--max_train_steps", | |
type=int, | |
default=None, | |
help="Total number of training steps to perform. If provided, overrides num_train_epochs.", | |
) | |
parser.add_argument( | |
"--gradient_accumulation_steps", | |
type=int, | |
default=1, | |
help="Number of updates steps to accumulate before performing a backward/update pass.", | |
) | |
parser.add_argument( | |
"--lr_scheduler_type", | |
type=SchedulerType, | |
default="linear", | |
help="The scheduler type to use.", | |
choices=[ | |
"linear", | |
"cosine", | |
"cosine_with_restarts", | |
"polynomial", | |
"constant", | |
"constant_with_warmup", | |
], | |
) | |
parser.add_argument( | |
"--warmup_steps", | |
type=int, | |
default=0, | |
help="Number of steps for the warmup in the lr scheduler.", | |
) | |
parser.add_argument( | |
"--warmup_ratio", | |
type=float, | |
default=0.0, | |
help="Ratio of total training steps used for a linear warmup from 0 to `learning_rate`.", | |
) | |
parser.add_argument("--output_dir", type=str, default=None, | |
help="Where to store the final model.") | |
parser.add_argument("--seed", type=int, default=None, | |
help="A seed for reproducible training.") | |
parser.add_argument( | |
"--model_type", | |
type=str, | |
default=None, | |
help="Model type to use if training from scratch.", | |
choices=MODEL_TYPES, | |
) | |
parser.add_argument( | |
"--block_size", | |
type=int, | |
default=None, | |
help=( | |
"Optional input sequence length after tokenization. The training dataset will be truncated in block of" | |
" this size for training. Default to the model max input length for single sentence inputs (take into" | |
" account special tokens)." | |
), | |
) | |
parser.add_argument( | |
"--preprocessing_num_workers", | |
type=int, | |
default=None, | |
help="The number of processes to use for the preprocessing.", | |
) | |
parser.add_argument( | |
"--dataloader_num_workers", | |
type=int, | |
default=0, | |
help="The number of sub-processes to use for the dataloader.", | |
) | |
parser.add_argument( | |
"--overwrite_cache", | |
action="store_true", | |
help="Overwrite the cached training and evaluation sets", | |
) | |
parser.add_argument( | |
"--no_keep_linebreaks", | |
action="store_true", | |
help="Do not keep line breaks when using TXT files.", | |
) | |
parser.add_argument( | |
"--checkpointing_steps", | |
type=str, | |
default=None, | |
help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.", | |
) | |
parser.add_argument( | |
"--resume_from_checkpoint", | |
type=str, | |
default=None, | |
help="If the training should continue from a checkpoint folder.", | |
) | |
parser.add_argument( | |
"--ignore_dryrun_on_load_strategy", | |
action="store_true", | |
) | |
parser.add_argument( | |
"--logging_steps", | |
type=int, | |
default=0, | |
help="Log every X updates steps. Zero means do not logging.", | |
) | |
parser.add_argument( | |
"--report_to", | |
type=str, | |
default=None, | |
help=( | |
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' | |
' `"matplotlib"`, and `"all"`. Use `"all"` (default) to report to all integrations.' | |
"Only applicable when `--logging_steps` bigger than 0." | |
), | |
choices=["all", "tensorboard", "matplotlib"], | |
) | |
parser.add_argument( | |
"--ignore_mismatched_sizes", | |
action="store_true", | |
help="If passed, will set ignore_mismatched_sizes=True when calling from_pretrained.", | |
) | |
parser.add_argument( | |
"--distributed_method", | |
default="ddp", | |
choices=["ddp", "fsdp"], | |
help="If passed, use fsdp", | |
) | |
parser.add_argument( | |
"--fsdp_cpu_offload", | |
action="store_true", | |
help="If passed, offload model params to cpu memory.", | |
) | |
parser.add_argument( | |
"--precision", | |
type=str, | |
choices=["fp32", "bf16_amp", "fp16_amp", "bf16"], | |
default="bf16_amp", | |
) | |
parser.add_argument( | |
"--gradient_checkpointing", | |
action="store_true", | |
help="Use gradient checkpointing or not.", | |
) | |
parser.add_argument( | |
"--peft_type", | |
type=str, | |
default=None, | |
help="Whether use peft and use what type of peft.", | |
) | |
parser.add_argument( | |
"--lora_r", | |
type=int, | |
default=8, | |
help="Lora attention dimension.", | |
) | |
parser.add_argument( | |
"--lora_alpha", | |
type=int, | |
default=16, | |
help="The alpha parameter for Lora scaling.", | |
) | |
parser.add_argument( | |
"--lora_dropout", | |
type=float, | |
default=0.05, | |
help="The dropout probability for Lora layers.", | |
) | |
parser.add_argument( | |
"--lora_target_modules", | |
nargs="*", | |
default=["q_proj", "v_proj"], | |
help="The names of the modules to apply Lora to.", | |
) | |
parser.add_argument( | |
"--peft_task_type", | |
type=str, | |
default=TaskType.CAUSAL_LM, | |
choices=[TaskType.SEQ_CLS, TaskType.SEQ_2_SEQ_LM, | |
TaskType.CAUSAL_LM, TaskType.TOKEN_CLS], | |
help="Peft task type.", | |
) | |
parser.add_argument( | |
"--fsdp_wrap_trainable_outmost", | |
action="store_true", | |
help="If fsdp would use wrap_trainable_outmost for peft model.", | |
) | |
parser.add_argument( | |
"--random_log_n_training_samples", | |
type=int, | |
default=3, | |
help="Log a few random samples from the training set.", | |
) | |
parser.add_argument( | |
"--max_shard_size", | |
default=None, | |
help=( | |
"The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size " | |
"lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`)." | |
"`None` means no shard." | |
), | |
) | |
parser.add_argument( | |
"--enable_profiler", | |
action="store_true", | |
help="If passed, use torch.profiler.profile", | |
) | |
parser.add_argument( | |
"--using_xla", | |
action="store_true", | |
help="If passed, using xla device", | |
) | |
parser.add_argument( | |
"--init_emtpy_offload", | |
action="store_true", | |
help="If passed, use init_empty_weights_with_disk_offload. Should be used when training from scratch.", | |
) | |
args = parser.parse_args() | |
# Sanity checks | |
if ( | |
args.dataset_name is None | |
and args.train_file is None | |
and args.validation_file is None | |
and args.dataset_path is None | |
): | |
raise ValueError( | |
"Need either a dataset name or a training/validation file.") | |
else: | |
if args.train_file is not None: | |
extension = args.train_file.split(".")[-1] | |
assert extension in [ | |
"csv", | |
"json", | |
"txt", | |
], "`train_file` should be a csv, json or txt file." | |
if args.validation_file is not None: | |
extension = args.validation_file.split(".")[-1] | |
assert extension in [ | |
"csv", | |
"json", | |
"txt", | |
], "`validation_file` should be a csv, json or txt file." | |
return args | |
# for auto_accelerate | |
def optim_param_func(model, args): | |
optimizer_grouped_parameters = [ | |
{ | |
"params": [ | |
p for n, p in model.named_parameters() if not any(nd in n for nd in args.no_decay) and p.requires_grad | |
], | |
"weight_decay": args.weight_decay, | |
}, | |
{ | |
"params": [ | |
p for n, p in model.named_parameters() if any(nd in n for nd in args.no_decay) and p.requires_grad | |
], | |
"weight_decay": 0.0, | |
}, | |
] | |
return optimizer_grouped_parameters | |
# for auto_accelerate | |
def my_loss_func(_, outputs): | |
if isinstance(outputs, dict): | |
return outputs["loss"] | |
# for auto_accelerate | |
def my_prepare_input(batch, device): | |
batch = {k: v.to(device=device, non_blocking=True) | |
for k, v in batch.items()} | |
return batch | |
def get_dataset(args): | |
raw_datasets = None | |
if is_local_main_process(): | |
datasets.utils.logging.set_verbosity_warning() | |
transformers.utils.logging.set_verbosity_info() | |
else: | |
datasets.utils.logging.set_verbosity_error() | |
transformers.utils.logging.set_verbosity_error() | |
# If passed along, set the training seed now. | |
if args.seed is not None: | |
set_seed(args.seed) | |
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) | |
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ | |
# (the dataset will be downloaded automatically from the datasets Hub). | |
# | |
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called | |
# 'text' is found. You can easily tweak this behavior (see below). | |
# | |
# In distributed training, the load_dataset function guarantee that only one local process can concurrently | |
# download the dataset. | |
if args.dataset_name is not None: | |
# Downloading and loading a dataset from the hub. | |
raw_datasets = load_dataset( | |
args.dataset_name, args.dataset_config_name) | |
if "validation" not in raw_datasets.keys(): | |
raw_datasets["validation"] = load_dataset( | |
args.dataset_name, | |
args.dataset_config_name, | |
split=f"train[:{args.validation_split_percentage}%]", | |
) | |
raw_datasets["train"] = load_dataset( | |
args.dataset_name, | |
args.dataset_config_name, | |
split=f"train[{args.validation_split_percentage}%:]", | |
) | |
elif args.dataset_path is not None: | |
raw_datasets = load_from_disk(args.dataset_path) | |
else: | |
data_files = {} | |
dataset_args = {} | |
if args.train_file is not None: | |
data_files["train"] = args.train_file | |
if args.validation_file is not None: | |
data_files["validation"] = args.validation_file | |
extension = args.train_file.split(".")[-1] | |
if extension == "txt": | |
extension = "text" | |
dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks | |
raw_datasets = load_dataset( | |
extension, data_files=data_files, **dataset_args) | |
# If no validation data is there, validation_split_percentage will be used to divide the dataset. | |
if "validation" not in raw_datasets.keys(): | |
raw_datasets["validation"] = load_dataset( | |
extension, | |
data_files=data_files, | |
split=f"train[:{args.validation_split_percentage}%]", | |
**dataset_args, | |
) | |
raw_datasets["train"] = load_dataset( | |
extension, | |
data_files=data_files, | |
split=f"train[{args.validation_split_percentage}%:]", | |
**dataset_args, | |
) | |
return raw_datasets | |
def get_config(args): | |
config = None | |
if args.config_name: | |
config = AutoConfig.from_pretrained( | |
args.config_name, | |
trust_remote_code=args.trust_remote_code, | |
ignore_mismatched_sizes=args.ignore_mismatched_sizes, | |
) | |
elif args.model_name_or_path: | |
config = AutoConfig.from_pretrained( | |
args.model_name_or_path, | |
trust_remote_code=args.trust_remote_code, | |
ignore_mismatched_sizes=args.ignore_mismatched_sizes, | |
) | |
else: | |
config = CONFIG_MAPPING[args.model_type]() | |
logger.warning( | |
"You are instantiating a new config instance from scratch.") | |
return config | |
def get_tokenizer(args): | |
tokenizer = None | |
if args.tokenizer_name: | |
tokenizer = AutoTokenizer.from_pretrained( | |
args.tokenizer_name, | |
use_fast=not args.use_slow_tokenizer, | |
trust_remote_code=args.trust_remote_code, | |
) | |
elif args.model_name_or_path: | |
tokenizer = AutoTokenizer.from_pretrained( | |
args.model_name_or_path, | |
use_fast=not args.use_slow_tokenizer, | |
trust_remote_code=args.trust_remote_code, | |
) | |
else: | |
raise ValueError( | |
"You are instantiating a new tokenizer from scratch. This is not supported by this script." | |
"You can do it from another script, save it, and load it from here, using --tokenizer_name." | |
) | |
return tokenizer | |
def get_model(args, config): | |
model = None | |
if args.model_name_or_path: | |
model = AutoModelForCausalLM.from_pretrained( | |
args.model_name_or_path, | |
from_tf=bool(".ckpt" in args.model_name_or_path), | |
config=config, | |
trust_remote_code=args.trust_remote_code, | |
ignore_mismatched_sizes=args.ignore_mismatched_sizes, | |
) | |
else: | |
logger.info("Training new model from scratch") | |
model = AutoModelForCausalLM.from_config( | |
config, | |
trust_remote_code=args.trust_remote_code, | |
) | |
if args.peft_type is not None: | |
peft_config = get_peft_config(args) | |
logger.info(f"Load Peft {args.peft_type} model ......") | |
if args.gradient_checkpointing and args.peft_type == "lora": | |
# Make Lora and gradient checkpointing compatible | |
# https://github.com/huggingface/peft/issues/137 | |
if hasattr(model, "enable_input_require_grads"): | |
model.enable_input_require_grads() | |
else: | |
def make_inputs_require_grad(module, input, output): | |
output.requires_grad_(True) | |
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) | |
model = get_peft_model(model, peft_config) | |
return model | |
def get_peft_config(args): | |
""" | |
Returns: | |
config(PeftConfig) | |
""" | |
if args.peft_type == "lora": | |
peft_config = LoraConfig( | |
task_type=args.peft_task_type, | |
inference_mode=False, | |
r=args.lora_r, | |
lora_alpha=args.lora_alpha, | |
lora_dropout=args.lora_dropout, | |
target_modules=args.lora_target_modules, | |
) | |
else: | |
raise NotImplementedError(f"Not support {args.peft_type}") | |
return peft_config | |
def tokenize_dataset(args, model, raw_datasets, tokenizer): | |
def tokenize_function(examples): | |
return tokenizer(examples[text_column_name]) | |
column_names = raw_datasets["train"].column_names | |
text_column_name = "text" if "text" in column_names else column_names[0] | |
with main_process_first(): | |
tokenized_datasets = raw_datasets.map( | |
tokenize_function, | |
batched=True, | |
num_proc=args.preprocessing_num_workers, | |
remove_columns=column_names, | |
load_from_cache_file=not args.overwrite_cache, | |
desc="Running tokenizer on dataset", | |
) | |
return tokenized_datasets | |
def process_dataset(args, tokenized_datasets, tokenizer): | |
# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. | |
def group_texts(examples): | |
# Concatenate all texts. | |
concatenated_examples = { | |
k: list(chain(*examples[k])) for k in examples.keys()} | |
total_length = len(concatenated_examples[list(examples.keys())[0]]) | |
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can | |
# customize this part to your needs. | |
if total_length >= block_size: | |
total_length = (total_length // block_size) * block_size | |
# Split by chunks of max_len. | |
result = { | |
k: [t[i: i + block_size] | |
for i in range(0, total_length, block_size)] | |
for k, t in concatenated_examples.items() | |
} | |
result["labels"] = result["input_ids"].copy() | |
return result | |
if args.block_size is None: | |
block_size = tokenizer.model_max_length | |
if block_size > 1024: | |
logger.warning( | |
"The chosen tokenizer supports a `model_max_length` that is longer than the default `block_size` value" | |
" of 1024. If you would like to use a longer `block_size` up to `tokenizer.model_max_length` you can" | |
" override this default with `--block_size xxx`." | |
) | |
block_size = 1024 | |
else: | |
if args.block_size > tokenizer.model_max_length: | |
logger.warning( | |
f"The block_size passed ({args.block_size}) is larger than the maximum length for the model" | |
f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}." | |
) | |
block_size = min(args.block_size, tokenizer.model_max_length) | |
with main_process_first(): | |
lm_datasets = tokenized_datasets.map( | |
group_texts, | |
batched=True, | |
num_proc=args.preprocessing_num_workers, | |
load_from_cache_file=not args.overwrite_cache, | |
desc=f"Grouping texts in chunks of {block_size}", | |
) | |
return lm_datasets | |
def compute_training_flops( | |
batch_size, | |
sequence_length, | |
hidden_size, | |
vocab_size, | |
intermediate_size, | |
num_layers, | |
use_gradient_checkpointing=False, | |
use_peft=False, | |
): | |
"""Returns: | |
hardware flops | |
model flops | |
The source of formula: | |
Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM's | |
(APPENDIX: FLOATING-POINT OPERATIONS) | |
Assuming that backward pass has twice FLOPs as many as forward pass. Only matrix multiplication FLOPs are computed. | |
For use_peft, backward pass FLOPS is a little more than the forward pass. Assuming equal for simplicity here. | |
""" | |
attention_forward_flops = ( | |
8 * batch_size * sequence_length * hidden_size**2 + | |
4 * batch_size * sequence_length**2 * hidden_size | |
) | |
# llama2 use gate_proj, has 3 Linears | |
two_mlps_forward_flops = 3 * 2 * batch_size * \ | |
sequence_length * hidden_size * intermediate_size | |
logits_forward_flops = 2 * batch_size * \ | |
sequence_length * hidden_size * vocab_size | |
decoder_layer_forward_flops = attention_forward_flops + two_mlps_forward_flops | |
# forward FLOPs without gradient checkpointing | |
forward_flops_wo_gc = num_layers * \ | |
decoder_layer_forward_flops + logits_forward_flops | |
factor = 2 if use_peft else 3 | |
if not use_gradient_checkpointing: | |
return forward_flops_wo_gc * factor, forward_flops_wo_gc * factor | |
else: | |
return ( | |
num_layers * decoder_layer_forward_flops * | |
(factor + 1) + logits_forward_flops * factor, | |
forward_flops_wo_gc * factor, | |
) | |
def rewrite_logs(d): | |
new_d = {} | |
eval_prefix = "eval_" | |
eval_prefix_len = len(eval_prefix) | |
test_prefix = "test_" | |
test_prefix_len = len(test_prefix) | |
train_prefix = "train_" | |
train_prefix_len = len(train_prefix) | |
for k, v in d.items(): | |
if k.startswith(eval_prefix): | |
new_d["eval/" + k[eval_prefix_len:]] = v | |
elif k.startswith(test_prefix): | |
new_d["test/" + k[test_prefix_len:]] = v | |
elif k.startswith(train_prefix): | |
new_d["train/" + k[train_prefix_len:]] = v | |
else: | |
new_d["train/" + k] = v | |
return new_d | |
def default_logdir() -> str: | |
""" | |
Same default as PyTorch | |
""" | |
import socket | |
from datetime import datetime | |
current_time = datetime.now().strftime("%b%d_%H-%M-%S") | |
return os.path.join("runs", current_time + "_" + socket.gethostname()) | |
try: | |
import psutil | |
PSUTILS_INSTALLED = True | |
except ImportError: | |
PSUTILS_INSTALLED = False | |
pass | |
try: | |
from pynvml.smi import nvidia_smi | |
PYNAMY_INSTALLED = True | |
except ImportError: | |
nvidia_smi = None | |
PYNAMY_INSTALLED = False | |
class ThroughputTimer: | |
def __init__( | |
self, | |
batch_size, | |
start_step=2, | |
steps_per_output=50, | |
monitor_memory=False, | |
logging_fn=None, | |
): | |
self.start_time = 0 | |
self.end_time = 0 | |
self.started = False | |
self.batch_size = 1 if batch_size is None else batch_size | |
self.start_step = start_step | |
self.epoch_count = 0 | |
self.micro_step_count = 0 | |
self.global_step_count = 0 | |
self.total_elapsed_time = 0 | |
self.step_elapsed_time = 0 | |
self.steps_per_output = steps_per_output | |
self.monitor_memory = monitor_memory | |
self.logging = logging_fn | |
if self.logging is None: | |
import logging | |
logger = logging.getLogger(__name__) | |
self.logging = logger.info | |
self.initialized = False | |
if self.monitor_memory and not PSUTILS_INSTALLED: | |
self.logging( | |
"Unable to import `psutil`, please install package by `pip install psutil`. Set monitor_memory=False" | |
) | |
self.monitor_memory = False | |
self.nvsmi = nvidia_smi.getInstance() if PYNAMY_INSTALLED else None | |
def update_epoch_count(self): | |
self.epoch_count += 1 | |
self.micro_step_count = 0 | |
def _init_timer(self): | |
self.initialized = True | |
def start(self): | |
self._init_timer() | |
self.started = True | |
if self.global_step_count >= self.start_step: | |
torch.cuda.synchronize() | |
self.start_time = time.time() | |
return self.start_time | |
def stop(self, global_step=False, report_speed=True): | |
if not self.started: | |
return | |
self.started = False | |
self.micro_step_count += 1 | |
if global_step: | |
self.global_step_count += 1 | |
if self.start_time > 0: | |
torch.cuda.synchronize() | |
self.end_time = time.time() | |
duration = self.end_time - self.start_time | |
self.total_elapsed_time += duration | |
self.step_elapsed_time += duration | |
if global_step: | |
if report_speed and self.global_step_count % self.steps_per_output == 0: | |
logging_infos = ( | |
f"epoch={self.epoch_count}/micro_step={self.micro_step_count}/" | |
f"global_step={self.global_step_count}, RunningAvgSamplesPerSec={self.avg_samples_per_sec()}," | |
f" CurrSamplesPerSec={self.batch_size / self.step_elapsed_time}," | |
f" MemAllocated={round(torch.cuda.memory_allocated() / 1024**3, 2)}GB," | |
f" MaxMemAllocated={round(torch.cuda.max_memory_allocated() / 1024**3, 2)}GB" | |
) | |
if PYNAMY_INSTALLED: | |
current_node_gpu_mem = [] | |
nvsmi_gpu_memory_usage = self.nvsmi.DeviceQuery( | |
"memory.used, memory.total")["gpu"] | |
for gpu_id, memory_dict in enumerate(nvsmi_gpu_memory_usage): | |
total_memory, used_memory, unit = ( | |
memory_dict["fb_memory_usage"]["total"], | |
memory_dict["fb_memory_usage"]["used"], | |
memory_dict["fb_memory_usage"]["unit"], | |
) | |
current_node_gpu_mem.append( | |
f"GPU{gpu_id}:{int(used_memory)}/{int(total_memory)}{unit}") | |
nvismi_gpu_memory_infos = ",".join( | |
current_node_gpu_mem) | |
logging_infos += ". " + nvismi_gpu_memory_infos | |
self.logging(logging_infos) | |
if self.monitor_memory: | |
virt_mem = psutil.virtual_memory() | |
swap = psutil.swap_memory() | |
self.logging( | |
f"epoch={self.epoch_count}/micro_step={self.micro_step_count}/" | |
f"global_step={self.global_step_count} virtual_memory %: {virt_mem.percent}, " | |
f"swap_memory %: {swap.percent}" | |
) | |
self.step_elapsed_time = 0 | |
return self.end_time | |
def avg_samples_per_sec(self): | |
if self.global_step_count > 0: | |
total_step_offset = self.global_step_count - self.start_step | |
avg_time_per_step = self.total_elapsed_time / total_step_offset | |
# training samples per second | |
return self.batch_size / avg_time_per_step | |
return float("-inf") | |
def main(): | |
args = parse_args() | |
# no gc this server;otherwise server will stop | |
if args.using_xla: | |
pjrt.initialize_multiprocess( | |
os.environ["LOCAL_RANK"], os.environ["WORLD_SIZE"]) | |
device = xm.xla_device() | |
server = xp.start_server(9012) | |
dist.init_process_group('xla', init_method='xla://') | |
else: | |
device = torch.device("cuda:%d" % local_rank()) | |
dist.init_process_group(backend='nccl', init_method='env://', | |
rank=int(os.environ["LOCAL_RANK"]), | |
world_size=int(os.environ["WORLD_SIZE"]) | |
) | |
logging.basicConfig( | |
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", | |
datefmt="%m/%d/%Y %H:%M:%S", | |
level=logging.INFO, | |
) | |
assert args.logging_steps >= 0, f"logging_steps must bigger or equal than 0 but got {args.logging_steps}." | |
with_tracking = args.logging_steps > 0 and args.output_dir is not None | |
if args.report_to is not None and not with_tracking: | |
logger.info( | |
f"Found args.logging_steps=={args.logging_steps} and args.output_dir=={args.output_dir}." | |
"args.report_to will be ignored." | |
) | |
if args.output_dir is not None and is_main_process() == 0: | |
os.makedirs(args.output_dir, exist_ok=True) | |
logger.info(f"output_dir is {args.output_dir}") | |
config = get_config(args) | |
model = get_model(args, config) | |
tokenizer = get_tokenizer(args) | |
embedding_size = model.get_input_embeddings().weight.shape[0] | |
if len(tokenizer) > embedding_size: | |
model.resize_token_embeddings(len(tokenizer)) | |
num_params, num_trainable_params = model_parameters_num(model) | |
if is_local_main_process(): | |
logger.info( | |
f"Model has {num_params} parameters and {num_trainable_params} " | |
f"trainable parameters({100 * num_trainable_params / num_params:.3f}%)." | |
) | |
if "alpaca" in args.dataset_path: | |
train_dataset = InstructionDataset( | |
args.dataset_path, | |
tokenizer, | |
partition="train", | |
max_words=args.block_size, | |
) | |
eval_dataset = InstructionDataset( | |
args.dataset_path, | |
tokenizer, | |
partition="eval", | |
max_words=args.block_size, | |
) | |
else: | |
raw_datasets = get_dataset(args) | |
tokenized_datasets = tokenize_dataset( | |
args, model, raw_datasets, tokenizer) | |
lm_datasets = process_dataset(args, tokenized_datasets, tokenizer) | |
train_dataset = lm_datasets["train"] | |
eval_dataset = lm_datasets["validation"] | |
# Log a few random samples from the training set: | |
for index in random.sample(range(len(train_dataset)), args.random_log_n_training_samples): | |
logger.info( | |
f"Sample {index} of the training set: {train_dataset[index]}.") | |
# DataLoaders creation: | |
eval_sampler = DistributedSampler(eval_dataset, shuffle=False) | |
eval_dataloader = DataLoader( | |
eval_dataset, | |
shuffle=False, | |
sampler=eval_sampler, | |
collate_fn=default_data_collator, | |
batch_size=args.per_device_eval_batch_size, | |
pin_memory=True, | |
drop_last=True, | |
) | |
dataloader_args = { | |
"shuffle": True, | |
"collate_fn": default_data_collator, | |
"batch_size": args.per_device_train_batch_size, | |
"pin_memory": True, | |
"num_workers": args.dataloader_num_workers, | |
"persistent_workers": args.dataloader_num_workers > 0, | |
} | |
if "amp" in args.precision: | |
pass | |
if args.gradient_checkpointing: | |
pass | |
# strategy.append(("checkpoint", (LlamaDecoderLayer,))) | |
model = model.to(device=device) | |
if args.distributed_method == "ddp": | |
model = DistributedDataParallel(model, gradient_as_bucket_view=True, find_unused_parameters=True, | |
static_graph=True, broadcast_buffers=False) | |
logger.info("Using DDP") | |
else: | |
wrapper_config = {} | |
if args.using_xla: | |
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP | |
from torch_xla.distributed.fsdp.wrap import (size_based_auto_wrap_policy, | |
transformer_auto_wrap_policy) | |
wrapper_config["compute_dtype"] = torch.bfloat16 | |
else: | |
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP | |
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy | |
mixed_precision_config = MixedPrecision( | |
param_dtype=torch.bfloat16, | |
reduce_dtype=torch.bfloat16, | |
buffer_dtype=torch.bfloat16, | |
) | |
wrapper_config["mixed_precision"] = mixed_precision_config | |
wrapper_config["forward_prefetch"]=True | |
wrap_policy = functools.partial( | |
transformer_auto_wrap_policy, transformer_layer_cls={LlamaDecoderLayer, }) | |
wrapper_config["auto_wrap_policy"] = wrap_policy | |
model = FSDP(model, **wrapper_config).to(device=device) | |
logger.info("Using FSDP") | |
def check_add_to_fsdp(name, module): | |
type_name = type(module).__module__ + "." + type(module).__name__ | |
module_params_numel = sum([p.numel() for p in module.parameters() if p.requires_grad]) | |
if module_params_numel > 5e7: | |
numel_show = "%.2fM" % (module_params_numel / 1e6) | |
else: | |
numel_show = "%.2fK" % (module_params_numel / 1e3) | |
if not name or not isinstance(module, FSDP): | |
if module_params_numel > 1e3: | |
print("name=%s type=%s trainable=%s does not add to fsdp" % | |
(name, type_name, numel_show)) | |
for sub_name, sub_module in module.named_children(): | |
check_add_to_fsdp("%s.%s" % (name, sub_name), sub_module) | |
else: | |
print("name=%s type=%s trainable=%s has add to fsdp" % | |
(name, type_name, numel_show)) | |
check_add_to_fsdp("",model) | |
if args.using_xla: | |
optimizer = syncfree.AdamW(model.parameters(), lr=args.learning_rate) | |
else: | |
optimizer = torch.optim.AdamW( | |
model.parameters(), lr=args.learning_rate) | |
train_dataloader = torch.utils.data.DataLoader( | |
train_dataset, **dataloader_args) | |
if args.using_xla: | |
train_dataloader = pl.MpDeviceLoader( | |
train_dataloader, # wraps PyTorch DataLoader | |
device,) | |
num_devices = xr.global_runtime_device_count() | |
print("xla num_devices", num_devices) | |
overrode_max_train_steps = False | |
num_update_steps_per_epoch = math.ceil( | |
len(train_dataloader) / args.gradient_accumulation_steps) | |
if args.max_train_steps is None: | |
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch | |
overrode_max_train_steps = True | |
warmup_steps = 0 | |
num_training_steps = args.max_train_steps * args.gradient_accumulation_steps | |
if args.warmup_steps > 0: | |
warmup_steps = args.warmup_steps * args.gradient_accumulation_steps | |
elif args.warmup_ratio > 0.0: | |
warmup_steps = int(num_training_steps * args.warmup_ratio) | |
lr_scheduler = get_scheduler( | |
name=args.lr_scheduler_type, | |
optimizer=optimizer, | |
num_warmup_steps=warmup_steps, | |
num_training_steps=num_training_steps, | |
) | |
# We need to recalculate our total training steps as the size of the training dataloader may have changed. | |
num_update_steps_per_epoch = math.ceil( | |
len(train_dataloader) / args.gradient_accumulation_steps) | |
if overrode_max_train_steps: | |
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch | |
# Afterwards we recalculate our number of training epochs | |
args.num_train_epochs = math.ceil( | |
args.max_train_steps / num_update_steps_per_epoch) | |
# Figure out how many steps we should save the Accelerator states | |
checkpointing_steps = args.checkpointing_steps | |
if checkpointing_steps is not None and checkpointing_steps.isdigit(): | |
checkpointing_steps = int(checkpointing_steps) | |
# create a summary_writer to write training metrics to tensorboard | |
summary_writer = None | |
report_to_tb = with_tracking and args.report_to in ("tensorboard", "all") | |
report_to_tb = False | |
if report_to_tb: | |
tb_path = os.path.join(args.output_dir, default_logdir()) | |
summary_writer = SummaryWriter(tb_path) | |
logger.info(f"Tensorboard eventfiles will be saved at {tb_path}") | |
# Train! | |
total_batch_size = args.total_train_batch_size * args.gradient_accumulation_steps | |
if args.total_train_batch_size > 0: | |
per_device_train_batch_size = int( | |
args.total_train_batch_size / world_size()) | |
total_train_batch_size = args.total_train_batch_size | |
elif args.per_device_train_batch_size > 0: | |
per_device_train_batch_size = args.per_device_train_batch_size | |
total_train_batch_size = per_device_train_batch_size * world_size() | |
else: | |
raise ValueError( | |
f"per_device_train_batch_size must greater than 0 but got {per_device_train_batch_size}") | |
flops_per_gpu_per_iteration, _ = compute_training_flops( | |
per_device_train_batch_size, | |
args.block_size, | |
config.hidden_size, | |
config.vocab_size, | |
config.intermediate_size, | |
config.num_hidden_layers, | |
args.gradient_checkpointing, | |
args.peft_type is not None, | |
) | |
tput_timer = ThroughputTimer( | |
total_train_batch_size, start_step=2, steps_per_output=50) | |
logger.info("***** Running training *****") | |
logger.info(f" Num examples = {len(train_dataset)}") | |
logger.info(f" Num Epochs = {args.num_train_epochs}") | |
logger.info( | |
f" Instantaneous batch size per device = {per_device_train_batch_size}") | |
logger.info( | |
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") | |
logger.info( | |
f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") | |
logger.info(f" Total optimization steps = {args.max_train_steps}") | |
# Only show the progress bar once on each machine. | |
progress_bar = tqdm(range(args.max_train_steps), | |
disable=not is_local_main_process()) | |
completed_steps = 0 | |
completed_eval_steps = 0 | |
starting_epoch = 0 | |
# # Potentially load in the weights and states from a previous save | |
if args.resume_from_checkpoint: | |
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": | |
path = os.path.basename(args.resume_from_checkpoint) | |
else: | |
# Get the most recent checkpoint | |
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] | |
dirs.sort(key=os.path.getctime) | |
# Sorts folders by date modified, most recent checkpoint is the last | |
path = dirs[-1] | |
# Extract `epoch_{i}` or `step_{i}` | |
training_difference = os.path.splitext(path)[0] | |
if "epoch" in training_difference: | |
starting_epoch = int(training_difference.replace("epoch_", "")) + 1 | |
resume_step = None | |
else: | |
# need to multiply `gradient_accumulation_steps` to reflect real steps | |
resume_step = int(training_difference.replace( | |
"step_", "")) * args.gradient_accumulation_steps | |
starting_epoch = resume_step // len(train_dataloader) | |
resume_step -= starting_epoch * len(train_dataloader) | |
# update the progress_bar if load from checkpoint | |
progress_bar.update(starting_epoch * num_update_steps_per_epoch) | |
completed_steps = starting_epoch * num_update_steps_per_epoch | |
total_train_losses = [[], []] # steps, loss | |
total_eval_losses = [[], []] # steps, loss | |
all_results = {} | |
training_time = 0 | |
for epoch in range(starting_epoch, args.num_train_epochs): | |
model.train() | |
if with_tracking: | |
total_loss = torch.tensor(0.0, device=model.device) | |
torch.cuda.synchronize() | |
current_epoch_start_time = time.time() | |
start_time = time.time() | |
step = 0 | |
for step, batch in enumerate(train_dataloader): | |
# We need to skip steps until we reach the resumed step | |
optimizer.zero_grad() | |
if args.resume_from_checkpoint and epoch == starting_epoch: | |
if resume_step is not None and step < resume_step: | |
if step % args.gradient_accumulation_steps == 0: | |
progress_bar.update(1) | |
completed_steps += 1 | |
continue | |
if args.using_xla: | |
enable_trace = True | |
if step == 20 and enable_trace: | |
xp.trace_detached('localhost:9012', "./llama_xla_trace") | |
if enable_trace: | |
context = xp.StepTrace('train_loop', step_num=step) | |
else: | |
context = nullcontext() | |
autocast = xla_autocast(device, dtype=torch.bfloat16) | |
else: | |
if step == 20: | |
context = torch.profiler.profile( | |
activities=[ | |
torch.profiler.ProfilerActivity.CPU, | |
torch.profiler.ProfilerActivity.CUDA, | |
], | |
profile_memory=False, | |
with_stack=True, | |
with_modules=True, | |
record_shapes=True, | |
) | |
else: | |
context = nullcontext() | |
autocast = torch_autocast("cuda", dtype=torch.bfloat16) | |
with context as prof: | |
# step_start_timestamp = tput_timer.start() | |
batch = my_prepare_input(batch, device) | |
with autocast: | |
outputs = model(**batch) | |
loss = outputs["loss"] | |
loss.backward() | |
if args.using_xla: | |
# add this will cause loss=nan;if using syncfree optimizer,use this | |
# gradients = xm._fetch_gradients(optimizer) | |
# xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size(),pin_layout=False) | |
found_inf = torch.isnan(loss).to(torch.float32) | |
optimizer.step(found_inf=found_inf) | |
else: | |
optimizer.step() | |
lr_scheduler.step() | |
if step == 20 and not args.using_xla: | |
prof.export_chrome_trace("./llama_trace%d.json" % local_rank()) | |
if step % 20 == 0: | |
rate = per_device_train_batch_size / (time.time() - start_time) | |
if args.using_xla: | |
xm.add_step_closure( | |
_train_update, args=(device, step, loss, rate)) | |
else: | |
_train_update(device, step, loss, rate) | |
start_time = time.time() | |
torch.cuda.synchronize() | |
current_epoch_elapse_time = time.time() - current_epoch_start_time | |
if is_main_process(): | |
logger.info( | |
f"Training epoch {epoch} takes {current_epoch_elapse_time:.3f} seconds.") | |
training_time += current_epoch_elapse_time | |
continue | |
def _init_xla(): | |
import os | |
print("os.environ", os.environ) | |
import torch.nn.parallel.distributed | |
import torch.distributed.utils | |
torch.distributed.utils._verify_param_shape_across_processes = torch.nn.parallel.distributed._verify_param_shape_across_processes = lambda process_group, tensors, logger=None: None | |
if __name__ == "__main__": | |
_init_xla() | |
main() | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment