Phi-4-multimodal-korean-finetuning
from datasets import load_dataset
import torch
import sacrebleu
import json
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
import soundfile as sf
from jiwer import cer
from whisper_normalizer.basic import BasicTextNormalizer
from tqdm.auto import tqdm
import re
import numpy as np
user_prompt = '<|user|>'
assistant_prompt = '<|assistant|>'
prompt_suffix = '<|end|>'

# task prompts are taken from the technical report
asr_prompt = f'{user_prompt}<|audio_1|>Transcribe the audio clip into text.{prompt_suffix}{assistant_prompt}'
ast_ko_prompt = f'{user_prompt}<|audio_1|>Translate the audio to Korean.{prompt_suffix}{assistant_prompt}'
ast_cot_ko_prompt = f'{user_prompt}<|audio_1|>Transcribe the audio to text, and then translate the audio to Korean. Use <sep> as a separator between the original transcript and the translation.{prompt_suffix}{assistant_prompt}'
ast_en_prompt = f'{user_prompt}<|audio_1|>Translate the audio to English.{prompt_suffix}{assistant_prompt}'
ast_cot_en_prompt = f'{user_prompt}<|audio_1|>Transcribe the audio to text, and then translate the audio to English. Use <sep> as a separator between the original transcript and the translation.{prompt_suffix}{assistant_prompt}'

normalizer = BasicTextNormalizer()
def inference_audio(prompt, audio, model, processor, generation_config, max_new_tokens=64):
    inputs = processor(text=prompt, audios=[audio], return_tensors='pt').to(model.device)
    generate_ids = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        generation_config=generation_config,
    )
    generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
    response = processor.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return response
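
# Hedged usage sketch (not part of the original gist): how a single local WAV file could be
# transcribed with the ASR prompt above. The helper name and file path are illustrative;
# `model`, `processor`, and `generation_config` are expected to be loaded as in __main__ below.
def transcribe_file_example(wav_path, model, processor, generation_config):
    array, sampling_rate = sf.read(wav_path)  # soundfile returns (samples, sample_rate)
    return inference_audio(asr_prompt, (array, sampling_rate), model, processor, generation_config, max_new_tokens=128)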
def evaluate(model, asr_ds, ast_ds, debug=False):
    result = dict()
    aggregated = dict()

    result["ASR"] = []
    for item in tqdm(asr_ds, desc="ASR Evaluation..."):
        if len(result["ASR"]) >= 10 and debug:
            break
        ref = item["text"]
        audio = (item["audio"]["array"], item["audio"]["sampling_rate"])
        hyp = inference_audio(asr_prompt, audio, model, processor, generation_config, max_new_tokens=128)
        utt_cer = round(cer(re.sub(r"\s+", "", normalizer(ref)), re.sub(r"\s+", "", normalizer(hyp))) * 100, 2)
        result["ASR"].append({"ref": ref, "hyp": hyp, "cer": utt_cer})
    scores = [d["cer"] for d in result["ASR"]]
    aggregated["ASR"] = sum(scores) / len(scores)
result["AST_KO_EN"] = [] | |
for item in tqdm(ast_ds, desc="AST Ko -> En Evaluation..."): | |
if len(result["AST_KO_EN"]) >= 10 and debug: | |
break | |
ref = item["en_transcription"] | |
audio = (np.asarray(item["ko_audio"]["array"]), item["ko_audio"]["sampling_rate"]) | |
hyp = inference_audio(ast_en_prompt, audio, model, processor, generation_config, max_new_tokens=128) | |
utt_bleu = sacrebleu.sentence_bleu(hyp, [ref]).score | |
result["AST_KO_EN"].append({"ref": ref, "hyp": hyp, "bleu": utt_bleu}) | |
scores = [d["bleu"] for d in result["AST_KO_EN"]] | |
aggregated["AST_KO_EN"] = sum(scores) / len(scores) | |
result["AST_KO_EN_COT"] = [] | |
for item in tqdm(ast_ds, desc="AST Ko -> En COT Evaluation..."): | |
if len(result["AST_KO_EN_COT"]) >= 10 and debug: | |
break | |
ref = item["en_transcription"] | |
audio = (np.asarray(item["ko_audio"]["array"]), item["ko_audio"]["sampling_rate"]) | |
orig_hyp = inference_audio(ast_cot_en_prompt, audio, model, processor, generation_config, max_new_tokens=256) | |
if "<sep>" in hyp: | |
hyp = orig_hyp.split("<sep>")[-1].strip() | |
else: | |
hyp = orig_hyp | |
utt_bleu = sacrebleu.sentence_bleu(hyp, [ref]).score | |
result["AST_KO_EN_COT"].append({"ref": ref, "hyp": orig_hyp, "bleu": utt_bleu}) | |
scores = [d["bleu"] for d in result["AST_KO_EN_COT"]] | |
aggregated["AST_KO_EN_COT"] = sum(scores) / len(scores) | |
result["AST_EN_KO"] = [] | |
for item in tqdm(ast_ds, desc="AST En -> Ko Evaluation..."): | |
if len(result["AST_EN_KO"]) >= 10 and debug: | |
break | |
ref = item["ko_transcription"] | |
audio = (np.asarray(item["en_audio"]["array"]), item["en_audio"]["sampling_rate"]) | |
hyp = inference_audio(ast_ko_prompt, audio, model, processor, generation_config, max_new_tokens=128) | |
utt_bleu = sacrebleu.sentence_bleu(hyp, [ref], tokenize="ko-mecab").score | |
result["AST_EN_KO"].append({"ref": ref, "hyp": hyp, "bleu": utt_bleu}) | |
scores = [d["bleu"] for d in result["AST_EN_KO"]] | |
aggregated["AST_EN_KO"] = sum(scores) / len(scores) | |
result["AST_EN_KO_COT"] = [] | |
for item in tqdm(ast_ds, desc="AST En -> Ko COT Evaluation..."): | |
if len(result["AST_EN_KO_COT"]) >= 10 and debug: | |
break | |
ref = item["ko_transcription"] | |
audio = (np.asarray(item["en_audio"]["array"]), item["en_audio"]["sampling_rate"]) | |
orig_hyp = inference_audio(ast_cot_ko_prompt, audio, model, processor, generation_config, max_new_tokens=256) | |
if "<sep>" in hyp: | |
hyp = orig_hyp.split("<sep>")[-1].strip() | |
else: | |
hyp = orig_hyp | |
utt_bleu = sacrebleu.sentence_bleu(hyp, [ref], tokenize="ko-mecab").score | |
result["AST_EN_KO_COT"].append({"ref": ref, "hyp": orig_hyp, "bleu": utt_bleu}) | |
scores = [d["bleu"] for d in result["AST_EN_KO_COT"]] | |
aggregated["AST_EN_KO_COT"] = sum(scores) / len(scores) | |
result["aggregated"] = aggregated | |
return result | |
if __name__ == "__main__": | |
# Load model | |
orig_model_path = "microsoft/Phi-4-multimodal-instruct" | |
ft_model_path = "seastar105/Phi-4-mm-inst-zeroth-kor" | |
generation_config = GenerationConfig.from_pretrained(orig_model_path, 'generation_config.json') | |
processor = AutoProcessor.from_pretrained(orig_model_path, trust_remote_code=True) | |
asr_ds = load_dataset("kresnik/zeroth_korean", split="test") | |
ast_ds = load_dataset("seastar105/fleurs_ko_en_test", split="train") | |
print("Evaluating the model before fine-tuning") | |
model = AutoModelForCausalLM.from_pretrained( | |
orig_model_path, | |
trust_remote_code=True, | |
torch_dtype='auto', | |
_attn_implementation='flash_attention_2', | |
).cuda() | |
orig_result = evaluate(model, asr_ds, ast_ds) | |
print("Evaluation result before fine-tuning") | |
print(json.dumps(orig_result["aggregated"], ensure_ascii=False, indent=2)) | |
with open("orig_result.json", "w", encoding="utf-8") as f: | |
json.dump(orig_result, f, ensure_ascii=False, indent=2) | |
del model | |
__import__('gc').collect() | |
torch.cuda.empty_cache() | |
print("Evaluating the model after fine-tuning") | |
model = AutoModelForCausalLM.from_pretrained( | |
ft_model_path, | |
trust_remote_code=True, | |
torch_dtype='auto', | |
_attn_implementation='flash_attention_2', | |
).cuda() | |
ft_result = evaluate(model, asr_ds, ast_ds) | |
print("Evaluation result after fine-tuning") | |
print(json.dumps(ft_result["aggregated"], ensure_ascii=False, indent=2)) | |
with open("ft_result.json", "w", encoding="utf-8") as f: | |
json.dump(ft_result, f, ensure_ascii=False, indent=2) |
""" | |
finetune Phi-4-multimodal-instruct on an speech task | |
scipy==1.15.1 | |
peft==0.13.2 | |
backoff==2.2.1 | |
transformers==4.46.1 | |
accelerate==1.3.0 | |
""" | |
import argparse
import json
import os
from pathlib import Path

import torch
from jiwer import cer
import re
from whisper_normalizer.basic import BasicTextNormalizer
from accelerate import Accelerator
from accelerate.utils import gather_object
from datasets import load_dataset
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    BatchFeature,
    Trainer,
    TrainingArguments,
    StoppingCriteria,
    StoppingCriteriaList,
)
# Note: INSTSRUCTION, TOKENIZER, _TRAIN_SIZE, and _EVAL_SIZE are leftovers that the Korean ASR
# fine-tuning below does not use.
INSTSRUCTION = {
    "en_zh-CN": "Translate the audio to Mandarin.",
    "en_id": "Translate the audio to Indonesian.",
    "en_sl": "Translate the audio to Slovenian.",
}
TOKENIZER = {
    "en_zh-CN": "zh",
    "en_ja": "ja-mecab",
}
ANSWER_SUFFIX = "<|end|><|endoftext|>"
_IGNORE_INDEX = -100
_TRAIN_SIZE = 50000
_EVAL_SIZE = 200
class MultipleTokenBatchStoppingCriteria(StoppingCriteria):
    """Stopping criteria capable of receiving multiple stop-tokens and handling batched inputs."""

    def __init__(self, stop_tokens: torch.LongTensor, batch_size: int = 1) -> None:
        """Initialize the multiple token batch stopping criteria.

        Args:
            stop_tokens: Stop-tokens.
            batch_size: Batch size.
        """
        self.stop_tokens = stop_tokens
        self.max_stop_tokens = stop_tokens.shape[-1]
        self.stop_tokens_idx = torch.zeros(batch_size, dtype=torch.long, device=stop_tokens.device)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Only gather the maximum number of inputs compatible with stop tokens
        # and check whether generated inputs are equal to `stop_tokens`
        generated_inputs = torch.eq(input_ids[:, -self.max_stop_tokens :].unsqueeze(1), self.stop_tokens)
        equal_generated_inputs = torch.all(generated_inputs, dim=2)

        # Mark the position where a stop token has been produced for each input in the batch,
        # but only if the corresponding entry is not already set
        sequence_idx = torch.any(equal_generated_inputs, dim=1)
        sequence_set_mask = self.stop_tokens_idx == 0
        self.stop_tokens_idx[sequence_idx & sequence_set_mask] = input_ids.shape[-1]

        return torch.all(self.stop_tokens_idx)
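
# Added note: __call__ returns True only once every sequence in the batch has emitted one of the
# stop-token sequences; stop_tokens_idx records, per sequence, the generation length at which that
# happened, so evaluate() below can truncate each output at its own stop position.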
class STTDataset(Dataset):
    def __init__(self, processor, rank=0, world_size=1, split='train'):
        self.dataset = load_dataset("kresnik/zeroth_korean", split=split)
        self.processor = processor
        self.rank = rank
        self.world_size = world_size
        self.instruction = "Transcribe the audio clip into text."
        self.training = "train" in split

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        data = self.dataset[idx]
        user_message = {
            'role': 'user',
            'content': '<|audio_1|>\n' + self.instruction,
        }
        prompt = self.processor.tokenizer.apply_chat_template(
            [user_message], tokenize=False, add_generation_prompt=True
        )
        inputs = self.processor(text=prompt, audios=[(data["audio"]["array"], data["audio"]["sampling_rate"])], return_tensors='pt')

        answer = f"{data['text']}{ANSWER_SUFFIX}"
        answer_ids = self.processor.tokenizer(answer, return_tensors='pt').input_ids
        if self.training:
            input_ids = torch.cat([inputs.input_ids, answer_ids], dim=1)
            labels = torch.full_like(input_ids, _IGNORE_INDEX)
            labels[:, -answer_ids.shape[1]:] = answer_ids
        else:
            input_ids = inputs.input_ids
            labels = answer_ids

        return {
            'input_ids': input_ids,
            'labels': labels,
            'input_audio_embeds': inputs.input_audio_embeds,
            'audio_embed_sizes': inputs.audio_embed_sizes,
        }
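
# Hedged sanity-check sketch (not part of the original script): defining this helper has no effect
# on training; calling it would show the tensor shapes one example produces.
def _peek_example(processor, idx=0):
    ds = STTDataset(processor, split='test')
    item = ds[idx]
    # e.g. input_ids: [1, prompt_len], input_audio_embeds: [1, frames, feat_dim]
    return {k: tuple(v.shape) for k, v in item.items()}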
def pad_sequence(sequences, padding_side='right', padding_value=0):
    """
    Pad a list of sequences to the same length.
    sequences: list of tensors in [seq_len, *] shape
    """
    assert padding_side in ['right', 'left']
    max_size = sequences[0].size()
    trailing_dims = max_size[1:]
    max_len = max(len(seq) for seq in sequences)
    batch_size = len(sequences)
    output = sequences[0].new_full((batch_size, max_len) + trailing_dims, padding_value)
    for i, seq in enumerate(sequences):
        length = seq.size(0)
        if padding_side == 'right':
            output.data[i, :length] = seq
        else:
            output.data[i, -length:] = seq
    return output
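
# Worked example (comment only): pad_sequence([torch.tensor([1, 2, 3]), torch.tensor([4])],
# padding_side='left') -> tensor([[1, 2, 3], [0, 0, 4]]). Left padding keeps every prompt flush
# against the right edge, which is what decoder-only generation expects at inference time.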
def cat_with_pad(tensors, dim, padding_value=0):
    """
    cat along dim, while pad to max for all other dims
    """
    ndim = tensors[0].dim()
    assert all(
        t.dim() == ndim for t in tensors[1:]
    ), 'All tensors must have the same number of dimensions'

    out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
    out_size[dim] = sum(t.shape[dim] for t in tensors)
    output = tensors[0].new_full(out_size, padding_value)

    index = 0
    for t in tensors:
        # Create a slice list where every dimension except dim is full slice
        slices = [slice(0, t.shape[d]) for d in range(ndim)]
        # Update only the concat dimension slice
        slices[dim] = slice(index, index + t.shape[dim])
        output[slices] = t
        index += t.shape[dim]
    return output
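
# Worked example (comment only): cat_with_pad([torch.zeros(1, 3, 4), torch.zeros(2, 5, 4)], dim=0)
# returns a tensor of shape [3, 5, 4]: sizes are summed along dim=0 and padded up to the maximum
# (here 5) along every other dimension.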
def collate_fn(batch):
    input_ids_list = []
    labels_list = []
    input_audio_embeds_list = []
    audio_embed_sizes_list = []
    audio_attention_mask_list = []
    for inputs in batch:
        input_ids_list.append(inputs['input_ids'][0])
        labels_list.append(inputs['labels'][0])
        input_audio_embeds_list.append(inputs['input_audio_embeds'])
        audio_embed_sizes_list.append(inputs['audio_embed_sizes'])
        audio_attention_mask_list.append(
            inputs['input_audio_embeds'].new_full((inputs['input_audio_embeds'].size(1),), True, dtype=torch.bool)
        )

    try:
        input_ids = pad_sequence(input_ids_list, padding_side='left', padding_value=0)
        labels = pad_sequence(labels_list, padding_side='left', padding_value=0)
        audio_attention_mask = (
            pad_sequence(audio_attention_mask_list, padding_side='right', padding_value=False)
            if len(audio_attention_mask_list) > 1
            else None
        )
    except Exception as e:
        print(e)
        print(input_ids_list)
        print(labels_list)
        raise
    attention_mask = (input_ids != 0).long()
    input_audio_embeds = cat_with_pad(input_audio_embeds_list, dim=0)
    audio_embed_sizes = torch.cat(audio_embed_sizes_list)

    return BatchFeature(
        {
            'input_ids': input_ids,
            'labels': labels,
            'attention_mask': attention_mask,
            'input_audio_embeds': input_audio_embeds,
            'audio_embed_sizes': audio_embed_sizes,
            'audio_attention_mask': audio_attention_mask,
            'input_mode': 2,  # speech mode
        }
    )
def create_model(model_name_or_path, use_flash_attention=False):
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        torch_dtype=torch.bfloat16 if use_flash_attention else torch.float32,
        _attn_implementation='flash_attention_2' if use_flash_attention else 'sdpa',
        trust_remote_code=True,
    ).to('cuda')
    return model
@torch.no_grad()
def evaluate(
    model, processor, eval_dataset, save_path=None, disable_tqdm=False, eval_batch_size=1
):
    rank = int(os.environ.get('RANK', 0))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))

    model.eval()
    all_generated_texts = []
    all_labels = []

    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=eval_batch_size,
        collate_fn=collate_fn,
        shuffle=False,
        drop_last=False,
        num_workers=8,
        prefetch_factor=2,
        pin_memory=True,
    )
    stop_tokens = ["<|end|>", processor.tokenizer.eos_token]
    stop_tokens_ids = processor.tokenizer(stop_tokens, add_special_tokens=False, padding="longest", return_tensors="pt")["input_ids"]
    stop_tokens_ids = stop_tokens_ids.to(f'cuda:{local_rank}')

    for inputs in tqdm(
        eval_dataloader, disable=(rank != 0) or disable_tqdm, desc='running eval'
    ):
        stopping_criteria = StoppingCriteriaList([MultipleTokenBatchStoppingCriteria(stop_tokens_ids, batch_size=inputs.input_ids.size(0))])
        inputs = inputs.to(f'cuda:{local_rank}')
        generated_ids = model.generate(
            **inputs, eos_token_id=processor.tokenizer.eos_token_id, max_new_tokens=64,
            stopping_criteria=stopping_criteria,
        )

        stop_tokens_idx = stopping_criteria[0].stop_tokens_idx.reshape(inputs.input_ids.size(0), -1)[:, 0]
        stop_tokens_idx = torch.where(
            stop_tokens_idx > 0,
            stop_tokens_idx - stop_tokens_ids.shape[-1],
            generated_ids.shape[-1],
        )
        generated_text = [
            processor.decode(_pred_ids[inputs["input_ids"].shape[1] : _stop_tokens_idx], skip_special_tokens=True, clean_up_tokenization_spaces=False)
            for _pred_ids, _stop_tokens_idx in zip(generated_ids, stop_tokens_idx)
        ]
        all_generated_texts.extend(generated_text)
        labels = [processor.decode(_label_ids[_label_ids != 0]).rstrip(ANSWER_SUFFIX) for _label_ids in inputs["labels"]]
        all_labels.extend(labels)

    all_generated_texts = gather_object(all_generated_texts)
    all_labels = gather_object(all_labels)

    if rank == 0:
        assert len(all_generated_texts) == len(all_labels)
        normalizer = BasicTextNormalizer()
        hyps = [re.sub(r"\s+", "", normalizer(text)) for text in all_generated_texts]
        refs = [re.sub(r"\s+", "", normalizer(text)) for text in all_labels]
        cer_score = round(cer(refs, hyps) * 100, 2)
        if save_path:
            with open(save_path, 'w', encoding='utf-8') as f:
                for ref, hyp in zip(all_labels, all_generated_texts):
                    utt_cer = round(cer(re.sub(r"\s+", "", normalizer(ref)), re.sub(r"\s+", "", normalizer(hyp))) * 100, 2)
                    print(json.dumps({'ref': ref, 'hyp': hyp, "cer": utt_cer}, ensure_ascii=False), file=f)
        return cer_score
    return None
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_name_or_path',
        type=str,
        default='microsoft/Phi-4-multimodal-instruct',
        help='Model name or path to load from',
    )
    parser.add_argument('--use_flash_attention', action='store_true', help='Use Flash Attention')
    parser.add_argument('--output_dir', type=str, default='./output/', help='Output directory')
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size')
    parser.add_argument(
        '--batch_size_per_gpu',
        type=int,
        default=32,
        help='Batch size per GPU (adjust this to fit in GPU memory)',
    )
    parser.add_argument(
        '--num_train_epochs', type=int, default=1, help='Number of training epochs'
    )
    parser.add_argument('--learning_rate', type=float, default=4.0e-5, help='Learning rate')
    parser.add_argument('--wd', type=float, default=0.01, help='Weight decay')
    parser.add_argument('--no-tqdm', dest='tqdm', action='store_false', help='Disable tqdm')
    args = parser.parse_args()

    accelerator = Accelerator()

    with accelerator.local_main_process_first():
        processor = AutoProcessor.from_pretrained(
            args.model_name_or_path,
            trust_remote_code=True,
        )
        model = create_model(
            args.model_name_or_path,
            use_flash_attention=args.use_flash_attention,
        )
    model.set_lora_adapter('speech')

    rank = int(os.environ.get('RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))

    eval_dataset = STTDataset(processor, split='test', rank=rank, world_size=world_size)
    train_dataset = STTDataset(processor, split='train')

    num_gpus = accelerator.num_processes
    print(f'training on {num_gpus} GPUs')
    assert (
        args.batch_size % (num_gpus * args.batch_size_per_gpu) == 0
    ), 'Batch size must be divisible by the number of GPUs times the per-GPU batch size'
    gradient_accumulation_steps = args.batch_size // (num_gpus * args.batch_size_per_gpu)

    if args.use_flash_attention:
        fp16 = False
        bf16 = True
    else:
        fp16 = True
        bf16 = False

    # hard coded training args
    training_args = TrainingArguments(
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.batch_size_per_gpu,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={'use_reentrant': False},
        gradient_accumulation_steps=gradient_accumulation_steps,
        optim='adamw_torch',
        adam_beta1=0.9,
        adam_beta2=0.95,
        adam_epsilon=1e-7,
        learning_rate=args.learning_rate,
        weight_decay=args.wd,
        max_grad_norm=1.0,
        lr_scheduler_type='linear',
        warmup_steps=50,
        logging_steps=10,
        output_dir=args.output_dir,
        save_strategy='no',
        save_total_limit=10,
        save_only_model=True,
        bf16=bf16,
        fp16=fp16,
        remove_unused_columns=False,
        report_to='none',
        deepspeed=None,
        disable_tqdm=not args.tqdm,
        dataloader_num_workers=4,
        ddp_find_unused_parameters=True,  # for unused SigLIP layers
    )

    # eval before fine-tuning
    out_path = Path(training_args.output_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    score = evaluate(
        model,
        processor,
        eval_dataset,
        save_path=out_path / 'eval_before.json',
        disable_tqdm=not args.tqdm,
        eval_batch_size=args.batch_size_per_gpu,
    )
    if accelerator.is_main_process:
        print(f'CER Score before finetuning: {score}')

    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'Trainable parameters: {trainable_params / 1e6:.2f}M')

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=collate_fn,
        train_dataset=train_dataset,
    )
    trainer.train()
    trainer.save_model()
    if accelerator.is_main_process:
        processor.save_pretrained(training_args.output_dir)
    accelerator.wait_for_everyone()

    # eval after fine-tuning (load saved checkpoint)
    # first try to clear GPU memory
    del model
    del trainer
    __import__('gc').collect()
    torch.cuda.empty_cache()

    # reload the model for inference
    model = AutoModelForCausalLM.from_pretrained(
        training_args.output_dir,
        torch_dtype=torch.bfloat16 if args.use_flash_attention else torch.float32,
        trust_remote_code=True,
        _attn_implementation='flash_attention_2' if args.use_flash_attention else 'sdpa',
    ).to('cuda')

    score = evaluate(
        model,
        processor,
        eval_dataset,
        save_path=out_path / 'eval_after.json',
        disable_tqdm=not args.tqdm,
        eval_batch_size=args.batch_size_per_gpu,
    )
    if accelerator.is_main_process:
        print(f'CER Score after finetuning: {score}')


if __name__ == '__main__':
    main()
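
# Hypothetical launch commands (the script filename is illustrative, not part of the gist):
#   single GPU:  python finetune_phi4mm_asr.py --use_flash_attention
#   multi GPU:   accelerate launch finetune_phi4mm_asr.py --use_flash_attention \
#                    --batch_size 128 --batch_size_per_gpu 32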