我有一个字典的元组,它包含pyTorchTensor:
tuple_of_dicts_of_tensors = (
{'key_1': torch.tensor([1,1,1]), 'key_2': torch.tensor([4,4,4])},
{'key_1': torch.tensor([2,2,2]), 'key_2': torch.tensor([5,5,5])},
{'key_1': torch.tensor([3,3,3]), 'key_2': torch.tensor([6,6,6])}
)
我想把它转换成Tensor字典:
dict_of_tensors = {
'key_1': torch.tensor([[1,1,1], [2,2,2], [3,3,3]]),
'key_2': torch.tensor([[4,4,4], [5,5,5], [6,6,6]])
}
你建议怎么做?最有效的方法是什么?Tensor在GPU设备上,因此需要最少的for循环。
谢谢你,谢谢
2条答案
按热度按时间odopli941#
你可以使用torch内置的
default_collate()
函数:这是一个非常强大的功能,尽管它的文档可能并不清楚。引用文件中的简短内容:
该函数接受一批数据,并将批内的元素放入具有额外外部维度(批大小)的Tensor中。
下面是一般的输入类型(基于批处理中元素的类型)到输出类型的Map:
在您的情况下,批处理的元素(即你的元组)是Map(即,Tensor(tensors)。所以,
换句话说,您可以将
default_collate()
的任务视为 * 将批处理维度向内移动 *:一批包含 B 对象的 A 对象(在您的情况下:包含Tensor对象的字典对象的元组)变成 B 批对象的 A 对象(在您的情况下:批量Tensor的字典,其中每个“批量Tensor”再次是具有新的前置批量维度的单个Tensor)。qybjjes12#
在这段代码中,dict_of_lists_of_tensors首先是通过使用字典理解从字典元组中提取每个键的Tensor来构造的。然后,对于每个键,使用torch.stack()将Tensor列表沿一个新的维度沿着堆叠,这将为您提供所需的Tensor字典。这种方法最大限度地减少了显式循环的使用,并利用了PyTorch的GPU加速操作。
输出:-