This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Any, Optional, Union, Tuple | |
import torch | |
from torch import nn | |
from transformers.activations import ACT2FN | |
from transformers.models.deberta.modeling_deberta import ( | |
DebertaPreTrainedModel, | |
DebertaModel, | |
) | |
from transformers.models.deberta_v2.modeling_deberta_v2 import ( |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
class MultiSampleDropout(nn.Module): | |
def __init__(self, dropout_probs, problem_type, num_labels) -> None: | |
super().__init__() | |
self.dropouts = [nn.Dropout(p=p) for p in dropout_probs] | |
self.problem_type = problem_type |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
function SUMMARIZE(input, repo_id="google/pegasus-xsum", use_gpu=false) { | |
// other models to consider | |
// short sequences | |
// sshleifer/distilbart-cnn-12-6 | |
// knkarthick/MEETING_SUMMARY | |
// long sequences | |
// google/bigbird-pegasus-large-bigpatent |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Generic LM | |
roberta-base | |
roberta-large | |
microsoft/deberta-v3-base | |
microsoft/deberta-v3-large | |
microsoft/deberta-v3-xsmall | |
# Long LM | |
allenai/longformer-base-4096 | |
google/bigbird-roberta-base |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import json | |
from typing import Optional, Any, Union, Dict | |
import mlflow | |
from transformers import TrainingArguments | |
from accelerate.tracking import GeneralTracker | |
from accelerate.logging import get_logger |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
from itertools import chain | |
import evaluate | |
from datasets import load_dataset | |
from transformers import AutoTokenizer, AutoModelForMaskedLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling | |
if __name__ == "__main__": |
OlderNewer