Skip to content

Instantly share code, notes, and snippets.

@MitI-7
Created May 27, 2015 06:41
Show Gist options
  • Save MitI-7/24b1a6d9a448ed8b5bcd to your computer and use it in GitHub Desktop.
Save MitI-7/24b1a6d9a448ed8b5bcd to your computer and use it in GitHub Desktop.
ビタビアルゴリズムでN-bestをだす
import heapq
import queue
import sys
from collections import defaultdict
class Node:
    """A graph vertex carrying an identifier, a display name and a cost.

    Attributes:
        id: Identifier passed in as ``id_``.
        name: Label returned (as a string) when the node is printed.
        weight: Cost associated with the node itself.
    """

    def __init__(self, id_, name, weight):
        self.id = id_
        self.name = name
        self.weight = weight

    def __str__(self):
        # Coerce explicitly: __str__ must return a str, so a non-string
        # name (e.g. an int label) would otherwise raise TypeError.
        return str(self.name)

    def __repr__(self):
        return "Node(id={!r}, name={!r}, weight={!r})".format(
            self.id, self.name, self.weight
        )
class Edge:
    """Directed, weighted connection from ``start_node`` to ``end_node``.

    Both endpoints are stored as the objects passed in; ``weight`` is the
    cost attached to the connection itself.
    """

    def __init__(self, start_node, end_node, weight):
        self.start_node, self.end_node, self.weight = start_node, end_node, weight

    def __str__(self):
        # Rendered as "<start name>-<end name>(<weight>)".
        pieces = [
            str(self.start_node.name),
            "-",
            str(self.end_node.name),
            "(",
            str(self.weight),
            ")",
        ]
        return "".join(pieces)
class Graph:
    """Directed graph with weighted nodes and edges, indexed by node id.

    Node ids are consecutive integers handed out by ``add_node`` starting
    at 0, so the first node added has id 0 and ``last_id`` always holds the
    id of the most recently added node.
    """

    def __init__(self):
        self.id_node = {}  # {node id: Node object}
        self.id_nextEdges = defaultdict(list)  # {node id: edges leaving that node}
        self.id_prevEdges = defaultdict(list)  # {node id: edges entering that node}
        self.last_id = -1  # id of the most recently added node (-1 while empty)

    def get_next_edges(self, node_id):
        """Return the list of edges leaving ``node_id`` (empty if there are none)."""
        # .get() instead of indexing: indexing a defaultdict inserts an empty
        # entry for every id that is merely queried.
        return self.id_nextEdges.get(node_id, [])

    def get_prev_edges(self, node_id):
        """Return the list of edges entering ``node_id`` (empty if there are none)."""
        return self.id_prevEdges.get(node_id, [])

    def add_node(self, name, weight):
        """Create a node with the next free id and return that id."""
        self.last_id += 1
        self.id_node[self.last_id] = Node(self.last_id, name, weight)
        return self.last_id

    def add_edge(self, start_id, end_id, weight):
        """Connect two existing nodes with a weighted, directed edge.

        Raises:
            KeyError: If ``start_id`` or ``end_id`` is not a known node id.
        """
        edge = Edge(self.id_node[start_id], self.id_node[end_id], weight)
        # The same Edge object is registered in both directions.
        self.id_nextEdges[start_id].append(edge)
        self.id_prevEdges[end_id].append(edge)
def viterbi(graph):
    """Compute the cheapest cost of reaching every node from node 0.

    The cost of a path is the sum of its edge weights plus the weights of
    every node on it except node 0, whose cost is fixed at 0.

    Nodes are processed in ascending id order, so every edge must point
    from a lower id to a higher id (nodes added in topological order).

    Args:
        graph: Object exposing ``id_node`` ({id: node with ``weight``}) and
            ``get_prev_edges(node_id)`` returning edges that have
            ``start_node.id`` and ``weight``.

    Returns:
        dict mapping node id to its minimal cost. Nodes with no incoming
        edge (other than node 0) get ``float("inf")``.

    Raises:
        KeyError: If an edge starts at a node that has not been processed
            yet, i.e. the id-ordering assumption above is violated.
    """
    unreachable = float("inf")
    id_sumCost = {0: 0}  # {node id: minimal cost of reaching it}

    for node_id, node in sorted(graph.id_node.items()):
        if node_id == 0:
            continue
        # Best (cost up to predecessor + connecting edge) over all incoming
        # edges; a real infinity keeps unreachable nodes distinguishable and
        # never clamps genuinely large costs.
        cheapest_arrival = min(
            (
                id_sumCost[prev_edge.start_node.id] + prev_edge.weight
                for prev_edge in graph.get_prev_edges(node_id)
            ),
            default=unreachable,
        )
        id_sumCost[node_id] = cheapest_arrival + node.weight

    return id_sumCost
def n_best(n, graph, id_sumCost):
    """Return up to ``n`` cheapest paths from node 0 to the last-added node.

    Performs a backward best-first (A*) search from ``graph.last_id``
    towards node 0. Each partial path is ranked by its exact cost from the
    current node to the end plus ``id_sumCost`` of the current node (the
    best possible cost of the still-missing prefix), so complete paths are
    produced in non-decreasing total cost. Every queue entry carries its
    own path, so paths sharing nodes cannot corrupt one another.

    The graph is assumed to be acyclic.

    Args:
        n: Maximum number of paths to return.
        graph: Object exposing ``last_id``, ``id_node`` and
            ``get_prev_edges(node_id)``.
        id_sumCost: {node id: minimal cost from node 0}, as computed by
            ``viterbi``.

    Returns:
        A list of paths, best first; each path is a list of ``str(node)``
        values ordered from node 0 to the end node.

    Raises:
        KeyError: If ``id_sumCost`` lacks an entry for a visited node.
    """
    if n <= 0 or graph.last_id < 0:
        return []

    end_id = graph.last_id
    # Monotonic counter: breaks priority ties in insertion order and keeps
    # the heap from ever comparing the later tuple fields.
    sequence = 0
    # Entry: (estimated total, sequence, cost node->end, node id, ids node..end)
    frontier = [(0, sequence, 0, end_id, (end_id,))]
    shortest_path_list = []

    while frontier and len(shortest_path_list) < n:
        _, _, tail_cost, node_id, path_ids = heapq.heappop(frontier)

        # Reached the start node: the carried ids form a complete path.
        if node_id == 0:
            shortest_path_list.append([str(graph.id_node[i]) for i in path_ids])
            continue

        # Stepping back over an edge pays for the edge and for the node
        # being left (node weights are charged on arrival in id_sumCost).
        node_cost = graph.id_node[node_id].weight
        for prev_edge in graph.get_prev_edges(node_id):
            prev_node_id = prev_edge.start_node.id
            new_tail_cost = tail_cost + prev_edge.weight + node_cost
            sequence += 1
            heapq.heappush(
                frontier,
                (
                    new_tail_cost + id_sumCost[prev_node_id],
                    sequence,
                    new_tail_cost,
                    prev_node_id,
                    (prev_node_id,) + path_ids,
                ),
            )

    return shortest_path_list
def main():
    """Build a small sample graph and print its three best paths."""
    graph = Graph()

    # (label, node weight), listed in the order the nodes are created.
    node_specs = [
        ("<S>", 0),
        ("さ", 100),
        ("さか", 200),
        ("さかな", 100),
        ("かな", 200),
        ("なだ", 200),
        ("だ", 10),
        ("よ", 10),
        ("</S>", 0),
    ]
    label_to_id = {}
    for label, node_weight in node_specs:
        label_to_id[label] = graph.add_node(label, node_weight)

    # (from label, to label, edge weight), listed in insertion order.
    edge_specs = [
        ("<S>", "さ", 30),
        ("<S>", "さか", 30),
        ("<S>", "さかな", 30),
        ("さ", "かな", 30),
        ("さか", "なだ", 30),
        ("かな", "だ", 30),
        ("さかな", "だ", 30),
        ("だ", "よ", 30),
        ("なだ", "よ", 45),
        ("よ", "</S>", 30),
    ]
    for source, target, edge_weight in edge_specs:
        graph.add_edge(label_to_id[source], label_to_id[target], edge_weight)

    best_costs = viterbi(graph)
    for path in n_best(3, graph, best_costs):
        print(path)
# Run the demo only when this file is executed as a script, not on import.
if "__main__" == __name__:
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment