This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import csv | |
import errno | |
import os | |
import re |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Downloads HuggingFace datasets and concatenates them based on split type.""" | |
import datasets | |
from datasets import concatenate_datasets, load_dataset | |
from datasets.dataset_dict import DatasetDict | |
# `config_name`s for the `universal_dependencies` dataset | |
TBIDS = [ | |
"af_afribooms", |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# replace function: https://github.com/huggingface/transformers/blob/f9c16b02e3f5d2ee0a1cadb6f50dc9e3281e2536/src/transformers/data/data_collator.py#L78 | |
def torch_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]: | |
"""place this function in transformers/data/data_collator.py""" | |
import torch | |
if not isinstance(features[0], (dict, BatchEncoding)): | |
features = [vars(f) for f in features] | |
first = features[0] | |
batch = {} |