python 为rnnDataset实现一个嵌入器,它接收输入Tensor,将其转换为单词->标记->嵌入,然后存储嵌入

koaltpgm  于 2022-12-21  发布在  Python
关注(0)|答案(1)|浏览(121)

我需要写一个函数来嵌入一个RNN数据集,它将输入作为Tensor并将其转换为单词。我知道这个函数具有下面的结构,但不知道如何声明参数。

def rnn_embedder(tensor, embedding_length):
'''
# Takes a tensor and a vocabulary and returns the BoW embedding of that tensor
# Args:
    tensor (torch.Tensor): A tensor of words represented by their index in the vocabulary
    vocab_lenght (int): The number of entries in the vocabulary
Returns (torch.Tensor): An tensor containing the BoW embedding of the input tensor
'''

tensor = tensor.long()
embedding = ...
words = ...

for ...

return numpy.asarray(embedding)
xuo3flqw

xuo3flqw1#

要声明rnn_embedder函数的参数,您需要指定函数将接受的输入的类型和名称。
第一个参数,Tensor,是一个由词汇表中的索引表示的单词Tensor。这个参数应该是一个 Torch 。Tensor对象。
第二个参数embedding_length是一个整数,表示词汇表中的条目数,这个参数应该是一个int对象。
使用此信息,可以按如下所示声明rnn_embedder函数的参数:

def rnn_embedder(tensor: torch.Tensor, embedding_length: int) -> torch.Tensor:

函数定义指定rnn_embedder函数采用两个参数:turch.tensor类型的Tensor和int类型的embedding_length,并返回turch.tensor对象。

相关问题