SFT Llama 3.1 8B - full training vs LoRA vs QLoRA
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" | |
# Everything below was run on 8 x H100 GPUs | |
# Full training with packing (learns chat template) | |
accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero3.yaml scratch/sft_llama.py \ | |
--model_name_or_path meta-llama/Llama-3.1-8B \ | |
--dataset_name trl-lib/Capybara \ | |
--report_to wandb \ | |
--learning_rate 2.0e-5 \ | |
--per_device_train_batch_size 8 \ | |
--gradient_accumulation_steps 1 \ | |
--gradient_checkpointing \ | |
--output_dir Llama-3.1-8B-SFT-full-packing \ | |
--logging_steps 10 \ | |
--num_train_epochs 1 \ | |
--push_to_hub \ | |
--packing \ | |
--bf16 | |
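
# Note: the full fine-tune uses the DeepSpeed ZeRO-3 config, presumably because the optimizer states and
# gradients of all 8B parameters need to be sharded across the 8 GPUs, whereas the LoRA / QLoRA runs below
# only train adapter weights and fit under plain DDP (multi_gpu.yaml).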

# LoRA with packing, all-linear target modules, plus lm_head and embed_tokens saved (learns chat template)
accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml scratch/sft_llama.py \
    --model_name_or_path meta-llama/Llama-3.1-8B \
    --dataset_name trl-lib/Capybara \
    --report_to wandb \
    --learning_rate 2.0e-4 \
    --per_device_train_batch_size 8 \
    --gradient_accumulation_steps 1 \
    --gradient_checkpointing \
    --ddp_find_unused_parameters False \
    --output_dir Llama-3.1-8B-SFT-LoRA-packing \
    --logging_steps 10 \
    --num_train_epochs 1 \
    --push_to_hub \
    --use_peft \
    --lora_r 16 \
    --lora_alpha 32 \
    --lora_target_modules all-linear \
    --lora_modules_to_save lm_head embed_tokens \
    --packing
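
# Why lm_head and embed_tokens are listed in --lora_modules_to_save: the chat template relies on special
# tokens such as <|start_header_id|> and <|eot_id|> that the base model was never trained to emit, so the
# input embeddings and output head apparently need to be trainable for the model to pick them up. The run
# below drops --lora_modules_to_save and, as noted, does not learn the chat template.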

# LoRA with packing and all-linear target modules only (does not learn the chat template!)
accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml scratch/sft_llama.py \
    --model_name_or_path meta-llama/Llama-3.1-8B \
    --dataset_name trl-lib/Capybara \
    --report_to wandb \
    --learning_rate 2.0e-4 \
    --per_device_train_batch_size 8 \
    --gradient_accumulation_steps 1 \
    --gradient_checkpointing \
    --ddp_find_unused_parameters False \
    --output_dir Llama-3.1-8B-SFT-LoRA-packing-no-saved-modules \
    --logging_steps 10 \
    --num_train_epochs 1 \
    --push_to_hub \
    --use_peft \
    --lora_r 16 \
    --lora_alpha 32 \
    --lora_target_modules all-linear \
    --packing

# QLoRA with packing (learns chat template)
accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml scratch/sft_llama.py \
    --model_name_or_path meta-llama/Llama-3.1-8B \
    --dataset_name trl-lib/Capybara \
    --report_to wandb \
    --learning_rate 2.0e-4 \
    --per_device_train_batch_size 8 \
    --gradient_accumulation_steps 1 \
    --gradient_checkpointing \
    --ddp_find_unused_parameters False \
    --output_dir Llama-3.1-8B-SFT-QLoRA-packing \
    --logging_steps 10 \
    --num_train_epochs 1 \
    --push_to_hub \
    --use_peft \
    --lora_r 16 \
    --lora_alpha 32 \
    --lora_target_modules all-linear \
    --lora_modules_to_save lm_head embed_tokens \
    --packing \
    --load_in_4bit
"""

from datasets import load_dataset
from transformers import AutoTokenizer

from trl import (
    ModelConfig,
    SFTConfig,
    SFTScriptArguments,
    SFTTrainer,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)

if __name__ == "__main__":
    parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args, model_config = parser.parse_args_and_config()
    print(f"Script args: {script_args}\n")
    print(f"Training args: {training_args}\n")
    print(f"Model config: {model_config}\n")

    ################
    # Model init kwargs & Tokenizer
    ################
    quantization_config = get_quantization_config(model_config)
    model_kwargs = dict(
        revision=model_config.model_revision,
        trust_remote_code=model_config.trust_remote_code,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=model_config.torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    training_args.model_init_kwargs = model_kwargs
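
    # Note on the QLoRA path: when --load_in_4bit is passed, get_quantization_config(model_config) returns
    # a transformers BitsAndBytesConfig roughly along the lines of
    #     BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", ...)
    # with quant type, compute dtype and nested quantization taken from the corresponding ModelConfig
    # fields; otherwise it returns None and the model is loaded without quantization.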

    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True
    )
    tokenizer.pad_token = tokenizer.eos_token
    print(f"Using pad token: {tokenizer.pad_token}")

    # Set Llama chat template
    tokenizer.chat_template = "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message + builtin tools #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if builtin_tools is defined or tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{%- if builtin_tools is defined %}\n {{- \"Tools: \" + builtin_tools | reject('equalto', 'code_interpreter') | join(\", \") + \"\\n\\n\"}}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {%- if builtin_tools is defined and tool_call.name in builtin_tools %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- \"<|python_tag|>\" + tool_call.name + \".call(\" }}\n {%- for arg_name, arg_val in tool_call.arguments | items %}\n {{- arg_name + '=\"' + arg_val + '\"' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \")\" }}\n {%- else %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {%- endif %}\n {%- if builtin_tools is defined %}\n {#- This means we're in ipython mode #}\n {{- \"<|eom_id|>\" }}\n {%- else %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n"
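
    # Illustrative sanity check (not required for training): render the template on a toy conversation to
    # confirm that the Llama 3.1 headers and <|eot_id|> terminators appear as expected.
    example_messages = [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "The capital of France is Paris."},
    ]
    print(tokenizer.apply_chat_template(example_messages, tokenize=False))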

    ################
    # Dataset
    ################
    dataset = load_dataset(script_args.dataset_name)

    ################
    # Training
    ################
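    # For the LoRA / QLoRA runs, get_peft_config(model_config) assembles a peft.LoraConfig from the
    # --lora_* flags, roughly like the sketch below (illustrative only; it returns None without --use_peft,
    # which is what gives the full fine-tune):
    #
    #     from peft import LoraConfig
    #     peft_config = LoraConfig(
    #         r=16,
    #         lora_alpha=32,
    #         target_modules="all-linear",
    #         modules_to_save=["lm_head", "embed_tokens"],  # dropped in the "no-saved-modules" run
    #         task_type="CAUSAL_LM",
    #     )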
    trainer = SFTTrainer(
        model=model_config.model_name_or_path,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split],
        tokenizer=tokenizer,
        peft_config=get_peft_config(model_config),
    )
    trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)
Run inference tests with the pushed checkpoints to check whether each variant actually follows the chat template. A minimal sketch with transformers (the repo ID is a placeholder; substitute one of the output_dir names above under your Hub namespace):
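
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "your-username/Llama-3.1-8B-SFT-LoRA-packing"  # placeholder: swap in the run you want to test

tokenizer = AutoTokenizer.from_pretrained(repo_id)
# For the LoRA / QLoRA runs the repo holds an adapter; transformers loads it on top of the base model
# when peft is installed.
model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=torch.bfloat16, device_map="auto")

messages = [{"role": "user", "content": "Name three uses of LoRA adapters."}]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)

# If the chat template was learned, the completion should stop cleanly at <|eot_id|> rather than rambling on.
outputs = model.generate(inputs, max_new_tokens=256, eos_token_id=tokenizer.convert_tokens_to_ids("<|eot_id|>"))
print(tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=False))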