# by stas00 and sshleifer
from pathlib import Path

import nlp
from tqdm import tqdm

dataset = 'wmt19'
s = 'ru'
t = 'en'
pair = f'{s}-{t}'

ds = nlp.load_dataset(dataset, pair)

save_dir = Path(f'{dataset}-{pair}')
save_dir.mkdir(exist_ok=True)

for split in tqdm(ds.keys()):
    tr_list = list(ds[split])
    data = [x['translation'] for x in tr_list]
    src, tgt = [], []
    for example in data:
        src.append(example[s])
        tgt.append(example[t])
    if split == 'validation':
        split = 'val'  # to save to val.source, val.target like summary datasets
    src_path = save_dir.joinpath(f'{split}.source')
    src_path.open('w+').write('\n'.join(src))
    tgt_path = save_dir.joinpath(f'{split}.target')
    tgt_path.open('w+').write('\n'.join(tgt))
@stas00 this is how I got en-ro
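(Presumably the en-ro run is the same script with only the parameters at the top changed, e.g., assuming the wmt16 English-Romanian pair:)

dataset = 'wmt16'
s = 'en'
t = 'ro'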
Thank you, @sshleifer!
You may want to add:

save_dir.mkdir(exist_ok=True)

after save_dir = Path('wmt_en_ro'), and probably rename en_ro to ro_en, since that's the direction it's creating.
And there was another buglet: it was reversing the direction. Here is an updated script where you only need to change the parameters at the top:
from pathlib import Path

import nlp
from tqdm import tqdm

dataset = 'wmt19'
s = 'ru'
t = 'en'
pair = f'{s}-{t}'

ds = nlp.load_dataset(dataset, pair)

save_dir = Path(f'{dataset}-{pair}')
save_dir.mkdir(exist_ok=True)

for split in tqdm(ds.keys()):
    tr_list = list(ds[split])
    data = [x['translation'] for x in tr_list]
    src, tgt = [], []
    for example in data:
        src.append(example[s])
        tgt.append(example[t])
    if split == 'validation':
        split = 'val'  # to save to val.source, val.target like summary datasets
    src_path = save_dir.joinpath(f'{split}.source')
    src_path.open('w+').write('\n'.join(src))
    tgt_path = save_dir.joinpath(f'{split}.target')
    tgt_path.open('w+').write('\n'.join(tgt))
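(A minimal sanity check, not part of the original script: each split's .source and .target files should end up with the same number of lines.)

from pathlib import Path

save_dir = Path('wmt19-ru-en')  # adjust to the directory the script created
for src_file in save_dir.glob('*.source'):
    tgt_file = src_file.with_suffix('.target')
    n_src = sum(1 for _ in src_file.open())
    n_tgt = sum(1 for _ in tgt_file.open())
    assert n_src == n_tgt, f'{src_file.stem}: {n_src} != {n_tgt}'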
thx. Updated the gist and credited you as an author.
wmt16 ru-en was about 1.5M records. wmt19 is ~40M! Can't read it all into memory: it wants some 50-100GB of RAM.
This is a rewrite that doesn't load everything into memory. I'm sure there is a faster way to do it, but it works well, and for small datasets it's about the same speed.
from pathlib import Path

import nlp
from tqdm import tqdm

dataset = 'wmt19'
sl = 'ru'
tl = 'en'
pair = f'{sl}-{tl}'

ds = nlp.load_dataset(dataset, pair)

save_dir = Path(f'{dataset}-{pair}')
save_dir.mkdir(exist_ok=True)

for split in ds.keys():
    print(ds[split])
    # to save to val.source, val.target like summary datasets
    fn = 'val' if split == 'validation' else split
    src_path = save_dir.joinpath(f'{fn}.source')
    src_fp = src_path.open('w+')
    tgt_path = save_dir.joinpath(f'{fn}.target')
    tgt_fp = tgt_path.open('w+')

    def get_item():
        # stream one example at a time instead of materializing the whole split
        for x in tqdm(ds[split]):
            yield x['translation']

    for ex in get_item():
        src_fp.write(ex[sl] + '\n')
        tgt_fp.write(ex[tl] + '\n')
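(A variation on the loop body above, not the original: using context managers guarantees the files are flushed and closed even if iteration fails partway through a 40M-record split.)

for split in ds.keys():
    fn = 'val' if split == 'validation' else split
    with open(save_dir / f'{fn}.source', 'w') as src_fp, \
         open(save_dir / f'{fn}.target', 'w') as tgt_fp:
        for x in tqdm(ds[split]):
            ex = x['translation']
            src_fp.write(ex[sl] + '\n')
            tgt_fp.write(ex[tl] + '\n')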
To the future reader: You will find the latest version of this script here: https://github.com/huggingface/transformers/blob/master/examples/seq2seq/download_wmt.py
This puts the WMT data in the same format as the CNNDM and XSUM summarization datasets.
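(For reference, a sketch of that layout: one example per line, with parallel .source/.target files per split; the directory name here is illustrative.)

wmt_en_ro/
    train.source    # one source-language sentence per line
    train.target    # the aligned target-language sentence per line
    val.source
    val.target
    test.source
    test.target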