Metric | Value
---|---
MRR@1 | 0.8084420567920184
MRR@2 | 0.8511128165771297
MRR@3 | 0.861857252494244
MRR@5 | 0.8678357636224099
MRR@10 | 0.8696916151981385
Top 1 Accuracy | 0.8084420567920184
Top 2 Accuracy | 0.8937835763622409
Top 3 Accuracy | 0.9260168841135841
Top 5 Accuracy | 0.9516500383729855
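The table reports MRR@k and top-k accuracy, and MRR@1 matches Top 1 Accuracy exactly. A minimal sketch of how such metrics are commonly computed may make that relationship easier to see. It assumes exactly one relevant passage per query; the function names and the toy data are hypothetical, and this is not necessarily how the library computed the figures above.

```python
from typing import Hashable, Sequence


def reciprocal_rank_at_k(ranked: Sequence[Hashable], relevant: Hashable, k: int) -> float:
    """Return 1/rank of the relevant item if it is within the top k, else 0."""
    for rank, item in enumerate(ranked[:k], start=1):
        if item == relevant:
            return 1.0 / rank
    return 0.0


def mrr_at_k(rankings: Sequence[Sequence[Hashable]], relevants: Sequence[Hashable], k: int) -> float:
    """Mean reciprocal rank over all queries, truncated at rank k."""
    scores = [reciprocal_rank_at_k(r, t, k) for r, t in zip(rankings, relevants)]
    return sum(scores) / len(scores)


def top_k_accuracy(rankings: Sequence[Sequence[Hashable]], relevants: Sequence[Hashable], k: int) -> float:
    """Fraction of queries whose relevant item appears in the top k."""
    hits = sum(1 for r, t in zip(rankings, relevants) if t in r[:k])
    return hits / len(rankings)


# Toy data (assumption): three queries, each with the single relevant passage "a".
rankings = [["a", "b", "c"], ["b", "a", "c"], ["c", "b", "a"]]
relevants = ["a", "a", "a"]

for k in (1, 2, 3):
    print(f"MRR@{k} = {mrr_at_k(rankings, relevants, k):.4f}, "
          f"Top {k} Accuracy = {top_k_accuracy(rankings, relevants, k):.4f}")
```

At k = 1 the reciprocal rank is either 1 or 0, so MRR@1 and top-1 accuracy coincide. For larger k, MRR gives partial credit to lower-ranked hits, so it grows more slowly than top-k accuracy.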
import logging

from simpletransformers.retrieval import RetrievalModel, RetrievalArgs

logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)
import logging

import pandas as pd
from simpletransformers.classification import (
    MultiLabelClassificationModel,
    MultiLabelClassificationArgs,
)

logging.basicConfig(level=logging.INFO)
import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import (
    BertPreTrainedModel,  # base class used below; missing from the original import
    RobertaModel,
)

class RobertaForMultiLabelSequenceClassification(BertPreTrainedModel):
    """
Model | English to Sinhalese | Sinhalese to English
---|---|---
mT5 | 10.3 | 24.4
Tatoeba | 9.2 | 22.1
# Predict
sinhala_preds = model.predict(to_sinhala)
eng_sin_bleu = sacrebleu.corpus_bleu(sinhala_preds, sinhala_truth)
print("--------------------------")
print("English to Sinhalese: ", eng_sin_bleu.score)

english_preds = model.predict(to_english)
sin_eng_bleu = sacrebleu.corpus_bleu(english_preds, english_truth)
eval_df = pd.read_csv("data/eval.tsv", sep="\t").astype(str)

sinhala_truth = [eval_df.loc[eval_df["prefix"] == "translate english to sinhala"]["target_text"].tolist()]
to_sinhala = eval_df.loc[eval_df["prefix"] == "translate english to sinhala"]["input_text"].tolist()

english_truth = [eval_df.loc[eval_df["prefix"] == "translate sinhala to english"]["target_text"].tolist()]
to_english = eval_df.loc[eval_df["prefix"] == "translate sinhala to english"]["input_text"].tolist()
import logging

import sacrebleu
import pandas as pd
from simpletransformers.t5 import T5Model, T5Args

logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)
# Train the model
model.train_model(train_df, eval_data=eval_df)