pytorch 如何 得到 张 量 在 某 个 维度 上 的 最 大 值 ?

vdzxcuhz  于 2022-11-09  发布在  其他
关注(0)|答案(2)|浏览(537)

我有一个3DTensor,想在Libtorch中沿第0维取最大值。
我知道如何在Python(PyTorch)中完成此操作,但在LibTorch中却遇到了麻烦。
在LibTorch中,我的代码是

auto target_q_T = torch::rand({5, 10, 1});
auto max_q = torch::max({target_q_T}, 0);
std::cout << max_q;

它会传回这个长且重复的错误。

note: candidate: ‘template<class _Traits> std::basic_ostream<char, _Traits>& std::operator<<(std::basic_ostream<char, _Traits>&, const char*)’
  611 |     operator<<(basic_ostream<char, _Traits>& __out, const char* __s)
      |     ^~~~~~~~
/usr/include/c++/11/ostream:611:5: note:   template argument deduction/substitution failed:
/home/iii/tor/m_gym/multiv_normal.cpp:432:18: note:   cannot convert ‘max_q’ (type ‘std::tuple<at::Tensor, at::Tensor>’) to type ‘const char*’
  432 |     std::cout << max_q;
      |                  ^~~~~
In file included from /usr/include/c++/11/istream:39,
                 from /usr/include/c++/11/sstream:38,
                 from /home/iii/tor/m_gym/libtorch/include/c10/macros/Macros.h:246,
                 from /home/iii/tor/m_gym/libtorch/include/c10/core/DeviceType.h:8,
                 from /home/iii/tor/m_gym/libtorch/include/c10/core/Device.h:3,
                 from /home/iii/tor/m_gym/libtorch/include/ATen/core/TensorBody.h:11,
                 from /home/iii/tor/m_gym/libtorch/include/ATen/core/Tensor.h:3,
                 from /home/iii/tor/m_gym/libtorch/include/ATen/Tensor.h:3,
                 from /home/iii/tor/m_gym/libtorch/include/torch/csrc/autograd/function_hook.h:3,
                 from /home/iii/tor/m_gym/libtorch/include/torch/csrc/autograd/cpp_hook.h:2,
                 from /home/iii/tor/m_gym/libtorch/include/torch/csrc/autograd/variable.h:6,
                 from /home/iii/tor/m_gym/libtorch/include/torch/csrc/autograd/autograd.h:3,
                 from /home/iii/tor/m_gym/libtorch/include/torch/csrc/api/include/torch/autograd.h:3,
                 from /home/iii/tor/m_gym/libtorch/include/torch/csrc/api/include/torch/all.h:7,
                 from /home/iii/tor/m_gym/libtorch/include/torch/csrc/api/include/torch/torch.h:3,
                 from /home/iii/tor/m_gym/multiv_normal.cpp:2:
/usr/include/c++/11/ostream:624:5: note: candidate: ‘template<class _Traits> std::basic_ostream<char, _Traits>& std::operator<<(std::basic_ostream<char, _Traits>&, const signed char*)’
  624 |     operator<<(basic_ostream<char, _Traits>& __out, const signed char* __s)
      |     ^~~~~~~~

这是Python中的工作方式。

target_q_np = torch.rand(5, 10, 1)
max_q = torch.max(target_q_np, 0)
max_q

torch.return_types.max(
values=tensor([[0.8517],
        [0.7526],
        [0.6546],
        [0.9913],
        [0.8521],
        [0.9757],
        [0.9080],
        [0.9376],
        [0.9901],
        [0.7445]]),
indices=tensor([[4],
        [2],
        [3],
        [4],
        [1],
        [0],
        [2],
        [4],
        [4],
        [4]]))
nzkunb0c

nzkunb0c1#

如果你读到编译器错误,它基本上告诉你你正在尝试打印一个由两个Tensor组成的元组。这是因为C++代码的工作方式和python代码完全一样,返回最大值和它们各自的索引(你的python代码打印的就是这个)。你需要std get来从元组中提取Tensor:

auto target_q_T = torch::rand({5, 10, 1});
auto max_q = torch::max({target_q_T}, 0);
std::cout << "max: " << std::get<0>(max_q) 
          << "indices: " << std::get<1>(max_q)
          << std::endl;

在C++17中,您还应该能够编写

auto [max_t, idx_t] = torch::max({target_q_T}, 0);
std::cout << ... ;
acruukt9

acruukt92#

我从来没有发现max在LibTorch中的等价使用,就像在PyTorch中一样,所以我做了一个变通办法。
LibTorch中的max将从一个一维数组中获取max,因此我循环遍历索引数组并连接结果,实际上返回的结果与torch.max(target_q_np,0)相同。
我在LibTorch(C++)中的解决方案。最大值数组以与原始Tensor相反的顺序返回,所以我将其翻转。

auto target_q_T = torch::rand({5, 10, 1});

torch::Tensor zero_max;
for (int i=0; i<5; i++) {
    zero_max = torch::cat({torch::max({target_q_T[i]}).unsqueeze(-1), zero_max}, 0);
}
zero_max = zero_max.flip(-1);

相关问题