Created
September 17, 2024 10:17
-
-
Save alvarobartt/b2a3067d69622a0ef05aee3a113ce73b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%env PROJECT_ID=your-project\n", | |
"%env LOCATION=us-central1\n", | |
"%env BUCKET_URI=gs://your-bucket\n", | |
"%env CONTAINER_URI=us-docker.pkg.dev/deeplearning-platform-release/gcr.io/huggingface-pytorch-training-cu121.2-3.transformers.4-42.ubuntu2204.py310" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"!gcloud storage buckets create $BUCKET_URI --project $PROJECT_ID --location=$LOCATION --default-storage-class=STANDARD --uniform-bucket-level-access" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import os\n", | |
"from google.cloud import aiplatform\n", | |
"\n", | |
"aiplatform.init(\n", | |
" project=os.getenv(\"PROJECT_ID\"),\n", | |
" location=os.getenv(\"LOCATION\"),\n", | |
" staging_bucket=os.getenv(\"BUCKET_URI\"),\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%%writefile sft.py\n", | |
"from datasets import load_dataset\n", | |
"from transformers import AutoTokenizer\n", | |
"from trl import (\n", | |
" ModelConfig,\n", | |
" SFTConfig,\n", | |
" SFTTrainer,\n", | |
" get_peft_config,\n", | |
" get_quantization_config,\n", | |
" get_kbit_device_map,\n", | |
")\n", | |
"from trl.commands.cli_utils import SFTScriptArguments, TrlParser\n", | |
"\n", | |
"if __name__ == \"__main__\":\n", | |
" parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))\n", | |
" args, training_args, model_config = parser.parse_args_and_config()\n", | |
"\n", | |
" ################\n", | |
" # Model init kwargs & Tokenizer\n", | |
" ################\n", | |
" quantization_config = get_quantization_config(model_config)\n", | |
" model_kwargs = dict(\n", | |
" revision=model_config.model_revision,\n", | |
" trust_remote_code=model_config.trust_remote_code,\n", | |
" attn_implementation=model_config.attn_implementation,\n", | |
" torch_dtype=model_config.torch_dtype,\n", | |
" use_cache=False if training_args.gradient_checkpointing else True,\n", | |
" device_map=get_kbit_device_map() if quantization_config is not None else None,\n", | |
" quantization_config=quantization_config,\n", | |
" )\n", | |
" training_args.model_init_kwargs = model_kwargs\n", | |
" tokenizer = AutoTokenizer.from_pretrained(\n", | |
" model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True\n", | |
" )\n", | |
" tokenizer.pad_token = tokenizer.eos_token\n", | |
"\n", | |
" ################\n", | |
" # Dataset\n", | |
" ################\n", | |
"    train_dataset = load_dataset(\"json\", data_files=[args.dataset_name], split=\"train\")\n", | |
"\n", | |
" ################\n", | |
" # Training\n", | |
" ################\n", | |
" trainer = SFTTrainer(\n", | |
" model=model_config.model_name_or_path,\n", | |
" args=training_args,\n", | |
" train_dataset=train_dataset,\n", | |
" tokenizer=tokenizer,\n", | |
" peft_config=get_peft_config(model_config),\n", | |
" )\n", | |
"\n", | |
" trainer.train()\n", | |
" trainer.save_model(training_args.output_dir)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"job = aiplatform.CustomTrainingJob(\n", | |
"    display_name=\"sft-lora\",\n", | |
"    script_path=\"sft.py\",\n", | |
"    container_uri=os.getenv(\"CONTAINER_URI\"),\n", | |
"    requirements=[\"gcsfs==0.7.1\"],\n", | |
")\n", | |
"\n", | |
"args = [\n", | |
"    # MODEL\n", | |
"    \"--model_name_or_path=mistralai/Mistral-7B-v0.3\",\n", | |
"    \"--torch_dtype=bfloat16\",\n", | |
"    \"--attn_implementation=flash_attention_2\",\n", | |
"    # DATASET\n", | |
"    f\"--dataset_name={os.getenv('BUCKET_URI').replace('gs://', '/gcs/')}/dataset.jsonl\",\n", | |
"    \"--dataset_text_field=text\",\n", | |
"    # PEFT\n", | |
"    \"--use_peft\",\n", | |
"    \"--lora_r=16\",\n", | |
"    \"--lora_alpha=32\",\n", | |
"    \"--lora_dropout=0.1\",\n", | |
"    \"--lora_target_modules=all-linear\",\n", | |
"    # TRAINER\n", | |
"    \"--bf16\",\n", | |
"    \"--max_seq_length=1024\",\n", | |
"    \"--per_device_train_batch_size=2\",\n", | |
"    \"--gradient_accumulation_steps=8\",\n", | |
"    \"--gradient_checkpointing\",\n", | |
"    \"--learning_rate=0.0002\",\n", | |
"    \"--lr_scheduler_type=cosine\",\n", | |
"    \"--optim=adamw_bnb_8bit\",\n", | |
"    \"--num_train_epochs=1\",\n", | |
"    \"--logging_steps=10\",\n", | |
"    \"--report_to=none\",\n", | |
"    f\"--output_dir={os.getenv('BUCKET_URI').replace('gs://', '/gcs/')}/lora-ft\",\n", | |
"    \"--overwrite_output_dir\",\n", | |
"    \"--seed=42\",\n", | |
"    \"--log_level=debug\",\n", | |
"]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from huggingface_hub import get_token\n", | |
"\n", | |
"job.submit(\n", | |
" args=args,\n", | |
" replica_count=1,\n", | |
" machine_type=\"g2-standard-12\",\n", | |
" accelerator_type=\"NVIDIA_L4\",\n", | |
" accelerator_count=1,\n", | |
" base_output_dir=f\"{os.getenv('BUCKET_URI')}/lora-ft\",\n", | |
" environment_variables={\n", | |
" \"HF_HOME\": \"/root/.cache/huggingface\",\n", | |
" \"HF_TOKEN\": get_token(),\n", | |
" \"TRL_USE_RICH\": \"0\",\n", | |
" \"ACCELERATE_LOG_LEVEL\": \"INFO\",\n", | |
" \"TRANSFORMERS_LOG_LEVEL\": \"INFO\",\n", | |
" \"TQDM_POSITION\": \"-1\",\n", | |
" },\n", | |
" timeout=60 * 60 * 3, # 3 hours (10800s)\n", | |
" create_request_timeout=60 * 10, # 10 minutes (600s)\n", | |
")" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.14" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment