Phi-4-multimodal-korean-finetuning
from datasets import load_dataset
import torch
import sacrebleu
import json
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
import soundfile as sf
from jiwer import cer
from whisper_normalizer.basic import BasicTextNormalizer
from tqdm.auto import tqdm
import re
import numpy as np
user_prompt = '<|user|>'
assistant_prompt = '<|assistant|>'
prompt_suffix = '<|end|>'
# task prompts are taken from the technical report
asr_prompt = f'{user_prompt}<|audio_1|>Transcribe the audio clip into text.{prompt_suffix}{assistant_prompt}'
ast_ko_prompt = f'{user_prompt}<|audio_1|>Translate the audio to Korean.{prompt_suffix}{assistant_prompt}'
ast_cot_ko_prompt = f'{user_prompt}<|audio_1|>Transcribe the audio to text, and then translate the audio to Korean. Use <sep> as a separator between the original transcript and the translation.{prompt_suffix}{assistant_prompt}'
ast_en_prompt = f'{user_prompt}<|audio_1|>Translate the audio to English.{prompt_suffix}{assistant_prompt}'
ast_cot_en_prompt = f'{user_prompt}<|audio_1|>Transcribe the audio to text, and then translate the audio to English. Use <sep> as a separator between the original transcript and the translation.{prompt_suffix}{assistant_prompt}'
normalizer = BasicTextNormalizer()
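# CER below is computed on normalized, whitespace-free text, so it ignores spacing
# differences. A rough sketch of the normalization (assuming BasicTextNormalizer
# strips punctuation and collapses whitespace):
#
#   normalizer("안녕하세요, 반갑습니다!")          # -> roughly "안녕하세요 반갑습니다"
#   re.sub(r"\s+", "", "안녕하세요 반갑습니다")    # -> "안녕하세요반갑습니다"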
def inference_audio(prompt, audio, model, processor, generation_config, max_new_tokens=64):
    inputs = processor(text=prompt, audios=[audio], return_tensors='pt').to(model.device)
    generate_ids = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        generation_config=generation_config,
    )
    generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
    response = processor.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return response
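# Example usage (a sketch; assumes a local 16 kHz mono file "sample.wav" and that
# `model`, `processor`, and `generation_config` are loaded as in the __main__ block below):
#
#   wav, sr = sf.read("sample.wav")
#   transcript = inference_audio(asr_prompt, (wav, sr), model, processor, generation_config)
#   translation = inference_audio(ast_en_prompt, (wav, sr), model, processor, generation_config, max_new_tokens=128)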
def evaluate(model, asr_ds, ast_ds, debug=False):
    # NOTE: `processor` and `generation_config` are module-level globals defined in
    # the __main__ block below.
    result = dict()
    aggregated = dict()

    result["ASR"] = []
    for item in tqdm(asr_ds, desc="ASR Evaluation..."):
        if len(result["ASR"]) >= 10 and debug:
            break
        ref = item["text"]
        audio = (item["audio"]["array"], item["audio"]["sampling_rate"])
        hyp = inference_audio(asr_prompt, audio, model, processor, generation_config, max_new_tokens=128)
        utt_cer = round(cer(re.sub(r"\s+", "", normalizer(ref)), re.sub(r"\s+", "", normalizer(hyp))) * 100, 2)
        result["ASR"].append({"ref": ref, "hyp": hyp, "cer": utt_cer})
    scores = [d["cer"] for d in result["ASR"]]
    aggregated["ASR"] = sum(scores) / len(scores)

    result["AST_KO_EN"] = []
    for item in tqdm(ast_ds, desc="AST Ko -> En Evaluation..."):
        if len(result["AST_KO_EN"]) >= 10 and debug:
            break
        ref = item["en_transcription"]
        audio = (np.asarray(item["ko_audio"]["array"]), item["ko_audio"]["sampling_rate"])
        hyp = inference_audio(ast_en_prompt, audio, model, processor, generation_config, max_new_tokens=128)
        utt_bleu = sacrebleu.sentence_bleu(hyp, [ref]).score
        result["AST_KO_EN"].append({"ref": ref, "hyp": hyp, "bleu": utt_bleu})
    scores = [d["bleu"] for d in result["AST_KO_EN"]]
    aggregated["AST_KO_EN"] = sum(scores) / len(scores)

    result["AST_KO_EN_COT"] = []
    for item in tqdm(ast_ds, desc="AST Ko -> En COT Evaluation..."):
        if len(result["AST_KO_EN_COT"]) >= 10 and debug:
            break
        ref = item["en_transcription"]
        audio = (np.asarray(item["ko_audio"]["array"]), item["ko_audio"]["sampling_rate"])
        orig_hyp = inference_audio(ast_cot_en_prompt, audio, model, processor, generation_config, max_new_tokens=256)
if "<sep>" in hyp:
hyp = orig_hyp.split("<sep>")[-1].strip()
else:
hyp = orig_hyp
        utt_bleu = sacrebleu.sentence_bleu(hyp, [ref]).score
        result["AST_KO_EN_COT"].append({"ref": ref, "hyp": orig_hyp, "bleu": utt_bleu})
    scores = [d["bleu"] for d in result["AST_KO_EN_COT"]]
    aggregated["AST_KO_EN_COT"] = sum(scores) / len(scores)

    result["AST_EN_KO"] = []
    for item in tqdm(ast_ds, desc="AST En -> Ko Evaluation..."):
        if len(result["AST_EN_KO"]) >= 10 and debug:
            break
        ref = item["ko_transcription"]
        audio = (np.asarray(item["en_audio"]["array"]), item["en_audio"]["sampling_rate"])
        hyp = inference_audio(ast_ko_prompt, audio, model, processor, generation_config, max_new_tokens=128)
        utt_bleu = sacrebleu.sentence_bleu(hyp, [ref], tokenize="ko-mecab").score
        result["AST_EN_KO"].append({"ref": ref, "hyp": hyp, "bleu": utt_bleu})
    scores = [d["bleu"] for d in result["AST_EN_KO"]]
    aggregated["AST_EN_KO"] = sum(scores) / len(scores)

    result["AST_EN_KO_COT"] = []
    for item in tqdm(ast_ds, desc="AST En -> Ko COT Evaluation..."):
        if len(result["AST_EN_KO_COT"]) >= 10 and debug:
            break
        ref = item["ko_transcription"]
        audio = (np.asarray(item["en_audio"]["array"]), item["en_audio"]["sampling_rate"])
        orig_hyp = inference_audio(ast_cot_ko_prompt, audio, model, processor, generation_config, max_new_tokens=256)
if "<sep>" in hyp:
hyp = orig_hyp.split("<sep>")[-1].strip()
else:
hyp = orig_hyp
        utt_bleu = sacrebleu.sentence_bleu(hyp, [ref], tokenize="ko-mecab").score
        result["AST_EN_KO_COT"].append({"ref": ref, "hyp": orig_hyp, "bleu": utt_bleu})
    scores = [d["bleu"] for d in result["AST_EN_KO_COT"]]
    aggregated["AST_EN_KO_COT"] = sum(scores) / len(scores)

    result["aggregated"] = aggregated
    return result
if __name__ == "__main__":
    # Load model
    orig_model_path = "microsoft/Phi-4-multimodal-instruct"
    ft_model_path = "seastar105/Phi-4-mm-inst-zeroth-kor"
    generation_config = GenerationConfig.from_pretrained(orig_model_path, 'generation_config.json')
    processor = AutoProcessor.from_pretrained(orig_model_path, trust_remote_code=True)

    asr_ds = load_dataset("kresnik/zeroth_korean", split="test")
    ast_ds = load_dataset("seastar105/fleurs_ko_en_test", split="train")

    print("Evaluating the model before fine-tuning")
    model = AutoModelForCausalLM.from_pretrained(
        orig_model_path,
        trust_remote_code=True,
        torch_dtype='auto',
        _attn_implementation='flash_attention_2',
    ).cuda()
    orig_result = evaluate(model, asr_ds, ast_ds)
    print("Evaluation result before fine-tuning")
    print(json.dumps(orig_result["aggregated"], ensure_ascii=False, indent=2))
    with open("orig_result.json", "w", encoding="utf-8") as f:
        json.dump(orig_result, f, ensure_ascii=False, indent=2)

    del model
    __import__('gc').collect()
    torch.cuda.empty_cache()

    print("Evaluating the model after fine-tuning")
    model = AutoModelForCausalLM.from_pretrained(
        ft_model_path,
        trust_remote_code=True,
        torch_dtype='auto',
        _attn_implementation='flash_attention_2',
    ).cuda()
    ft_result = evaluate(model, asr_ds, ast_ds)
    print("Evaluation result after fine-tuning")
    print(json.dumps(ft_result["aggregated"], ensure_ascii=False, indent=2))
    with open("ft_result.json", "w", encoding="utf-8") as f:
        json.dump(ft_result, f, ensure_ascii=False, indent=2)
"""
finetune Phi-4-multimodal-instruct on an speech task
scipy==1.15.1
peft==0.13.2
backoff==2.2.1
transformers==4.46.1
accelerate==1.3.0
"""
import argparse
import json
import os
from pathlib import Path
import torch
from jiwer import cer
import re
from whisper_normalizer.basic import BasicTextNormalizer
from accelerate import Accelerator
from accelerate.utils import gather_object
from datasets import load_dataset
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    BatchFeature,
    Trainer,
    TrainingArguments,
    StoppingCriteria,
    StoppingCriteriaList,
)
# NOTE: INSTRUCTION, TOKENIZER, _TRAIN_SIZE and _EVAL_SIZE are not referenced
# elsewhere in this script.
INSTRUCTION = {
    "en_zh-CN": "Translate the audio to Mandarin.",
    "en_id": "Translate the audio to Indonesian.",
    "en_sl": "Translate the audio to Slovenian.",
}
TOKENIZER = {
    "en_zh-CN": "zh",
    "en_ja": "ja-mecab",
}
ANSWER_SUFFIX = "<|end|><|endoftext|>"
_IGNORE_INDEX = -100
_TRAIN_SIZE = 50000
_EVAL_SIZE = 200
class MultipleTokenBatchStoppingCriteria(StoppingCriteria):
    """Stopping criteria capable of receiving multiple stop-tokens and handling batched inputs."""

    def __init__(self, stop_tokens: torch.LongTensor, batch_size: int = 1) -> None:
        """Initialize the multiple token batch stopping criteria.

        Args:
            stop_tokens: Stop-tokens.
            batch_size: Batch size.
        """
        self.stop_tokens = stop_tokens
        self.max_stop_tokens = stop_tokens.shape[-1]
        self.stop_tokens_idx = torch.zeros(batch_size, dtype=torch.long, device=stop_tokens.device)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Only gather the maximum number of inputs compatible with stop tokens
        # and check whether generated inputs are equal to `stop_tokens`
        generated_inputs = torch.eq(input_ids[:, -self.max_stop_tokens :].unsqueeze(1), self.stop_tokens)
        equal_generated_inputs = torch.all(generated_inputs, dim=2)

        # Mark the position where a stop token has been produced for each input in the batch,
        # but only if the corresponding entry is not already set
        sequence_idx = torch.any(equal_generated_inputs, dim=1)
        sequence_set_mask = self.stop_tokens_idx == 0
        self.stop_tokens_idx[sequence_idx & sequence_set_mask] = input_ids.shape[-1]

        return torch.all(self.stop_tokens_idx)
class STTDataset(Dataset):
    def __init__(self, processor, rank=0, world_size=1, split='train'):
        self.dataset = load_dataset("kresnik/zeroth_korean", split=split)
        self.processor = processor
        self.rank = rank
        self.world_size = world_size
        self.instruction = "Transcribe the audio clip into text."
        self.training = "train" in split

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        data = self.dataset[idx]
        user_message = {
            'role': 'user',
            'content': '<|audio_1|>\n' + self.instruction,
        }
        prompt = self.processor.tokenizer.apply_chat_template(
            [user_message], tokenize=False, add_generation_prompt=True
        )
        inputs = self.processor(text=prompt, audios=[(data["audio"]["array"], data["audio"]["sampling_rate"])], return_tensors='pt')
        answer = f"{data['text']}{ANSWER_SUFFIX}"
        answer_ids = self.processor.tokenizer(answer, return_tensors='pt').input_ids
        if self.training:
            input_ids = torch.cat([inputs.input_ids, answer_ids], dim=1)
            labels = torch.full_like(input_ids, _IGNORE_INDEX)
            labels[:, -answer_ids.shape[1]:] = answer_ids
        else:
            input_ids = inputs.input_ids
            labels = answer_ids
        return {
            'input_ids': input_ids,
            'labels': labels,
            'input_audio_embeds': inputs.input_audio_embeds,
            'audio_embed_sizes': inputs.audio_embed_sizes,
        }
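# A sketch of what a single item looks like (illustrative only):
#
#   training (split='train'):
#     input_ids : [prompt tokens ... audio placeholders ... answer tokens <|end|><|endoftext|>]
#     labels    : [-100, -100, ..., -100,                  answer tokens <|end|><|endoftext|>]
#   evaluation (split='test'):
#     input_ids : prompt only
#     labels    : reference transcription token ids
#
# i.e. during training the loss is only computed on the answer span; prompt and audio
# positions are masked with _IGNORE_INDEX.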
def pad_sequence(sequences, padding_side='right', padding_value=0):
    """
    Pad a list of sequences to the same length.
    sequences: list of tensors in [seq_len, *] shape
    """
    assert padding_side in ['right', 'left']
    max_size = sequences[0].size()
    trailing_dims = max_size[1:]
    max_len = max(len(seq) for seq in sequences)
    batch_size = len(sequences)
    output = sequences[0].new_full((batch_size, max_len) + trailing_dims, padding_value)
    for i, seq in enumerate(sequences):
        length = seq.size(0)
        if padding_side == 'right':
            output.data[i, :length] = seq
        else:
            output.data[i, -length:] = seq
    return output
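# A minimal sketch of pad_sequence (illustrative only, not executed here):
#
#   a = torch.tensor([1, 2, 3])
#   b = torch.tensor([4, 5])
#   pad_sequence([a, b], padding_side='left', padding_value=0)
#   # -> tensor([[1, 2, 3],
#   #            [0, 4, 5]])
#
# Left padding is what collate_fn below uses for input_ids and labels.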
def cat_with_pad(tensors, dim, padding_value=0):
    """
    cat along dim, while pad to max for all other dims
    """
    ndim = tensors[0].dim()
    assert all(
        t.dim() == ndim for t in tensors[1:]
    ), 'All tensors must have the same number of dimensions'

    out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
    out_size[dim] = sum(t.shape[dim] for t in tensors)
    output = tensors[0].new_full(out_size, padding_value)

    index = 0
    for t in tensors:
        # Create a slice list where every dimension except dim is full slice
        slices = [slice(0, t.shape[d]) for d in range(ndim)]
        # Update only the concat dimension slice
        slices[dim] = slice(index, index + t.shape[dim])
        output[slices] = t
        index += t.shape[dim]

    return output
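# A minimal sketch of cat_with_pad (illustrative only, not executed here):
#
#   x = torch.ones(1, 2, 4)
#   y = torch.ones(1, 3, 4)
#   cat_with_pad([x, y], dim=0).shape
#   # -> torch.Size([2, 3, 4])   concatenated along dim 0, other dims padded with zeros
#
# This is how audio embeddings of different lengths are batched in collate_fn below.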
def collate_fn(batch):
    input_ids_list = []
    labels_list = []
    input_audio_embeds_list = []
    audio_embed_sizes_list = []
    audio_attention_mask_list = []
    for inputs in batch:
        input_ids_list.append(inputs['input_ids'][0])
        labels_list.append(inputs['labels'][0])
        input_audio_embeds_list.append(inputs['input_audio_embeds'])
        audio_embed_sizes_list.append(inputs['audio_embed_sizes'])
        audio_attention_mask_list.append(
            inputs['input_audio_embeds'].new_full((inputs['input_audio_embeds'].size(1),), True, dtype=torch.bool)
        )

    try:
        input_ids = pad_sequence(input_ids_list, padding_side='left', padding_value=0)
        labels = pad_sequence(labels_list, padding_side='left', padding_value=0)
        audio_attention_mask = (
            pad_sequence(audio_attention_mask_list, padding_side='right', padding_value=False)
            if len(audio_attention_mask_list) > 1
            else None
        )
    except Exception as e:
        print(e)
        print(input_ids_list)
        print(labels_list)
        raise
    attention_mask = (input_ids != 0).long()
    input_audio_embeds = cat_with_pad(input_audio_embeds_list, dim=0)
    audio_embed_sizes = torch.cat(audio_embed_sizes_list)

    return BatchFeature(
        {
            'input_ids': input_ids,
            'labels': labels,
            'attention_mask': attention_mask,
            'input_audio_embeds': input_audio_embeds,
            'audio_embed_sizes': audio_embed_sizes,
            'audio_attention_mask': audio_attention_mask,
            'input_mode': 2,  # speech mode
        }
    )
def create_model(model_name_or_path, use_flash_attention=False):
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        torch_dtype=torch.bfloat16 if use_flash_attention else torch.float32,
        _attn_implementation='flash_attention_2' if use_flash_attention else 'sdpa',
        trust_remote_code=True,
    ).to('cuda')
    return model
@torch.no_grad()
def evaluate(
    model, processor, eval_dataset, save_path=None, disable_tqdm=False, eval_batch_size=1
):
    rank = int(os.environ.get('RANK', 0))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))

    model.eval()
    all_generated_texts = []
    all_labels = []

    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=eval_batch_size,
        collate_fn=collate_fn,
        shuffle=False,
        drop_last=False,
        num_workers=8,
        prefetch_factor=2,
        pin_memory=True,
    )
    stop_tokens = ["<|end|>", processor.tokenizer.eos_token]
    stop_tokens_ids = processor.tokenizer(stop_tokens, add_special_tokens=False, padding="longest", return_tensors="pt")["input_ids"]
    stop_tokens_ids = stop_tokens_ids.to(f'cuda:{local_rank}')

    for inputs in tqdm(
        eval_dataloader, disable=(rank != 0) or disable_tqdm, desc='running eval'
    ):
        stopping_criteria = StoppingCriteriaList([MultipleTokenBatchStoppingCriteria(stop_tokens_ids, batch_size=inputs.input_ids.size(0))])
        inputs = inputs.to(f'cuda:{local_rank}')
        generated_ids = model.generate(
            **inputs, eos_token_id=processor.tokenizer.eos_token_id, max_new_tokens=64,
            stopping_criteria=stopping_criteria,
        )

        stop_tokens_idx = stopping_criteria[0].stop_tokens_idx.reshape(inputs.input_ids.size(0), -1)[:, 0]
        stop_tokens_idx = torch.where(
            stop_tokens_idx > 0,
            stop_tokens_idx - stop_tokens_ids.shape[-1],
            generated_ids.shape[-1],
        )

        generated_text = [
            processor.decode(_pred_ids[inputs["input_ids"].shape[1] : _stop_tokens_idx], skip_special_tokens=True, clean_up_tokenization_spaces=False)
            for _pred_ids, _stop_tokens_idx in zip(generated_ids, stop_tokens_idx)
        ]
        all_generated_texts.extend(generated_text)
        # remove the literal ANSWER_SUFFIX ("<|end|><|endoftext|>") from the decoded reference
        labels = [processor.decode(_label_ids[_label_ids != 0]).removesuffix(ANSWER_SUFFIX) for _label_ids in inputs["labels"]]
        all_labels.extend(labels)

    all_generated_texts = gather_object(all_generated_texts)
    all_labels = gather_object(all_labels)

    if rank == 0:
        assert len(all_generated_texts) == len(all_labels)
        normalizer = BasicTextNormalizer()
        hyps = [re.sub(r"\s+", "", normalizer(text)) for text in all_generated_texts]
        refs = [re.sub(r"\s+", "", normalizer(text)) for text in all_labels]
        cer_score = round(cer(refs, hyps) * 100, 2)

        if save_path:
            with open(save_path, 'w', encoding='utf-8') as f:
                for ref, hyp in zip(all_labels, all_generated_texts):
                    utt_cer = round(cer(re.sub(r"\s+", "", normalizer(ref)), re.sub(r"\s+", "", normalizer(hyp))) * 100, 2)
                    print(json.dumps({'ref': ref, 'hyp': hyp, "cer": utt_cer}, ensure_ascii=False), file=f)
        return cer_score

    return None
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_name_or_path',
        type=str,
        default='microsoft/Phi-4-multimodal-instruct',
        help='Model name or path to load from',
    )
    parser.add_argument('--use_flash_attention', action='store_true', help='Use Flash Attention')
    parser.add_argument('--output_dir', type=str, default='./output/', help='Output directory')
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size')
    parser.add_argument(
        '--batch_size_per_gpu',
        type=int,
        default=32,
        help='Batch size per GPU (adjust this to fit in GPU memory)',
    )
    parser.add_argument(
        '--num_train_epochs', type=int, default=1, help='Number of training epochs'
    )
    parser.add_argument('--learning_rate', type=float, default=4.0e-5, help='Learning rate')
    parser.add_argument('--wd', type=float, default=0.01, help='Weight decay')
    parser.add_argument('--no-tqdm', dest='tqdm', action='store_false', help='Disable tqdm')
    args = parser.parse_args()

    accelerator = Accelerator()

    with accelerator.local_main_process_first():
        processor = AutoProcessor.from_pretrained(
            args.model_name_or_path,
            trust_remote_code=True,
        )
        model = create_model(
            args.model_name_or_path,
            use_flash_attention=args.use_flash_attention,
        )
    model.set_lora_adapter('speech')

    rank = int(os.environ.get('RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))

    eval_dataset = STTDataset(processor, split='test', rank=rank, world_size=world_size)
    train_dataset = STTDataset(processor, split='train')

    num_gpus = accelerator.num_processes
    print(f'training on {num_gpus} GPUs')
    assert (
        args.batch_size % (num_gpus * args.batch_size_per_gpu) == 0
    ), 'Batch size must be divisible by the number of GPUs times the per-GPU batch size'
    gradient_accumulation_steps = args.batch_size // (num_gpus * args.batch_size_per_gpu)
    if args.use_flash_attention:
        fp16 = False
        bf16 = True
    else:
        fp16 = True
        bf16 = False

    # hard coded training args
    training_args = TrainingArguments(
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.batch_size_per_gpu,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={'use_reentrant': False},
        gradient_accumulation_steps=gradient_accumulation_steps,
        optim='adamw_torch',
        adam_beta1=0.9,
        adam_beta2=0.95,
        adam_epsilon=1e-7,
        learning_rate=args.learning_rate,
        weight_decay=args.wd,
        max_grad_norm=1.0,
        lr_scheduler_type='linear',
        warmup_steps=50,
        logging_steps=10,
        output_dir=args.output_dir,
        save_strategy='no',
        save_total_limit=10,
        save_only_model=True,
        bf16=bf16,
        fp16=fp16,
        remove_unused_columns=False,
        report_to='none',
        deepspeed=None,
        disable_tqdm=not args.tqdm,
        dataloader_num_workers=4,
        ddp_find_unused_parameters=True,  # for unused SigLIP layers
    )

    # eval before fine-tuning
    out_path = Path(training_args.output_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    score = evaluate(
        model,
        processor,
        eval_dataset,
        save_path=out_path / 'eval_before.json',
        disable_tqdm=not args.tqdm,
        eval_batch_size=args.batch_size_per_gpu,
    )
    if accelerator.is_main_process:
        print(f'CER Score before finetuning: {score}')

    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'Trainable parameters: {trainable_params / 1e6:.2f}M')

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=collate_fn,
        train_dataset=train_dataset,
    )
    trainer.train()
    trainer.save_model()
    if accelerator.is_main_process:
        processor.save_pretrained(training_args.output_dir)
    accelerator.wait_for_everyone()

    # eval after fine-tuning (load saved checkpoint)
    # first try to clear GPU memory
    del model
    del trainer
    __import__('gc').collect()
    torch.cuda.empty_cache()

    # reload the model for inference
    model = AutoModelForCausalLM.from_pretrained(
        training_args.output_dir,
        torch_dtype=torch.bfloat16 if args.use_flash_attention else torch.float32,
        trust_remote_code=True,
        _attn_implementation='flash_attention_2' if args.use_flash_attention else 'sdpa',
    ).to('cuda')

    score = evaluate(
        model,
        processor,
        eval_dataset,
        save_path=out_path / 'eval_after.json',
        disable_tqdm=not args.tqdm,
        eval_batch_size=args.batch_size_per_gpu,
    )
    if accelerator.is_main_process:
        print(f'CER Score after finetuning: {score}')


if __name__ == '__main__':
    main()
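# Launch sketch: the script can be run with plain `python`, or for multi-GPU training
# with Accelerate, e.g. `accelerate launch --num_processes 2 <this_file>.py --use_flash_attention`
# (the file name here is a placeholder for whatever name this script is saved under).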