Generate window contexts and count cooccurrence within them.
from typing import List
from itertools import tee, combinations
from collections import Counter


# def count_cooccurrence_in_window(context_window, delimiter=' '):
#     return Counter([delimiter.join(bi) for bi in combinations(context_window, 2)])


def window_cooccurrence(sentence: List[str], window: int = 5) -> Counter:
    """Count co-occurrences within a sliding window over the given sentence.

    1. Enumerate window contexts by
       [list(it)[i:i + window] for i, it in enumerate(tee(sentence, num))]:
       ['A', 'B', 'C', 'D', 'E', 'F', 'G'], 4 ->
       [['A', 'B', 'C', 'D'],
        ['B', 'C', 'D', 'E'],
        ['C', 'D', 'E', 'F'],
        ['D', 'E', 'F', 'G']]
    2. Generate and count co-occurrence pairs within each context above by
       [pair for pair in combinations(list(it)[i:i + window], 2)]:
       ['A', 'B', 'C', 'D'] ->
       [('A', 'B'), ('A', 'C'), ('A', 'D'), ('B', 'C'), ('B', 'D'), ('C', 'D')]
    """
    assert len(sentence) >= window
    num = len(sentence) - window + 1
    return Counter([
        pair
        for i, it in enumerate(tee(sentence, num))
        for pair in combinations(list(it)[i:i + window], 2)
    ])
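A quick sanity check of window_cooccurrence on the example sentence from the docstring (a hypothetical run, not part of the gist); a pair's count equals the number of windows that contain both tokens:

counts = window_cooccurrence(['A', 'B', 'C', 'D', 'E', 'F', 'G'], window=4)
print(counts[('C', 'D')])  # 3: the pair appears in the first three windows
print(counts[('A', 'D')])  # 1: only the first window contains both 'A' and 'D'
print(counts[('A', 'G')])  # 0: 'A' and 'G' never share a window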
LM-generation-candidates-spans.py
import math
from heapq import nlargest
from itertools import tee
from operator import itemgetter

# `coocs` (window co-occurrence counts) is assumed to be precomputed,
# e.g. with window_cooccurrence above; `target` marks the symbol of interest.
target = 'target'
coocs_target = {k[0] if k[1] == target else k[1]: v
                for k, v in coocs.items() if target in k}


def topk_average(nums, k=5):
    return sum(nlargest(k, nums)) / k


def context_score(context, occs, coocs_target):
    """log P(target | context)"""
    N = sum(occs.values())
    N_target = sum(coocs_target.values())
    return sum([math.log(coocs_target[c] * N / occs[c] / N_target)
                for c in context if c in coocs_target and c in occs])
    # return topk_average([coocs_target[x] for x in context if x in coocs_target])


def cooc_score(text, occs, coocs_target, window=10, topk=30):
    candidates = [list(it)[i:i + window]
                  for i, it in enumerate(tee(text, len(text) - window + 1))]
    cands_scores = [(cand, context_score(cand, occs, coocs_target)) for cand in candidates]
    return nlargest(topk, cands_scores, key=itemgetter(1))


if __name__ == '__main__':
    # wakatis, occs, coocs_target
    # cooc_score(wakatis[0], occs, coocs_target)
    pass
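The snippet above assumes `occs` and `coocs` already exist. A minimal sketch of how they could be built with window_cooccurrence from the first file, on a hypothetical tokenized corpus `wakatis` (the data here is purely illustrative):

from collections import Counter

wakatis = [['target', 'is', 'a', 'placeholder', 'token', 'here'],
           ['another', 'sentence', 'mentioning', 'the', 'target', 'token']]
occs = Counter(w for ws in wakatis for w in ws)  # unigram counts N(c)
coocs = Counter()                                # pair counts N(a, b)
for ws in wakatis:
    if len(ws) >= 5:
        for pair, n in window_cooccurrence(ws, window=5).items():
            coocs[tuple(sorted(pair))] += n      # normalize the pair key order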
window_cooccurence_documents.py
import math
from collections import Counter
from heapq import nlargest
from itertools import accumulate, combinations, groupby, tee
from operator import itemgetter
from typing import Dict, List, Tuple
def window_cooccurrence_documents(
    documents: List[List[str]], window: int = 5
) -> Counter:
    return Counter(
        [
            tuple(sorted(pair))
            for sentence in documents
            # skip sentences shorter than the window before calling tee
            if len(sentence) >= window
            for i, it in enumerate(tee(sentence, len(sentence) - window + 1))
            for pair in combinations(list(it)[i : i + window], 2)
        ]
    )
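# Quick illustration on toy documents (not from the gist): pair keys are
# order-normalized, and the one-token document is skipped because it is
# shorter than the window.
#
#   docs = [['A', 'B', 'C'], ['X']]
#   window_cooccurrence_documents(docs, window=3)
#   # -> Counter({('A', 'B'): 1, ('A', 'C'): 1, ('B', 'C'): 1})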
def topk_average(nums: List[float], k: int = 5) -> float:
    return sum(nlargest(k, nums)) / k
def context_score(
    context: List[str], occs: Dict[str, int], coocs_target: Dict[str, int]
) -> float:
    """
    With the target symbol fixed, compute a context -> target generation
    probability score log P(target | context) via the approximation:
        log Π_{c in context} P(c, target) / P(c)
        = Σ_c log P(c, target) / P(c)
        = Σ_c log (N(c, target) / Σ_x N(x, target)) * (Σ_x N(x) / N(c))
    Arguments:
        occs: map giving N(c)
        coocs_target: map giving N(c, target) (target is a special symbol
            denoting the analysis target of interest)
    """
    N = sum(occs.values())
    N_target = sum(coocs_target.values())
    return sum(
        [
            math.log(coocs_target[c] * N / occs[c] / N_target)
            for c in context
            if c in coocs_target and c in occs
        ]
    )
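# Worked toy example (assumed counts, not from the gist): with
#   occs = {'a': 2, 'b': 3, 'c': 5}      -> N = 10
#   coocs_target = {'a': 2, 'b': 2}      -> N_target = 4
# context_score(['a', 'b', 'unseen'], occs, coocs_target) skips 'unseen'
# and returns
#   log(2 * 10 / 2 / 4) + log(2 * 10 / 3 / 4) ≈ 0.916 + 0.511 ≈ 1.427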
def consecutive_ints(data: List[int], gap: int = 1) -> List[List[int]]:
    """
    INPUT:  [ 1, 4,5,6, 10, 15,16,17,18, 22, 25,26,27,28 ]
    OUTPUT: [[1], [4, 5, 6], [10], [15, 16, 17, 18], [22], [25, 26, 27, 28]]
    """
    # NOTE: `gap` is currently unused; runs are detected with step 1 only
    # (index - value is constant within a run of consecutive integers).
    return [
        list(map(itemgetter(1), g))
        for k, g in groupby(enumerate(data), lambda ix: ix[0] - ix[1])
    ]
def ordered_intersection(fst: List, snd: List) -> List:
    set_snd = frozenset(snd)
    return [x for x in fst if x in set_snd]
def extract_spans_by_context_overlap(
    text: List[str],
    occs: Dict[str, int],
    coocs_target: Dict[str, int],
    window: int = 15,
    topk: int = 100,
    group_size_threshold: int = 3,
) -> List[List[str]]:
    """
    Given a set of reference contexts, extract word sequences whose contexts are similar.
    The idea is to extract spans by how much their context overlaps with the target
    context, and then to recognize the common part of the extracted spans:
    1. Score each span with the context score and enumerate at most topk spans
    2. Re-sort the enumerated spans by index and extract runs of consecutive spans
       (consecutiveness still needs checking)
    3. Recognize the extraction target as the intersection of each run of spans
    Arguments:
        window: span size (hyperparameter)
        topk: cutoff on the number of top-scoring spans kept (hyperparameter)
    """
    # 1. Score each span with the context score and enumerate at most topk spans
    # TODO: keep the span's begin (i) and end (i + window) indices (e.g. in a small Span class)
    candidates = [
        list(it)[i : i + window]
        for i, it in enumerate(tee(text, len(text) - window + 1))
    ]
    cands_scores = [
        (i, cand, context_score(cand, occs, coocs_target))
        for i, cand in enumerate(candidates)
    ]
    large_spans = nlargest(topk, cands_scores, key=itemgetter(-1))
    # 2. Re-sort by index and extract runs of consecutive spans
    # TODO: keep span begin (j -> i) and end (j -> i + window) indices
    consecutive_spans = sorted(large_spans, key=itemgetter(0))
    id2span = {i: c for i, c, s in consecutive_spans}
    consecutive_ids = list(map(itemgetter(0), consecutive_spans))
    consecutive_id_groups = consecutive_ints(consecutive_ids, gap=1)
    consecutive_span_groups = [
        [id2span[i] for i in ids]
        for ids in consecutive_id_groups
        if len(ids) > group_size_threshold
    ]
    # 3. Recognize the extraction target as the intersection of each run of spans
    span_group_intersections = [
        list(accumulate(span_group, ordered_intersection))[-1]
        for span_group in consecutive_span_groups
    ]
    return span_group_intersections
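# Illustration of steps 2-3 on toy spans (not from the gist): three consecutive
# windows that all cover the same phrase reduce to their ordered intersection.
#
#   group = [['a', 'b', 'c', 'd'], ['b', 'c', 'd', 'e'], ['c', 'd', 'e', 'f']]
#   list(accumulate(group, ordered_intersection))[-1]
#   # -> ['c', 'd']   (the tokens shared by every span, in order)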
def fit(wakatis: List[List[str]]) -> Tuple[Counter, Dict[str, int]]:
    occs = Counter([w for ws in wakatis for w in ws])
    coocs = window_cooccurrence_documents(wakatis, window=5)
    target = "target"
    coocs_target = {
        k[0] if k[1] == target else k[1]: v for k, v in coocs.items() if target in k
    }
    return occs, coocs_target
def predict(
    wakatis: List[List[str]],
    occs: Dict[str, int],
    coocs_target: Dict[str, int],
    window: int = 15,
    topk: int = 100,
) -> List[List[List[str]]]:
    return [
        extract_spans_by_context_overlap(text_wakati, occs, coocs_target, window, topk)
        for text_wakati in wakatis
    ]
if __name__ == "__main__":
    pass
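A minimal end-to-end sketch of fit/predict on a toy corpus (all data here is hypothetical; in practice `wakatis` would be tokenized sentences containing the special 'target' symbol, and window/topk would be tuned):

toy_wakatis = [
    ['the', 'special', 'target', 'symbol', 'appears', 'in', 'this', 'tokenized', 'sentence'],
    ['another', 'tokenized', 'sentence', 'that', 'mentions', 'the', 'target', 'symbol', 'too'],
]
occs, coocs_target = fit(toy_wakatis)
spans = predict(toy_wakatis, occs, coocs_target, window=5, topk=10)
# spans[i] lists, for document i, the token runs shared by groups of
# consecutive high-scoring windows (may be empty for tiny corpora).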