Flan-T5 Instruction tuning code
import os

import evaluate
import numpy as np
import pandas as pd
import torch
from datasets import load_dataset, DatasetDict
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig, TrainingArguments, Trainer

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# DialogSum: dialogues paired with human-written summaries.
huggingface_dataset_name = "knkarthick/dialogsum"
dataset = load_dataset(huggingface_dataset_name)
print(dataset)

model_name = 'google/flan-t5-base'
# bfloat16 halves memory use; it needs a GPU with bf16 support (e.g. Ampere or newer).
original_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_name)
def tokenize_function(example):
    start_prompt = 'Summarize the following conversation.\n\n'
    end_prompt = '\n\nSummary: '
    prompt = [start_prompt + dialogue + end_prompt for dialogue in example["dialogue"]]
    tokenized_inputs = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt")
    example['input_ids'] = tokenized_inputs.input_ids
    # Keep the attention mask so the encoder ignores the padding positions.
    example['attention_mask'] = tokenized_inputs.attention_mask
    labels = tokenizer(example["summary"], padding="max_length", truncation=True, return_tensors="pt").input_ids
    # Mask pad tokens in the labels with -100 so the loss ignores padding.
    labels[labels == tokenizer.pad_token_id] = -100
    example['labels'] = labels
    return example
# The dataset contains three splits: train, validation, and test.
# tokenize_function handles all data across all splits in batches.
tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets = tokenized_datasets.remove_columns(['id', 'topic', 'dialogue', 'summary'])
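# Aside: padding every example to the model's max length is simple but wasteful.
# A common alternative (not used in this run, shown only as a hedged sketch) is
# dynamic per-batch padding with a collator, which also pads labels with -100:
#
#   from transformers import DataCollatorForSeq2Seq
#   data_collator = DataCollatorForSeq2Seq(tokenizer, model=original_model)
#   # ...then pass data_collator=data_collator to the Trainer below.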
# Keep only the first 12,000 training examples.
train_indices_to_keep = np.arange(12000)

def get_example(example, index):
    # Used with filter(..., with_indices=True): keep an example iff its index
    # falls within the range selected above.
    return index < len(train_indices_to_keep)

train_dataset = tokenized_datasets['train'].filter(get_example, with_indices=True)
val_dataset = tokenized_datasets['validation']
test_dataset = tokenized_datasets['test']
train_dataset = train_dataset.shuffle(seed=42)
reduced_dataset = DatasetDict({
    'train': train_dataset,
    'validation': val_dataset,
    'test': test_dataset,
})
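# Aside: an equivalent and arguably simpler way to keep the first 12,000
# training examples is `select` instead of the index filter above, e.g.
#   tokenized_datasets['train'].select(range(len(train_indices_to_keep)))
# `filter` is kept here to match the original run.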
# base_dir = "/content/drive/MyDrive/Model_wts"  # when running on Colab
base_dir = "Model_wts"
output_dir = os.path.join(base_dir, str(len(train_indices_to_keep)) + "_examples")
os.makedirs(output_dir, exist_ok=True)  # also creates base_dir if it is missing

print("Shapes of the datasets:")
print(f"Training: {reduced_dataset['train'].shape}")
print(f"Validation: {reduced_dataset['validation'].shape}")
print(f"Test: {reduced_dataset['test'].shape}")
print(reduced_dataset)
training_args = TrainingArguments(
    output_dir=output_dir,
    learning_rate=1e-5,
    num_train_epochs=7,
    weight_decay=0.01,
    logging_strategy="epoch",
    evaluation_strategy="epoch",
    per_device_train_batch_size=4,
)

trainer = Trainer(
    model=original_model,
    args=training_args,
    train_dataset=reduced_dataset['train'],
    eval_dataset=reduced_dataset['validation'],
)
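# Aside: plain Trainer evaluates with teacher forcing (loss only). To track
# ROUGE during training one could switch to Seq2SeqTrainer with
# predict_with_generate=True and a metric callback. An untested sketch, not
# part of the original run; `rouge_during_eval` is a hypothetical helper name:
#
#   from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
#   def rouge_during_eval(eval_preds):
#       preds, labels = eval_preds
#       labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
#       decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
#       decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
#       return evaluate.load('rouge').compute(predictions=decoded_preds, references=decoded_labels)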
print ("***training getting started**") | |
trainer.train() | |
trainer.save_model() | |
print ("Model saved") | |
# trainer.train(resume_from_checkpoint=True) | |
# trainer.train(resume_from_checkpoint=True) | |
# trainer.train(resume_from_checkpoint=True) | |
# | |
# Free the fine-tuned weights, then reload both models for a side-by-side
# comparison: the fine-tuned ("instruct") model and the untouched base model.
del original_model
instruct_model = AutoModelForSeq2SeqLM.from_pretrained(output_dir, torch_dtype=torch.bfloat16, device_map='auto')
original_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map='auto')
dash_line = "------------------------------------------"

# Qualitative check on a single test example.
index = 100
dialogue = dataset['test'][index]['dialogue']
human_baseline_summary = dataset['test'][index]['summary']

prompt = f"""
Summarize the following conversation.
{dialogue}
Summary:
"""

input_ids = tokenizer(prompt, return_tensors="pt").input_ids
original_model_outputs = original_model.generate(input_ids=input_ids.to('cuda'), generation_config=GenerationConfig(max_new_tokens=200, num_beams=1))
original_model_text_output = tokenizer.decode(original_model_outputs[0], skip_special_tokens=True)
instruct_model_outputs = instruct_model.generate(input_ids=input_ids.to('cuda'), generation_config=GenerationConfig(max_new_tokens=200, num_beams=1))
instruct_model_text_output = tokenizer.decode(instruct_model_outputs[0], skip_special_tokens=True)

print("Original conversation is:\n")
print(dialogue)
print(dash_line)
print(f'BASELINE HUMAN SUMMARY:\n{human_baseline_summary}')
print(dash_line)
print(f'ORIGINAL MODEL:\n{original_model_text_output}')
print(dash_line)
print(f'INSTRUCT MODEL:\n{instruct_model_text_output}')
rouge = evaluate.load('rouge')

dialogues = dataset['test']['dialogue']
human_baseline_summaries = dataset['test']['summary']

original_model_summaries = []
instruct_model_summaries = []

# Generate a summary for every test dialogue with both models.
for dialogue in dialogues:
    prompt = f"""
Summarize the following conversation.
{dialogue}
Summary: """
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    original_model_outputs = original_model.generate(input_ids=input_ids.to('cuda'), generation_config=GenerationConfig(max_new_tokens=200))
    original_model_summaries.append(tokenizer.decode(original_model_outputs[0], skip_special_tokens=True))
    instruct_model_outputs = instruct_model.generate(input_ids=input_ids.to('cuda'), generation_config=GenerationConfig(max_new_tokens=200))
    instruct_model_summaries.append(tokenizer.decode(instruct_model_outputs[0], skip_special_tokens=True))
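# Aside: the loop above generates one summary at a time, which is slow for the
# full test set. A hedged sketch of batched generation (batch size 8 is an
# arbitrary choice; `prompts` and `summaries` are hypothetical names for the
# list of formatted prompts and the collected outputs):
#
#   for i in range(0, len(prompts), 8):
#       batch = tokenizer(prompts[i:i + 8], return_tensors="pt", padding=True).to('cuda')
#       outputs = instruct_model.generate(**batch, max_new_tokens=200)
#       summaries.extend(tokenizer.batch_decode(outputs, skip_special_tokens=True))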
zipped_summaries = list(zip(human_baseline_summaries, original_model_summaries, instruct_model_summaries))
df = pd.DataFrame(zipped_summaries, columns=['human_baseline_summaries', 'original_model_summaries', 'instruct_model_summaries'])
df.to_csv(os.path.join(output_dir, str(len(train_indices_to_keep)) + ".csv"), index=None)
original_model_results = rouge.compute(
    predictions=original_model_summaries,
    references=human_baseline_summaries[0:len(original_model_summaries)],
    use_aggregator=True,
    use_stemmer=True,
)

instruct_model_results = rouge.compute(
    predictions=instruct_model_summaries,
    references=human_baseline_summaries[0:len(instruct_model_summaries)],
    use_aggregator=True,
    use_stemmer=True,
)

print('ORIGINAL MODEL:')
print(original_model_results)
print('INSTRUCT MODEL:')
print(instruct_model_results)
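# One way to summarize the gap: absolute improvement of the instruction-tuned
# model over the base model per ROUGE metric (both result dicts share the same keys).
improvement = (np.array(list(instruct_model_results.values()))
               - np.array(list(original_model_results.values())))
print('IMPROVEMENT OF INSTRUCT MODEL OVER ORIGINAL MODEL:')
for key, value in zip(instruct_model_results.keys(), improvement):
    print(f'{key}: {value * 100:.2f}%')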