@sshleifer
Last active August 11, 2020 04:56
# by stas00 and sshleifer
from pathlib import Path
import nlp
from tqdm import tqdm

dataset = 'wmt19'
s = 'ru'
t = 'en'

pair = f'{s}-{t}'

ds = nlp.load_dataset(dataset, pair)
save_dir = Path(f'{dataset}-{pair}')
save_dir.mkdir(exist_ok=True)
for split in tqdm(ds.keys()):
    tr_list = list(ds[split])
    data = [x['translation'] for x in tr_list]
    src, tgt = [], []
    for example in data:
        src.append(example[s])
        tgt.append(example[t])
    if split == 'validation':
        split = 'val'  # save to val.source, val.target like the summarization datasets
    src_path = save_dir.joinpath(f'{split}.source')
    src_path.open('w+').write('\n'.join(src))
    tgt_path = save_dir.joinpath(f'{split}.target')
    tgt_path.open('w+').write('\n'.join(tgt))
@sshleifer (Author)

This puts the WMT data in the same format as the CNNDM and XSUM summarization datasets.
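For reference, the format is simply line-aligned plain-text files: line i of {split}.source pairs with line i of {split}.target. A minimal sketch of reading them back as aligned pairs (the helper name and demo files are mine, not part of the gist):

```python
from pathlib import Path

def read_pairs(save_dir, split):
    # line i of {split}.source aligns with line i of {split}.target
    src = Path(save_dir, f'{split}.source').read_text().splitlines()
    tgt = Path(save_dir, f'{split}.target').read_text().splitlines()
    assert len(src) == len(tgt), 'source/target files must stay line-aligned'
    return list(zip(src, tgt))

# tiny self-contained demo with throwaway files
demo = Path('demo-wmt19-ru-en')
demo.mkdir(exist_ok=True)
demo.joinpath('val.source').write_text('привет\nмир')
demo.joinpath('val.target').write_text('hello\nworld')
pairs = read_pairs(demo, 'val')
```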

@sshleifer (Author)

@stas00 this is how I got en-ro

stas00 commented Aug 10, 2020

Thank you, @sshleifer!

You may want to add:

save_dir.mkdir(exist_ok=True)

after save_dir = Path('wmt_en_ro')

and probably rename en_ro to ro_en since that's the direction it's creating.

stas00 commented Aug 11, 2020

And there was another buglet: it was reversing the direction. Here is an updated script that only needs the parameters at the top changed:

from pathlib import Path
import nlp
from tqdm import tqdm

dataset = 'wmt19'
s = 'ru'
t = 'en'

pair = f'{s}-{t}'

ds = nlp.load_dataset(dataset, pair)
save_dir = Path(f'{dataset}-{pair}')
save_dir.mkdir(exist_ok=True)
for split in tqdm(ds.keys()):
    tr_list = list(ds[split])
    data = [x['translation'] for x in tr_list]
    src, tgt = [], []
    for example in data:
        src.append(example[s])
        tgt.append(example[t])
    if split == 'validation':
        split = 'val' # to save to val.source, val.target like summary datasets
    src_path = save_dir.joinpath(f'{split}.source')
    src_path.open('w+').write('\n'.join(src))
    tgt_path = save_dir.joinpath(f'{split}.target')
    tgt_path.open('w+').write('\n'.join(tgt))
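Since the whole point is keeping source and target line-aligned, a quick sanity check over a finished output directory can help (this checker is my suggestion, not part of the gist; the demo directory is a stand-in for the script's output):

```python
from pathlib import Path

def verify_splits(save_dir):
    # every {split}.source must have the same number of lines as {split}.target
    counts = {}
    for src in sorted(Path(save_dir).glob('*.source')):
        tgt = src.with_suffix('.target')
        n_src = len(src.read_text().splitlines())
        n_tgt = len(tgt.read_text().splitlines())
        assert n_src == n_tgt, f'{src.stem}: {n_src} source vs {n_tgt} target lines'
        counts[src.stem] = n_src
    return counts

# throwaway demo directory standing in for the script's output
demo = Path('demo-check')
demo.mkdir(exist_ok=True)
demo.joinpath('val.source').write_text('a\nb')
demo.joinpath('val.target').write_text('x\ny')
counts = verify_splits(demo)
```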

@sshleifer (Author)

Thanks. I updated the gist and credited you as an author.

stas00 commented Aug 11, 2020

wmt16 ru-en was about 1.5M records.

wmt19 is ~40M records! It can't all be read into memory: that would take some 50-100GB of RAM.

This is a rewrite that doesn't load everything into memory. I'm sure there is a faster way to do it, but it works well, and for small datasets it's about the same speed.

from pathlib import Path
import nlp
from tqdm import tqdm

dataset = 'wmt19'
sl = 'ru'
tl = 'en'

pair = f'{sl}-{tl}'

ds = nlp.load_dataset(dataset, pair)

save_dir = Path(f'{dataset}-{pair}')
save_dir.mkdir(exist_ok=True)
for split in ds.keys():
    print(ds[split])

    # to save to val.source, val.target like summary datasets
    fn = 'val' if split == 'validation' else split
    src_path = save_dir.joinpath(f'{fn}.source')
    src_fp = src_path.open('w+')
    tgt_path = save_dir.joinpath(f'{fn}.target')
    tgt_fp = tgt_path.open('w+')

    def get_item():
        for x in tqdm(ds[split]):
            yield x['translation']

    for ex in get_item():
        src_fp.write(ex[sl]+'\n')
        tgt_fp.write(ex[tl]+'\n')

    # close explicitly so both buffers are flushed to disk
    src_fp.close()
    tgt_fp.close()

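One small refinement worth considering (my suggestion, not part of the gist): opening both output files in a with statement guarantees they are flushed and closed even if iteration fails partway through a 40M-record split. A minimal sketch, with inline sample rows standing in for ds[split]:

```python
from pathlib import Path

sl, tl = 'ru', 'en'
rows = [{'ru': 'привет', 'en': 'hello'}, {'ru': 'мир', 'en': 'world'}]  # stand-in for ds[split]

save_dir = Path('demo-stream')
save_dir.mkdir(exist_ok=True)

# both handles are closed (and buffers flushed) when the block exits,
# even if an exception is raised mid-iteration
with save_dir.joinpath('train.source').open('w') as src_fp, \
     save_dir.joinpath('train.target').open('w') as tgt_fp:
    for ex in rows:
        src_fp.write(ex[sl] + '\n')
        tgt_fp.write(ex[tl] + '\n')
```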
stas00 commented Aug 11, 2020

To the future reader: you will find the latest version of this script here: https://github.com/huggingface/transformers/blob/master/examples/seq2seq/download_wmt.py
