@bench87 · Created November 11, 2024 13:17
bert-base multi-label classification — fine-tune klue/bert-base on SemEval-2018 Task 1 (subtask5.english) emotion labels
from datetime import datetime

import numpy as np
import torch
from datasets import load_dataset
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from torch.utils.tensorboard import SummaryWriter
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    EvalPrediction,
    Trainer,
    TrainingArguments,
)
# Load dataset (note: recent `datasets` releases may require
# trust_remote_code=True for script-based datasets such as this one)
dataset = load_dataset("sem_eval_2018_task_1", "subtask5.english")
train_df = dataset['train'].to_pandas()
validation_df = dataset['validation'].to_pandas()
test_df = dataset['test'].to_pandas()
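# (These DataFrames are not used in the training pipeline below; they are
# kept for optional inspection, e.g. train_df.head().)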
# Prepare labels
labels = [label for label in dataset['train'].features.keys() if label not in ['ID', 'Tweet']]
id2label = {idx: label for idx, label in enumerate(labels)}
label2id = {label: idx for idx, label in enumerate(labels)}
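# For subtask5.english these are the 11 emotion columns (anger, anticipation,
# disgust, fear, joy, love, optimism, pessimism, sadness, surprise, trust).
print(f"{len(labels)} labels: {labels}")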
# Tokenizer and preprocessing
tokenizer = AutoTokenizer.from_pretrained("klue/bert-base")
def preprocess_data(examples):
    text = examples["Tweet"]
    encoding = tokenizer(text, padding="max_length", truncation=True, max_length=128)
    # Build a (batch_size, num_labels) matrix of 0/1 float targets
    labels_batch = {k: examples[k] for k in examples.keys() if k in labels}
    labels_matrix = np.zeros((len(text), len(labels)))
    for idx, label in enumerate(labels):
        labels_matrix[:, idx] = labels_batch[label]
    encoding["labels"] = labels_matrix.tolist()
    return encoding
encoded_dataset = dataset.map(preprocess_data, batched=True, remove_columns=dataset['train'].column_names)
encoded_dataset.set_format("torch")
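# Quick shape check on one processed example (expected sizes given the
# settings above: max_length=128 and the 11 emotion labels):
example = encoded_dataset["train"][0]
print(example["input_ids"].shape, example["labels"].shape)  # torch.Size([128]) torch.Size([11])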
# Load model
model = AutoModelForSequenceClassification.from_pretrained(
    "klue/bert-base",
    problem_type="multi_label_classification",
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)
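# Note: with problem_type="multi_label_classification" the model is trained
# with BCEWithLogitsLoss, which expects each target to be a float vector of
# length num_labels; the labels built in preprocess_data satisfy this.
assert encoded_dataset["train"][0]["labels"].is_floating_point()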
# Define unique log directory for each run
log_dir = f"./logs/run-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
# Initialize TensorBoard writer
writer = SummaryWriter(log_dir=log_dir)
# Check the vocabulary size of the input embedding matrix
embedding_size = model.get_input_embeddings().weight.size(0)
# Make the metadata list the same length as the embedding matrix,
# e.g. by cycling the label names until the list reaches embedding_size
expanded_labels = labels * (embedding_size // len(labels)) + labels[:embedding_size % len(labels)]
# Add embeddings for projector visualization
embeddings = model.get_input_embeddings().weight
writer.add_embedding(embeddings, metadata=expanded_labels)
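# Optional alternative (an assumption, not in the original gist): label each
# embedding row with its actual vocabulary token instead of cycling the
# emotion labels, which makes the projector view more interpretable.
use_vocab_metadata = False
if use_vocab_metadata:
    tokens_by_id = [tok for tok, _ in sorted(tokenizer.get_vocab().items(), key=lambda kv: kv[1])]
    writer.add_embedding(embeddings, metadata=tokens_by_id[:embedding_size], tag="vocab_tokens")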
# Training arguments
batch_size = 8
metric_name = "f1"
args = TrainingArguments(
    output_dir="bert-finetuned-sem_eval-english",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=5,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    logging_dir=log_dir,  # reuse the SummaryWriter directory so all logs land in one TensorBoard run
    logging_steps=10,
)
def multi_label_metrics(predictions, labels, threshold=0.5):
    # Sigmoid the logits, then binarize at the threshold
    sigmoid = torch.nn.Sigmoid()
    probs = sigmoid(torch.Tensor(predictions))
    y_pred = np.zeros(probs.shape)
    y_pred[np.where(probs >= threshold)] = 1
    # Calculate metrics on the binarized predictions
    f1_micro_average = f1_score(y_true=labels, y_pred=y_pred, average='micro')
    roc_auc = roc_auc_score(y_true=labels, y_score=probs, average='micro')  # ROC-AUC scores the probabilities, not the thresholded output
    accuracy = accuracy_score(y_true=labels, y_pred=y_pred)
    metrics = {'f1': f1_micro_average, 'roc_auc': roc_auc, 'accuracy': accuracy}
    return metrics
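# Quick self-check of the metrics on two hypothetical examples (values are
# illustrative, not from the gist): logits [[2.0, -1.0], [-0.5, 1.5]] pass
# through the sigmoid and 0.5 threshold to [[1, 0], [0, 1]], matching the targets.
demo_metrics = multi_label_metrics(
    predictions=np.array([[2.0, -1.0], [-0.5, 1.5]]),
    labels=np.array([[1.0, 0.0], [0.0, 1.0]]),
)
print(demo_metrics)  # {'f1': 1.0, 'roc_auc': 1.0, 'accuracy': 1.0}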
def compute_metrics(p: EvalPrediction):
    preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    result = multi_label_metrics(predictions=preds, labels=p.label_ids)
    return result
# Trainer
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
# Start training and logging to TensorBoard
trainer.train()
trainer.evaluate()
# Close TensorBoard writer
writer.close()
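# Inference sketch (an assumption, not part of the original gist): score a new
# sentence with the fine-tuned model and apply the same 0.5 sigmoid threshold
# used in multi_label_metrics.
model.eval()
sample_text = "I'm so happy it's finally the weekend!"
inputs = tokenizer(sample_text, return_tensors="pt", truncation=True, max_length=128).to(model.device)
with torch.no_grad():
    logits = model(**inputs).logits
probs = torch.sigmoid(logits).squeeze(0)
predicted_labels = [id2label[i] for i, p in enumerate(probs) if p >= 0.5]
print(predicted_labels)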