Last active: October 20, 2021 14:17
Save davidmezzetti/f852631845929f7140e8f4d5fdd1c920 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Train a custom PyTorch text-classification model on the "emotion" dataset with
# txtai's HFTrainer, then measure its accuracy on the validation split through a
# txtai Labels pipeline.
#
# NOTE(review): `seed` and `Simple` are neither defined nor imported in this
# snippet, so running it standalone raises NameError. They are presumably
# defined elsewhere (e.g. an earlier notebook cell or companion file) and must
# be in scope before this script runs — confirm.
from datasets import load_dataset
from transformers import AutoTokenizer
from txtai.models import Registry
from txtai.pipeline import HFTrainer, Labels

ds = load_dataset("emotion")

# Set seed for reproducibility
seed()

# Define model. Arguments passed to Simple: tokenizer vocabulary size, 128
# (presumably the hidden/embedding dimension — confirm against Simple's
# definition), and the number of distinct label values in the training split.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = Simple(tokenizer.vocab_size, 128, len(ds["train"].unique("label")))

# Train model; HFTrainer is called with the (model, tokenizer) pair and returns
# the trained pair, which replaces the untrained objects.
train = HFTrainer()
model, tokenizer = train(
    (model, tokenizer),
    ds["train"],
    per_device_train_batch_size=8,
    learning_rate=1e-3,
    num_train_epochs=15,
    logging_steps=10000,
)

# Register custom model to fully support pipelines
Registry.register(model)

# Create labels pipeline using PyTorch model; dynamic=False is passed so the
# pipeline uses the model's own classification head rather than a dynamic
# (zero-shot) label list.
labels = Labels((model, tokenizer), dynamic=False)

# Determine accuracy on validation set. labels(text) is indexed as [0][0]:
# presumably a list of (label_id, score) pairs ordered best-first, so [0][0]
# is the top predicted label id, compared against the gold label.
results = [row["label"] == labels(row["text"])[0][0] for row in ds["validation"]]
print("Accuracy = ", sum(results) / len(ds["validation"]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.