Created
March 18, 2021 01:55
-
-
Save xsthunder/ed1aaff2e45fdf6dd59a4f9e6b0814b8 to your computer and use it in GitHub Desktop.
Su's single-sample (non-batched) version of beam search
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# coding: utf-8 | |
# # 学习 beam search
# | |
# ## ref | |
# | |
# [kexue.fm 束搜索](https://kexue.fm/archives/7500/comment-page-1)
# | |
# [C5W3L03 Beam Search - YouTube](https://www.youtube.com/watch?v=RLWuzLLSIgw) | |
# | |
# | |
# ## task | |
# [苏神单条版beam search解码](https://github.com/bojone/bert4keras/blob/34eab054227ea529c26b2311151327f1f3dd108d/bert4keras/snippets.py#L530)改写为batch版 | |
# | |
# 直接看看不懂,拿薛神代码看看 | |
# In[2]: | |
import numpy as np | |
# ## Test data
# Taken from [How to Implement a Beam Search Decoder for Natural Language Processing](https://machinelearningmastery.com/beam-search-decoder-natural-language-processing/)
# In[8]:
# Ten decoding steps over a five-token vocabulary: even-numbered steps use the
# ascending row, odd-numbered steps the same values in descending order.
data = np.tile(
    np.array([[0.1, 0.2, 0.3, 0.4, 0.5],
              [0.5, 0.4, 0.3, 0.2, 0.1]]),
    (5, 1),
)
maxlen, VOC_SIZE = data.shape  # number of decoding steps, vocabulary size
# Special tokens and their ids.
BOS = '<BOS>'
BOS_id = 0
EOS = '<EOS>'
EOS_id = 1
# ## Port of Su's (bert4keras) single-sample beam search
# In[43]:
class SuGod:
    """Single-sample beam-search decoder ported from bert4keras' snippets.py.

    ``predict`` is a toy model driven by the module-level ``data`` table;
    ``beam_search`` is the decoding loop being studied.
    """

    # Every beam starts with no tokens. An integer dtype keeps the decoded ids
    # integral; ``np.array([[]])`` is float64 and made every returned id a float.
    first_output_ids = np.empty((1, 0), dtype=int)
    maxlen = maxlen  # maximum number of decoding steps
    minlen = 1  # minimum output length before a finished beam may be returned
    end_id = EOS_id  # token id that marks a beam as finished

    def predict(self, inputs, output_ids, states, *args):
        """Return next-token scores for the live beams and the next state.

        Toy model: the distribution is row ``states`` of ``data``; it does not
        depend on ``inputs`` or ``output_ids``. Extra positional arguments
        (``temperature``, ``rtype`` from ``beam_search``) are accepted for
        signature compatibility and ignored.

        Returns:
            (logits, states): ``logits`` has shape (batch_size, VOC_SIZE) and
            holds log-probabilities, because ``beam_search`` adds scores across
            steps and a sum of logs is the log of the sequence probability;
            ``states`` is the step counter plus one.
        """
        # beam_search defaults states to None; treat that as the first step.
        step = 0 if states is None else states
        log_probas = np.log(data[step])
        # inputs = [(batch_size, ...), (batch_size, ...), ...]  # multi-input form
        # NOTE(review): beam_search wraps each input as (1, *input_shape) and later
        # repeats it to (n_beams, *input_shape), so len(inputs[0][0]) is the first
        # dimension of the caller's input, not the number of live beams. A single
        # row broadcasts against output_scores; a multi-row input (e.g. shape
        # (2, 3)) is not supported by this single-sample decoder.
        batch_size = len(inputs[0][0])
        logits = np.repeat(log_probas[None, :], batch_size, axis=0)
        assert logits.shape == (batch_size, VOC_SIZE)
        return logits, step + 1

    def beam_search(self, inputs, topk, states=None, temperature=1, min_ends=1):
        """Beam-search decode one sample and return the best token-id sequence.

        Args:
            inputs: list of model inputs for ONE sample, e.g.
                ``[token_ids, segment_ids]``; each gets a leading batch axis.
            topk: beam size.
            states: initial decoder state, passed through to ``predict``.
            temperature: forwarded to ``predict``.
            min_ends: number of ``end_id`` tokens that marks a beam as finished.

        Returns:
            1-D integer array: the highest-scoring decoded sequence.
        """
        inputs = [np.array([i]) for i in inputs]
        output_ids, output_scores = self.first_output_ids, np.zeros(1)
        for step in range(self.maxlen):
            scores, states = self.predict(
                inputs, output_ids, states, temperature, 'logits'
            )  # scores for the current step
            scores = output_scores.reshape((-1, 1)) + scores  # cumulative scores
            # argpartition needs fewer kept items than candidates; a beam wider
            # than the first step's candidate count would otherwise raise.
            topk = min(topk, scores.size)
            if step == 0:  # after the first prediction, replicate inputs per beam
                inputs = [np.repeat(i, topk, axis=0) for i in inputs]
            indices = scores.argpartition(-topk, axis=None)[-topk:]  # keep top-k
            indices_1 = indices // scores.shape[1]  # row (source beam) index
            indices_2 = (indices % scores.shape[1]).reshape((-1, 1))  # token id
            output_ids = np.concatenate([output_ids[indices_1], indices_2], 1)
            output_scores = np.take_along_axis(scores, indices, axis=None)
            end_counts = (output_ids == self.end_id).sum(1)  # end tokens per beam
            if output_ids.shape[1] >= self.minlen:  # minimum-length check
                best_one = output_scores.argmax()  # highest-scoring beam
                if end_counts[best_one] == min_ends:  # best beam has finished
                    return output_ids[best_one]
                else:  # otherwise keep only the unfinished beams
                    flag = (end_counts < min_ends)
                    if not flag.all():  # some beams finished: drop them
                        inputs = [i[flag] for i in inputs]
                        output_ids = output_ids[flag]
                        output_scores = output_scores[flag]
                        topk = flag.sum()  # beam shrinks accordingly
        # Maximum length reached: return the best remaining beam.
        return output_ids[output_scores.argmax()]
# In[59]:
sg = SuGod()
print(sg.beam_search([np.zeros((1, 3))], topk=2, states=0))
# A batched decoder should print two rows here (one per sentence); Su's
# single-sample version does not support multiple sentences.
# NOTE(review): with the current SuGod this call raises IndexError.
print(sg.beam_search([np.zeros((2, 3))], topk=2, states=0))
# In[ ]:
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment