Created
December 4, 2019 21:38
-
-
Save lambdabaa/cd3c46de53fed0807283f11351428c06 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import collections | |
import json | |
import os | |
import random | |
import re | |
import sys | |
# A bare identifier: a letter or underscore followed by letters, digits or
# underscores.
IDENT = re.compile(r'^[_a-zA-Z][_a-zA-Z0-9]*$')


def is_punctuation(word):
    """Return True when *word* falls into none of the non-punctuation classes.

    A token is treated as non-punctuation when it is an identifier, contains
    a quote character, looks like a number (digits with optional dots, or a
    ``0x`` prefix), or contains ``$`` or a space.  Everything else, including
    the empty string, counts as punctuation.
    """
    return not (
        IDENT.match(word)
        # String literal (single or double quoted).
        or "'" in word
        or '"' in word
        # Decimal or hexadecimal number.
        or word.replace('.', '').isdigit()
        or word.startswith('0x')
        # Tokens carrying '$' or an embedded space.
        or '$' in word
        or ' ' in word
    )
def read_last_n(tokens, n=100):
    """Return an iterator over the last *n* tokens, ignoring ``'new'``.

    Args:
        tokens: A reversible sequence of tokens (e.g. a list of strings).
        n: Maximum number of tokens to keep.  Defaults to 100, the window
            size this function previously had hard-coded.  A value of zero
            or less yields an empty window.

    Returns:
        A ``reversed`` iterator that yields the kept tokens in their original
        (oldest-to-newest) order.  Every ``'new'`` token is skipped and does
        not count towards *n*.
    """
    result = []
    if n > 0:
        # Walk backwards so only the tail of the sequence is ever visited.
        for token in reversed(tokens):
            if token == 'new':
                continue
            result.append(token)
            if len(result) == n:
                break
    # The tokens were collected newest-first; flip them back.
    return reversed(result)
def main():
    """Build training examples from the token files in the current directory.

    Reads ``word2idx.json`` and ``idx2word.json``, then walks every other
    ``*.json`` file in ``.`` (each is iterated as a sequence of tokens).  For
    every token that is neither ``'new'`` nor punctuation, the window returned
    by ``read_last_n`` over the preceding tokens becomes the feature vector,
    and up to two labels are emitted for it:

    * the token's entry in the inverted ``idx2word`` table, when the token is
      in the vocabulary;
    * ``len(idx2word) + position``, when the token also occurs in the window
      (``position`` is its last occurrence inside the window).

    Duplicate (features, label) pairs are dropped.  The parallel feature and
    label arrays are written to ``x.json`` and ``y.json``.
    """
    # x.json / y.json from a previous run sit in the same directory and also
    # end in '.json'; feeding them back in as token files breaks the run.
    output_names = {'x.json', 'y.json'}
    table_suffixes = ('idx2word.json', 'word2idx.json')
    files = [
        name for name in os.listdir('.')
        if name.endswith('.json')
        and not name.endswith(table_suffixes)
        and name not in output_names
    ]

    with open('word2idx.json', 'r') as f:
        word2idx = json.load(f)
    with open('idx2word.json', 'r') as f:
        idx2word = json.load(f)

    # Snapshot of the known words; word2idx itself may grow below when the
    # placeholder tokens are first needed.
    vocabulary = set(word2idx.keys())
    idx2word_inverse = {x: i for i, x in idx2word.items()}
    idx2word_size = len(idx2word)

    def process(word):
        """Map an out-of-vocabulary word onto its placeholder token."""
        if word in vocabulary:
            return word
        if "'" in word or '"' in word:
            return '<str>'
        if word.replace('.', '').isdigit() or word.startswith('0x'):
            return '<num>'
        return '<unk>'

    # Unique in-vocabulary training examples.
    uniq1 = set()
    # Unique repetition training examples.
    uniq2 = set()
    # Parallel arrays of features (X) and labels (Y).
    X = []
    Y = []

    def emit(seen, features, label):
        """Record (features, label) unless the same pair is already in *seen*."""
        key = ','.join(map(str, features + [label]))
        if key not in seen:
            seen.add(key)
            X.append(features)
            Y.append(label)

    for filecount, file in enumerate(files, 1):
        with open(file, 'r') as f:
            source = json.load(f)

        for idx, item in enumerate(source):
            if item == 'new' or is_punctuation(item):
                continue

            try:
                prev = read_last_n(source[:idx])
            except (TypeError, KeyError):
                # The file's top-level value cannot be sliced (e.g. a JSON
                # object rather than an array); skip the token, as before.
                continue

            x = []
            index = -1
            for idx2, w in enumerate(prev):
                if w == item:
                    # Later matches overwrite earlier ones, so this ends up
                    # as the last occurrence within the window.
                    index = idx2
                val = process(w)
                if val not in word2idx:
                    word2idx[val] = len(word2idx)
                x.append(word2idx[val])

            # Nothing to predict: unknown word that is not in the window.
            if item not in vocabulary and index == -1:
                continue

            if item in vocabulary:
                # NOTE(review): idx2word comes from a JSON object, whose keys
                # are strings, so this label is a str while the repetition
                # label below is an int -- confirm the consumer expects that.
                emit(uniq1, x, idx2word_inverse[item])
            if index != -1:
                # Repetition labels live past the end of the vocabulary range.
                emit(uniq2, x, idx2word_size + index)

        # Progress: files done | unique vocabulary examples | unique
        # repetition examples.
        sys.stdout.write('\r[%d | %d | %d]' % (filecount, len(uniq1), len(uniq2)))

    with open('x.json', 'w') as f:
        json.dump(X, f)
    with open('y.json', 'w') as f:
        json.dump(Y, f)
# Run the dataset builder only when executed as a script, not on import.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment