import torch
from torchvision.models import resnet50
from torchvision.models.feature_extraction import create_feature_extractor
x = torch.rand(1, 3, 224, 224)
model = resnet50()
return_nodes = {
"layer4.2.relu_2": "layer4"
}
model2 = create_feature_extractor(model, return_nodes=return_nodes)
intermediate_outputs = model2(x)
字符串
老答案
你可以在你想要的特定层上注册一个forward hook。比如:
def some_specific_layer_hook(module, input_, output):
pass # the value is in 'output'
model.some_specific_layer.register_forward_hook(some_specific_layer_hook)
model(some_input)
3条答案
按热度按时间y3bcpkx11#
新答案
编辑:there's a new feature in torchvision v0.11.0 that allows extracting features。
例如,如果你想从层
layer4.2.relu_2
中提取特征,你可以这样做:字符串
老答案
你可以在你想要的特定层上注册一个forward hook。比如:
型
例如,要在ResNet中获得
res5c
输出,您可能需要使用nonlocal
变量(或Python 2中的global
):型
5lhxktic2#
接受的答案是非常有帮助的!我在这里发布了一个完整的例子(使用@bryant1410描述的注册钩子),为懒惰的人寻找一个工作解决方案:
字符串
在这里,您可以使用要素提取函数,只需使用以下代码片段调用该函数即可从
resnet18.avgpool
图层中获取要素型
rsaldnfx3#
使用
register_forward_hook
的替代方案,但使用类而不是全局变量。简单的例子:
字符串
从多个图层中提取(存储在字典中):
型
functools.partial
允许我们创建一个Map到FeatureExtractor.extract_features
的 callable,其中特定参数已经传递给 name 参数。