pytorch函数(如RoIPool)是如何工作的?

ryevplcw  于 2022-11-09  发布在  其他
关注(0)|答案(1)|浏览(143)

例如,我试图查看RoI Pooling在pytorch中的实现。
下面的代码片段展示了如何在pytorch中使用RoIPool

import torch
from torchvision.ops.roi_pool import RoIPool

device = torch.device('cuda')

# create feature layer, proposals and targets

num_proposals = 10
feature_map = torch.randn(1, 64, 32, 32)

proposals = torch.zeros((num_proposals, 4))
proposals[:, 0] = torch.randint(0, 16, (num_proposals,))
proposals[:, 1] = torch.randint(0, 16, (num_proposals,))
proposals[:, 2] = torch.randint(16, 32, (num_proposals,))
proposals[:, 3] = torch.randint(16, 32, (num_proposals,))

roi_pool_obj = RoIPool(3, 2**-1)
roi_pool = roi_pool_obj(feature_map, [proposals])

我使用的是pychram,所以当我从第二行开始跟随RoIPool时,它会打开一个位于~/anaconda3/envs/CV/lib/python3.8/site-package/torchvision/ops/roi_pool.py的文件,这与文档中的代码完全相同。
我粘贴了下面的代码,没有文档。

from typing import List, Union

import torch
from torch import nn, Tensor
from torch.jit.annotations import BroadcastingList2
from torch.nn.modules.utils import _pair
from torchvision.extension import _assert_has_ops

from ..utils import _log_api_usage_once
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape

def roi_pool(
    input: Tensor,
    boxes: Union[Tensor, List[Tensor]],
    output_size: BroadcastingList2[int],
    spatial_scale: float = 1.0,
) -> Tensor:

    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(roi_pool)
    _assert_has_ops()
    check_roi_boxes_shape(boxes)
    rois = boxes
    output_size = _pair(output_size)
    if not isinstance(rois, torch.Tensor):
        rois = convert_boxes_to_roi_format(rois)
    output, _ = torch.ops.torchvision.roi_pool(input, rois, spatial_scale, output_size[0], output_size[1])
    return output

class RoIPool(nn.Module):

    def __init__(self, output_size: BroadcastingList2[int], spatial_scale: float):
        super().__init__()
        _log_api_usage_once(self)
        self.output_size = output_size
        self.spatial_scale = spatial_scale

    def forward(self, input: Tensor, rois: Tensor) -> Tensor:
        return roi_pool(input, rois, self.output_size, self.spatial_scale)

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}(output_size={self.output_size}, spatial_scale={self.spatial_scale})"
        return s

因此,在程式码范例中:
当运行roi_pool_obj = RoIPool(3, 2**-1)时,它将通过调用__init__方法创建RoIPool的示例,该方法只初始化了两个示例变量;
当运行roi_pool = roi_pool_obj(feature_map, [proposals])时,它必须调用forward()方法(但我不知道如何调用),然后该方法调用上面的roi_pool()函数;
当运行roi_pool()函数时,它首先进行一些检查,然后使用output, _ = torch.ops.torchvision.roi_pool(input, rois, spatial_scale, output_size[0], output_size[1])行计算输出。
但这并没有显示roi_pool是如何实现的细节,当我试图跟踪torch.ops.torchvision.roi_pool时,pycharm显示了Cannot find declaration to go to
总结一下,我有两个问题:
1.如何通过运行roi_pool = roi_pool_obj(feature_map, [proposals])来调用forward()
1.如何查看torch.ops.torchvision.roi_pool的源代码,或者包含其实现的文件位于何处?
最后但并非最不重要的是,我刚刚开始阅读源代码,这对我来说是相当困难的。如果你也能提供一些建议或教程,我将不胜感激。

lg40wkob

lg40wkob1#

  1. RoIPool是torch.nn.Module的一个子类。源代码:
    https://github.com/pytorch/vision/blob/07ae61bf9c21ddd1d5f65d326aa9636849b383ca/torchvision/ops/roi_pool.py#L56
  2. nn。模块定义了__call__方法,该方法依次调用forward方法。源代码:
    https://github.com/pytorch/pytorch/blob/b2311192e6c4745aac3fdd774ac9d56a36b396d4/torch/nn/modules/module.py#L1234
    1.当您执行roi_pool = roi_pool_obj(feature_map, [proposals])语句时,__call__方法使用RoiPool的forward()。源代码:
    https://github.com/pytorch/vision/blob/07ae61bf9c21ddd1d5f65d326aa9636849b383ca/torchvision/ops/roi_pool.py#L67
  3. RoiPool.forward调用torch.ops.torchvision.roi_pool
    https://github.com/pytorch/vision/blob/07ae61bf9c21ddd1d5f65d326aa9636849b383ca/torchvision/ops/roi_pool.py#L52
  4. ops是加载用C++实现的本机库的对象:
    https://github.com/pytorch/pytorch/blob/b2311192e6c4745aac3fdd774ac9d56a36b396d4/torch/_ops.py#L537
    因此当您调用torch.ops.torchvision时,它将使用torchvision库。
    1.这里注册了roi_pool函数:
    https://github.com/pytorch/vision/blob/7947fc8fb38b1d3a2aca03f22a2e6a3caa63f2a0/torchvision/csrc/ops/roi_pool.cpp#L53
    1.在这里,您可以找到rol_pool的实际实现
    中央处理器:
    https://github.com/pytorch/vision/blob/7947fc8fb38b1d3a2aca03f22a2e6a3caa63f2a0/torchvision/csrc/ops/cpu/roi_pool_kernel.cpp
    图形处理器:https://github.com/pytorch/vision/blob/7947fc8fb38b1d3a2aca03f22a2e6a3caa63f2a0/torchvision/csrc/ops/cuda/roi_pool_kernel.cu

相关问题