如何将pyTorchTensor字典的元组转换为Tensor字典?

vi4fp9gy  于 2023-10-20  发布在  其他
关注(0)|答案(2)|浏览(225)

我有一个字典的元组,它包含pyTorchTensor:

tuple_of_dicts_of_tensors = (
    {'key_1': torch.tensor([1,1,1]), 'key_2': torch.tensor([4,4,4])},
    {'key_1': torch.tensor([2,2,2]), 'key_2': torch.tensor([5,5,5])},
    {'key_1': torch.tensor([3,3,3]), 'key_2': torch.tensor([6,6,6])}
)

我想把它转换成Tensor字典:

dict_of_tensors = {
    'key_1': torch.tensor([[1,1,1], [2,2,2], [3,3,3]]),
    'key_2': torch.tensor([[4,4,4], [5,5,5], [6,6,6]])
}

你建议怎么做?最有效的方法是什么?Tensor在GPU设备上,因此需要最少的for循环。
谢谢你,谢谢

odopli94

odopli941#

你可以使用torch内置的default_collate()函数:

import torch
from torch.utils.data import default_collate

tuple_of_dicts_of_tensors = (
    {'key_1': torch.tensor([1,1,1]), 'key_2': torch.tensor([4,4,4])},
    {'key_1': torch.tensor([2,2,2]), 'key_2': torch.tensor([5,5,5])},
    {'key_1': torch.tensor([3,3,3]), 'key_2': torch.tensor([6,6,6])}
)

dict_of_tensors = default_collate(tuple_of_dicts_of_tensors)
print(dict_of_tensors)

# >>> {'key_1': tensor([[1, 1, 1], [2, 2, 2], [3, 3, 3]]),
#      'key_2': tensor([[4, 4, 4], [5, 5, 5], [6, 6, 6]])}

这是一个非常强大的功能,尽管它的文档可能并不清楚。引用文件中的简短内容:
该函数接受一批数据,并将批内的元素放入具有额外外部维度(批大小)的Tensor中。
下面是一般的输入类型(基于批处理中元素的类型)到输出类型的Map:

  • torch.Tensor -> torch.Tensor(添加了外部尺寸批量大小)
  • Map[K,V_i] ->Map[K,default_collate([V_1,V_2,...])]

在您的情况下,批处理的元素(即你的元组)是Map(即,Tensor(tensors)。所以,

  • 在第一步中,Map被“移到了外部”--这意味着你最终得到了一个dict(引用文档中的第二个项目符号);
  • 在第二步中,该函数再次应用于dict的所有值,这些值是Tensor-这意味着,对于每个键,Tensor都被整理成一个具有新批处理维度的Tensor(引用文档的第一个项目符号)。

换句话说,您可以将default_collate()的任务视为 * 将批处理维度向内移动 *:一批包含 B 对象的 A 对象(在您的情况下:包含Tensor对象的字典对象的元组)变成 B 批对象的 A 对象(在您的情况下:批量Tensor的字典,其中每个“批量Tensor”再次是具有新的前置批量维度的单个Tensor)。

qybjjes1

qybjjes12#

在这段代码中,dict_of_lists_of_tensors首先是通过使用字典理解从字典元组中提取每个键的Tensor来构造的。然后,对于每个键,使用torch.stack()将Tensor列表沿一个新的维度沿着堆叠,这将为您提供所需的Tensor字典。这种方法最大限度地减少了显式循环的使用,并利用了PyTorch的GPU加速操作。

import torch

tuple_of_dicts_of_tensors = (
    {'key_1': torch.tensor([1,1,1]), 'key_2': torch.tensor([4,4,4])},
    {'key_1': torch.tensor([2,2,2]), 'key_2': torch.tensor([5,5,5])},
    {'key_1': torch.tensor([3,3,3]), 'key_2': torch.tensor([6,6,6])}
)
dict_of_lists_of_tensors = {key: [d[key] for d in tuple_of_dicts_of_tensors] for key in tuple_of_dicts_of_tensors[0]}
dict_of_tensors = {key: torch.stack(tensor_list) for key, tensor_list in dict_of_lists_of_tensors.items()}

print(dict_of_tensors)

输出:-

{'key_1': tensor([[1, 1, 1],
        [2, 2, 2],
        [3, 3, 3]]), 'key_2': tensor([[4, 4, 4],
        [5, 5, 5],
        [6, 6, 6]])}

相关问题