Created
May 6, 2019 09:28
-
-
Save andreasvc/af8672abb1643a197f6aa6694dcafebb to your computer and use it in GitHub Desktop.
Prepare https://benjaminvdb.github.io/110kDBRD/ for use with fastText
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Prepare https://benjaminvdb.github.io/110kDBRD/ for use with fastText. | |
Divide train set into 90% train and 10% dev, balance positive and negative | |
reviews, and shuffle. Write result in fastText format."""
import os | |
import re | |
import random | |
import glob | |
from syntok.tokenizer import Tokenizer | |
def process(filenames, outfilename):
    """Read reviews from filenames and write them in fastText format.

    One line is written to outfilename per input file: the label
    (``__label__neg`` or ``__label__pos``), a tab, and the review text with
    every run of whitespace collapsed to a single space.

    The label is derived from the rating that follows the last underscore in
    the file's base name (extension removed): ratings 1 and 2 give ``neg``,
    ratings 4 and 5 give ``pos``.

    Raises:
        ValueError: if a file name does not end in one of those ratings.
    """
    labels = {'1': 'neg', '2': 'neg', '4': 'pos', '5': 'pos'}
    # Be explicit about the encoding so the result does not depend on the
    # platform's locale; the review text contains non-ASCII characters.
    with open(outfilename, 'w', encoding='utf-8') as out:
        for filename in filenames:
            with open(filename, encoding='utf-8') as inp:
                text = re.sub(r'\s+', ' ', inp.read())
            docid, _ = os.path.splitext(os.path.basename(filename))
            # Without an underscore the whole docid ends up as the rating,
            # which is rejected below with a descriptive message.
            rating = docid.rpartition('_')[2]
            if rating not in labels:
                raise ValueError(
                    'unexpected rating %r in file name %r'
                    % (rating, filename))
            out.write('__label__%s\t%s\n' % (labels[rating], text))
def main():
    """Split the training reviews into balanced train and dev files.

    File names are collected from ``train/pos/*.txt`` and
    ``train/neg/*.txt`` and shuffled per class. The same number of files is
    taken from each class: the first 90% go to the train split and the
    remainder to the dev split. Each split is shuffled again and written
    with process() to ``dutchbookreviews_train.txt`` and
    ``dutchbookreviews_dev.txt``.
    """
    pos = glob.glob('train/pos/*.txt')
    neg = glob.glob('train/neg/*.txt')
    random.shuffle(pos)
    random.shuffle(neg)
    # Base the split on the smaller class so both splits stay balanced even
    # when the two directories hold different numbers of files; any surplus
    # of the larger class is left out.
    n = min(len(pos), len(neg))
    cutoff = int(0.9 * n)
    train = pos[:cutoff] + neg[:cutoff]
    dev = pos[cutoff:n] + neg[cutoff:n]
    random.shuffle(train)
    random.shuffle(dev)
    process(train, 'dutchbookreviews_train.txt')
    process(dev, 'dutchbookreviews_dev.txt')
# Run the split only when this file is executed as a script.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment