Skip to content

Instantly share code, notes, and snippets.

@kzinmr
Last active May 19, 2020 01:52
Show Gist options
  • Save kzinmr/1a916232909f41594838a025283f68ef to your computer and use it in GitHub Desktop.
Save kzinmr/1a916232909f41594838a025283f68ef to your computer and use it in GitHub Desktop.
Generate window contexts and count cooccurrence within them.
from typing import List
from itertools import tee, combinations
from collections import Counter
# def count_cooccurrence_in_window(context_window, delimiter=' '):
# return Counter([delimiter.join(bi) for bi in combinations(context_window, 2)])
def window_cooccurrence(sentence: List[str], window: int = 5) -> Counter:
    """Count co-occurring token pairs inside every sliding window of a sentence.

    1. Enumerate all contiguous length-`window` contexts, e.g.
       ['A','B','C','D','E','F','G'], window=4 ->
       [['A','B','C','D'], ['B','C','D','E'], ['C','D','E','F'], ['D','E','F','G']]
    2. Count every unordered pair within each context, e.g.
       ['A','B','C','D'] ->
       [('A','B'), ('A','C'), ('A','D'), ('B','C'), ('B','D'), ('C','D')]

    Pairs appearing in several overlapping windows are counted once per window.

    Arguments:
        sentence: tokenized sentence; must contain at least `window` tokens.
        window: context size in tokens.
    Returns:
        Counter mapping (token, token) tuples (in sentence order) to counts.
    Raises:
        ValueError: if the sentence is shorter than the window.
    """
    # BUG FIX: the original used `assert`, which is stripped under `python -O`;
    # input validation must raise explicitly.
    if len(sentence) < window:
        raise ValueError(f"sentence length {len(sentence)} < window {window}")
    # Plain slicing enumerates the same windows as the original
    # tee/enumerate construction, without copying the iterator n times.
    return Counter(
        pair
        for start in range(len(sentence) - window + 1)
        for pair in combinations(sentence[start:start + window], 2)
    )
@kzinmr
Copy link
Author

kzinmr commented May 15, 2020

window_cooccurrence_documents.py

from typing import List
from itertools import tee, combinations
from collections import Counter

def window_cooccurrence_documents(documents: List[List[str]], window: int = 5) -> Counter:
    """Count sorted co-occurrence pairs within sliding windows over many sentences.

    Each pair is canonicalized with `tuple(sorted(pair))` so ('b','a') and
    ('a','b') share one count. Sentences shorter than `window` are skipped.

    Arguments:
        documents: list of tokenized sentences.
        window: context size in tokens.
    Returns:
        Counter mapping sorted (token, token) tuples to counts.
    """
    # BUG FIX: the original placed `if len(sentence) >= window` in the
    # innermost position, AFTER `tee(sentence, len(sentence) - window + 1)`
    # had already run — a sentence shorter than window - 1 passed a negative
    # count to tee and raised ValueError. The guard now precedes the
    # window enumeration.
    return Counter(
        tuple(sorted(pair))
        for sentence in documents
        if len(sentence) >= window
        for start in range(len(sentence) - window + 1)
        for pair in combinations(sentence[start:start + window], 2)
    )

if __name__ == '__main__':
    # BUG FIX: a comment-only `if` body is a SyntaxError; `pass` makes the
    # snippet importable while keeping the example wiring visible.
    # Requires tokenized documents `wakatis` and per-sentence counters:
    # occs = Counter([w for ws in wakatis for w in ws])
    # coocs = sum(cooc_counters, Counter())
    pass

@kzinmr
Copy link
Author

kzinmr commented May 15, 2020

LM-generation-candidates-spans.py

from heapq import nlargest
from operator import itemgetter
import math

# Special symbol marking the analysis target; co-occurrence pairs containing
# it are re-keyed by their *other* member below.
target = 'target'
# Build N(c, target): map each context word c that co-occurs with `target`
# to its co-occurrence count.
# NOTE(review): `coocs` is not defined in this snippet — it must be a Counter
# of sorted co-occurrence pairs produced upstream (e.g. by
# window_cooccurrence_documents); confirm before running stand-alone.
coocs_target = {k[0] if k[1] == target else k[1]: v for k, v in coocs.items() if target in k}


def topk_average(nums, k=5):
    """Return the sum of the k largest values in `nums` divided by k.

    Note: when len(nums) < k the sum covers every element but is still
    divided by k, deflating the result — preserved from the original contract.
    """
    # BUG FIX: the original called `heapq.nlargest`, but only
    # `from heapq import nlargest` is imported — `heapq` is unbound.
    return sum(nlargest(k, nums)) / k

def context_score(context, occs, coocs_target):
    """Approximate log P(target | context) by summing pointwise scores.

    For each context word c present in both maps, adds
    log( N(c, target) * N / (N(c) * N_target) ); words missing from either
    map are skipped.
    """
    total_occ = sum(occs.values())
    total_target = sum(coocs_target.values())
    score = 0
    for word in context:
        if word in coocs_target and word in occs:
            score += math.log(coocs_target[word] * total_occ / occs[word] / total_target)
    return score

def cooc_score(text, occs, coocs_target, window=10, topk=30):
    """Score every length-`window` contiguous span of `text` with
    `context_score` and return the `topk` highest (span, score) pairs.

    Arguments:
        text: tokenized text; should contain at least `window` tokens.
        occs: unigram counts N(c).
        coocs_target: co-occurrence counts N(c, target).
    Returns:
        List of (span, score) tuples, highest score first, at most `topk` long.
    """
    # BUG FIX: the original built windows with `itertools.tee`, which this
    # snippet never imports (NameError at call time); plain slicing
    # enumerates the identical sliding windows.
    candidates = [text[i:i + window] for i in range(len(text) - window + 1)]
    cands_scores = [(cand, context_score(cand, occs, coocs_target)) for cand in candidates]
    return nlargest(topk, cands_scores, key=itemgetter(1))

if __name__ == '__main__':
    # BUG FIX: a comment-only `if` body is a SyntaxError; `pass` makes the
    # snippet importable while keeping the example wiring visible.
    # Requires `wakatis` (tokenized docs), `occs`, `coocs_target`:
    # cooc_score(wakatis[0], occs, coocs_target)
    pass

@kzinmr
Copy link
Author

kzinmr commented May 18, 2020

import math
from collections import Counter
from heapq import nlargest
from itertools import accumulate, combinations, groupby, tee
from operator import itemgetter
from typing import Dict, List, Tuple


def window_cooccurrence_documents(
    documents: List[List[str]], window: int = 5
) -> Counter:
    """Count sorted co-occurrence pairs within sliding windows over many sentences.

    Each pair is canonicalized with `tuple(sorted(pair))` so ('b','a') and
    ('a','b') share one count. Sentences shorter than `window` are skipped.

    Arguments:
        documents: list of tokenized sentences.
        window: context size in tokens.
    Returns:
        Counter mapping sorted (token, token) tuples to counts.
    """
    # BUG FIX: the original's `if len(sentence) >= window` was the innermost
    # clause, evaluated AFTER `tee(sentence, len(sentence) - window + 1)` —
    # a sentence shorter than window - 1 handed tee a negative count and
    # raised ValueError. The guard now precedes the window enumeration.
    return Counter(
        tuple(sorted(pair))
        for sentence in documents
        if len(sentence) >= window
        for start in range(len(sentence) - window + 1)
        for pair in combinations(sentence[start : start + window], 2)
    )


def topk_average(nums: List[float], k: int = 5) -> float:
    """Return the sum of the k largest values in `nums` divided by k.

    Note: when len(nums) < k the sum covers every element but is still
    divided by k, deflating the result — preserved from the original contract.
    """
    # BUG FIX: the original called `heapq.nlargest`, but this module only
    # does `from heapq import nlargest` — the `heapq` name is unbound.
    return sum(nlargest(k, nums)) / k


def context_score(
    context: List[str], occs: Dict[str, int], coocs_target: Dict[str, int]
) -> float:
    """Approximate log P(target | context) for a fixed target symbol.

    Computed via the factorization
        log Π_{c in context} P(c, target) / P(c)
        = Σ_c log P(c, target) / P(c)
        = Σ_c log (N(c, target) / Σ_x N(x, target)) * (Σ_x N(x) / N(c))

    Arguments:
        occs: map giving the unigram count N(c).
        coocs_target: map giving N(c, target), where target is the special
            symbol marking the analysis subject.

    Words absent from either map are skipped.
    """
    total_occ = sum(occs.values())
    total_target = sum(coocs_target.values())
    score = 0
    for word in context:
        if word in coocs_target and word in occs:
            score += math.log(coocs_target[word] * total_occ / occs[word] / total_target)
    return score


def consecutive_ints(data: List[int], gap: int = 1) -> List[List[int]]:
    """Split a sorted integer list into maximal runs with constant step `gap`.

    INPUT: [ 1, 4,5,6, 10, 15,16,17,18, 22, 25,26,27,28 ]
    OUTPUT: [[1], [4, 5, 6], [10], [15, 16, 17, 18], [22], [25, 26, 27, 28]]

    BUG FIX: the original accepted `gap` but never used it, always grouping
    step-1 runs. The groupby key now scales the position by `gap`, so an
    arithmetic run of step `gap` shares one key value. Behavior for the
    default gap=1 is unchanged.
    """
    return [
        [value for _, value in group]
        # value - position * gap is constant exactly along a step-`gap` run
        for _, group in groupby(enumerate(data), lambda iv: iv[1] - iv[0] * gap)
    ]


def ordered_intersection(fst: List, snd: List) -> List:
    """Return the elements of `fst` that also occur in `snd`, preserving
    their order (and multiplicity) in `fst`."""
    members = frozenset(snd)
    result = []
    for item in fst:
        if item in members:
            result.append(item)
    return result


def extract_spans_by_context_overlap(
    text: List[str],
    occs: Dict[str, int],
    coocs_target: Dict[str, int],
    window: int = 15,
    topk: int = 100,
    group_size_threshold: int = 3,
) -> List[List[str]]:
    """
    Extract token spans whose contexts resemble a reference context set.

    Idea: score spans by how strongly their context overlaps the target's
    context, then recover the common core of overlapping high-scoring spans:
        1. Score every window-sized span with the context score and keep the
           top `topk`.
        2. Re-sort the kept spans by position and gather runs of consecutive
           start indices (consecutiveness must be checked).
        3. Take the ordered intersection of each run's spans as the
           extracted target.

    Arguments:
        window: span size (hyperparameter).
        topk: cut-off count of highest-scoring spans kept (hyperparameter).
        group_size_threshold: runs with this many spans or fewer are dropped.

    NOTE(review): `tee(text, len(text) - window + 1)` raises ValueError when
    len(text) < window - 1; callers appear to pass texts longer than
    `window` — confirm before reuse.
    """
    # 1. Score each window-sized span and keep the top `topk`.
    # TODO: keep the span's begin (i) / end (i + window) indices (define a small Span class)
    candidates = [
        list(it)[i : i + window]
        for i, it in enumerate(tee(text, len(text) - window + 1))
    ]
    cands_scores = [
        (i, cand, context_score(cand, occs, coocs_target))
        for i, cand in enumerate(candidates)
    ]
    large_spans = nlargest(topk, cands_scores, key=itemgetter(-1))
    # 2. Re-sort by start index and group runs of consecutive span ids.
    # TODO: keep the span's begin (j->i) / end (j->i + window) indices
    consecutive_spans = sorted(large_spans, key=itemgetter(0))
    id2span = {i: c for i, c, s in consecutive_spans}
    consecutive_ids = list(map(itemgetter(0), consecutive_spans))
    consecutive_id_groups = consecutive_ints(consecutive_ids, gap=1)
    consecutive_span_groups = [
        [id2span[i] for i in ids] for ids in consecutive_id_groups if len(ids) > group_size_threshold
    ]
    # 3. The ordered intersection of each group's spans is taken as the
    #    recognized extraction target.
    span_group_intersections = [
        list(accumulate(span_group, ordered_intersection))[-1]
        for span_group in consecutive_span_groups
    ]
    return span_group_intersections


def fit(wakatis: List[List[str]]) -> Tuple[Counter, Dict[str, int]]:
    """Build the count statistics that `predict` consumes from tokenized docs.

    Returns:
        occs: unigram counts N(w) over all tokens.
        coocs_target: N(c, target) for every context word c co-occurring with
            the special symbol "target" inside a 5-token window.

    BUG FIX: the original annotated the return type as List[List[List[str]]]
    although the function returns an (occs, coocs_target) tuple.
    """
    occs = Counter([w for ws in wakatis for w in ws])
    coocs = window_cooccurrence_documents(wakatis, window=5)
    target = "target"
    # re-key each pair containing `target` by its other (non-target) member
    coocs_target = {
        k[0] if k[1] == target else k[1]: v for k, v in coocs.items() if target in k
    }
    return occs, coocs_target


def predict(
    wakatis: List[List[str]],
    occs: Dict[str, int],
    coocs_target: Dict[str, int],
    window: int = 15,
    topk: int = 100,
) -> List[List[List[str]]]:
    """Run context-overlap span extraction over every tokenized document.

    Returns one list of extracted spans (each a token list) per document,
    in input order.
    """
    results = []
    for tokens in wakatis:
        spans = extract_spans_by_context_overlap(
            tokens, occs, coocs_target, window, topk
        )
        results.append(spans)
    return results


if __name__ == "__main__":
    pass

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment