@xsthunder
Created March 18, 2021 01:55
Porting Su's single-sample beam search (toward a batch version)
#!/usr/bin/env python
# coding: utf-8
# # Learning beam search
#
# ## ref
#
# [kexue.fm: beam search](https://kexue.fm/archives/7500/comment-page-1)
#
# [C5W3L03 Beam Search - YouTube](https://www.youtube.com/watch?v=RLWuzLLSIgw)
#
#
# ## task
# Rewrite [Su's single-sample beam search decoder](https://github.com/bojone/bert4keras/blob/34eab054227ea529c26b2311151327f1f3dd108d/bert4keras/snippets.py#L530) into a batch version.
#
# I couldn't understand it just by reading, so let's walk through Su's code first.
# In[2]:
import numpy as np
# ## Test data
# from [How to Implement a Beam Search Decoder for Natural Language Processing](https://machinelearningmastery.com/beam-search-decoder-natural-language-processing/)
# In[8]:
data = np.array([[0.1, 0.2, 0.3, 0.4, 0.5],
                 [0.5, 0.4, 0.3, 0.2, 0.1],
                 [0.1, 0.2, 0.3, 0.4, 0.5],
                 [0.5, 0.4, 0.3, 0.2, 0.1],
                 [0.1, 0.2, 0.3, 0.4, 0.5],
                 [0.5, 0.4, 0.3, 0.2, 0.1],
                 [0.1, 0.2, 0.3, 0.4, 0.5],
                 [0.5, 0.4, 0.3, 0.2, 0.1],
                 [0.1, 0.2, 0.3, 0.4, 0.5],
                 [0.5, 0.4, 0.3, 0.2, 0.1]])
VOC_SIZE = len(data[0])
BOS = '<BOS>'
BOS_id = 0
EOS_id = 1
EOS = '<EOS>'
maxlen = len(data)
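# As a quick sanity check on this toy data, here is a minimal standalone beam
# search in the spirit of the machinelearningmastery tutorial linked above.
# It is only a sketch under my own assumptions: it scores candidates by summing
# negative log-probabilities, ignores BOS/EOS handling, reuses the `data`
# matrix defined above, and the name `simple_beam_search` is mine, not Su's.
# In[ ]:
def simple_beam_search(probs, k):
    """Return the k best index sequences over a (steps, vocab) probability matrix."""
    # each candidate is (sequence, cumulative negative log score); lower is better
    beams = [([], 0.0)]
    for row in probs:
        candidates = []
        for seq, score in beams:
            for idx, p in enumerate(row):
                # accumulate -log(p) so the best sequence has the smallest score
                candidates.append((seq + [idx], score - np.log(p)))
        # keep only the k best candidates for the next step
        beams = sorted(candidates, key=lambda t: t[1])[:k]
    return beams

for seq, score in simple_beam_search(data, 2):
    print(seq, score)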
# ## Porting Su's version
# In[43]:
class SuGod:
    # copied from Su's code; may contain issues
    first_output_ids = np.array([[]])
    maxlen = maxlen
    minlen = 1
    end_id = EOS_id

    def predict(self, inputs, output_ids, states, *args):
        """Predict the logits for the current step."""
        # `states` is not used inside beam_search itself; it is simply passed
        # back into this method on every call (here it indexes a row of `data`)
        logits = data[states]
        # inputs = [(batch_size, ...), (batch_size, ...), ...]  # multi-input form
        batch_size = len(inputs[0][0])
        # broadcast the single row of logits to the batch
        logits = logits[None, :].repeat(batch_size, axis=0)
        assert logits.shape == (batch_size, VOC_SIZE)
        states = states + 1
        return logits, states

    def beam_search(self, inputs, topk, states=None, temperature=1, min_ends=1):
        """Beam search decoding.

        Note: topk here is the beam size.
        inputs: multiple inputs, e.g. [token_ids, segment_ids].
        Returns: the best decoded sequence.
        """
        inputs = [np.array([i]) for i in inputs]
        output_ids, output_scores = self.first_output_ids, np.zeros(1)
        for step in range(self.maxlen):
            scores, states = self.predict(
                inputs, output_ids, states, temperature, 'logits'
            )  # scores for the current step
            if step == 0:  # after the first prediction, repeat the inputs topk times
                inputs = [np.repeat(i, topk, axis=0) for i in inputs]
            scores = output_scores.reshape((-1, 1)) + scores  # accumulated scores
            indices = scores.argpartition(-topk, axis=None)[-topk:]  # keep only the topk
            indices_1 = indices // scores.shape[1]  # row (beam) indices
            indices_2 = (indices % scores.shape[1]).reshape((-1, 1))  # column (token) indices
            output_ids = np.concatenate([output_ids[indices_1], indices_2],
                                        1)  # update output sequences
            output_scores = np.take_along_axis(
                scores, indices, axis=None
            )  # update scores
            end_counts = (output_ids == self.end_id).sum(1)  # count end tokens per sequence
            if output_ids.shape[1] >= self.minlen:  # minimum-length check
                best_one = output_scores.argmax()  # the highest-scoring sequence
                if end_counts[best_one] == min_ends:  # if it has already terminated
                    return output_ids[best_one]  # return it directly
                else:  # otherwise keep only the unfinished sequences
                    flag = (end_counts < min_ends)  # mark unfinished sequences
                    if not flag.all():  # if some sequences have finished
                        inputs = [i[flag] for i in inputs]  # drop finished sequences
                        output_ids = output_ids[flag]  # drop finished sequences
                        output_scores = output_scores[flag]  # drop finished scores
                        end_counts = end_counts[flag]  # drop finished end counts
                        topk = flag.sum()  # shrink topk accordingly
        # max length reached: return the best sequence
        return output_ids[output_scores.argmax()]
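# The trickiest part above is the flattened top-k selection. A tiny
# illustration (toy numbers of my own, not from the gist) of how argpartition
# over the flattened score matrix recovers the beam index and the token id:
# In[ ]:
toy_scores = np.array([[0.9, 0.1, 0.3, 0.2, 0.8],
                       [0.4, 0.7, 0.2, 0.6, 0.5]])
k = 2
flat = toy_scores.argpartition(-k, axis=None)[-k:]  # flat indices of the k largest scores
rows = flat // toy_scores.shape[1]                  # which beam each winner came from
cols = flat % toy_scores.shape[1]                   # which token id was chosen
print(rows, cols, toy_scores[rows, cols])           # the k largest scores, in no particular order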
# In[59]:
sg = SuGod()
print(
sg.beam_search([np.zeros((1,3))], 2, states=0)
)
# This should produce two decoded sequences, but Su's single-sample version does not support multiple sentences.
print(
sg.beam_search([np.zeros((2,3))], 2, states=0)
)
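# As I read it, the second call fails because at step 0 `scores` gets one row
# per input sentence while `first_output_ids` has only a single row, so a row
# index of 1 coming out of `indices_1` is out of bounds. A minimal reproduction
# of that indexing error (my own snippet, not part of Su's code):
# In[ ]:
try:
    np.array([[]])[np.array([0, 1])]
except IndexError as e:
    print('IndexError:', e)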
# In[ ]: