Byte-pair encoding tools
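Four small Python 3 scripts for subword segmentation with byte-pair encoding (BPE): an encoder that rewrites a tokenized corpus as integer subword IDs, two trainers that learn the merge vocabulary from a corpus (an incremental one that patches pair counts after each merge, and a baseline that recounts them at every step), and a decoder that restores plain text from the IDs.

All scripts share one tab-separated vocabulary file format: IDs 0-3 are reserved for <unk>, <s>, </s> and </w>, three-column lines list single characters with their frequencies, and four-column lines record merge operations in the order they were learned. As a sketch (IDs and counts are invented for illustration; columns are tab-separated):

0	<unk>	13
1	<s>	10000
2	</s>	10000
3	</w>	250000
4	a	51000
5	b	9800
...
40	e	</w>	22000
41	t	h	18000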
Encoder script: converts a tokenized corpus into sequences of BPE word IDs using a trained vocabulary.
#!/usr/bin/env python3
import sys
from argparse import ArgumentParser
from collections import defaultdict


def parse_args():
    p = ArgumentParser(description='Converts words to integers using byte-pair encoding.')
    p.add_argument(
        '--input',
        type=str, metavar='FILE', required=True, help='source corpus')
    p.add_argument(
        '--output',
        type=str, metavar='FILE', required=True, help='output corpus')
    p.add_argument(
        '--vocab',
        type=str, metavar='FILE', required=True, help='BPE vocabulary file')
    args = p.parse_args()
    return args


def load_vocab(filename):
    # The vocabulary file is tab-separated: 3-column lines define single
    # characters (and special tokens), 4-column lines define merge operations.
    vocab = {}
    chars = defaultdict(lambda: '<unk>')  # characters outside the vocabulary map to <unk>
    ops = []
    with open(filename) as fp:
        for line in fp:
            wid, left, right, *rest = line.strip().split('\t')
            if len(rest) == 0:
                # ID <TAB> character <TAB> frequency
                vocab[left] = wid
                chars[left] = left
            else:
                # ID <TAB> left <TAB> right <TAB> frequency: a merge operation
                vocab[left + ' ' + right] = wid
                before = '\t' + left + '\t' + right + '\t'
                after = '\t' + left + ' ' + right + '\t'
                ops.append((before, after))
    return vocab, chars, ops


def convert(line, vocab, chars, ops, memo):
    wids = []
    for word in line.split():
        if word in memo:
            # Word was converted before: reuse the cached result.
            wids.append(memo[word])
        else:
            # Represent the word as tab-delimited characters plus the
            # end-of-word marker, then apply each merge operation in order.
            subwords = '\t' + '\t'.join(chars[c] for c in word) + '\t</w>\t'
            for before, after in ops:
                # Loop: a single replace() consumes the tab shared by two
                # adjacent matches and would miss the second occurrence.
                while before in subwords:
                    subwords = subwords.replace(before, after)
            result = ' '.join(vocab[x] for x in subwords.strip('\t').split('\t'))
            memo[word] = result
            wids.append(result)
    return ' '.join(wids)


def main():
    args = parse_args()
    vocab, chars, ops = load_vocab(args.vocab)
    memo = {}
    with open(args.input) as ifp, open(args.output, 'w') as ofp:
        for i, line in enumerate(ifp):
            print(convert(line, vocab, chars, ops, memo), file=ofp)
            if (i + 1) % 100 == 0:
                print('Processed %d lines.' % (i + 1), end='\r', file=sys.stderr)
    print('Processed %d lines.' % (i + 1), file=sys.stderr)


if __name__ == '__main__':
    main()
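The encoder works on a single string per word: the word is held as tab-delimited subwords, so applying a merge is just a substring replacement. A standalone sketch of that mechanic (the two merge operations here are made up for the example):

# Standalone illustration of the tab-delimited merge trick used above.
# The merge list is invented for this example.
ops = [
    ('\tl\to\t', '\tl o\t'),      # merge 'l' + 'o' -> 'l o'
    ('\tl o\tw\t', '\tl o w\t'),  # merge 'l o' + 'w' -> 'l o w'
]
word = 'low'
subwords = '\t' + '\t'.join(word) + '\t</w>\t'   # '\tl\to\tw\t</w>\t'
for before, after in ops:
    while before in subwords:
        subwords = subwords.replace(before, after)
print(subwords.strip('\t').split('\t'))          # ['l o w', '</w>']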
Trainer script (incremental variant): learns the BPE vocabulary, keeping the pair-frequency table up to date with per-merge diffs.
#!/usr/bin/env python3
import multiprocessing
import sys
import time
from argparse import ArgumentParser
from collections import defaultdict


def parse_args():
    p = ArgumentParser(description='Constructs a BPE vocabulary file.')
    p.add_argument(
        '--input',
        type=str, metavar='FILE', required=True, help='source corpus')
    p.add_argument(
        '--output',
        type=str, metavar='FILE', required=True, help='vocabulary file')
    p.add_argument(
        '--size',
        type=int, metavar='N', required=True, help='vocabulary size')
    p.add_argument(
        '--min-freq',
        type=int, metavar='N', required=True,
        help='minimum occurrences per character')
    p.add_argument(
        '--threads',
        type=int, metavar='N', required=True, help='number of worker processes')
    args = p.parse_args()
    assert args.size > 3
    return args


def trace(*args, nolf=False):
    # Progress messages go to stderr; nolf=True overwrites the current line.
    print(*args, file=sys.stderr, end='\r' if nolf else '\n')
    sys.stderr.flush()


def word2key(word):
    # A word is held as its characters joined by tabs plus the end-of-word
    # marker, e.g. 'low' -> '\tl\to\tw\t</w>\t'.
    return '\t' + '\t'.join(word) + '\t</w>\t'


def key2subwords(key):
    return key.strip('\t').split('\t')


def subwords2key(subwords):
    return '\t' + '\t'.join(subwords) + '\t'


def calculate_unigram_freq(encoding):
    freq = defaultdict(int)
    for key, val in encoding.items():
        for sw in key2subwords(key):
            freq[sw] += val
    return freq


def calculate_bigram_freq(encoding):
    freq = defaultdict(int)
    for key, val in encoding.items():
        subwords = key2subwords(key)
        for i in range(len(subwords) - 1):
            freq[subwords[i], subwords[i + 1]] += val
    return freq


def load_initial_encoding(filename):
    # Counts each distinct word once, so later steps work on the much
    # smaller word-to-frequency map instead of the raw corpus.
    encoding = defaultdict(int)
    with open(filename) as fp:
        for i, line in enumerate(fp):
            for word in line.split():
                encoding[word2key(word)] += 1
            if (i + 1) % 10000 == 0:
                trace('Loaded', i + 1, 'lines', nolf=True)
    trace('Loaded', i + 1, 'lines')
    trace('#unique encodings:', len(encoding))
    return i + 1, encoding


def filter_chars(encoding, min_freq):
    # Replaces characters rarer than min_freq with <unk>.
    freq = calculate_unigram_freq(encoding)
    trace('#unique characters:', len(freq))
    filtered = {c for c in freq if freq[c] >= min_freq}
    trace('#filtered characters:', len(filtered))
    result = defaultdict(int)
    for key, val in encoding.items():
        subwords = key2subwords(key)
        new_subwords = [(sw if sw in filtered else '<unk>') for sw in subwords]
        result[subwords2key(new_subwords)] += val
    trace('#filtered encodings:', len(result))
    return filtered, result


def make_shards(encoding, n):
    # Splits the encoding map into n roughly equal dicts for the worker pool.
    shards = [{} for _ in range(n)]
    for i, (key, val) in enumerate(encoding.items()):
        shards[i % n][key] = val
    return shards


def merge_freqs(freq, diffs):
    # Folds per-shard frequency diffs into the running bigram table.
    for diff in diffs:
        for key, val in diff.items():
            freq[key] += val


def merge_bigram(config):
    # Applies one merge operation to a shard and returns both the rewritten
    # shard and the changes this causes in the bigram frequencies.
    encoding, (left, right) = config
    before = '\t' + left + '\t' + right + '\t'
    after = '\t' + left + ' ' + right + '\t'
    new_encoding = {}
    diff = defaultdict(int)
    for key, val in encoding.items():
        new_key = key
        # Loop: a single replace() consumes the tab shared by two adjacent
        # matches and would miss the second occurrence.
        while before in new_key:
            new_key = new_key.replace(before, after)
        if new_key != key:
            subwords = key2subwords(new_key)
            for i in range(len(subwords) - 1):
                diff[subwords[i], subwords[i + 1]] += val
            subwords = key2subwords(key)
            for i in range(len(subwords) - 1):
                diff[subwords[i], subwords[i + 1]] -= val
        new_encoding[new_key] = val
    return new_encoding, {key: val for key, val in diff.items() if val != 0}


def main():
    args = parse_args()
    total_begin_time = time.time()
    # IDs 0-3 are reserved for <unk>, <s>, </s> and </w>.
    max_words = args.size - 3
    num_lines, encoding = load_initial_encoding(args.input)
    chars, encoding = filter_chars(encoding, args.min_freq)
    assert len(chars) <= max_words
    pool = multiprocessing.Pool(args.threads)
    shards = make_shards(encoding, args.threads)
    for i, shard in enumerate(shards):
        trace('Shard %d size: %d' % (i, len(shard)))
    # The bigram table is computed once and then kept up to date with the
    # diffs returned by each merge step.
    freq = defaultdict(int)
    merge_freqs(freq, pool.imap_unordered(calculate_bigram_freq, shards))
    ops = []
    for i in range(len(chars), max_words):
        begin_time = time.time()
        left, right = max(freq, key=freq.get)
        merged_freq = freq[left, right]
        results = pool.map(merge_bigram, ((x, (left, right)) for x in shards))
        shards = [x[0] for x in results]
        merge_freqs(freq, (x[1] for x in results))
        ops.append((left, right))
        elapsed = time.time() - begin_time
        trace('Merged %d/%d: "%s" + "%s" (freq=%d, time=%fs)'
              % (i + 1, max_words, left, right, merged_freq, elapsed))
    trace('Writing vocabulary file ...')
    freq = defaultdict(int)
    merge_freqs(freq, pool.imap_unordered(calculate_unigram_freq, shards))
    num_unk = freq['<unk>'] if '<unk>' in freq else 0
    with open(args.output, 'w') as fp:
        print('0\t<unk>\t%d' % num_unk, file=fp)
        print('1\t<s>\t%d' % num_lines, file=fp)
        print('2\t</s>\t%d' % num_lines, file=fp)
        print('3\t</w>\t%d' % freq['</w>'], file=fp)
        # Character entries; </w> already has ID 3, so it is skipped without
        # leaving a gap in the ID sequence.
        wid = 4
        for c in sorted(chars):
            if c == '</w>':
                continue
            print('%d\t%s\t%d' % (wid, c, freq[c]), file=fp)
            wid += 1
        # Merge entries, in the order the merges were learned.
        for i, (left, right) in enumerate(ops):
            print('%d\t%s\t%s\t%d'
                  % (wid + i, left, right, freq[left + ' ' + right]), file=fp)
    total_elapsed = time.time() - total_begin_time
    trace('Total time elapsed: %fs' % total_elapsed)


if __name__ == '__main__':
    main()
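Rather than recounting every pair after each merge, merge_bigram reports only the delta each rewritten key contributes to the bigram table, and main folds those deltas into the running counts. A quick check of that bookkeeping (assuming merge_bigram from the script above is in scope; the key and its count are made up):

enc = {'\ta\tb\tc\t': 5}                       # the word 'abc' seen 5 times
new_enc, diff = merge_bigram((enc, ('b', 'c')))
assert new_enc == {'\ta\tb c\t': 5}
assert diff == {('a', 'b c'): 5, ('a', 'b'): -5, ('b', 'c'): -5}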
Trainer script (baseline variant): produces the same vocabulary file, but rebuilds the pair-frequency table from scratch before every merge.
#!/usr/bin/env python3
import multiprocessing
import sys
import time
from argparse import ArgumentParser
from collections import defaultdict


def parse_args():
    p = ArgumentParser(description='Constructs a BPE vocabulary file.')
    p.add_argument(
        '--input',
        type=str, metavar='FILE', required=True, help='source corpus')
    p.add_argument(
        '--output',
        type=str, metavar='FILE', required=True, help='vocabulary file')
    p.add_argument(
        '--size',
        type=int, metavar='N', required=True, help='vocabulary size')
    p.add_argument(
        '--min-freq',
        type=int, metavar='N', required=True,
        help='minimum occurrences per character')
    p.add_argument(
        '--threads',
        type=int, metavar='N', required=True, help='number of worker processes')
    args = p.parse_args()
    assert args.size > 3
    return args


def trace(*args, nolf=False):
    # Progress messages go to stderr; nolf=True overwrites the current line.
    print(*args, file=sys.stderr, end='\r' if nolf else '\n')
    sys.stderr.flush()


def word2key(word):
    # A word is held as its characters joined by tabs plus the end-of-word
    # marker, e.g. 'low' -> '\tl\to\tw\t</w>\t'.
    return '\t' + '\t'.join(word) + '\t</w>\t'


def key2subwords(key):
    return key.strip('\t').split('\t')


def subwords2key(subwords):
    return '\t' + '\t'.join(subwords) + '\t'


def calculate_unigram_freq(encoding):
    freq = defaultdict(int)
    for key, val in encoding.items():
        for sw in key2subwords(key):
            freq[sw] += val
    return freq


def calculate_bigram_freq(encoding):
    freq = defaultdict(int)
    for key, val in encoding.items():
        subwords = key2subwords(key)
        for i in range(len(subwords) - 1):
            freq[subwords[i], subwords[i + 1]] += val
    return freq


def load_initial_encoding(filename):
    # Counts each distinct word once, so later steps work on the much
    # smaller word-to-frequency map instead of the raw corpus.
    encoding = defaultdict(int)
    with open(filename) as fp:
        for i, line in enumerate(fp):
            for word in line.split():
                encoding[word2key(word)] += 1
            if (i + 1) % 10000 == 0:
                trace('Loaded', i + 1, 'lines', nolf=True)
    trace('Loaded', i + 1, 'lines')
    trace('#unique encodings:', len(encoding))
    return i + 1, encoding


def filter_chars(encoding, min_freq):
    # Replaces characters rarer than min_freq with <unk>.
    freq = calculate_unigram_freq(encoding)
    trace('#unique characters:', len(freq))
    filtered = {c for c in freq if freq[c] >= min_freq}
    trace('#filtered characters:', len(filtered))
    result = defaultdict(int)
    for key, val in encoding.items():
        subwords = key2subwords(key)
        new_subwords = [(sw if sw in filtered else '<unk>') for sw in subwords]
        result[subwords2key(new_subwords)] += val
    trace('#filtered encodings:', len(result))
    return filtered, result


def make_shards(encoding, n):
    # Splits the encoding map into n roughly equal dicts for the worker pool.
    shards = [{} for _ in range(n)]
    for i, (key, val) in enumerate(encoding.items()):
        shards[i % n][key] = val
    return shards


def merge_freqs(freqs):
    # Sums per-shard frequency tables into a single table.
    total = defaultdict(int)
    for freq in freqs:
        for key, val in freq.items():
            total[key] += val
    return total


def merge_bigram(config):
    # Applies one merge operation to a shard. The bigram is passed along
    # with the shard so the function can run in a worker process.
    encoding, (left, right) = config
    before = '\t' + left + '\t' + right + '\t'
    after = '\t' + left + ' ' + right + '\t'
    result = {}
    for key, val in encoding.items():
        new_key = key
        # Loop: a single replace() consumes the tab shared by two adjacent
        # matches and would miss the second occurrence.
        while before in new_key:
            new_key = new_key.replace(before, after)
        result[new_key] = val
    return result


def main():
    args = parse_args()
    total_begin_time = time.time()
    # IDs 0-3 are reserved for <unk>, <s>, </s> and </w>.
    max_words = args.size - 3
    num_lines, encoding = load_initial_encoding(args.input)
    chars, encoding = filter_chars(encoding, args.min_freq)
    assert len(chars) <= max_words
    pool = multiprocessing.Pool(args.threads)
    shards = make_shards(encoding, args.threads)
    for i, shard in enumerate(shards):
        trace('Shard %d size: %d' % (i, len(shard)))
    ops = []
    for i in range(len(chars), max_words):
        begin_time = time.time()
        # Rebuild the full bigram table from every shard at each step.
        freq = merge_freqs(pool.imap_unordered(calculate_bigram_freq, shards))
        bigram = max(freq, key=freq.get)
        shards = pool.map(merge_bigram, ((x, bigram) for x in shards))
        ops.append(bigram)
        elapsed = time.time() - begin_time
        # Show the merged pair without its internal separator spaces.
        l_str = ''.join(bigram[0].split(' '))
        r_str = ''.join(bigram[1].split(' '))
        trace('Merged %d/%d: %s + %s (freq=%d, time=%fs)'
              % (i + 1, max_words, l_str, r_str, freq[bigram], elapsed))
    trace('Writing vocabulary file ...')
    freq = merge_freqs(pool.imap_unordered(calculate_unigram_freq, shards))
    num_unk = freq['<unk>'] if '<unk>' in freq else 0
    with open(args.output, 'w') as fp:
        print('0\t<unk>\t%d' % num_unk, file=fp)
        print('1\t<s>\t%d' % num_lines, file=fp)
        print('2\t</s>\t%d' % num_lines, file=fp)
        print('3\t</w>\t%d' % freq['</w>'], file=fp)
        # Character entries; </w> already has ID 3, so it is skipped without
        # leaving a gap in the ID sequence.
        wid = 4
        for c in sorted(chars):
            if c == '</w>':
                continue
            print('%d\t%s\t%d' % (wid, c, freq[c]), file=fp)
            wid += 1
        # Merge entries, in the order the merges were learned.
        for i, (left, right) in enumerate(ops):
            print('%d\t%s\t%s\t%d'
                  % (wid + i, left, right, freq[left + ' ' + right]), file=fp)
    total_elapsed = time.time() - total_begin_time
    trace('Total time elapsed: %fs' % total_elapsed)


if __name__ == '__main__':
    main()
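The two trainers differ only in the merge loop. This baseline rebuilds the whole bigram table from every shard before each of the up-to --size merge steps, so every step costs a full pass over the encoding; the first trainer scans the shards once and afterwards only applies the per-merge diffs returned by its merge_bigram, which touch just the keys containing the merged pair. Both write the same vocabulary file format.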
Decoder script: converts sequences of BPE word IDs back into plain text.
#!/usr/bin/env python3
from argparse import ArgumentParser
from collections import defaultdict


def parse_args():
    p = ArgumentParser(description='Converts integers to words using byte-pair encoding.')
    p.add_argument(
        '--input',
        type=str, metavar='FILE', required=True, help='source corpus')
    p.add_argument(
        '--output',
        type=str, metavar='FILE', required=True, help='output corpus')
    p.add_argument(
        '--vocab',
        type=str, metavar='FILE', required=True, help='BPE vocabulary file')
    args = p.parse_args()
    return args


def load_vocab(filename):
    # Maps each word ID to its list of subword pieces. The default must be a
    # list, not a string, so unknown IDs decode to the single token '<unk>'
    # rather than its individual characters.
    vocab = defaultdict(lambda: ['<unk>'])
    with open(filename) as fp:
        for line in fp:
            wid, left, right, *rest = line.strip().split('\t')
            if len(rest) == 0:
                vocab[int(wid)] = left.split(' ')
            else:
                vocab[int(wid)] = left.split(' ') + right.split(' ')
    return vocab


def convert(line, vocab):
    # Accumulates subword pieces until the end-of-word marker, then joins
    # them back into a plain word.
    cache = []
    words = []
    for wid in line.split():
        cache += vocab[int(wid)]
        if cache[-1] == '</w>':
            words.append(''.join(cache[:-1]))
            cache = []
    if cache:
        words.append(''.join(cache))
    return ' '.join(words)


def main():
    args = parse_args()
    vocab = load_vocab(args.vocab)
    with open(args.input) as ifp, open(args.output, 'w') as ofp:
        for line in ifp:
            print(convert(line, vocab), file=ofp)


if __name__ == '__main__':
    main()
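Finally, a hypothetical end-to-end run. The script names below are invented (the gist does not fix them), the flags are exactly those defined by the parsers above, and the values are illustrative:

python3 bpe_train.py --input corpus.txt --output bpe.vocab --size 16000 --min-freq 3 --threads 4
python3 bpe_encode.py --input corpus.txt --output corpus.bpe --vocab bpe.vocab
python3 bpe_decode.py --input corpus.bpe --output corpus.restored --vocab bpe.vocab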