In a PyTorch C++ extension, how do I access a single element of a Tensor and convert it to a standard C++ data type?

Posted by wwwo4jvm on 2024-01-09 in Other

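I'm writing a C++ extension for PyTorch in which I need to access elements of a Tensor by index, and I also need to convert each element to a standard C++ type. Here is a short example: suppose I have a 2D Tensor `a`, and I need to access `a[i][j]` and convert it to `float`.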

```cpp
#include <torch/extension.h>

float get(torch::Tensor a, int i, int j) {
  return a[i][j];
}
```

The code above is saved in a file named `tensortest.cpp`, and the extension is built with the following `setup.py`:

```python
from setuptools import setup, Extension
from torch.utils import cpp_extension

setup(name='tensortest',
      ext_modules=[cpp_extension.CppExtension('tensortest_cpp', ['tensortest.cpp'])],
      cmdclass={'build_ext': cpp_extension.BuildExtension})
```


When I run `python setup.py install`, the compiler reports the following error:

```
running install
running bdist_egg
running egg_info
creating tensortest.egg-info
writing tensortest.egg-info/PKG-INFO
writing dependency_links to tensortest.egg-info/dependency_links.txt
writing top-level names to tensortest.egg-info/top_level.txt
writing manifest file 'tensortest.egg-info/SOURCES.txt'
/home/trisst/.local/lib/python3.8/site-packages/torch/utils/cpp_extension.py:335: UserWarning: Attempted to use ninja as the BuildExtension backend but we could not find ninja.. Falling back to using the slow distutils backend.
  warnings.warn(msg.format('we could not find ninja.'))
reading manifest file 'tensortest.egg-info/SOURCES.txt'
writing manifest file 'tensortest.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_ext
building 'tensortest_cpp' extension
creating build
creating build/temp.linux-x86_64-3.8
x86_64-linux-gnu-gcc -pthread -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -g -fstack-protector-strong -Wformat -Werror=format-security -Wdate-time -D_FORTIFY_SOURCE=2 -fPIC -I/home/user/.local/lib/python3.8/site-packages/torch/include -I/home/user/.local/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/home/user/.local/lib/python3.8/site-packages/torch/include/TH -I/home/user/.local/lib/python3.8/site-packages/torch/include/THC -I/usr/include/python3.8 -c tensortest.cpp -o build/temp.linux-x86_64-3.8/tensortest.o -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=tensortest_cpp -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++14
In file included from /home/user/.local/lib/python3.8/site-packages/torch/include/ATen/Parallel.h:149,
                 from /home/user/.local/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/utils.h:3,
                 from /home/user/.local/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/nn/cloneable.h:5,
                 from /home/user/.local/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/nn.h:3,
                 from /home/user/.local/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/all.h:7,
                 from /home/user/.local/lib/python3.8/site-packages/torch/include/torch/extension.h:4,
                 from tensortest.cpp:1:
/home/user/.local/lib/python3.8/site-packages/torch/include/ATen/ParallelOpenMP.h:84: warning: ignoring #pragma omp parallel [-Wunknown-pragmas]
   84 | #pragma omp parallel for if ((end - begin) >= grain_size)
      |
tensortest.cpp: In function ‘float get(at::Tensor, int, int)’:
tensortest.cpp:4:15: error: cannot convert at::Tensor to float in return
    4 | return a[i][j];
      | ^
error: command 'x86_64-linux-gnu-gcc' failed with exit status 1
```


What should I do?

Answer 1, by xvw2m8pv

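Indexing a Tensor with `a[i][j]` returns another Tensor (a zero-dimensional one), not a `float`, which is why the conversion fails. Call `.item<float>()` on that Tensor to extract its value as a C++ `float`: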

```cpp
#include <torch/extension.h>

float get(torch::Tensor a, int i, int j)
{
  // a[i][j] is a 0-dimensional Tensor; item<float>() copies its value out as a C++ float.
  return a[i][j].item<float>();
}
```
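To see why the original version fails to compile and why `item<float>()` fixes it, here is a torch-free miniature written purely for illustration. `MiniTensor` is a hypothetical class, not part of PyTorch, and it is far simpler than `at::Tensor` (a single element type, no negative indices). As with a real Tensor, indexing it returns another tensor object (a view sharing the same storage), and only `item<T>()`, loosely modelled on `Tensor::item<T>()`, turns a single remaining element into a plain C++ value.

```cpp
#include <cstddef>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical miniature tensor (NOT part of PyTorch): shared storage plus an
// offset, sizes and strides, so that indexing can return views.
class MiniTensor {
 public:
  // Builds a contiguous, row-major tensor from data and sizes.
  MiniTensor(std::vector<double> data, std::vector<std::int64_t> sizes)
      : storage_(std::make_shared<std::vector<double>>(std::move(data))),
        sizes_(std::move(sizes)),
        strides_(sizes_.size()) {
    std::int64_t stride = 1;
    for (std::size_t d = sizes_.size(); d-- > 0;) {
      strides_[d] = stride;
      stride *= sizes_[d];
    }
  }

  // Like Tensor::operator[]: drops the first dimension and returns a view
  // (another MiniTensor sharing the same storage), never a number.
  MiniTensor operator[](std::int64_t i) const {
    if (sizes_.empty()) throw std::out_of_range("cannot index a 0-dim tensor");
    if (i < 0 || i >= sizes_[0]) throw std::out_of_range("index out of range");
    MiniTensor view = *this;  // copies metadata, shares storage
    view.offset_ += i * strides_[0];
    view.sizes_.erase(view.sizes_.begin());
    view.strides_.erase(view.strides_.begin());
    return view;
  }

  // Number of elements; a 0-dim tensor has exactly one.
  std::int64_t numel() const {
    std::int64_t n = 1;
    for (std::int64_t s : sizes_) n *= s;
    return n;
  }

  // Only valid when exactly one element remains; the stored value is
  // converted to the requested C++ type.
  template <typename T>
  T item() const {
    if (numel() != 1) throw std::runtime_error("item() needs exactly one element");
    return static_cast<T>((*storage_)[static_cast<std::size_t>(offset_)]);
  }

 private:
  std::shared_ptr<std::vector<double>> storage_;
  std::int64_t offset_ = 0;
  std::vector<std::int64_t> sizes_;
  std::vector<std::int64_t> strides_;
};

// Analogue of the fixed get(): index down to one element, then convert.
// Returning a[i][j] directly would not compile, just like in the question.
float get(const MiniTensor& a, int i, int j) {
  return a[i][j].item<float>();
}
```

With real PyTorch, `item<float>()` also converts from the tensor's own dtype, so it works on a `double` tensor as well. The cost is that each `a[i][j]` creates two temporary Tensors, which adds up quickly inside a loop.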


Answer 2, by uujelgoq

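You can use the tensor's accessor. It works somewhat like an iterator over a `std::` container, and according to the documentation it is more efficient for element-wise access, because the dtype and dimension checks are done once when the accessor is created rather than on every access.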
https://pytorch.org/cppdocs/notes/tensor_basics.html

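```cpp
#include <torch/extension.h>

float get(torch::Tensor a, int i, int j) {
  // a must be a 2-dimensional float32 tensor in CPU memory;
  // accessor<float, 2>() checks the dtype and dimension once, here.
  auto a_accessor = a.accessor<float, 2>();
  return a_accessor[i][j];
}
```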
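The efficiency argument can be illustrated with a small model. The sketch below is torch-free, and every name in it (`DType`, `RawTensor`, `Accessor2D`) is hypothetical; it only mimics the idea behind `Tensor::accessor<T, N>()`: validate the element type and the number of dimensions once, when the accessor is built, and then make each `acc[i][j]` a plain strided pointer lookup with no further checks.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical, torch-free model; names are invented for illustration.
enum class DType { Float32, Float64 };

// A type-erased tensor: the element type is only known at run time.
struct RawTensor {
  DType dtype;
  void* data;                         // not owned
  std::vector<std::int64_t> sizes;
  std::vector<std::int64_t> strides;  // measured in elements, not bytes
};

// Maps a C++ element type to its run-time tag.
template <typename T> DType dtype_of();
template <> inline DType dtype_of<float>() { return DType::Float32; }
template <> inline DType dtype_of<double>() { return DType::Float64; }

template <typename T>
class Accessor2D {
 public:
  // Every check happens once, here, when the accessor is created.
  explicit Accessor2D(const RawTensor& t) {
    if (t.dtype != dtype_of<T>()) throw std::invalid_argument("dtype mismatch");
    if (t.sizes.size() != 2 || t.strides.size() != 2)
      throw std::invalid_argument("expected a 2-D tensor");
    data_ = static_cast<T*>(t.data);
    stride0_ = t.strides[0];
    stride1_ = t.strides[1];
  }

  // Lightweight view of one row, so that acc[i][j] reads like a nested array.
  class Row {
   public:
    Row(T* base, std::int64_t stride) : base_(base), stride_(stride) {}
    T& operator[](std::int64_t j) const { return base_[j * stride_]; }

   private:
    T* base_;
    std::int64_t stride_;
  };

  // No per-element checks (not even bounds checks in this model):
  // indexing is plain pointer arithmetic on the strides.
  Row operator[](std::int64_t i) const { return Row(data_ + i * stride0_, stride1_); }

 private:
  T* data_ = nullptr;
  std::int64_t stride0_ = 0;
  std::int64_t stride1_ = 0;
};

// Analogue of the answer's get(): validate once, then index.
float get(const RawTensor& a, int i, int j) {
  Accessor2D<float> acc(a);
  return acc[i][j];
}
```

Because the model follows strides, it also reads non-contiguous data such as a transposed view correctly. In PyTorch, `accessor<float, 2>()` likewise rejects a tensor whose dtype is not float32 or whose dimensionality is not 2, and the documentation points to packed accessors (`packed_accessor64` / `packed_accessor32`) for use inside CUDA kernels.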

