Skip to content

Instantly share code, notes, and snippets.

@sshleifer
Last active August 11, 2020 04:56
Show Gist options
  • Save sshleifer/9157129ad8315c64279eaf2c1776020c to your computer and use it in GitHub Desktop.
Save sshleifer/9157129ad8315c64279eaf2c1776020c to your computer and use it in GitHub Desktop.
# by stas00 and sshleifer
# Export every split of a translation dataset to plain-text files with one
# example per line: <split>.source gets the source-language side and
# <split>.target gets the target-language side.
# NOTE: this variant collects a whole split in memory before writing.
from pathlib import Path  # Path is used for save_dir below

import nlp
from tqdm import tqdm

dataset = 'wmt19'
s = 'ru'  # source language code
t = 'en'  # target language code
pair = f'{s}-{t}'
ds = nlp.load_dataset(dataset, pair)
save_dir = Path(f'{dataset}-{pair}')
save_dir.mkdir(exist_ok=True)
for split in tqdm(ds.keys()):
    # Each record holds a 'translation' mapping keyed by language code.
    # Iterate the split directly rather than copying it into extra lists.
    src, tgt = [], []
    for record in ds[split]:
        example = record['translation']
        src.append(example[s])
        tgt.append(example[t])
    # to save to val.source, val.target like summary datasets
    fn = 'val' if split == 'validation' else split
    # write_text closes the file for us; explicit UTF-8 keeps non-ASCII text
    # from depending on the platform's default encoding.
    src_path = save_dir.joinpath(f'{fn}.source')
    src_path.write_text('\n'.join(src), encoding='utf-8')
    tgt_path = save_dir.joinpath(f'{fn}.target')
    tgt_path.write_text('\n'.join(tgt), encoding='utf-8')
@stas00
Copy link

stas00 commented Aug 11, 2020

wmt16 ru-en was about 1.5M records.

wmt19 is ~40M records! It can't all be read into memory - doing so would need some 50-100GB of RAM.

This is a rewrite that doesn't load everything into memory. I'm sure there is a faster way to do it, but it works well, and for small datasets it's about the same speed.

from pathlib import Path
import nlp
from tqdm import tqdm

# Stream every split of a translation dataset to disk, writing one example
# per line: <split>.source gets the source-language text and <split>.target
# the target-language text. Records are written as they are read, so a whole
# split is never held in memory.
dataset = 'wmt19'
sl = 'ru'  # source language code
tl = 'en'  # target language code

pair = f'{sl}-{tl}'

ds = nlp.load_dataset(dataset, pair)

save_dir = Path(f'{dataset}-{pair}')
save_dir.mkdir(exist_ok=True)
for split in ds.keys():
    print(ds[split])

    # to save to val.source, val.target like summary datasets
    fn = 'val' if split == 'validation' else split
    src_path = save_dir.joinpath(f'{fn}.source')
    tgt_path = save_dir.joinpath(f'{fn}.target')

    # The with-block flushes and closes both files even if a record fails
    # midway; explicit UTF-8 keeps non-ASCII text from depending on the
    # platform's default encoding.
    with src_path.open('w', encoding='utf-8') as src_fp, \
            tgt_path.open('w', encoding='utf-8') as tgt_fp:
        for x in tqdm(ds[split]):
            # Each record holds a 'translation' mapping keyed by language code.
            ex = x['translation']
            src_fp.write(ex[sl] + '\n')
            tgt_fp.write(ex[tl] + '\n')

@stas00
Copy link

stas00 commented Aug 11, 2020

To the future reader: You will find the latest version of this script here: https://github.com/huggingface/transformers/blob/master/examples/seq2seq/download_wmt.py

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment