Skip to content

Instantly share code, notes, and snippets.

@alvarobartt
Created September 17, 2024 10:17
Show Gist options
  • Save alvarobartt/b2a3067d69622a0ef05aee3a113ce73b to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%env PROJECT_ID=your-project\n",
"%env LOCATION=us-central1\n",
"%env BUCKET_URI=gs://your-bucket\n",
"%env CONTAINER_URI=us-docker.pkg.dev/deeplearning-platform-release/gcr.io/huggingface-pytorch-training-cu121.2-3.transformers.4-42.ubuntu2204.py310"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!gcloud storage buckets create $BUCKET_URI --project $PROJECT_ID --location=$LOCATION --default-storage-class=STANDARD --uniform-bucket-level-access"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from google.cloud import aiplatform\n",
"\n",
"aiplatform.init(\n",
" project=os.getenv(\"PROJECT_ID\"),\n",
" location=os.getenv(\"LOCATION\"),\n",
" staging_bucket=os.getenv(\"BUCKET_URI\"),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%writefile sft.py\n",
"from datasets import load_dataset\n",
"from transformers import AutoTokenizer\n",
"from trl import (\n",
" ModelConfig,\n",
" SFTConfig,\n",
" SFTTrainer,\n",
" get_peft_config,\n",
" get_quantization_config,\n",
" get_kbit_device_map,\n",
")\n",
"from trl.commands.cli_utils import SFTScriptArguments, TrlParser\n",
"\n",
"if __name__ == \"__main__\":\n",
" parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))\n",
" args, training_args, model_config = parser.parse_args_and_config()\n",
"\n",
" ################\n",
" # Model init kwargs & Tokenizer\n",
" ################\n",
" quantization_config = get_quantization_config(model_config)\n",
" model_kwargs = dict(\n",
" revision=model_config.model_revision,\n",
" trust_remote_code=model_config.trust_remote_code,\n",
" attn_implementation=model_config.attn_implementation,\n",
" torch_dtype=model_config.torch_dtype,\n",
" use_cache=False if training_args.gradient_checkpointing else True,\n",
" device_map=get_kbit_device_map() if quantization_config is not None else None,\n",
" quantization_config=quantization_config,\n",
" )\n",
" training_args.model_init_kwargs = model_kwargs\n",
" tokenizer = AutoTokenizer.from_pretrained(\n",
" model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True\n",
" )\n",
" tokenizer.pad_token = tokenizer.eos_token\n",
"\n",
" ################\n",
" # Dataset\n",
" ################\n",
" train_dataset = load_dataset(\"json\", data_files=[args.dataset_name], split=\"train\")\n",
"\n",
" ################\n",
" # Training\n",
" ################\n",
" trainer = SFTTrainer(\n",
" model=model_config.model_name_or_path,\n",
" args=training_args,\n",
" train_dataset=train_dataset,\n",
" tokenizer=tokenizer,\n",
" peft_config=get_peft_config(model_config),\n",
" )\n",
"\n",
" trainer.train()\n",
" trainer.save_model(training_args.output_dir)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"job = aiplatform.CustomJob.from_local_script(\n",
" display_name=\"sft-lora\",\n",
" script_path=\"sft.py\",\n",
" container_uri=os.getenv(\"CONTAINER_URI\"),\n",
" requirements=[\"gcsfs==0.7.1\"],\n",
" replica_count=1,\n",
" args = [\n",
" # MODEL\n",
" \"--model_name_or_path=mistralai/Mistral-7B-v0.3\",\n",
" \"--torch_dtype=bfloat16\",\n",
" \"--attn_implementation=flash_attention_2\",\n",
" # DATASET\n",
" f\"--dataset_name={os.getenv('BUCKET_URI').replace('gs://', '/gcs/')}/dataset.jsonl\",\n",
" \"--dataset_text_field=text\",\n",
" # PEFT\n",
" \"--use_peft\",\n",
" \"--lora_r=16\",\n",
" \"--lora_alpha=32\",\n",
" \"--lora_dropout=0.1\",\n",
" \"--lora_target_modules=all-linear\",\n",
" # TRAINER\n",
" \"--bf16\",\n",
" \"--max_seq_length=1024\",\n",
" \"--per_device_train_batch_size=2\",\n",
" \"--gradient_accumulation_steps=8\",\n",
" \"--gradient_checkpointing\",\n",
" \"--learning_rate=0.0002\",\n",
" \"--lr_scheduler_type=cosine\",\n",
" \"--optim=adamw_bnb_8bit\",\n",
" \"--num_train_epochs=1\",\n",
" \"--logging_steps=10\",\n",
" \"--report_to=none\",\n",
" f\"--output_dir={os.getenv('BUCKET_URI').replace('gs://', '/gcs/')}/lora-ft\",\n",
" \"--overwrite_output_dir\",\n",
" \"--seed=42\",\n",
" \"--log_level=debug\",\n",
" ]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from huggingface_hub import get_token\n",
"\n",
"job.submit(\n",
" machine_type=\"g2-standard-12\",\n",
" machine_type=\"g2-standard-12\",\n",
" accelerator_type=\"NVIDIA_L4\",\n",
" accelerator_count=1,\n",
" base_output_dir=f\"{os.getenv('BUCKET_URI')}/lora-ft\",\n",
" environment_variables={\n",
" \"HF_HOME\": \"/root/.cache/huggingface\",\n",
" \"HF_TOKEN\": get_token(),\n",
" \"TRL_USE_RICH\": \"0\",\n",
" \"ACCELERATE_LOG_LEVEL\": \"INFO\",\n",
" \"TRANSFORMERS_LOG_LEVEL\": \"INFO\",\n",
" \"TQDM_POSITION\": \"-1\",\n",
" },\n",
" timeout=60 * 60 * 3, # 3 hours (10800s)\n",
" create_request_timeout=60 * 10, # 10 minutes (600s)\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment