Fine-tune Phi-2 with EasyDeL
fjformer
datasets
gradio
wandb
sentencepiece
transformers==4.38.0
jax[tpu]==0.4.22
-e git+https://github.com/erfanzar/EasyDeL.git#egg=EasyDeL
tensorflow
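These requirements assume a TPU VM (note jax[tpu]). After installing them, for example with pip install -r requirements.txt, a quick check that JAX actually sees the accelerators can save a long debugging session; this check is not part of the original gist, just a sanity pass before running either script:

import jax

print("backend:", jax.default_backend())    # expect "tpu" on a TPU VM
print("device count:", jax.device_count())  # the training script further down assumes 8 devices
print(jax.devices())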
from EasyDel import JAXServer, JAXServerConfig, EasyServe, EasyDelState, easystate_to_huggingface_model, AutoEasyDelConfig
from fjformer import get_dtype
from EasyDel.serve.prompters import GemmaPrompter, Llama2Prompter, OpenChatPrompter, Qwen2Prompter
from EasyDel.serve.prompters.base_prompter import BasePrompter
from huggingface_hub import hf_hub_download
from transformers import PhiForCausalLM as TorchModule, AutoTokenizer
from jax import numpy as jnp, lax
import jax
import torch
from typing import List, Union, Optional
max_sequence_length = 1536  # maximum context length the server will handle
max_compile_tokens = 32  # tokens generated per compiled sampling step
max_new_tokens_ratio = 25  # max_new_tokens = max_compile_tokens * max_new_tokens_ratio
dtype = "fp16"
prompter_type = "chatml"
sharding_axis_dims = (1, -1, 1, 1)  # put all devices on the fsdp axis
pretrained_model_name_or_path = "REPO_ID"  # placeholder: repo id of the fine-tuned model to serve
attn_mechanism = "normal"
scan_mlp_chunk_size = max_compile_tokens
use_scan_mlp = True
scan_ring_attention = True
block_k = 128
block_q = 128
use_sharded_kv_caching = False
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
prompters = {
    "gemma": GemmaPrompter(),
    "llama": Llama2Prompter(),
    "openchat": OpenChatPrompter(),
    "chatml": Qwen2Prompter()
}
prompter: BasePrompter = prompters[prompter_type]
server_config = JAXServerConfig(
    max_sequence_length=max_sequence_length,
    max_compile_tokens=max_compile_tokens,
    max_new_tokens=max_compile_tokens * max_new_tokens_ratio,
    dtype=dtype,
    pre_compile=False,
    eos_token_id=tokenizer.encode(prompter.end_of_turn_token)[0],
    bos_token_id=tokenizer.encode(prompter.end_of_turn_token)[0],
)
class JAXServerC(JAXServer):
    @staticmethod
    def format_chat(history: List[List[str]], prompt: str, system: Union[str, None]) -> str:
        return prompter.format_message(
            history=history,
            prompt=prompt,
            system_message=system,
            prefix=None
        )

    @staticmethod
    def format_instruct(system: str, instruction: str) -> str:
        return prompter.format_message(
            prefix=None,
            system_message=system,
            prompt=instruction,
            history=[]
        )
server = JAXServerC.from_torch_pretrained(
    server_config=server_config,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    device=jax.devices('cpu')[0],
    dtype=get_dtype(dtype=dtype),
    param_dtype=get_dtype(dtype=dtype),
    precision=jax.lax.Precision("fastest"),
    sharding_axis_dims=sharding_axis_dims,
    sharding_axis_names=("dp", "fsdp", "tp", "sp"),
    input_shape=(1, server_config.max_sequence_length),
    device_map="auto",
    torch_dtype=torch.float16,
    model_config_kwargs=dict(
        fully_sharded_data_parallel=True,
        attn_mechanism=attn_mechanism,
        scan_mlp_chunk_size=max_compile_tokens,
        use_scan_mlp=use_scan_mlp,
        scan_ring_attention=scan_ring_attention,
        block_k=block_k,
        block_q=block_q,
        use_sharded_kv_caching=use_sharded_kv_caching
    )
)
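JAXServerC also defines format_instruct, which the chat loop below never exercises. A one-shot, non-interactive call might look like the sketch below; it reuses the same server.sample streaming API as the loop and is only an illustration:

instruct_prompt = server.format_instruct(
    system="You are a helpful coding assistant.",
    instruction="Explain what jax.jit does in one short paragraph."
)
final_response = ""
for partial_response, _ in server.sample(instruct_prompt, greedy=True):
    final_response = partial_response  # keep only the latest streamed state
print(final_response.replace("<|im_end|>", ""))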
history = []
while True:
    user_prompt = input("> ")
    model_prompt = server.format_chat(
        history,
        user_prompt,
        "You are an AI assistant. Be as polite and respectful as possible and help the user with debugging and writing code."
    )
    past_response_length = 0
    for response, used_tokens in server.sample(
        model_prompt,
        greedy=False
    ):
        response = response.replace("<|im_end|>", "")
        print(response[past_response_length:], end="")
        past_response_length = len(response)
    print()  # newline after the streamed response
    history.append([user_prompt, response])

from EasyDel import (
    AutoEasyDelModelForCausalLM,
    TrainArguments,
    CausalLanguageModelTrainer,
    EasyDelOptimizers,
    EasyDelSchedulers,
    EasyDelGradientCheckPointers,
    EasyDelState,
    EasyDeLXRapTureConfig,
    get_modules_by_type,
    easystate_to_huggingface_model
)
from datasets import load_dataset
from jax.sharding import PartitionSpec
from flax.core import FrozenDict
from transformers import AutoTokenizer
from jax import numpy as jnp
import jax
from transformers import PhiForCausalLM as ModuleTorch
from huggingface_hub import HfApi, hf_hub_download
from fjformer import GenerateRNG, save_ckpt, get_dtype
rng_g = GenerateRNG()
api = HfApi()
sharding_axis_dims = (1, -1, 1, 1) # Using FSDP instead of SeqP
max_length = 2048
input_shape = (8, max_length) # 8 TPU devices are available
pretrained_model_name_or_path = "microsoft/phi-2"
model, params = AutoEasyDelModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path,
    device=jax.devices('cpu')[0],
    input_shape=input_shape,
    device_map="auto",
    sharding_axis_dims=sharding_axis_dims
)
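Before spending TPU time it can help to confirm the torch-to-flax conversion produced the expected parameter count (Phi-2 is roughly 2.7B parameters). A small optional check over the returned params pytree:

param_count = sum(x.size for x in jax.tree_util.tree_leaves(params))
print(f"converted parameters: {param_count / 1e9:.2f}B")  # expect roughly 2.7B for Phi-2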
config = model.config
model_parameters = FrozenDict({"params": params})
dtype = jnp.bfloat16
use_lora = False
config.add_basic_configurations(
    attn_mechanism="flash",
    block_b=1,
    block_q=128,
    block_k=128,
    block_k_major=128,
    use_shard_map=False
)
tokenizer = AutoTokenizer.from_pretrained(
    "erfanzar/phi-2",
    trust_remote_code=True
)
configs_to_initialize_model_class = {
    "config": config,
    "dtype": dtype,
    "param_dtype": dtype,
    "input_shape": input_shape
}
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
rapture_config = EasyDeLXRapTureConfig(
    model_parameters,
    lora_dim=64,
    fully_fine_tune_parameters=["embed_tokens"],
    lora_fine_tune_parameters=["q_proj", "v_proj", "k_proj", "o_proj"],
    verbose=True
) if use_lora else None
dataset = load_dataset(
    "erfanzar/GPT-4-Prompts",
    split="train",
)

def tokenization_process(data_chunk) -> dict:
    return tokenizer(
        data_chunk["chatml_prompt"],
        add_special_tokens=False,
        max_length=max_length,
        padding="max_length",
        truncation=True  # keep every example at exactly max_length tokens
    )

dataset = dataset.map(
    tokenization_process,
    num_proc=18,
    remove_columns=dataset.column_names
)
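Because padding="max_length" (with truncation) pins every example to max_length tokens, it is worth spot-checking one mapped record before launching a long TPU run. This check is not part of the original script, just a quick sanity pass:

sample = dataset[0]
print(list(sample.keys()))              # typically input_ids and attention_mask
print(len(sample["input_ids"]))         # should equal max_length (2048)
print(tokenizer.decode(sample["input_ids"])[:200])  # eyeball the ChatML formatting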
rules = (
    ("embed_tokens/embedding", PartitionSpec(("fsdp", "sp"), )),
    ("final_layernorm/(scale|bias)", PartitionSpec(None, )),
    ("mlp/fc1/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
    ("mlp/fc1/bias", PartitionSpec("tp", )),
    ("mlp/fc2/kernel", PartitionSpec(("fsdp", "sp"), "tp", )),
    ("mlp/fc2/bias", PartitionSpec("tp", )),
    ("self_attn/dense/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
    ("self_attn/dense/bias", PartitionSpec("tp")),
    ("self_attn/(q_proj|k_proj|v_proj)/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
    ("self_attn/(q_proj|k_proj|v_proj)/bias", PartitionSpec("tp", )),
    ("lm_head/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
    ("lm_head/bias", PartitionSpec("tp")),
    (".*", PartitionSpec(None, ))
)
config.get_partition_rules = lambda x: rules  # hand the custom sharding rules above to the trainer
train_args = TrainArguments(
    model_class=get_modules_by_type(config.model_type)[1],
    configs_to_initialize_model_class=configs_to_initialize_model_class,
    custom_rule=config.get_partition_rules(True),
    model_name="TLLM",
    num_train_epochs=3,
    learning_rate=2e-5,
    learning_rate_end=7e-6,
    warmup_steps=50,
    optimizer=EasyDelOptimizers.ADAMW,
    scheduler=EasyDelSchedulers.LINEAR,
    weight_decay=0.02,
    total_batch_size=16,
    max_sequence_length=max_length,
    gradient_checkpointing=EasyDelGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=sharding_axis_dims,
    use_pjit_attention_force=False,
    gradient_accumulation_steps=1,
    init_input_shape=input_shape,
    dtype=dtype,
    param_dtype=dtype,
    step_start_point=0,
    training_time="8H",
    rapture_config=rapture_config,
)
trainer = CausalLanguageModelTrainer(
    train_args,
    dataset.shuffle().shuffle().shuffle(),
    checkpoint_path=None
)
output = trainer.train(
    model_parameters=model_parameters if not use_lora else None,
    state=None
)
new_repo_id = "Phi-2-TLLM"
try:
    api.create_repo(new_repo_id, private=False)
except Exception:
    ...  # the repo probably already exists
api.upload_file(
    path_or_fileobj=output.checkpoint_path,
    repo_id=new_repo_id,
    path_in_repo=output.last_save_file_name
)
with jax.default_device(jax.devices("cpu")[0]):
    model = easystate_to_huggingface_model(
        state=EasyDelState.load_state(
            output.checkpoint_path
        ),
        base_huggingface_module=ModuleTorch,
        config=config
    )
half_model = model.half()
half_model.push_to_hub(new_repo_id)
tokenizer.push_to_hub(new_repo_id)
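Once the half-precision weights and tokenizer are pushed, the checkpoint can be consumed with plain transformers. A minimal sketch, assuming the pushes above succeeded and that repo_id below is replaced with the full <username>/Phi-2-TLLM path of the pushed repository; the ChatML-style prompt simply mirrors the training data and is only illustrative:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

repo_id = "<username>/Phi-2-TLLM"  # replace with the namespace push_to_hub actually used
loaded_tokenizer = AutoTokenizer.from_pretrained(repo_id)
loaded_model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=torch.float16)

prompt = "<|im_start|>user\nWrite a hello world program in Python.<|im_end|>\n<|im_start|>assistant\n"
inputs = loaded_tokenizer(prompt, return_tensors="pt")
outputs = loaded_model.generate(**inputs, max_new_tokens=64)
print(loaded_tokenizer.decode(outputs[0], skip_special_tokens=True))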