from datasets import load_dataset
from transformers import AutoTokenizer
from transformers import DataCollatorWithPadding
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
import numpy as np

# Load the training data (JSON Lines, one question/context/label record per line).
dataset = load_dataset('json', data_files=['data/train_qa_vi_mailong.jsonl'])
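# A hedged sketch of the expected JSONL record layout, inferred from the fields
# accessed below ("question", "context", "label"); the values are placeholders:
#   {"question": "...", "context": "...", "label": 1}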
checkpoint_name = "xlm-roberta-large"
tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)

def preprocess_function(examples):
    # Encode each (question, context) pair as one sequence, padded/truncated to 512 tokens.
    return tokenizer(examples["question"], examples["context"], padding="max_length", truncation=True, max_length=512)

tokenized_data = dataset.map(preprocess_function, batched=True)
# Hold out 10% for evaluation; the fixed seed keeps the split reproducible.
tokenized_data = tokenized_data["train"].train_test_split(test_size=0.1, seed=1996)
print(tokenized_data)

# Dynamic-padding collator (largely a no-op here, since tokenization already pads to max_length).
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# Binary relevance classifier: label 1 = the context answers the question, 0 = it does not.
model = AutoModelForSequenceClassification.from_pretrained(checkpoint_name, num_labels=2)

training_args = TrainingArguments(
    output_dir=f"{checkpoint_name}-finetuned-retrieval",
    learning_rate=1e-6,
    auto_find_batch_size=True,  # retries with a smaller batch size on CUDA OOM (requires `accelerate`)
    num_train_epochs=2,
    save_total_limit=5,
)
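# A hedged, optional addition (not in the original gist): the Trainer below reports
# only loss at evaluation time. A compute_metrics callback like this sketch, passed
# via `compute_metrics=compute_metrics` in Trainer(...), would also log accuracy.
# def compute_metrics(eval_pred):
#     logits, labels = eval_pred
#     predictions = np.argmax(logits, axis=-1)
#     return {"accuracy": (predictions == labels).mean()}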
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_data["train"],
    eval_dataset=tokenized_data["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)
trainer.train()
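# Optional, hedged sketch (not in the original gist): persist the fine-tuned model
# and tokenizer so the classifier can be reloaded without retraining. The output
# path here is an assumption.
# trainer.save_model(f"{checkpoint_name}-finetuned-retrieval/final")
# tokenizer.save_pretrained(f"{checkpoint_name}-finetuned-retrieval/final")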
# Run inference on the held-out split and score predictions against the gold labels.
results = trainer.predict(tokenized_data["test"])
labels = tokenized_data["test"]["label"]
preds = results.predictions  # raw logits, shape (n_examples, 2)

# Tally the gold label distribution, the prediction distribution, and overall accuracy.
count_true = 0
count_false = 0
count_correct = 0
predict_true = 0
predict_false = 0
for label, pred in zip(labels, preds):
    pred = np.argmax(pred)  # logits -> predicted class id
    print(label, pred)
    if label == pred:
        count_correct += 1
    if label == 0:
        count_false += 1
    if label == 1:
        count_true += 1
    if pred == 0:
        predict_false += 1
    if pred == 1:
        predict_true += 1

print("Negative (0) labels: {}\nPositive (1) labels: {}\nAccuracy: {}".format(count_false, count_true, count_correct / len(labels)))
print("Predicted positive: {}  Predicted negative: {}".format(predict_true, predict_false))
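# A hedged alternative (not in the original gist): the same evaluation expressed as
# standard precision/recall/F1 via scikit-learn, assuming sklearn is installed.
# from sklearn.metrics import classification_report
# print(classification_report(labels, np.argmax(preds, axis=-1), digits=4))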
# 150.65.183.82