#!/usr/bin/env python3
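"""Builds a subword vocabulary from a plain-text corpus using BPE-style
merges: every word starts as a sequence of characters, and the most
frequent adjacent symbol pair is merged repeatedly until the vocabulary
reaches --size entries.

Example invocation (file names here are illustrative):

  python3 build_vocab.py --input corpus.txt --output vocab.tsv --size 8000 --min-freq 5 --threads 4
"""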


import multiprocessing
import sys
import time
from argparse import ArgumentParser
from collections import defaultdict


def parse_args():
  p = ArgumentParser(description='Constructs vocabulary file.')
  p.add_argument(
      '--input',
      type=str, metavar='FILE', required=True, help='source corpus')
  p.add_argument(
      '--output',
      type=str, metavar='FILE', required=True, help='vocabulary file')
  p.add_argument(
      '--size',
      type=int, metavar='N', required=True, help='vocabulary size')
  p.add_argument(
      '--min-freq',
      type=int, metavar='N', required=True, help='minimum number of occurrences per character')
  p.add_argument(
      '--threads',
      type=int, metavar='N', required=True, help='number of threads')
  args = p.parse_args()
  assert args.size > 3, '--size must exceed the three reserved tokens'
  return args


def trace(*args, nolf=False):
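  """Prints progress to stderr; with nolf=True the line is overwritten in place."""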
  print(*args, file=sys.stderr, end='\r' if nolf else '\n')
  sys.stderr.flush()


def word2key(word):
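  r"""Encodes a word as tab-joined characters with an end-of-word marker,
  e.g. 'cat' -> '\tc\ta\tt\t</w>\t'.
  """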
  return '\t' + '\t'.join(word) + '\t</w>\t'


def key2subwords(key):
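  r"""Splits an encoding key back into its subword list, e.g.
  '\tc\ta\tt\t</w>\t' -> ['c', 'a', 't', '</w>'].
  """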
  return key.strip('\t').split('\t')


def subwords2key(subwords):
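  """Inverse of key2subwords: joins subwords with tabs, keeping the
  surrounding tab delimiters.
  """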
  return '\t' + '\t'.join(subwords) + '\t'


def calculate_unigram_freq(encoding):
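  """Counts how often each subword occurs, weighted by word frequency."""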
  freq = defaultdict(int)
  for key, val in encoding.items():
    for sw in key2subwords(key):
      freq[sw] += val
  return freq


def calculate_bigram_freq(encoding):
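  """Counts how often each adjacent subword pair occurs, weighted by word
  frequency.
  """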
  freq = defaultdict(int)
  for key, val in encoding.items():
    subwords = key2subwords(key)
    for i in range(len(subwords) - 1):
      freq[subwords[i], subwords[i + 1]] += val
  return freq


def load_initial_encoding(filename):
  """Reads the corpus and builds the initial encoding table: each unique
  word maps to its character-level key, weighted by occurrence count.
  """
  encoding = defaultdict(int)
  num_lines = 0
  with open(filename) as fp:
    for line in fp:
      for word in line.split():
        encoding[word2key(word)] += 1
      num_lines += 1
      if num_lines % 10000 == 0:
        trace('Loaded', num_lines, 'lines', nolf=True)
  trace('Loaded', num_lines, 'lines')
  trace('#unique encodings:', len(encoding))
  return num_lines, encoding


def filter_chars(encoding, min_freq):
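  """Replaces characters rarer than min_freq with '<unk>' and returns the
  surviving character set together with the rewritten encoding table.
  """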
  freq = calculate_unigram_freq(encoding)
  trace('#unique characters:', len(freq))
  filtered = {c for c in freq if freq[c] >= min_freq}
  trace('#filtered characters:', len(filtered))
  result = defaultdict(int)
  for key, val in encoding.items():
    subwords = key2subwords(key)
    new_subwords = [(sw if sw in filtered else '<unk>') for sw in subwords]
    new_key = subwords2key(new_subwords)
    result[new_key] += val
  trace('#filtered encodings:', len(result))
  return filtered, result


def make_shards(encoding, n):
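  """Splits the encoding table round-robin into n roughly equal shards so
  each worker can apply merges to its own shard independently.
  """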
  shards = [{} for _ in range(n)]
  for i, (key, val) in enumerate(encoding.items()):
    shards[i % n][key] = val
  return shards


def merge_freqs(freq, diffs):
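  """Accumulates per-shard frequency diffs into the global freq table."""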
  for diff in diffs:
    for key, val in diff.items():
      freq[key] += val


def merge_bigram(config):
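  """Applies one merge operation (left, right) to every key of one shard.

  Returns the rewritten shard and the resulting change in bigram counts,
  so the global bigram table can be updated incrementally instead of
  being recomputed from scratch.
  """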
  encoding, (left, right) = config
  before = '\t' + left + '\t' + right + '\t'
  after = '\t' + left + ' ' + right + '\t'
  new_encoding = {}
  diff = defaultdict(int)
  for key, val in encoding.items():
    # A single str.replace() misses back-to-back occurrences that share a
    # delimiter tab, so repeat until no occurrence of the pair remains.
    new_key = key
    while before in new_key:
      new_key = new_key.replace(before, after)
    if new_key != key:
      subwords = key2subwords(new_key)
      for i in range(len(subwords) - 1):
        diff[subwords[i], subwords[i + 1]] += val
      subwords = key2subwords(key)
      for i in range(len(subwords) - 1):
        diff[subwords[i], subwords[i + 1]] -= val
    new_encoding[new_key] = val
  return new_encoding, {key: val for key, val in diff.items() if val != 0}


def main():
  args = parse_args()

  total_begin_time = time.time()

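  # Three slots are reserved for <unk>, <s>, </s>; </w> is counted among
  # the characters, so only 3 is subtracted here.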
  max_words = args.size - 3

  num_lines, encoding = load_initial_encoding(args.input)
  chars, encoding = filter_chars(encoding, args.min_freq)
  assert len(chars) <= max_words

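  # multiprocessing.Pool spawns worker processes (not threads), so the
  # CPU-bound merge loop can actually use --threads cores.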
  pool = multiprocessing.Pool(args.threads)
  shards = make_shards(encoding, args.threads)
  for i, shard in enumerate(shards):
    trace('Shard %d size: %d' % (i, len(shard)))

  freq = defaultdict(int)
  merge_freqs(freq, pool.imap_unordered(calculate_bigram_freq, shards))

  ops = []
  for i in range(len(chars), max_words):
    begin_time = time.time()
    if not freq:
      break
    left, right = max(freq, key=freq.get)
    merged_freq = freq[left, right]
    if merged_freq <= 0:
      break  # no pair occurs any more; the vocabulary cannot grow further
    results = pool.map(merge_bigram, ((x, (left, right)) for x in shards))
    shards = [x[0] for x in results]
    merge_freqs(freq, (x[1] for x in results))
    ops.append((left, right))
    elapsed = time.time() - begin_time
    trace('Merged %d/%d: "%s" + "%s" (freq=%d, time=%fs)'
          % (i + 1, max_words, left, right, merged_freq, elapsed))

  trace('Writing vocabulary file ...')
  freq = defaultdict(int)
  merge_freqs(freq, pool.imap_unordered(calculate_unigram_freq, shards))
  num_unk = freq.get('<unk>', 0)

  with open(args.output, 'w') as fp:
    print('0\t<unk>\t%d' % num_unk, file=fp)
    print('1\t<s>\t%d' % num_lines, file=fp)
    print('2\t</s>\t%d' % num_lines, file=fp)
    print('3\t</w>\t%d' % freq['</w>'], file=fp)
    # '</w>' already got index 3 above, so keep a running index here:
    # enumerate() would leave a gap at the skipped position and collide
    # with the first merge entry.
    index = 4
    for c in sorted(chars):
      if c == '</w>':
        continue
      print('%d\t%s\t%d' % (index, c, freq[c]), file=fp)
      index += 1
    for left, right in ops:
      # A merged token's unigram key is its subwords joined by spaces.
      print('%d\t%s\t%s\t%d' % (index, left, right, freq[left + ' ' + right]), file=fp)
      index += 1

  total_elapsed = time.time() - total_begin_time
  trace('Total time elapsed: %fs' % total_elapsed)


if __name__ == '__main__':
  main()