Last active
October 1, 2021 12:46
-
-
Save jbrry/414f520e87919dd35f397eaede9ac844 to your computer and use it in GitHub Desktop.
Downloads HuggingFace datasets and concatenates them based on split type.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Downloads HuggingFace datasets and concatenates them based on split type.""" | |
import datasets | |
from datasets import concatenate_datasets, load_dataset | |
from datasets.dataset_dict import DatasetDict | |
# `config_name`s for the `universal_dependencies` dataset | |
TBIDS = [ | |
"af_afribooms", | |
"ga_idt", | |
] | |
SPLITS = ["train", "validation", "test"] | |
train_files = [] | |
validation_files = [] | |
test_files = [] | |
# Store files based on split type | |
tmp_files = {split: [] for split in SPLITS} | |
# Download and sort the files by split type | |
for tbid in TBIDS: | |
for split in SPLITS: | |
# first argument is the dataset, second is the config and split determines the split type | |
d = load_dataset("universal_dependencies", tbid, split=split) | |
tmp_files[split].append(d) | |
# Create a DatasetDict from concatenated datasets | |
dd = datasets.DatasetDict( | |
{split: concatenate_datasets(files) for split, files in tmp_files.items()} | |
) | |
print(f"the concatenated dataset \n {dd}") |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.