Created
July 23, 2023 23:59
-
-
Save conceptofmind/f27822cdafcc165e490b20a281192649 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse
import math
import os
import random
import time
from datetime import timedelta
from itertools import chain
from typing import Optional

import torch
import wandb
import yaml
from accelerate import Accelerator
from accelerate.utils import (DummyOptim, DummyScheduler,
                              InitProcessGroupKwargs, set_seed)
from datasets import load_dataset
from huggingface_hub import HfApi, HfFolder
from huggingface_hub.utils._errors import HfHubHTTPError
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import LlamaTokenizer, LlamaForCausalLM
from transformers import (AutoModelForCausalLM, AutoTokenizer, get_scheduler,
                          set_seed, default_data_collator)
class CFG:
    """Static run configuration for the Open-LLaMA fine-tuning script.

    Reference settings recorded by the author (batch size per device):

    * 3B  - bs 18 - z2 offload - activation checkpointing - 2k - A100 80GB - 3e-5
    * 7B  - bs 13 - z2 offload - activation checkpointing - 2k - A100 80GB - 3e-5
    * 13B - bs 6  - z2 offload - activation checkpointing - 2k - A100 80GB - 2e-5
    * 7B  - bs 5  - z2 offload - activation checkpointing - 2k - A100 40GB - 3e-5

    ("z2" presumably means DeepSpeed ZeRO stage 2 and "2k" the sequence
    length -- TODO confirm.)
    """

    # Per-device micro-batch size given to the DataLoader.
    BATCH_SIZE: int = 8
    # Micro-batches accumulated per optimizer step (passed to Accelerator).
    GRADIENT_ACCUMULATE_EVERY: int = 1
    # Directory written by `accelerator.save_state`, named "step_<N>", to
    # resume from; None (or "") starts a fresh run.
    RESUME_FROM_CHECKPOINT: Optional[str] = None
    # Save the full training state every this many optimizer steps.
    CHECKPOINTING_STEPS: int = 500
    # Root directory for step checkpoints and the final model.
    OUTPUT_DIR: str = ""
    # Weights & Biases entity (user or team) that owns the run.
    ENTITY_NAME: str = ""
def _parse_resume_step(checkpoint_path):
    """Return the optimizer-step count N encoded in a ``step_<N>`` checkpoint path.

    A trailing path separator is tolerated (``"out/step_500/"`` -> 500).

    Raises:
        ValueError: if the last path component is not of the form ``step_<N>``.
    """
    # normpath strips a trailing separator; otherwise basename() returns "".
    name = os.path.basename(os.path.normpath(checkpoint_path))
    stem = os.path.splitext(name)[0]
    return int(stem.replace("step_", ""))


def _push_to_hub_with_retries(obj, repo_id, max_retries=5):
    """Push ``obj`` (a model or tokenizer) to a private Hub repo, retrying on HTTP errors.

    Waits a random 1-10 s between attempts.

    Returns:
        True once a push succeeds, False after ``max_retries`` failed attempts.
    """
    for attempt in range(1, max_retries + 1):
        try:
            obj.push_to_hub(repo_id, private=True)
        except HfHubHTTPError:
            if attempt == max_retries:
                break  # no point sleeping after the final attempt
            wait_time = random.uniform(1, 10)  # wait between 1 to 10 seconds
            print(f"Attempt {attempt} failed, waiting for {wait_time:.2f} seconds before retrying...")
            time.sleep(wait_time)
        else:
            print(f"Pushed to hub after {attempt} attempts.")
            return True
    print(f"Failed to push to hub after {max_retries} attempts.")
    return False


def main():
    """Fine-tune Open-LLaMA 3B with Accelerate + DeepSpeed and publish the result.

    Logs in to W&B, builds a bf16 Accelerator with W&B tracking, and trains
    for one pass over the dataset. It checkpoints with
    ``accelerator.save_state`` every ``CFG.CHECKPOINTING_STEPS`` optimizer
    steps and optionally resumes from ``CFG.RESUME_FROM_CHECKPOINT``. At the
    end it saves the final weights under ``CFG.OUTPUT_DIR`` and pushes the
    model and tokenizer to the Hugging Face Hub from the main process only.
    """
    wandb.login(
        key=""
    )

    set_seed(42)

    # 1_000_000 s (~11.6 days) effectively disables the process-group timeout.
    timeout = InitProcessGroupKwargs(timeout=timedelta(seconds=1_000_000))
    accelerator = Accelerator(
        gradient_accumulation_steps=CFG.GRADIENT_ACCUMULATE_EVERY,
        mixed_precision="bf16",
        log_with="wandb",
        kwargs_handlers=[timeout],
    )
    accelerator.init_trackers(
        project_name="open_llama",
        init_kwargs={"wandb": {"entity": CFG.ENTITY_NAME}},
    )
    accelerator.print(f"Total GPUS: {accelerator.num_processes}")

    tokenizer = LlamaTokenizer.from_pretrained("openlm-research/open_llama_3b")
    model = LlamaForCausalLM.from_pretrained(
        "openlm-research/open_llama_3b",
        use_cache=False,  # the KV cache is not needed for training
    )
    model.gradient_checkpointing_enable()
    accelerator.print(f"Training a {model.num_parameters():,} parameter model")

    # Dataloaders
    train_dataset = load_dataset('conceptofmind/tasksource-instruct-open-llama-2k', split='train')
    train_loader = DataLoader(
        train_dataset,
        collate_fn=default_data_collator,
        shuffle=True,
        batch_size=CFG.BATCH_SIZE,
    )

    # Dummy optimizer: the real optimizer comes from the DeepSpeed config.
    optim = DummyOptim(
        model.parameters(),
        lr=2e-5,
    )

    # Before `prepare` the loader is not sharded yet, so this counts optimizer
    # steps for the whole dataset as if run on a single process.
    max_train_steps = math.ceil(len(train_loader) / CFG.GRADIENT_ACCUMULATE_EVERY)
    accelerator.print(f"Max train steps: {max_train_steps}")

    # The DeepSpeed engine steps its LR scheduler once per optimizer step on
    # each process. After `prepare`, each process sees only ~1/num_processes of
    # the batches, so size both the schedule and its warmup in per-process steps.
    steps_per_process = math.ceil(max_train_steps / accelerator.num_processes)
    scheduler = DummyScheduler(
        optim,
        total_num_steps=steps_per_process,
        warmup_num_steps=int(steps_per_process * 0.01),
    )

    model, optim, train_loader, scheduler = accelerator.prepare(
        model, optim, train_loader, scheduler
    )
    # Include the scheduler in save_state / load_state checkpoints.
    accelerator.register_for_checkpointing(scheduler)

    # Recalculate now that the loader is sharded across processes.
    max_train_steps = math.ceil(len(train_loader) / CFG.GRADIENT_ACCUMULATE_EVERY)
    accelerator.print(f"Max train steps recalculated: {max_train_steps}")

    # Total batch size for logging
    total_batch_size = (
        CFG.BATCH_SIZE * accelerator.num_processes * CFG.GRADIENT_ACCUMULATE_EVERY
    )
    accelerator.print(f"Total batch size: {total_batch_size}")

    progress_bar = tqdm(
        range(max_train_steps), disable=not accelerator.is_local_main_process
    )
    completed_steps = 0

    # Resume: restore the saved state, then fast-forward the loader.
    if CFG.RESUME_FROM_CHECKPOINT:
        accelerator.print(f"Resuming from checkpoint {CFG.RESUME_FROM_CHECKPOINT}")
        accelerator.load_state(CFG.RESUME_FROM_CHECKPOINT)
        resume_step = _parse_resume_step(CFG.RESUME_FROM_CHECKPOINT)
        # Checkpoints are named by optimizer steps. Each optimizer step
        # consumed GRADIENT_ACCUMULATE_EVERY micro-batches from the loader.
        train_loader = accelerator.skip_first_batches(
            train_loader, resume_step * CFG.GRADIENT_ACCUMULATE_EVERY
        )
        completed_steps += resume_step
        progress_bar.update(resume_step)
        accelerator.print(f"Resuming training from step {resume_step}")

    # training
    model.train()
    for batch in train_loader:
        with accelerator.accumulate(model):
            input_ids = batch["input_ids"]
            # Causal-LM objective: labels are the inputs themselves.
            loss = model(input_ids, labels=input_ids).loss
            accelerator.backward(loss)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
            optim.step()
            scheduler.step()
            optim.zero_grad()

        # Everything below happens once per optimizer step, not per micro-batch.
        if accelerator.sync_gradients:
            progress_bar.update(1)
            completed_steps += 1
            # Log against the optimizer-step counter (same one used for
            # checkpoint names). The loader index counts micro-batches and
            # restarts at 0 after a resume.
            accelerator.log({"loss": loss.item()}, step=completed_steps)

            if (
                isinstance(CFG.CHECKPOINTING_STEPS, int)
                and CFG.CHECKPOINTING_STEPS > 0
                and completed_steps % CFG.CHECKPOINTING_STEPS == 0
            ):
                output_dir = f"step_{completed_steps}"
                if CFG.OUTPUT_DIR is not None:
                    output_dir = os.path.join(CFG.OUTPUT_DIR, output_dir)
                accelerator.save_state(output_dir)

        if completed_steps >= max_train_steps:
            break

    # end training
    accelerator.print("Training Finished")
    accelerator.end_training()

    # save final model
    accelerator.print(f"Saving model to {CFG.OUTPUT_DIR}")
    if CFG.OUTPUT_DIR is not None:
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        # os.path.join keeps an empty OUTPUT_DIR relative to the working
        # directory. An f-string would produce the absolute path "/final/...".
        # get_state_dict is called on every rank, as save_pretrained needs it.
        unwrapped_model.save_pretrained(
            os.path.join(CFG.OUTPUT_DIR, "final", "open_llama_2k_3b-test"),
            is_main_process=accelerator.is_main_process,
            save_function=accelerator.save,
            state_dict=accelerator.get_state_dict(model),
        )

        # Only the main process talks to the Hub. Pushing from every rank
        # uploads the same repo num_processes times. Retrying inside
        # `main_process_first` (which contains barriers) can desync ranks.
        if accelerator.is_main_process:
            _push_to_hub_with_retries(unwrapped_model, "Open-Llama-3b-test")
            _push_to_hub_with_retries(tokenizer, "Open-Llama-3b-test")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment