from transformers import BertTokenizerFast, BertForTokenClassification, Trainer, TrainingArguments
from datasets import load_dataset
# Load the pre-trained BERT tokenizer (the fast tokenizer is required for word_ids() below)
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
# Tokenize the text and align the labels with the subword tokens
def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(examples['tokens'], truncation=True, is_split_into_words=True, padding='max_length')
    labels = []
    for i, label in enumerate(examples['ner_tags']):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        label_ids = []
        previous_word_idx = None
        for word_idx in word_ids:
            if word_idx is None:
                label_ids.append(-100)  # Special tokens and padding: ignored in loss computation
            elif word_idx != previous_word_idx:
                label_ids.append(label[word_idx])  # First subword of a word gets the word's label
            else:
                label_ids.append(label[word_idx])  # Subsequent subwords repeat the label (-100 is a common alternative)
            previous_word_idx = word_idx
        labels.append(label_ids)
    tokenized_inputs["labels"] = labels
    return tokenized_inputs
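
# Illustrative sanity check (not in the original gist): the sentence and tag ids
# below are hand-picked to show the alignment. Special tokens and padding get
# -100; each subword inherits its parent word's label.
sample = {'tokens': [['EU', 'rejects', 'German', 'call']], 'ner_tags': [[3, 0, 7, 0]]}
encoded = tokenize_and_align_labels(sample)
print(encoded['labels'][0][:8])  # e.g. [-100, 3, 0, 7, 0, -100, -100, -100]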
# Load the CoNLL-2003 dataset and prepare it for BERT
dataset = load_dataset('conll2003')
tokenized_dataset = dataset.map(tokenize_and_align_labels, batched=True)
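
# Inspect the label names carried by the dataset schema (the Hugging Face
# conll2003 dataset stores NER tags as a sequence of ClassLabel features)
label_list = dataset['train'].features['ner_tags'].feature.names
print(label_list)  # ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']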
# Load pre-trained BERT model for token classification;
# the label count must match the dataset's tag set (9 NER tags for CoNLL-2003)
num_labels = dataset['train'].features['ner_tags'].feature.num_classes
model = BertForTokenClassification.from_pretrained('bert-base-uncased', num_labels=num_labels)
# Define training arguments
training_args = TrainingArguments(
    output_dir='./results',
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
)
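
# Optional (not in the original gist): entity-level metrics for evaluation.
# A minimal sketch assuming the `evaluate` and `seqeval` packages are installed
# and that `label_list` was built above; pass compute_metrics=compute_metrics
# to the Trainer below to report precision/recall/F1 instead of loss alone.
import numpy as np
import evaluate

seqeval = evaluate.load('seqeval')

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    # Keep only positions with a real label (drop the -100 special-token slots)
    true_labels = [
        [label_list[l] for l in row if l != -100]
        for row in labels
    ]
    true_predictions = [
        [label_list[p] for p, l in zip(pred_row, row) if l != -100]
        for pred_row, row in zip(predictions, labels)
    ]
    results = seqeval.compute(predictions=true_predictions, references=true_labels)
    return {
        'precision': results['overall_precision'],
        'recall': results['overall_recall'],
        'f1': results['overall_f1'],
        'accuracy': results['overall_accuracy'],
    }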
# Fine-tune BERT using the Trainer API
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset['train'],
    eval_dataset=tokenized_dataset['validation'],
)
# Start training
trainer.train()
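
# Post-training steps (illustrative, not in the original gist): score the
# held-out test split, then look at the raw prediction tensor shape.
test_metrics = trainer.evaluate(tokenized_dataset['test'])
print(test_metrics)

predictions = trainer.predict(tokenized_dataset['test'])
print(predictions.predictions.shape)  # (num_examples, max_length, num_labels)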