#!/usr/bin/env python3
"""Constructs a vocabulary file from a corpus by byte-pair encoding (BPE):
characters occurring fewer than --min-freq times are replaced with <unk>,
then the most frequent adjacent subword pairs are merged greedily until
--size vocabulary entries exist."""

import multiprocessing
import sys
import time
from argparse import ArgumentParser
from collections import defaultdict


def parse_args():
    p = ArgumentParser(description='Constructs vocabulary file.')
    p.add_argument(
        '--input', type=str, metavar='FILE', required=True,
        help='source corpus')
    p.add_argument(
        '--output', type=str, metavar='FILE', required=True,
        help='vocabulary file')
    p.add_argument(
        '--size', type=int, metavar='N', required=True,
        help='vocabulary size')
    p.add_argument(
        '--min-freq', type=int, metavar='N', required=True,
        help='minimum occurrence of each character')
    p.add_argument(
        '--threads', type=int, metavar='N', required=True,
        help='number of worker processes')
    args = p.parse_args()
    # Ids 0-2 are reserved for <unk>, <s> and </s>.
    assert args.size > 3
    return args


def trace(*args, nolf=False):
    """Prints progress to stderr; nolf=True overwrites the current line."""
    print(*args, file=sys.stderr, end='\r' if nolf else '\n')
    sys.stderr.flush()


def word2key(word):
    """Encodes a word as tab-delimited characters followed by '</w>'."""
    return '\t' + '\t'.join(word) + '\t</w>\t'


def key2subwords(key):
    return key.strip('\t').split('\t')


def subwords2key(subwords):
    return '\t' + '\t'.join(subwords) + '\t'


def calculate_unigram_freq(encoding):
    """Counts each subword, weighted by how often its word occurs."""
    freq = defaultdict(int)
    for key, val in encoding.items():
        for sw in key2subwords(key):
            freq[sw] += val
    return freq


def calculate_bigram_freq(encoding):
    """Counts each adjacent subword pair, weighted by word frequency."""
    freq = defaultdict(int)
    for key, val in encoding.items():
        subwords = key2subwords(key)
        for i in range(len(subwords) - 1):
            freq[subwords[i], subwords[i + 1]] += val
    return freq


def load_initial_encoding(filename):
    """Reads the corpus and returns (number of lines, word key -> count)."""
    encoding = defaultdict(int)
    num_lines = 0
    with open(filename) as fp:
        for i, line in enumerate(fp):
            for word in line.split():
                encoding[word2key(word)] += 1
            num_lines = i + 1
            if num_lines % 10000 == 0:
                trace('Loaded', num_lines, 'lines', nolf=True)
    trace('Loaded', num_lines, 'lines')
    trace('#unique encodings:', len(encoding))
    return num_lines, encoding


def filter_chars(encoding, min_freq):
    """Replaces characters occurring fewer than min_freq times with <unk>."""
    freq = calculate_unigram_freq(encoding)
    trace('#unique characters:', len(freq))
    filtered = [c for c in freq if freq[c] >= min_freq]
    trace('#filtered characters:', len(filtered))
    filtered_set = set(filtered)  # set membership instead of a linear scan
    result = defaultdict(int)
    for key, val in encoding.items():
        subwords = key2subwords(key)
        new_subwords = [sw if sw in filtered_set else '<unk>' for sw in subwords]
        result[subwords2key(new_subwords)] += val
    trace('#filtered encodings:', len(result))
    return filtered, result


def make_shards(encoding, n):
    """Splits the encoding into n dicts for parallel processing."""
    shards = [{} for _ in range(n)]
    for i, (key, val) in enumerate(encoding.items()):
        shards[i % n][key] = val
    return shards


def merge_freqs(freq, diffs):
    """Applies per-shard frequency tables (or deltas) to a global table."""
    for diff in diffs:
        for key, val in diff.items():
            freq[key] += val


def merge_bigram(config):
    """Merges one bigram in a shard; returns (new shard, bigram deltas)."""
    encoding, (left, right) = config
    before = '\t' + left + '\t' + right + '\t'
    after = '\t' + left + ' ' + right + '\t'  # merged subword joined by a space
    new_encoding = {}
    diff = defaultdict(int)
    for key, val in encoding.items():
        # A single str.replace() misses overlapping occurrences (the match
        # consumes the shared tab), so repeat until none remain.
        new_key = key
        while before in new_key:
            new_key = new_key.replace(before, after)
        if new_key != key:
            # Record the bigram count changes caused by this rewrite.
            subwords = key2subwords(new_key)
            for i in range(len(subwords) - 1):
                diff[subwords[i], subwords[i + 1]] += val
            subwords = key2subwords(key)
            for i in range(len(subwords) - 1):
                diff[subwords[i], subwords[i + 1]] -= val
        new_encoding[new_key] = val
    return new_encoding, {key: val for key, val in diff.items() if val != 0}
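

# A minimal worked example of the encoding and merge step above (an
# illustration only; nothing in the pipeline calls it). The word 'low' and
# the pair ('l', 'o') are made-up inputs.
def _merge_example():
    shard = {word2key('low'): 5}  # {'\tl\to\tw\t</w>\t': 5}
    new_shard, diff = merge_bigram((shard, ('l', 'o')))
    # The key is rewritten to '\tl o\tw\t</w>\t', and the deltas say that
    # ('l', 'o') and ('o', 'w') each lost 5 occurrences while the new pair
    # ('l o', 'w') gained 5; ('w', '</w>') is unchanged and thus absent.
    assert new_shard == {subwords2key(['l o', 'w', '</w>']): 5}
    assert diff == {('l o', 'w'): 5, ('l', 'o'): -5, ('o', 'w'): -5}
    return new_shard, diff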


def main():
    args = parse_args()
    total_begin_time = time.time()
    # Ids 0-2 go to <unk>, <s>, </s>; everything else ('</w>', characters,
    # merged bigrams) competes for the remaining size - 3 slots.
    max_words = args.size - 3
    num_lines, encoding = load_initial_encoding(args.input)
    chars, encoding = filter_chars(encoding, args.min_freq)
    assert len(chars) <= max_words, 'vocabulary size too small for the character set'
    pool = multiprocessing.Pool(args.threads)
    shards = make_shards(encoding, args.threads)
    for i, shard in enumerate(shards):
        trace('Shard %d size: %d' % (i, len(shard)))

    # Build the initial bigram table in parallel, then greedily merge the
    # most frequent pair until the vocabulary is full.
    freq = defaultdict(int)
    merge_freqs(freq, pool.imap_unordered(calculate_bigram_freq, shards))
    ops = []
    for i in range(len(chars), max_words):
        begin_time = time.time()
        if not freq:
            trace('No bigrams left to merge; stopping early.')
            break
        left, right = max(freq, key=freq.get)
        merged_freq = freq[left, right]
        if merged_freq <= 0:
            trace('No bigrams left to merge; stopping early.')
            break
        results = pool.map(merge_bigram, ((x, (left, right)) for x in shards))
        shards = [x[0] for x in results]
        merge_freqs(freq, (x[1] for x in results))
        ops.append((left, right))
        elapsed = time.time() - begin_time
        trace('Merged %d/%d: "%s" + "%s" (freq=%d, time=%fs)'
              % (i + 1, max_words, left, right, merged_freq, elapsed))

    trace('Writing vocabulary file ...')
    freq = defaultdict(int)
    merge_freqs(freq, pool.imap_unordered(calculate_unigram_freq, shards))
    pool.close()
    pool.join()
    num_unk = freq.get('<unk>', 0)
    with open(args.output, 'w') as fp:
        print('0\t<unk>\t%d' % num_unk, file=fp)
        print('1\t<s>\t%d' % num_lines, file=fp)
        print('2\t</s>\t%d' % num_lines, file=fp)
        print('3\t</w>\t%d' % freq['</w>'], file=fp)
        # '</w>' already has id 3; skipping it must not leave a gap in the
        # id sequence, so track the next id explicitly.
        next_id = 4
        for c in sorted(chars):
            if c == '</w>':
                continue
            print('%d\t%s\t%d' % (next_id, c, freq[c]), file=fp)
            next_id += 1
        for left, right in ops:
            # A merged subword's unigram key is its parts joined by a space.
            print('%d\t%s\t%s\t%d' % (next_id, left, right, freq[left + ' ' + right]), file=fp)
            next_id += 1
    total_elapsed = time.time() - total_begin_time
    trace('Total time elapsed: %fs' % total_elapsed)


if __name__ == '__main__':
    main()
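
# Example invocation (script and file names are hypothetical):
#
#   python3 build_vocab.py --input corpus.txt --output vocab.tsv \
#       --size 16000 --min-freq 5 --threads 8
#
# Each output line is tab-separated: an id, then either a single token
# (special symbols and characters) or the two halves of a merged bigram,
# then the corpus frequency, e.g. '42\tt\th\t1234' for the merge 't'+'h'.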