Batched beam search in PyTorch

Posted by k2fxgqgv on 2023-08-05 in Other

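I am trying to implement a beam search decoding strategy in a text generation model. This is the function I use to decode the output probabilities: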

import torch

def beam_search_decoder(data, k):
    # each entry is [token sequence, cumulative negative log-likelihood]
    sequences = [[list(), 0.0]]
    # walk over each step in sequence
    for row in data:
        all_candidates = list()
        # expand every current beam with every vocabulary token
        for i in range(len(sequences)):
            seq, score = sequences[i]
            for j in range(len(row)):
                candidate = [seq + [j], score - torch.log(row[j])]
                all_candidates.append(candidate)
        # sort candidates by score (lower negative log-likelihood is better)
        ordered = sorted(all_candidates, key=lambda tup: tup[1])
        sequences = ordered[:k]
    return sequences

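As you can see, this function was written with a batch_size of 1 in mind. Adding another loop over the batch would make the algorithm O(n^4), and it is already slow as it is. Is there any way to speed this function up? My model output usually has size (32, 150, 9907), in the format (batch_size, max_len, vocab_size).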
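The slow part of the function above is the Python loop over beams and vocabulary. A minimal sketch of how one expansion step can instead be done for the whole batch at once with broadcasting and a single `topk` call; it assumes the inputs are log-probabilities (for example from `torch.log_softmax`), and `beam_step` is a hypothetical helper name, not part of PyTorch:

```python
import torch

def beam_step(beam_scores: torch.Tensor, step_log_probs: torch.Tensor, k: int):
    """One batched beam-search expansion step (hypothetical helper).

    beam_scores:    (batch, beams) cumulative log-probabilities of the current beams
    step_log_probs: (batch, vocab) log-probabilities for the next position
                    (assumed to be log-probabilities, not raw probabilities)
    Returns new scores (batch, k), the parent beam of each survivor (batch, k),
    and the token each survivor appends (batch, k).
    """
    batch, beams = beam_scores.shape
    vocab = step_log_probs.shape[-1]
    # (batch, beams, 1) + (batch, 1, vocab) -> (batch, beams, vocab):
    # every beam is extended by every token, for every batch item at once.
    candidates = beam_scores.unsqueeze(-1) + step_log_probs.unsqueeze(1)
    # Flatten beams and vocab so one topk call ranks all candidates per item.
    scores, flat = candidates.reshape(batch, beams * vocab).topk(k, dim=-1)
    parent = torch.div(flat, vocab, rounding_mode="floor")  # which beam was extended
    token = flat % vocab                                     # which token was appended
    return scores, parent, token
```

With the shapes from the question, each step ranks k × 9907 candidates per batch item in one call. The returned `parent` indices matter: a full decoder has to use them to reorder the stored token histories, not just append `token`.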

Answer 1 (2izufjch)

Below is my implementation; it may be somewhat faster than the for-loop version.

import torch

def beam_search_decoder(post, k):
    """Beam Search Decoder

    Parameters:

        post(Tensor) – the posterior of network.
        k(int) – beam size of decoder.

    Outputs:

        indices(Tensor) – a beam of index sequence.
        log_prob(Tensor) – a beam of log likelihood of sequence.

    Shape:

        post: (batch_size, seq_length, vocab_size).
        indices: (batch_size, beam_size, seq_length).
        log_prob: (batch_size, beam_size).

    Examples:

        >>> post = torch.softmax(torch.randn([32, 20, 1000]), -1)
        >>> indices, log_prob = beam_search_decoder(post, 3)

    """

    batch_size, seq_length, _ = post.shape
    log_post = post.log()
    log_prob, indices = log_post[:, 0, :].topk(k, sorted=True)
    indices = indices.unsqueeze(-1)
    for i in range(1, seq_length):
        log_prob = log_prob.unsqueeze(-1) + log_post[:, i, :].unsqueeze(1).repeat(1, k, 1)
        log_prob, index = log_prob.view(batch_size, -1).topk(k, sorted=True)
        indices = torch.cat([indices, index.unsqueeze(-1)], dim=-1)
    return indices, log_prob
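A tiny exhaustive search is a convenient reference for checking any of the implementations in this thread. A minimal sketch (the helper name `brute_force_topk` is hypothetical): because each row of the posterior here is a fixed distribution that does not depend on earlier tokens, a correct width-k beam search should return exactly the k best sequences of this exhaustive list, ties aside.

```python
import itertools
import torch

def brute_force_topk(log_post: torch.Tensor, k: int):
    """Exhaustively score every sequence for a single batch item.

    log_post: (seq_length, vocab_size) log-probabilities, one row per position.
    Returns the k best (sequence, score) pairs, best first.
    Only feasible for tiny inputs: there are vocab_size ** seq_length sequences.
    """
    seq_length, vocab_size = log_post.shape
    scored = []
    for seq in itertools.product(range(vocab_size), repeat=seq_length):
        # Positions are scored independently, so a sequence's score is a plain sum.
        score = sum(log_post[t, tok].item() for t, tok in enumerate(seq))
        scored.append((list(seq), score))
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]
```

Comparing `indices[0]` from a batched decoder against the sequences returned for `log_post[0]` quickly shows whether the decoder tracks the right beam histories.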


Answer 2 (ftf50wuq)

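/!\ The top-voted answer does not perform a correct beam search!

Based on the version proposed by 防暴队大盾, I decided to implement a version of the beam-search algorithm that does not overlook sequences that share initial tokens. This is done by recovering the correct beam and token indices from the indices into the flattened array.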

import torch

def beam_search(prediction, k=10):
    batch_size, seq_length, vocab_size = prediction.shape
    log_prob, indices = prediction[:, 0, :].topk(k, sorted=True)
    indices = indices.unsqueeze(-1)
    for n1 in range(1, seq_length):
        # extend each of the k beams with every token: (batch, k, vocab)
        log_prob_temp = log_prob.unsqueeze(-1) + prediction[:, n1, :].unsqueeze(1).repeat(1, k, 1)
        log_prob, index_temp = log_prob_temp.view(batch_size, -1).topk(k, sorted=True)
        idx_begin = torch.div(index_temp, vocab_size, rounding_mode="floor")  # retrieve index of start sequence
        idx_concat = index_temp % vocab_size  # retrieve index of new token
        # allocate on the input's device so GPU index tensors can be used below
        new_indices = torch.zeros((batch_size, k, n1 + 1), dtype=torch.int64, device=prediction.device)
        for n2 in range(batch_size):
            new_indices[n2, :, :-1] = indices[n2][idx_begin[n2]]  # copy parent history
            new_indices[n2, :, -1] = idx_concat[n2]               # append new token
        indices = new_indices
    return indices, log_prob

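This version assumes that prediction already holds log-probabilities (the log-domain scores used in cross-entropy, e.g. the output of torch.log_softmax) rather than probabilities, so there is no need to take the log here. Higher values must mean more likely, because topk keeps the largest scores.
If anyone knows how to avoid the inner loop over the batch with some fancy indexing, they could probably make this faster.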
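One way to remove that per-batch loop, offered as a sketch rather than the answerer's code: `torch.gather` along the beam dimension copies each survivor's parent history for all batch items at once. It keeps the same assumptions as the answer (log-probability input, k no larger than vocab_size) and returns the same shapes.

```python
import torch

def beam_search(prediction: torch.Tensor, k: int = 10):
    """Batched beam search over per-position log-probabilities.

    prediction: (batch_size, seq_length, vocab_size) log-probabilities (assumed).
    Returns indices (batch_size, k, seq_length) and log_prob (batch_size, k).
    """
    batch_size, seq_length, vocab_size = prediction.shape
    log_prob, indices = prediction[:, 0, :].topk(k, sorted=True)
    indices = indices.unsqueeze(-1)  # (batch, k, 1)
    for t in range(1, seq_length):
        # (batch, k, 1) + (batch, 1, vocab) -> (batch, k, vocab), ranked per item.
        candidates = log_prob.unsqueeze(-1) + prediction[:, t, :].unsqueeze(1)
        log_prob, flat = candidates.reshape(batch_size, -1).topk(k, sorted=True)
        parent = torch.div(flat, vocab_size, rounding_mode="floor")  # (batch, k)
        token = flat % vocab_size                                     # (batch, k)
        # Replaces the loop over n2: beam j of row b copies indices[b, parent[b, j], :].
        gather_idx = parent.unsqueeze(-1).expand(-1, -1, indices.shape[-1])
        history = indices.gather(1, gather_idx)
        indices = torch.cat([history, token.unsqueeze(-1)], dim=-1)
    return indices, log_prob
```

Everything stays on the input's device, and the only remaining Python loop is over time steps, which beam search needs anyway.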

Answer 3 (xt0899hw)

You can use this library:
https://pypi.org/project/pytorch-beam-search/
It implements beam search, greedy search and sampling for PyTorch sequence models.
The snippet below builds and trains a Transformer seq2seq model and then uses it to generate predictions.

#pip install pytorch-beam-search
from pytorch_beam_search import seq2seq

# Create vocabularies
# Tokenize the way you need
source = [list("abcdefghijkl"), list("mnopqrstwxyz")]
target = [list("ABCDEFGHIJKL"), list("MNOPQRSTWXYZ")]
# An Index object represents a mapping from the vocabulary
# to integers (indices) to feed into the models
source_index = seq2seq.Index(source)
target_index = seq2seq.Index(target)

# Create tensors
X = source_index.text2tensor(source)
Y = target_index.text2tensor(target)
# X.shape == (n_source_examples, len_source_examples) == (2, 11)
# Y.shape == (n_target_examples, len_target_examples) == (2, 12)

# Create and train the model
model = seq2seq.Transformer(source_index, target_index)    # just a PyTorch model
model.fit(X, Y, epochs = 100)    # basic method included

# Generate new predictions
new_source = [list("new first in"), list("new second in")]
new_target = [list("new first out"), list("new second out")]
X_new = source_index.text2tensor(new_source)
Y_new = target_index.text2tensor(new_target)
loss, error_rate = model.evaluate(X_new, Y_new)    # basic method included
predictions, log_probabilities = seq2seq.beam_search(model, X_new) 
output = [target_index.tensor2text(p) for p in predictions]
output
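For comparison with the beam search discussed above, greedy search and sampling are much simpler when the per-position distributions are fixed, as in the question's (batch_size, max_len, vocab_size) tensor. This is a sketch for that setting only: it is not the pytorch-beam-search API (which drives a model step by step), and `greedy_decode` and `sample_decode` are hypothetical names. Inputs are assumed to be log-probabilities.

```python
import torch

def greedy_decode(log_post: torch.Tensor):
    """Pick the highest-scoring token at every position.

    log_post: (batch_size, seq_length, vocab_size) log-probabilities (assumed).
    Returns tokens (batch, seq) and their total log-probability (batch,).
    """
    scores, tokens = log_post.max(dim=-1)
    return tokens, scores.sum(dim=-1)

def sample_decode(log_post: torch.Tensor, generator=None):
    """Draw one token per position from the per-position distributions."""
    batch_size, seq_length, vocab_size = log_post.shape
    # multinomial samples row-wise from (unnormalized) non-negative weights
    probs = log_post.exp().reshape(-1, vocab_size)
    tokens = torch.multinomial(probs, 1, generator=generator).view(batch_size, seq_length)
    # look up the log-probability of each sampled token and sum over positions
    scores = log_post.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
    return tokens, scores.sum(dim=-1)
```

With position-independent distributions, the greedy result coincides with the best beam, so beam search only pays off when each step's distribution depends on the tokens chosen so far, as with an autoregressive model.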

