Code Demo for Data2vec vs. SBERT on Text Classification
# Call Model
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels).to("cuda")
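Here, model_name and num_labels are assumed to be defined earlier in the notebook. A minimal sketch of what they might look like for this demo (checkpoint names and label count are assumptions, not part of the original gist):

# Assumed values, for illustration only:
model_name = "facebook/data2vec-text-base"  # Data2vec text encoder from the Hugging Face Hub;
                                            # swap in an SBERT-style checkpoint such as
                                            # "sentence-transformers/all-MiniLM-L6-v2" to compare
num_labels = 20                             # 20 Newsgroups has 20 classes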
# Load Metrics
metric = load_metric(metrics_name)  # e.g. "f1"

# Create Metrics
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    # "micro", "macro", etc. average scores across classes in multi-class settings;
    # for binary classification, leave average at its default or pass "binary"
    return metric.compute(predictions=predictions, references=labels, average="micro")
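As a quick sanity check, compute_metrics can be called directly on a (logits, labels) tuple; the toy values below are made up for illustration and assume metrics_name was "f1":

# Toy example (made-up logits/labels) to sanity-check compute_metrics
dummy_logits = np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
dummy_labels = np.array([1, 0, 1])
print(compute_metrics((dummy_logits, dummy_labels)))  # -> {'f1': 1.0} when all predictions match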
# Specify the arguments for the trainer
training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=num_epochs,     # total number of training epochs
    per_device_train_batch_size=8,   # batch size per device during training
    per_device_eval_batch_size=20,   # batch size per device during evaluation
    warmup_steps=500,                # number of warmup steps for the learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    load_best_model_at_end=True,     # load the best model when finished training (default metric is loss)
    metric_for_best_model="f1",      # metric used to select the best checkpoint
    logging_steps=200,               # log & save weights every `logging_steps`
    save_steps=200,
    evaluation_strategy="steps",     # evaluate every `logging_steps`
)

# Call the Trainer
trainer = Trainer(
    model=model,                     # the instantiated Transformers model to be trained
    args=training_args,              # training arguments, defined above
    train_dataset=train_dataset,     # training dataset
    eval_dataset=valid_dataset,      # evaluation dataset
    compute_metrics=compute_metrics, # the callback that computes metrics of interest
)

# Train the model
trainer.train()

# Evaluate on the validation set
trainer.evaluate()
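Beyond evaluate(), Trainer.predict returns the raw logits plus metrics for any dataset; a short sketch for scoring the validation split:

# Get raw predictions and metrics on the validation set
pred_output = trainer.predict(valid_dataset)
y_pred = np.argmax(pred_output.predictions, axis=1)  # predicted class ids
print(pred_output.metrics)                           # loss and the f1 defined above (keys prefixed, e.g. "test_f1")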
!pip install torch transformers memory_profiler datasets

import torch
import random
from transformers.file_utils import is_tf_available, is_torch_available, is_torch_tpu_available
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import load_metric
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np

%load_ext memory_profiler
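The memory_profiler extension is presumably loaded to compare the two models' memory footprints. With it active, the %memit magic measures the peak memory of a single statement; the call shown is just an example of how it might be used once the Trainer is built:

%memit trainer.train()  # prints peak resident memory and the increment for the statement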
from sklearn.datasets import fetch_20newsgroups

dataset = fetch_20newsgroups(subset="all",
                             shuffle=True,
                             remove=("headers", "footers", "quotes"))
documents = dataset.data
labels = dataset.target
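For reference, subset="all" merges the train and test splits (roughly 18,800 posts across 20 topics); a quick inspection:

print(len(documents))            # ~18,846 posts with subset="all"
print(dataset.target_names[:3])  # e.g. ['alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc']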
class MakeTorchData(torch.utils.data.Dataset):
    """Wrap tokenizer encodings and labels as a torch Dataset for the Trainer."""
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        # one tensor per encoding field (input_ids, attention_mask, ...) plus the label
        item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}
        item["labels"] = torch.tensor([self.labels[idx]])
        return item

    def __len__(self):
        return len(self.labels)

# convert our tokenized data into a torch Dataset
train_dataset = MakeTorchData(train_encodings, y_train.ravel())
valid_dataset = MakeTorchData(valid_encodings, y_test.ravel())
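Each item the Trainer receives is a dict of tensors; inspecting one element confirms the expected fields (the exact keys vary by tokenizer, and some also emit token_type_ids):

sample = train_dataset[0]
print(sample.keys())              # e.g. dict_keys(['input_ids', 'attention_mask', 'labels'])
print(sample["input_ids"].shape)  # sequence length after padding/truncation, up to max_length=512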
# option 1:
trainer.save_model("path/to/model")

# option 2:
model.save_pretrained("path/to/model")
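Either way, the saved directory can be reloaded later with from_pretrained; saving the tokenizer alongside it keeps the checkpoint self-contained:

# Reload the fine-tuned model (and keep the tokenizer with it)
tokenizer.save_pretrained("path/to/model")
model = AutoModelForSequenceClassification.from_pretrained("path/to/model").to("cuda")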
# Make data (assuming X/y are the 20 Newsgroups documents and labels loaded above)
X = pd.Series(documents)
y = labels
y = pd.factorize(y)[0]  # convert labels to numbers

# Split Data
X_train, X_test, y_train, y_test = train_test_split(X.tolist(), y, test_size=0.33)
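A quick check that the split came out as expected (the counts assume the 20 Newsgroups load above):

print(len(X_train), len(X_test))  # roughly a 67/33 split, e.g. ~12,600 / ~6,200 documents
print(np.bincount(y_train))       # per-class counts; the 20 Newsgroups classes are fairly balanced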
# Call the Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=True)

# Encode the text
train_encodings = tokenizer(X_train, truncation=True, padding=True, max_length=512)
valid_encodings = tokenizer(X_test, truncation=True, padding=True, max_length=512)
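At this point the encodings are plain lists of token ids; a peek at one example before wrapping them in the torch Dataset above:

print(train_encodings.keys())                                  # e.g. input_ids, attention_mask
print(tokenizer.decode(train_encodings["input_ids"][0][:20]))  # round-trip the first few tokens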