我正在尝试在文本生成模型中实现一种波束搜索解码策略。这是我用来解码输出概率的函数。
def beam_search_decoder(data, k):
sequences = [[list(), 0.0]]
# walk over each step in sequence
for row in data:
all_candidates = list()
for i in range(len(sequences)):
seq, score = sequences[i]
for j in range(len(row)):
candidate = [seq + [j], score - torch.log(row[j])]
all_candidates.append(candidate)
# sort candidates by score
ordered = sorted(all_candidates, key=lambda tup:tup[1])
sequences = ordered[:k]
return sequences
字符串
现在您可以看到,此函数是在考虑batch_size 1的情况下实现的。为批处理大小添加另一个循环将使算法变为O(n^4)
。它像现在这样缓慢。有没有办法提高这个函数的速度。我的模型输出的大小通常为(32, 150, 9907)
,其格式为(batch_size, max_len, vocab_size)
3条答案
按热度按时间2izufjch1#
下面是我的实现,它可能比for循环实现快一点。
字符串
ftf50wuq2#
/!\投票最多的答案未执行正确的波束搜索!
Based on the version proposed by 防暴队大盾, I decided to implement a version of the beam-search algorithm that does not overlook sequences that share initial tokens.这是通过从flatten数组的索引中检索正确的索引来完成的
字符串
这个版本假设
prediction
对应于交叉熵分数,而不是概率。因此,没有必要在这里记录。如果有人知道如何使用一些花哨的索引来避免最内层的循环,那么他可能会让这更快。
xt0899hw3#
你可以使用这个图书馆
https://pypi.org/project/pytorch-beam-search/
它实现了Beam Search,Greedy Search和PyTorch序列模型的采样。
下面的代码片段实现了一个Transformer seq2seq模型,并使用它来生成预测。
字符串