# -*- coding: utf-8 -*-
from __future__ import print_function
from __future__ import unicode_literals

# author: Kyle Kastner

# References:
# needleman wunsch (could use other alignment algorithms instead)
# https://colab.research.google.com/github/zaneveld/full_spectrum_bioinformatics/blob/master/content/08_phylogenetic_trees/needleman_wunsch_alignment.ipynb

# ROVER discussion
# Data-Diverse Redundant Processing for Noise-Robust Automatic Speech Recognition
# Mustafa K. Hotaki
# https://libraetd.lib.virginia.edu/downloads/n583xv77p?filename=2_Hotaki_Mustafa_2020_MS.pdf

# Related ROVER code
# https://github.com/Toloka/CrowdSpeech/blob/main/rover.py 

import numpy as np
from pprint import pprint
import string
from collections import OrderedDict, defaultdict

def pprint_arr(arr, seq_1=None, seq_2=None):
    """Print a 2D array row by row, optionally labelled with two sequences.

    arr -- 2D numpy array (e.g. a scoring or traceback matrix)
    seq_1 -- optional sequence used as row labels; row 0 is labelled "^" and
             row i is labelled seq_1[i - 1]
    seq_2 -- optional sequence used as column labels; printed as one header
             line that starts with the "^" start-of-sequence filler

    Each line is printed with its cells separated by single spaces.
    """
    lines = []
    if seq_2 is not None:
        # use ^ as the start of seq filler mark
        lines.append(["  ^"] + list(seq_2))

    for r_i in range(arr.shape[0]):
        cells = list(arr[r_i])
        if seq_1 is not None:
            label = "^" if r_i == 0 else seq_1[r_i - 1]
            cells = [label] + cells
        lines.append(cells)

    # Walk the collected lines themselves: the optional header adds one line,
    # so indexing by arr.shape[0] would silently drop the final matrix row.
    for line in lines:
        print(*line)

def needleman_wuntch_align(seq_1, seq_2, gap_penalty=-1, mismatch_penalty=-1, match_bonus=1):
    """Fill the Needleman-Wunsch scoring and traceback matrices for two sequences.

    seq_1 -- sequence laid out along the rows (needs len() and indexing,
             elements are compared with ==)
    seq_2 -- sequence laid out along the columns
    gap_penalty -- score added when moving from the left or from above (INS and DEL)
    mismatch_penalty -- score added on a diagonal move between unequal elements (SUB)
    match_bonus -- score added on a diagonal move between equal elements

    Returns (scoring_array, traceback_array), both of shape
    (len(seq_1) + 1, len(seq_2) + 1); the extra first row and column stand
    for the empty prefix. traceback_array holds one character per cell:
    "-" in the top-left corner, a left arrow, an up arrow or an up-left arrow
    naming the neighbour the cell's best score came from. On equal scores the
    preference order is left, then up, then diagonal.
    """
    n_rows = len(seq_1) + 1  # need an extra row up top
    n_cols = len(seq_2) + 1  # need an extra column on the left

    up_arrow = "\u2191"
    left_arrow = "\u2190"
    up_left_arrow = "\u2196"

    # Let the score dtype follow the penalties so that non-integer penalties
    # are not truncated when stored; with the integer defaults this is the
    # same default integer dtype as an array filled with 0.
    score_dtype = np.asarray([gap_penalty, mismatch_penalty, match_bonus]).dtype
    scoring_array = np.full([n_rows, n_cols], 0, dtype=score_dtype)
    # the top-left corner keeps the "-" start marker and a score of 0
    traceback_array = np.full([n_rows, n_cols], "-")

    for r in range(n_rows):
        for c in range(n_cols):
            if r == 0 and c == 0:
                # in the top left: nothing to fill in
                continue
            if r == 0:
                # first row, non-corner: can only be reached from the left
                score = scoring_array[r, c - 1] + gap_penalty
                arrow = left_arrow
            elif c == 0:
                # first col, non-corner: can only be reached from above
                score = scoring_array[r - 1, c] + gap_penalty
                arrow = up_arrow
            else:
                same = seq_1[r - 1] == seq_2[c - 1]
                candidates = (
                    (scoring_array[r, c - 1] + gap_penalty, left_arrow),
                    (scoring_array[r - 1, c] + gap_penalty, up_arrow),
                    (scoring_array[r - 1, c - 1] + (match_bonus if same else mismatch_penalty), up_left_arrow),
                )
                # max() returns the first maximal item, which gives the
                # left / up / diagonal tie-break order of the tuple above
                score, arrow = max(candidates, key=lambda pair: pair[0])
            scoring_array[r, c] = score
            traceback_array[r, c] = arrow
    return scoring_array, traceback_array


def traceback_alignment(traceback_array, seq1, seq2,
                        up_arrow="\u2191", left_arrow="\u2190", up_left_arrow="\u2196", stop="-", debug_print=False):
    """Align seq1 and seq2 using the traceback matrix and return as two strings
    traceback_array -- a numpy array with arrow characters indicating the direction from
    which the best path to a given alignment position originated
    seq1 - a sequence represented as a string
    seq2 - a sequence represented as a string
    up_arrow - the unicode used for the up arrows (there are several arrow symbols in Unicode)
    left_arrow - the unicode used for the left arrows
    up_left_arrow - the unicode used for the diagonal arrows
    stop - the symbol used in the upper left to indicate the end of the alignment
    debug_print - when True, print the position, the arrow and the partial alignment at every step

    Returns (aligned_seq1, aligned_seq2): the two sequences with "-" inserted
    wherever the other sequence has an element with no counterpart.
    Raises ValueError if a visited cell holds none of the four known symbols.

    The walk starts in the bottom-right cell (len(seq1), len(seq2)) and follows
    the arrows until it reaches the stop symbol. The aligned sequences are
    built by string concatenation, so sequence elements must be strings and
    multi-character elements are merged without a separator.

    from:
    https://colab.research.google.com/github/zaneveld/full_spectrum_bioinformatics/blob/master/content/08_phylogenetic_trees/needleman_wunsch_alignment.ipynb
    """

    row = len(seq1)
    col = len(seq2)
    aligned_seq1 = ""
    aligned_seq2 = ""
    alignment_indicator = ""
    # The loop ends through the explicit stop check below. Arrows are compared
    # by value (==): array elements are fresh string objects, so an identity
    # test against `stop` can never be used to detect the end of the walk.
    while True:
        if debug_print:
            print("Currently on row:",row)
            print("Currently on col:",col)
        arrow = traceback_array[row,col]
        if debug_print:
            print("Arrow:",arrow)
        if arrow == up_arrow:
            if debug_print:
                print("insert indel into top sequence")
            #We want to add the new indel onto the left
            #side of the growing aligned sequence
            aligned_seq2 = "-" + aligned_seq2
            aligned_seq1 = seq1[row-1] + aligned_seq1
            alignment_indicator = " "+alignment_indicator
            row -=1
        elif arrow == up_left_arrow:
            if debug_print:
                print("match or mismatch")
            #Note that we look up the row-1 and col-1 indexes
            #because there is an extra "-" character at the
            #start of each sequence
            seq1_character = seq1[row - 1]
            seq2_character = seq2[col - 1]
            aligned_seq1 = seq1_character + aligned_seq1
            aligned_seq2 = seq2_character + aligned_seq2
            if seq1_character == seq2_character:
                alignment_indicator = "|"+alignment_indicator
            else:
                alignment_indicator = " "+alignment_indicator
            row -=1
            col -=1
        elif arrow == left_arrow:
            if debug_print:
                print("Insert indel into left sequence")
            aligned_seq1 = "-" + aligned_seq1
            aligned_seq2 = seq2[col-1] + aligned_seq2
            alignment_indicator = " " + alignment_indicator
            col -=1
        elif arrow == stop:
            if debug_print:
                print("Finished!")
            break
        else:
            raise ValueError("Traceback array entry at {},{}: {} is not recognized as an up arrow ({}),left_arrow ({}), up_left_arrow ({}), or a stop ({}).".format(row, col, arrow, up_arrow, left_arrow, up_left_arrow, stop))
        if debug_print:
            print(aligned_seq1)
            print(alignment_indicator)
            print(aligned_seq2)
    return aligned_seq1, aligned_seq2

class OrderedLambdaDefaultDict(OrderedDict):
    """OrderedDict that creates a fresh ``([], 0)`` entry for any missing key.

    Looking up an absent key stores the value produced by ``factory`` under
    that key and returns it, so every key gets its own list object.
    """

    # Wrapped in staticmethod: a bare function stored on the class would be
    # bound as a method when reached through ``self`` and receive ``self`` as
    # an unexpected argument, making every missing-key lookup raise TypeError.
    factory = staticmethod(lambda: ([], 0))

    def __missing__(self, key):
        # store the default so later lookups return the same object
        self[key] = value = self.factory()
        return value


class WTN:
    """Word transition network: weighted transitions between word positions.

    Nodes ("word positions") can be any hashable value. ``transitions`` maps
    each known node to an OrderedDict of its successors and the accumulated
    weight of that transition, in insertion order. ``words_positions`` is the
    set of all known nodes.
    """

    def __init__(self):
        self.transitions = OrderedDict()
        self.words_positions = set()

    def _add_word(self, word_position):
        """Register a node, giving it an empty successor table if it is new."""
        if word_position not in self.words_positions:
            self.words_positions.add(word_position)
            if word_position not in self.transitions:
                self.transitions[word_position] = OrderedDict()

    def add_transition(self, from_word_position, to_word_position, weight=1.0):
        """Add ``weight`` to the transition between two nodes, creating both as needed."""
        self._add_word(from_word_position)
        self._add_word(to_word_position)
        outgoing = self.transitions[from_word_position]
        outgoing[to_word_position] = outgoing.get(to_word_position, 0) + weight

    def get_best_path_and_score(self, rover_alpha, confidence_fn):
        """Return (best_path, score) for the highest scoring path in the network.

        rover_alpha -- weight in [0, 1] style mixing: the transition term is
                       multiplied by rover_alpha, the confidence term by
                       (1.0 - rover_alpha)
        confidence_fn -- callable taking the path found so far (a list of
                         nodes) and returning a number; called once per
                         transition, in insertion order

        Nodes are visited once, in the order they were first added, and each
        outgoing transition may improve the best known path of its target.
        The step score of a transition is its accumulated weight divided by
        the number of distinct successors of the source node.

        Raises ValueError if the network has no nodes.
        """
        if not self.transitions:
            raise ValueError("cannot compute a best path on an empty word transition network")

        # Seed from the insertion-ordered transition table rather than from
        # the words_positions set: set iteration order depends on hashing, so
        # ties between equally scored end nodes would otherwise be resolved
        # differently from run to run.
        best_paths = OrderedDict((word_position, ([], 0)) for word_position in self.transitions)

        for from_word_position, outgoing in self.transitions.items():
            n_successors = float(len(outgoing))
            for to_word_position, weight in outgoing.items():
                # re-read inside the loop: a self-transition may just have updated it
                current_path, current_score = best_paths[from_word_position]
                step_score = (rover_alpha * (weight / n_successors)
                              + (1.0 - rover_alpha) * confidence_fn(current_path))
                new_score = current_score + step_score
                if new_score > best_paths[to_word_position][1]:
                    best_paths[to_word_position] = (current_path + [from_word_position], new_score)

        # Find the best path by the maximum score (first one wins on ties)
        final_word = max(best_paths, key=lambda word: best_paths[word][1])
        best_path, path_score = best_paths[final_word]
        # do we want the per step score totals?
        return best_path + [final_word], path_score

# fixed seed so the placeholder confidences are the same on every run
random_state = np.random.RandomState(2145)
def fake_confidences(current_preds):
    """Return a placeholder confidence: the next draw from the seeded generator.

    current_preds -- the path found so far; accepted so the function can be
                     used as a confidence callback, but currently ignored

    Every call advances the module-level ``random_state``, so the values
    depend on how many times the function has been called.
    """
    # totally fake confidences
    return random_state.rand()

if __name__ == "__main__":
    # Example from
    # https://libraetd.lib.virginia.edu/downloads/n583xv77p?filename=2_Hotaki_Mustafa_2020_MS.pdf
    word_based = True
    base_seq_1 = "the cat in the hat sat on the mat"
    rover_alpha = 1.0
    # rover_alpha = 1.0
    # [(u'-', 0), (u'the', 1), (u'cat', 2), (u'in', 3), (u'the', 4), (u'hat', 5), (u'sat', 6), (u'on', 7), (u'the', 8), (u'mat', 9)]
    # 22.1666666667

    # rover_alpha = 0.0
    # [(u'-', 0), (u'the', 1), (u'cat', 2), (u'in', 3), (u'the', 4), (u'hat', 5), (u'sat', 6), (u'on', 7), (u'-', 8), (u'mat', 9)]
    # 3.94283457846

    all_seq_2 = ["the cat and the hat on mat",
                 "the bat in that hat sat in the mat",
                 "the cat end hat on the mat",
                 "the cat end hat on the mat",
                 "the cat in the at sat on the mat",
                 "the cat in at sat on mat"]

    if not word_based:
        raise ValueError("Fix char based")

    # unichr only exists on Python 2, where chr is limited to 0..255
    try:
        code_to_char = unichr
    except NameError:
        code_to_char = chr
    # Each distinct word is encoded as ONE character, because the traceback
    # returns the aligned sequences as concatenated strings that are decoded
    # character by character below; multi-character codes (such as decimal
    # indices from 10 upwards) could not be told apart again. Codes start in
    # the Unicode Private Use Area so they never collide with the "-" gap mark.
    first_code = 0xE000

    core_wtn = WTN()
    # make these into words, instead of doing char based align
    base_words = base_seq_1.split(" ")
    for seq_index, base_seq_2 in enumerate(all_seq_2):
        hyp_words = base_seq_2.split(" ")

        # per-pair vocabulary: word -> single character code, in first-seen order
        vocab = OrderedDict()
        for w in base_words + hyp_words:
            if w not in vocab:
                vocab[w] = code_to_char(first_code + len(vocab))
        rev_vocab = {code: w for w, code in vocab.items()}
        seq_1 = [vocab[w] for w in base_words]
        seq_2 = [vocab[w] for w in hyp_words]

        # get alignment and traceback for dynamic programming path
        scoring_array, traceback_array = needleman_wuntch_align(seq_1, seq_2)

        align_1, align_2 = traceback_alignment(traceback_array, seq_1, seq_2, debug_print=False)
        # decode back to words (gap marks are not in rev_vocab and pass through)
        # and prepend "-" as the shared start node
        words_align_1 = ["-"] + [rev_vocab.get(el, el) for el in align_1]
        words_align_2 = ["-"] + [rev_vocab.get(el, el) for el in align_2]
        # a node is (word, position in this alignment)
        words_align_1 = list(zip(words_align_1, range(len(words_align_1))))
        words_align_2 = list(zip(words_align_2, range(len(words_align_2))))

        if seq_index == 0:
            # the base sequence is added once, as aligned against the first hypothesis
            for fwa, twa in zip(words_align_1[:-1], words_align_1[1:]):
                core_wtn.add_transition(fwa, twa)
        for fwa, twa in zip(words_align_2[:-1], words_align_2[1:]):
            core_wtn.add_transition(fwa, twa)
    # now that we have all aligned sequences, build word transition network and do ROVER scoring
    best_path, best_path_score = core_wtn.get_best_path_and_score(rover_alpha, fake_confidences)
    print(best_path)
    print(best_path_score)