Created
May 27, 2015 06:41
-
-
Save MitI-7/24b1a6d9a448ed8b5bcd to your computer and use it in GitHub Desktop.
ビタビアルゴリズムでN-bestをだす
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
from collections import defaultdict | |
import queue | |
class Node:
    """A vertex of the lattice.

    Attributes:
        id: identifier of the node inside its graph.
        name: label of the node (the surface string in the sample lattice).
        weight: cost charged for passing through this node.
    """

    def __init__(self, id_, name, weight):
        self.id = id_
        self.name = name
        self.weight = weight

    def __str__(self):
        # __str__ must return a str; coerce so that a non-string label
        # (e.g. an int) does not raise TypeError. Edge.__str__ coerces names
        # the same way.
        return str(self.name)

    def __repr__(self):
        return "Node(id={!r}, name={!r}, weight={!r})".format(
            self.id, self.name, self.weight
        )
class Edge:
    """A directed, weighted connection between two nodes.

    Attributes:
        start_node: node object the edge leaves from.
        end_node: node object the edge arrives at.
        weight: cost charged for following this edge.
    """

    def __init__(self, start_node, end_node, weight):
        self.start_node, self.end_node = start_node, end_node
        self.weight = weight

    def __str__(self):
        # Rendered as "<start name>-<end name>(<weight>)".
        return "{!s}-{!s}({!s})".format(
            self.start_node.name, self.end_node.name, self.weight
        )
class Graph:
    """Directed graph that keeps, per node id, its outgoing and incoming edges."""

    def __init__(self):
        # id of the most recently added node (-1 while the graph is empty)
        self.last_id = -1
        # {node id: node object}
        self.id_node = {}
        # {node id: list of edge objects leaving that node}
        self.id_nextEdges = defaultdict(list)
        # {node id: list of edge objects arriving at that node}
        self.id_prevEdges = defaultdict(list)

    def get_next_edges(self, node_id):
        """Return the list of edges that start at ``node_id``."""
        outgoing = self.id_nextEdges
        return outgoing[node_id]

    def get_prev_edges(self, node_id):
        """Return the list of edges that end at ``node_id``."""
        incoming = self.id_prevEdges
        return incoming[node_id]

    def add_node(self, name, weight):
        """Create a node with the next consecutive id and return that id."""
        self.last_id = new_id = self.last_id + 1
        self.id_node[new_id] = Node(new_id, name, weight)
        return new_id

    def add_edge(self, start_id, end_id, weight):
        """Connect two existing nodes with an edge of the given weight."""
        source = self.id_node[start_id]
        target = self.id_node[end_id]
        link = Edge(source, target, weight)
        # The same edge object is indexed from both of its endpoints.
        self.id_nextEdges[start_id].append(link)
        self.id_prevEdges[end_id].append(link)
def viterbi(graph):
    """Compute the minimum accumulated cost from node 0 to every node.

    The cost of a node is the cheapest ``cost(previous node) + edge weight``
    over its incoming edges, plus the node's own weight. Node 0 is fixed at
    cost 0. A node without incoming edges gets ``sys.maxsize`` plus its weight.

    Warning: nodes are processed in ascending id order, so every edge must
    lead from an already processed (smaller) id; otherwise the lookup of the
    predecessor's cost raises ``KeyError``.

    Returns:
        dict mapping node id to the minimum cost of reaching that node.
    """
    id_sumCost = {0: 0}
    for current_id in sorted(graph.id_node):
        if current_id == 0:
            continue
        current = graph.id_node[current_id]
        # Cost of arriving through each incoming edge:
        # (start -> predecessor) + (predecessor -> current edge).
        arrival_costs = [
            id_sumCost[edge.start_node.id] + edge.weight
            for edge in graph.get_prev_edges(current_id)
        ]
        # The sentinel goes first so it wins ties, like the running minimum.
        cheapest = min([sys.maxsize] + arrival_costs)
        id_sumCost[current_id] = cheapest + current.weight
    return id_sumCost
def n_best(n, graph, id_sumCost):
    """Return up to ``n`` lowest-cost paths from node 0 to the last node.

    The search runs backward from ``graph.last_id`` as an A* search. Each
    queue entry is a partial path (a suffix ending at the last node) and is
    ranked by ``id_sumCost[node] + exact cost of the suffix``. When
    ``id_sumCost`` holds the minimum forward costs (as produced by
    ``viterbi``), complete paths are popped in non-decreasing total cost.

    Every entry carries its own suffix, so paths that share a node cannot
    overwrite each other's successor information.

    Args:
        n: maximum number of paths to return; ``n <= 0`` yields ``[]``.
        graph: graph providing ``last_id``, ``id_node`` and
            ``get_prev_edges``; assumed to be acyclic.
        id_sumCost: dict mapping node id to the minimum cost from node 0.

    Returns:
        List of paths, best first; each path is the list of ``str(node)``
        values from node 0 to the last node.
    """
    shortest_path_list = []
    q = queue.PriorityQueue()
    # Entry layout: (priority, tie-breaker, node id, suffix cost, suffix ids).
    # The unique tie-breaker keeps tuple comparison from reaching the suffix.
    serial = 0
    q.put((0, serial, graph.last_id, 0, (graph.last_id,)))
    while not q.empty() and len(shortest_path_list) < n:
        _, _, node_id, suffix_cost, suffix = q.get()
        # Reached the start node: the suffix is now a complete path.
        if node_id == 0:
            shortest_path_list.append([str(graph.id_node[i]) for i in suffix])
            continue
        for prev_edge in graph.get_prev_edges(node_id):
            prev_node_id = prev_edge.start_node.id
            # Stepping back over this edge adds the edge weight and the
            # weight of the node being left (the edge's end node).
            extended_cost = suffix_cost + prev_edge.weight + prev_edge.end_node.weight
            serial += 1
            q.put((
                id_sumCost[prev_node_id] + extended_cost,
                serial,
                prev_node_id,
                extended_cost,
                (prev_node_id,) + suffix,
            ))
    return shortest_path_list
def main():
    """Build the sample word lattice and print its three best paths."""
    graph = Graph()

    # (label, node weight), listed in insertion order. The order fixes the
    # node ids, and viterbi() assumes nodes were added front to back.
    node_specs = [
        ("<S>", 0),
        ("さ", 100),
        ("さか", 200),
        ("さかな", 100),
        ("かな", 200),
        ("なだ", 200),
        ("だ", 10),
        ("よ", 10),
        ("</S>", 0),
    ]
    ids = {label: graph.add_node(label, weight) for label, weight in node_specs}

    # (from label, to label, edge weight), in the order the edges are added.
    edge_specs = [
        ("<S>", "さ", 30),
        ("<S>", "さか", 30),
        ("<S>", "さかな", 30),
        ("さ", "かな", 30),
        ("さか", "なだ", 30),
        ("かな", "だ", 30),
        ("さかな", "だ", 30),
        ("だ", "よ", 30),
        ("なだ", "よ", 45),
        ("よ", "</S>", 30),
    ]
    for source, target, weight in edge_specs:
        graph.add_edge(ids[source], ids[target], weight)

    forward_costs = viterbi(graph)
    for path in n_best(3, graph, forward_costs):
        print(path)


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment