Fine-tune Phi-2 with EasyDeL
Requirements:
fjformer
datasets
gradio
wandb
sentencepiece
transformers==4.38.0
jax[tpu]==0.4.22
-e git+https://github.com/erfanzar/EasyDeL.git#egg=EasyDeL
tensorflow
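
A quick way to verify the JAX TPU install before running either script (not part of the original gist):

import jax
print(jax.devices())  # on a TPU VM this should list the TPU cores (e.g. 8 devices on a v3-8)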
Inference script:
from EasyDel import JAXServer, JAXServerConfig
from EasyDel.serve.prompters import GemmaPrompter, Llama2Prompter, OpenChatPrompter, Qwen2Prompter
from EasyDel.serve.prompters.base_prompter import BasePrompter
from fjformer import get_dtype
from transformers import AutoTokenizer
from typing import List, Union
import jax
import torch
max_sequence_length = 1536          # total context window for serving
max_compile_tokens = 32             # tokens generated per compiled step
max_new_tokens_ratio = 25           # max_new_tokens = 32 * 25 = 800
dtype = "fp16"
prompter_type = "chatml"
sharding_axis_dims = (1, -1, 1, 1)  # shard parameters over the fsdp axis
pretrained_model_name_or_path = "REPO_ID"  # fill in the fine-tuned repo id
attn_mechanism = "normal"
scan_mlp_chunk_size = max_compile_tokens
use_scan_mlp = True
scan_ring_attention = True
block_k = 128
block_q = 128
use_sharded_kv_caching = False
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
prompters = {
    "gemma": GemmaPrompter(),
    "llama": Llama2Prompter(),
    "openchat": OpenChatPrompter(),
    "chatml": Qwen2Prompter()
}
prompter: BasePrompter = prompters[prompter_type]
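# For reference, the "chatml" prompter formats turns roughly like this
# (an assumption about the exact control tokens, which come from the
# Qwen2Prompter implementation):
#   <|im_start|>system ... <|im_end|>
#   <|im_start|>user ... <|im_end|>
#   <|im_start|>assistant
# which is why "<|im_end|>" serves as the end-of-turn/eos marker below.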
server_config = JAXServerConfig(
    max_sequence_length=max_sequence_length,
    max_compile_tokens=max_compile_tokens,
    max_new_tokens=max_compile_tokens * max_new_tokens_ratio,  # 800 tokens
    dtype=dtype,
    pre_compile=False,
    # both ids derive from the prompter's end-of-turn token, so generation
    # stops at the end of an assistant turn
    eos_token_id=tokenizer.encode(prompter.end_of_turn_token)[0],
    bos_token_id=tokenizer.encode(prompter.end_of_turn_token)[0],
)
class JAXServerC(JAXServer):
    @staticmethod
    def format_chat(history: List[List[str]], prompt: str, system: Union[str, None]) -> str:
        return prompter.format_message(
            history=history,
            prompt=prompt,
            system_message=system,
            prefix=None
        )

    @staticmethod
    def format_instruct(system: str, instruction: str) -> str:
        return prompter.format_message(
            prefix=None,
            system_message=system,
            prompt=instruction,
            history=[]
        )
server = JAXServerC.from_torch_pretrained(
    server_config=server_config,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    device=jax.devices('cpu')[0],  # load and convert the torch weights on CPU, then shard
    dtype=get_dtype(dtype=dtype),
    param_dtype=get_dtype(dtype=dtype),
    precision=jax.lax.Precision("fastest"),
    sharding_axis_dims=sharding_axis_dims,
    sharding_axis_names=("dp", "fsdp", "tp", "sp"),
    input_shape=(1, server_config.max_sequence_length),
    device_map="auto",
    torch_dtype=torch.float16,
    model_config_kwargs=dict(
        fully_sharded_data_parallel=True,
        attn_mechanism=attn_mechanism,
        scan_mlp_chunk_size=max_compile_tokens,
        use_scan_mlp=use_scan_mlp,
        scan_ring_attention=scan_ring_attention,
        block_k=block_k,
        block_q=block_q,
        use_sharded_kv_caching=use_sharded_kv_caching
    )
)
history = []
while True:
    user_prompt = input("> ")
    model_prompt = server.format_chat(
        history,
        user_prompt,
        "You are an AI assistant. Be as polite and respectful as possible, and help the user with debugging and writing code."
    )
    past_response_length = 0
    response = ""
    for response, used_tokens in server.sample(
        model_prompt,
        greedy=False
    ):
        # strip the end-of-turn marker before displaying
        response = response.replace("<|im_end|>", "")
        print(response[past_response_length:], end="", flush=True)
        past_response_length = len(response)
    print()  # finish the streamed line
    history.append([user_prompt, response])
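
For one-shot use, the same server exposes format_instruct; a minimal sketch reusing the sampling API above (the system/instruction strings are illustrative):

prompt = JAXServerC.format_instruct(
    system="You are a helpful coding assistant.",
    instruction="Write a Python function that reverses a string."
)
final = ""
for final, _ in server.sample(prompt, greedy=True):
    pass  # server.sample streams; keep only the last (full) decoded string
print(final.replace("<|im_end|>", ""))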
Fine-tuning script:
from EasyDel import (
    AutoEasyDelModelForCausalLM,
    TrainArguments,
    CausalLanguageModelTrainer,
    EasyDelOptimizers,
    EasyDelSchedulers,
    EasyDelGradientCheckPointers,
    EasyDelState,
    EasyDeLXRapTureConfig,
    get_modules_by_type,
    easystate_to_huggingface_model
)
from datasets import load_dataset
from jax.sharding import PartitionSpec
from flax.core import FrozenDict
from transformers import AutoTokenizer, PhiForCausalLM as ModuleTorch
from jax import numpy as jnp
import jax
from huggingface_hub import HfApi

api = HfApi()
sharding_axis_dims = (1, -1, 1, 1)  # pure FSDP sharding instead of sequence parallelism
max_length = 2048
input_shape = (8, max_length)  # leading dimension matches the 8 available TPU devices
pretrained_model_name_or_path = "microsoft/phi-2"

model, params = AutoEasyDelModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path,
    device=jax.devices('cpu')[0],  # convert the torch weights on CPU
    input_shape=input_shape,
    device_map="auto",
    sharding_axis_dims=sharding_axis_dims
)
config = model.config
model_parameters = FrozenDict({"params": params})
dtype = jnp.bfloat16
use_lora = False

# enable flash attention with TPU-friendly block sizes
config.add_basic_configurations(
    attn_mechanism="flash",
    block_b=1,
    block_q=128,
    block_k=128,
    block_k_major=128,
    use_shard_map=False
)
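# Note on the block sizes (an assumption about the kernel's tiling): flash
# attention computes the attention matrix in tiles instead of materializing
# it, and block_q/block_k/block_k_major of 128 line up with TPU-friendly
# tile widths; block_b=1 means the batch dimension is not tiled.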
tokenizer = AutoTokenizer.from_pretrained(
    "erfanzar/phi-2",
    trust_remote_code=True
)
configs_to_initialize_model_class = {
    "config": config,
    "dtype": dtype,
    "param_dtype": dtype,
    "input_shape": input_shape
}
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# optional LoRA (disabled above): fully fine-tune the embeddings and apply
# rank-64 adapters to the attention projections
rapture_config = EasyDeLXRapTureConfig(
    model_parameters,
    lora_dim=64,
    fully_fine_tune_parameters=["embed_tokens"],
    lora_fine_tune_parameters=["q_proj", "v_proj", "k_proj", "o_proj"],
    verbose=True
) if use_lora else None

dataset = load_dataset(
    "erfanzar/GPT-4-Prompts",
    split="train",
)
def tokenization_process(data_chunk) -> dict:
    return tokenizer(
        data_chunk["chatml_prompt"],
        add_special_tokens=False,
        truncation=True,  # clip prompts longer than max_length
        max_length=max_length,
        padding="max_length"
    )

dataset = dataset.map(
    tokenization_process,
    num_proc=18,
    remove_columns=dataset.column_names
)
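
# Optional sanity check (not in the original gist): after truncation and
# padding, every row should hold exactly max_length token ids.
sample = dataset[0]
assert len(sample["input_ids"]) == max_length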
rules = (
    ("embed_tokens/embedding", PartitionSpec(("fsdp", "sp"), )),
    ("final_layernorm/(scale|bias)", PartitionSpec(None, )),
    ("mlp/fc1/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
    ("mlp/fc1/bias", PartitionSpec("tp", )),
    ("mlp/fc2/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
    ("mlp/fc2/bias", PartitionSpec("tp", )),
    ("self_attn/dense/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
    ("self_attn/dense/bias", PartitionSpec("tp")),
    ("self_attn/(q_proj|k_proj|v_proj)/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
    ("self_attn/(q_proj|k_proj|v_proj)/bias", PartitionSpec("tp", )),
    ("lm_head/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
    ("lm_head/bias", PartitionSpec("tp")),
    (".*", PartitionSpec(None, ))  # replicate anything not matched above
)
# override the model's default partition rules with the custom set above
config.get_partition_rules = lambda x: rules
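# Sketch of how a rule is resolved (assuming EasyDeL matches each flattened
# parameter path against the patterns in order, first match wins):
import re
example_path = "model/layers/0/mlp/fc1/kernel"  # hypothetical path
spec = next(s for pattern, s in rules if re.search(pattern, example_path))
print(spec)  # expected: PartitionSpec(('fsdp', 'sp'), 'tp')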
train_args = TrainArguments(
    model_class=get_modules_by_type(config.model_type)[1],
    configs_to_initialize_model_class=configs_to_initialize_model_class,
    custom_rule=config.get_partition_rules(True),
    model_name="TLLM",
    num_train_epochs=3,
    learning_rate=2e-5,
    learning_rate_end=7e-6,  # linear decay from 2e-5 down to 7e-6
    warmup_steps=50,
    optimizer=EasyDelOptimizers.ADAMW,
    scheduler=EasyDelSchedulers.LINEAR,
    weight_decay=0.02,
    total_batch_size=16,
    max_sequence_length=max_length,
    gradient_checkpointing=EasyDelGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=sharding_axis_dims,
    use_pjit_attention_force=False,
    gradient_accumulation_steps=1,
    init_input_shape=input_shape,
    dtype=dtype,
    param_dtype=dtype,
    step_start_point=0,
    training_time="8H",  # hard stop after 8 hours
    rapture_config=rapture_config,
)
trainer = CausalLanguageModelTrainer(
    train_args,
    dataset.shuffle(seed=42),  # one seeded shuffle before training
    checkpoint_path=None
)
output = trainer.train(
    model_parameters=model_parameters if not use_lora else None,
    state=None
)
new_repo_id = "Phi-2-TLLM"
try:
    api.create_repo(new_repo_id, private=False)
except Exception:
    pass  # the repo may already exist

api.upload_file(
    path_or_fileobj=output.checkpoint_path,
    repo_id=new_repo_id,
    path_in_repo=output.last_save_file_name
)
# convert the trained EasyDeL state back into a HuggingFace PyTorch model
with jax.default_device(jax.devices("cpu")[0]):
    model = easystate_to_huggingface_model(
        state=EasyDelState.load_state(
            output.checkpoint_path
        ),
        base_huggingface_module=ModuleTorch,
        config=config
    )

half_model = model.half()
half_model.push_to_hub(new_repo_id)
tokenizer.push_to_hub(new_repo_id)
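
To confirm the push worked end to end, the half-precision checkpoint can be loaded back with plain transformers (a sketch; "<user>/Phi-2-TLLM" stands for the repo id under your Hub account):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("<user>/Phi-2-TLLM", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained("<user>/Phi-2-TLLM")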