如何从PyTorch模型中获取特定层的输出?

wixjitnu  于 11个月前  发布在  其他
关注(0)|答案(3)|浏览(129)

如何从预训练的PyTorch模型(如ResNet或VGG)中提取特定层的特征,而无需再次进行正向传递?

y3bcpkx1

y3bcpkx11#

新答案

编辑:there's a new feature in torchvision v0.11.0 that allows extracting features

例如,如果你想从层layer4.2.relu_2中提取特征,你可以这样做:

import torch
from torchvision.models import resnet50
from torchvision.models.feature_extraction import create_feature_extractor

x = torch.rand(1, 3, 224, 224)

model = resnet50()

return_nodes = {
    "layer4.2.relu_2": "layer4"
}
model2 = create_feature_extractor(model, return_nodes=return_nodes)
intermediate_outputs = model2(x)

字符串

老答案

你可以在你想要的特定层上注册一个forward hook。比如:

def some_specific_layer_hook(module, input_, output):
    pass  # the value is in 'output'

model.some_specific_layer.register_forward_hook(some_specific_layer_hook)
    
model(some_input)


例如,要在ResNet中获得res5c输出,您可能需要使用nonlocal变量(或Python 2中的global):

res5c_output = None

def res5c_hook(module, input_, output):
    nonlocal res5c_output
    res5c_output = output

resnet.layer4.register_forward_hook(res5c_hook)

resnet(some_input)
    
# Then, use `res5c_output`.

5lhxktic

5lhxktic2#

接受的答案是非常有帮助的!我在这里发布了一个完整的例子(使用@bryant1410描述的注册钩子),为懒惰的人寻找一个工作解决方案:

import torch 
import torchvision.models as models
from torchvision import transforms
from PIL import Image

def get_feat_vector(path_img, model):
    '''
    Input: 
        path_img: string, /path/to/image
        model: a pretrained torch model
    Output:
        my_output: torch.tensor, output of avgpool layer
    '''
    input_image = Image.open(path_img)
    
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    
    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(0)

    with torch.no_grad():
        my_output = None
        
        def my_hook(module_, input_, output_):
            nonlocal my_output
            my_output = output_

        a_hook = model.avgpool.register_forward_hook(my_hook)        
        model(input_batch)
        a_hook.remove()
        return my_output

字符串
在这里,您可以使用要素提取函数,只需使用以下代码片段调用该函数即可从resnet18.avgpool图层中获取要素

model = models.resnet18(pretrained=True)
model.eval()
path_ = '/path/to/image'
my_feature = get_feat_vector(path_, model)

rsaldnfx

rsaldnfx3#

使用register_forward_hook的替代方案,但使用类而不是全局变量。
简单的例子:

class FeatureExtractor:
    def __init__(self):
        self.extracted_features = None

    def __call__(self, module, input_, output):
        self.extracted_features = output

extractor = FeatureExtractor()
model.some_specific_layer.register_forward_hook(extractor)
model(some_input)
extractor.extracted_features

字符串
从多个图层中提取(存储在字典中):

class FeatureExtractor:
    def __init__(self):
        self.extracted_features = dict()

    def extract_features(self, module, input_, output, name):
        self.extracted_features[name] = output

    def get_forward_hook(self, name):
        return functools.partial(self.extract_features, name=name)

model.some_specific_layer.register_forward_hook(extractor.get_forward_hook(layer_name))
model(some_input)
extractor.extracted_features[layer_name]


functools.partial允许我们创建一个Map到FeatureExtractor.extract_featurescallable,其中特定参数已经传递给 name 参数。

相关问题