Binary classification with DistilBERT, minimal example
import evaluate
import numpy as np
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
from transformers import AutoTokenizer
from transformers import DataCollatorWithPadding

# Mapping between class indices and human-readable labels
id2label = {0: "NEGATIVE", 1: "POSITIVE"}
label2id = {"NEGATIVE": 0, "POSITIVE": 1}

# Accuracy metric from the `evaluate` library
accuracy = evaluate.load("accuracy")

# IMDB movie-review dataset: binary sentiment labels, train and test splits
imdb = load_dataset("imdb")

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# Pads each batch dynamically to the longest sequence in that batch
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)


def preprocess_function(examples):
    # Tokenize the raw review text, truncating to the model's maximum length
    return tokenizer(examples["text"], truncation=True)


tokenized_imdb = imdb.map(preprocess_function, batched=True)


def compute_metrics(eval_pred):
    # Convert logits to predicted class indices before computing accuracy
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)


# DistilBERT with a freshly initialized two-class classification head
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased",
                                                           num_labels=2,
                                                           id2label=id2label,
                                                           label2id=label2id)

training_args = TrainingArguments(output_dir="my_wonderful_model")

trainer = Trainer(model=model, args=training_args, train_dataset=tokenized_imdb["train"],
                  eval_dataset=tokenized_imdb["test"], tokenizer=tokenizer, data_collator=data_collator,
                  compute_metrics=compute_metrics)

trainer.train()
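
Once training finishes, the fine-tuned weights can be saved and reloaded for inference. The snippet below is a sketch continuing the script above; the "my_wonderful_model/final" save path and the sample review are illustrative choices, not part of the original gist.

# Continuing the script above: persist the final model, then run inference on it.
# Since a tokenizer was passed to the Trainer, save_model() writes it alongside
# the weights, so the pipeline can load everything from the same directory.
# The path "my_wonderful_model/final" is an arbitrary illustrative choice.
from transformers import pipeline

trainer.save_model("my_wonderful_model/final")

classifier = pipeline("text-classification", model="my_wonderful_model/final")
print(classifier("A surprisingly touching film with a terrific cast."))
# Expected output shape: [{"label": "POSITIVE", "score": ...}]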