@Joelfranklin96
Created September 24, 2024 04:43
from transformers import BertTokenizer, BertForTokenClassification, Trainer, TrainingArguments
from datasets import load_dataset

# Load pre-trained BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tokenize the text and align the word-level labels with the resulting subword tokens
def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(examples['tokens'], truncation=True, is_split_into_words=True, padding='max_length')
    labels = []
    for i, label in enumerate(examples['ner_tags']):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        label_ids = []
        previous_word_idx = None
        for word_idx in word_ids:
            if word_idx is None:
                label_ids.append(-100)  # Special tokens and padding: ignored in loss computation
            elif word_idx != previous_word_idx:
                label_ids.append(label[word_idx])  # Label of the first subword of a word
            else:
                label_ids.append(label[word_idx])  # Subsequent subwords inherit the same label
            previous_word_idx = word_idx
        labels.append(label_ids)
    tokenized_inputs["labels"] = labels
    return tokenized_inputs
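# For illustration (hypothetical example, exact split depends on the vocabulary):
# a word like "fine-tuning" may be broken by WordPiece into pieces such as
# 'fine', '-', 'tuning'; word_ids() maps each piece back to the same word index,
# so every piece receives that word's ner_tag, while [CLS], [SEP], and padding
# positions map to None and are assigned -100.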
# Load the CoNLL-2003 dataset and prepare it for BERT
dataset = load_dataset('conll2003')
tokenized_dataset = dataset.map(tokenize_and_align_labels, batched=True)

# Derive the number of NER labels from the dataset's ner_tags feature
label_list = dataset['train'].features['ner_tags'].feature.names
num_labels = len(label_list)

# Load pre-trained BERT model for token classification
model = BertForTokenClassification.from_pretrained('bert-base-uncased', num_labels=num_labels)
# Define training arguments
training_args = TrainingArguments(
    output_dir='./results',
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
)
# Fine-tune BERT using the Trainer API
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset['train'],
    eval_dataset=tokenized_dataset['validation'],
)
# Start training
trainer.train()
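Once training finishes, a quick sanity check is to decode the model's predictions on the test split back to tag names. A minimal sketch, assuming the label_list defined above and Trainer.predict from transformers:

import numpy as np

# Predict on the test split; Trainer.predict returns logits and the aligned label ids
predictions = trainer.predict(tokenized_dataset['test'])
pred_ids = np.argmax(predictions.predictions, axis=-1)

# Decode the first sentence, keeping only real-token positions (label id != -100)
first_pred, first_labels = pred_ids[0], predictions.label_ids[0]
predicted_tags = [label_list[p] for p, l in zip(first_pred, first_labels) if l != -100]
print(predicted_tags)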