Skip to content

Instantly share code, notes, and snippets.

@Houssem96
Created August 22, 2021 13:18
Show Gist options
  • Save Houssem96/392decdaa664b15873d875532d83d608 to your computer and use it in GitHub Desktop.
Save Houssem96/392decdaa664b15873d875532d83d608 to your computer and use it in GitHub Desktop.
Loading and fine-tuning a DistilBERT model on the ag_news dataset
# Script setup: imports, configuration, data splits and tokenizer.
from datasets import load_dataset
from sklearn.metrics import accuracy_score, f1_score
import torch
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

num_labels = 4  # number of target classes handed to the classification head
model_name = "distilbert-base-uncased"
dataset_name = "ag_news"
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Both splits are carved out of the "train" split: the first 8000 rows are
# used for training and the last 2000 rows for validation.
AG_news_dataset_train = load_dataset(dataset_name, split="train[:8000]")
AG_news_dataset_validation = load_dataset(dataset_name, split="train[-2000:]")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# No bare AutoModel is loaded here: `model` is bound to the
# sequence-classification model further down before it is ever used, so a
# second set of weights would only cost download time and device memory.
def tokenize(batch):
    """Tokenize the ``"text"`` entry of *batch* with the module-level tokenizer.

    Padding and truncation are both switched on; whatever the tokenizer
    returns is handed back unchanged.
    """
    texts = batch["text"]
    encoded = tokenizer(texts, padding=True, truncation=True)
    return encoded
# Run the tokenizer over both splits. batched=True hands `tokenize` batches of
# rows rather than single rows; batch_size=None is forwarded as-is to map().
AG_news_dataset_encoded_train = AG_news_dataset_train.map(
    tokenize,
    batched=True,
    batch_size=None,
)
AG_news_dataset_encoded_validation = AG_news_dataset_validation.map(
    tokenize,
    batched=True,
    batch_size=None,
)

# Pretrained checkpoint with a `num_labels`-way classification head, moved to
# the selected device.
model = AutoModelForSequenceClassification.from_pretrained(
    model_name, num_labels=num_labels
).to(device)

# Expose only the model inputs and the label, as torch tensors, on both splits.
AG_news_dataset_encoded_train.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "label"],
)
AG_news_dataset_encoded_validation.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "label"],
)
def compute_metrics(pred):
    """Score one set of predictions.

    ``pred`` must expose ``label_ids`` (the reference labels) and
    ``predictions``; the predicted class is the argmax over the last axis of
    ``predictions``.

    Returns a dict with the keys ``"accuracy"`` and ``"f1"`` (weighted
    average).
    """
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(-1)
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "f1": f1_score(y_true, y_pred, average="weighted"),
    }
# Training configuration and run.
batch_size = 32
# One log/evaluate/save cycle per pass over the training split.
# len() on the dataset gives the row count directly, without pulling the whole
# "text" column into memory; max(1, ...) keeps the interval non-zero when the
# split holds fewer rows than one batch.
logging_steps = max(1, len(AG_news_dataset_encoded_train) // batch_size)

training_args = TrainingArguments(
    output_dir="results",
    num_train_epochs=2,
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    disable_tqdm=False,
    logging_steps=logging_steps,
    # NOTE(review): recent transformers releases name this argument
    # `eval_strategy` -- confirm against the installed version.
    evaluation_strategy="steps",  # not epochs
    # Evaluate and checkpoint on the same interval so that every evaluated
    # step has a checkpoint that load_best_model_at_end can restore, instead
    # of relying on the default save interval lining up by chance.
    eval_steps=logging_steps,
    save_steps=logging_steps,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
)

trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=AG_news_dataset_encoded_train,
    eval_dataset=AG_news_dataset_encoded_validation,
)
trainer.train()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment