bert-base multi label
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from transformers import EvalPrediction
import numpy as np
import torch
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score
# Load dataset
dataset = load_dataset("sem_eval_2018_task_1", "subtask5.english")
train_df = dataset['train'].to_pandas()
validation_df = dataset['validation'].to_pandas()
test_df = dataset['test'].to_pandas()

# Prepare labels
labels = [label for label in dataset['train'].features.keys() if label not in ['ID', 'Tweet']]
id2label = {idx: label for idx, label in enumerate(labels)}
label2id = {label: idx for idx, label in enumerate(labels)}
# Tokenizer and preprocessing
tokenizer = AutoTokenizer.from_pretrained("klue/bert-base")

def preprocess_data(examples):
    # Tokenize the tweet text
    text = examples["Tweet"]
    encoding = tokenizer(text, padding="max_length", truncation=True, max_length=128)
    # Build a (batch_size, num_labels) matrix of 0/1 targets, one column per label
    labels_batch = {k: examples[k] for k in examples.keys() if k in labels}
    labels_matrix = np.zeros((len(text), len(labels)))
    for idx, label in enumerate(labels):
        labels_matrix[:, idx] = labels_batch[label]
    encoding["labels"] = labels_matrix.tolist()
    return encoding

encoded_dataset = dataset.map(preprocess_data, batched=True, remove_columns=dataset['train'].column_names)
encoded_dataset.set_format("torch")
# Load model
model = AutoModelForSequenceClassification.from_pretrained(
    "klue/bert-base",
    problem_type="multi_label_classification",
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id
)
# Define a unique log directory for each run
log_dir = f"./logs/run-{datetime.now().strftime('%Y%m%d-%H%M%S')}"

# Initialize TensorBoard writer
writer = SummaryWriter(log_dir=log_dir)

# Check the size of the input embedding matrix (vocabulary size)
embedding_size = model.get_input_embeddings().weight.size(0)

# Make the metadata as long as the embedding matrix,
# e.g. by repeating the labels until they match embedding_size
expanded_labels = labels * (embedding_size // len(labels)) + labels[:embedding_size % len(labels)]

# Add embeddings for projector visualization
embeddings = model.get_input_embeddings().weight
writer.add_embedding(embeddings, metadata=expanded_labels)
# Training arguments
batch_size = 8
metric_name = "f1"

args = TrainingArguments(
    output_dir="bert-finetuned-sem_eval-english",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=5,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    logging_dir=log_dir,
    logging_steps=10,
)
def multi_label_metrics(predictions, labels, threshold=0.5):
    # Apply a sigmoid to the raw logits, then binarize at the threshold
    sigmoid = torch.nn.Sigmoid()
    probs = sigmoid(torch.Tensor(predictions))
    y_pred = np.zeros(probs.shape)
    y_pred[np.where(probs >= threshold)] = 1
    # Calculate metrics
    f1_micro_average = f1_score(y_true=labels, y_pred=y_pred, average='micro')
    roc_auc = roc_auc_score(y_true=labels, y_score=probs, average='micro')  # Use probs as y_score
    accuracy = accuracy_score(y_true=labels, y_pred=y_pred)
    metrics = {'f1': f1_micro_average, 'roc_auc': roc_auc, 'accuracy': accuracy}
    return metrics

def compute_metrics(p: EvalPrediction):
    preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    result = multi_label_metrics(predictions=preds, labels=p.label_ids)
    return result
# Trainer
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

# Start training and logging to TensorBoard
trainer.train()
trainer.evaluate()

# Close TensorBoard writer
writer.close()
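# --- Optional: quick inference check (a minimal sketch, not part of the original gist) ---
# Assumes the fine-tuned `model`, `tokenizer`, and `id2label` defined above;
# the 0.5 threshold mirrors the default used in multi_label_metrics.
sample_text = "I can't believe how well this turned out, I'm thrilled!"  # hypothetical example input
inputs = tokenizer(sample_text, return_tensors="pt", truncation=True, max_length=128)
inputs = {k: v.to(model.device) for k, v in inputs.items()}
with torch.no_grad():
    logits = model(**inputs).logits
probs = torch.sigmoid(logits).squeeze(0)
predicted_labels = [id2label[i] for i, p in enumerate(probs.tolist()) if p >= 0.5]
print(predicted_labels)

# To browse the logged metrics and the embedding projector, point TensorBoard at the
# log directory used above, e.g.:  tensorboard --logdir ./logs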