Last active
August 30, 2021 10:53
-
-
Save seanbenhur/8f015588a10aabfc7a36d954ff7a24c6 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from datasets import load_dataset
from transformers import AutoTokenizer

# Load the "imdb" dataset.
dataset = load_dataset("imdb")

# Create the tokenizer for the "distilbert-base-uncased" checkpoint; it is
# used at module level by the batch-encoding function below.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
def encode_batch(batch):
    """Encode a batch of input data using the model tokenizer.

    Args:
        batch: Mapping whose ``"text"`` entry holds the raw texts to
            tokenize (presumably a list of strings, since the function is
            mapped with ``batched=True`` -- confirm against the dataset).

    Returns:
        The result of calling the module-level ``tokenizer`` on the texts.
        Sequences are truncated to at most 80 tokens and padded up to that
        same length, so every encoded example has a fixed size.
    """
    return tokenizer(
        batch["text"],
        max_length=80,
        truncation=True,
        padding="max_length",
    )
# Encode the input data, passing whole batches to encode_batch.
dataset = dataset.map(encode_batch, batched=True)

# The transformers model expects the target class column to be named "labels".
# The in-place `rename_column_` was deprecated in datasets 1.4 and removed in
# 2.0; `rename_column` returns a new object, so the result must be reassigned.
dataset = dataset.rename_column("label", "labels")

# Transform to PyTorch tensors and only output the required columns.
dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment