PyTorch: get all layers of a model

Posted by xriantvc on 2022-11-09 in Other

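What is the simplest way to take a PyTorch model and get a list of all its layers, without any nn.Sequential grouping? For example, is there a better way than this?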

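import pretrainedmodels
import torch.nn as nn

def unwrap_model(model):
    # Recurse into nn.Sequential containers; collect every other child as-is.
    for child in model.children():
        if isinstance(child, nn.Sequential):
            unwrap_model(child)
        else:
            l.append(child)

model = pretrainedmodels.__dict__['xception'](num_classes=1000, pretrained='imagenet')
l = []  # filled in by unwrap_model
unwrap_model(model)

print(l)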

Answer #1

You can use the modules() method to iterate over all modules of a model, including those inside each Sequential.

>>> import torch.nn as nn
>>> model = nn.Sequential(nn.Linear(2, 2),
...                       nn.ReLU(),
...                       nn.Sequential(nn.Linear(2, 1),
...                                     nn.Sigmoid()))

>>> l = [module for module in model.modules() if not isinstance(module, nn.Sequential)]

>>> l
[Linear(in_features=2, out_features=2, bias=True),
 ReLU(),
 Linear(in_features=2, out_features=1, bias=True),
 Sigmoid()]
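
Note that modules() also yields the model itself and any container that is not an nn.Sequential (a custom block or an nn.ModuleList, for example), so the filter above only removes Sequential wrappers. A minimal sketch, assuming you want only leaf layers, is to keep the modules that have no children. The Block class below is a hypothetical container invented for illustration; it is not part of the question.

```python
import torch.nn as nn

class Block(nn.Module):
    # Hypothetical custom container, used only for illustration.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 2)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.fc(x))

model = nn.Sequential(Block(), nn.Sequential(nn.Linear(2, 1), nn.Sigmoid()))

# Keep only modules with no children, i.e. the leaf layers.
leaf_layers = [m for m in model.modules() if len(list(m.children())) == 0]

print([type(m).__name__ for m in leaf_layers])
# ['Linear', 'ReLU', 'Linear', 'Sigmoid']
```

Filtering on "has no children" rather than on a specific container type means custom blocks are unwrapped as well.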

Answer #2

I worked this out for a deeper model in which not every block is an nn.Sequential.

import torch

def get_children(model: torch.nn.Module):
    # Get the direct children of the model.
    children = list(model.children())
    if not children:
        # The model has no children, so it is itself a leaf layer.
        return [model]
    flat_children = []
    # Recurse into each child until only leaf layers remain.
    for child in children:
        flat_children.extend(get_children(child))
    return flat_children

Answer #3

If you want the layers in a dict keyed by name, this is the simplest way:

named_layers = dict(model.named_modules())

This returns something like the following:

{
    'conv1': <some conv layer>,
    'fc1': <some fc layer>,
    ### and other layers
}

Example:

import torchvision.models as models

# Note: torchvision 0.13+ deprecates `pretrained` in favor of the `weights` argument.
model = models.inception_v3(pretrained=True)
named_layers = dict(model.named_modules())
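
To make the shape of that dict concrete, here is a small sketch on a toy model (the model is an assumption chosen for illustration, not inception_v3). It shows that named_modules() includes the model itself under the empty key '' and gives nested modules dotted names, so containers appear alongside the leaf layers.

```python
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(2, 2),
    nn.ReLU(),
    nn.Sequential(nn.Linear(2, 1), nn.Sigmoid()),
)

named_layers = dict(model.named_modules())

# The empty key '' is the model itself; nested modules get dotted names.
print(list(named_layers.keys()))
# ['', '0', '1', '2', '2.0', '2.1']

# Keep only entries without children to drop the root and the containers.
leaf_layers = {
    name: module
    for name, module in named_layers.items()
    if len(list(module.children())) == 0
}
print(list(leaf_layers.keys()))
# ['0', '1', '2.0', '2.1']
```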

Answer #4

This is how I did it:

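def flatten(el):
    # Flatten each direct child recursively.
    flattened = [flatten(child) for child in el.children()]
    # The result starts with the module itself, so containers are included too.
    res = [el]
    for c in flattened:
        res += c
    return res

# Custom_block_1 and Custom_block_2 stand for your own nn.Module instances.
cnn = nn.Sequential(Custom_block_1, Custom_block_2)
layers = flatten(cnn)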

Answer #5

If you want a nested dict with names as keys and modules as values, for example:

{'conv1': Conv2d(...),
 'bn1': BatchNorm2d(...),
 'block1':{
    'group1':{
        'conv1': Conv2d(...),
        'bn1': BatchNorm2d(...),
        'conv2': Conv2d(...),
        'bn2': BatchNorm2d(...),
    },
    'group2':{ ...
    }, ...
 }, ...
}

you can combine the answers from Kees and Mayukh Deb to get:

import torch

def nested_children(m: torch.nn.Module):
    children = dict(m.named_children())
    if not children:
        # The module has no children, so it is a leaf: return it directly.
        return m
    output = {}
    # Map each child's name to its own nested structure.
    for name, child in children.items():
        output[name] = nested_children(child)
    return output
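
As a usage sketch, the recursion can be tried on a small model with named submodules. The toy model and its layer sizes are assumptions made up for illustration; the function is restated in compact form so the block runs on its own.

```python
from collections import OrderedDict

import torch.nn as nn

def nested_children(m: nn.Module):
    # Leaf modules are returned as-is; containers become nested dicts.
    children = dict(m.named_children())
    if not children:
        return m
    return {name: nested_children(child) for name, child in children.items()}

# Toy model with named submodules (hypothetical, for illustration only).
model = nn.Sequential(OrderedDict([
    ('conv1', nn.Conv2d(3, 8, kernel_size=3)),
    ('bn1', nn.BatchNorm2d(8)),
    ('block1', nn.Sequential(OrderedDict([
        ('conv1', nn.Conv2d(8, 8, kernel_size=3)),
        ('bn1', nn.BatchNorm2d(8)),
    ]))),
]))

nested = nested_children(model)
print(list(nested.keys()))            # ['conv1', 'bn1', 'block1']
print(list(nested['block1'].keys()))  # ['conv1', 'bn1']
```

Leaf layers stay as module objects, so nested['block1']['conv1'] is the same Conv2d instance that the model uses.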

Answer #6

Here is my approach. You can generally pass in any model, and it returns a list of all its torch.nn.* layers.

def flatten_model(modules):
    def flatten_list(_2d_list):
        flat_list = []
        # Iterate through the outer list
        for element in _2d_list:
            if type(element) is list:
                # If the element is of type list, iterate through the sublist
                for item in element:
                    flat_list.append(item)
            else:
                flat_list.append(element)
        return flat_list

    ret = []
    try:
        # Works when `modules` yields (name, module) pairs, e.g. model.named_children()
        for _, n in modules:
            ret.append(flatten_model(n))
    except Exception:
        try:
            if not modules._modules:
                # No submodules: this is a leaf layer
                ret.append(modules)
            else:
                for _, n in modules._modules.items():
                    ret.append(flatten_model(n))
        except Exception:
            # Not a module (e.g. a plain list): keep it as-is
            ret.append(modules)
    return flatten_list(ret)

Answer #7

Extending Ivan's answer: https://stackoverflow.com/a/69544742/429476

target_layers = []
module_list = [module for module in model.modules()]  # this is needed
flatted_list = flatten_model(module_list)

for count, value in enumerate(flatted_list):
    if isinstance(value, (nn.Conv2d, nn.AvgPool2d, nn.BatchNorm2d)):
    # if isinstance(value, nn.Conv2d):
        print(count, value)
        target_layers.append(value)

Result for ResNet50
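
Because container modules are never instances of these layer types, the same isinstance filter can also be applied directly to model.modules(). The sketch below assumes a small stand-in model rather than ResNet50, and the layer sizes are made up for illustration.

```python
import torch.nn as nn

# Small stand-in model (hypothetical), not ResNet50.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Sequential(nn.Conv2d(8, 8, kernel_size=3), nn.AvgPool2d(2)),
)

target_types = (nn.Conv2d, nn.AvgPool2d, nn.BatchNorm2d)

# modules() walks the whole tree, so nested layers are found as well.
target_layers = [m for m in model.modules() if isinstance(m, target_types)]

for count, layer in enumerate(target_layers):
    print(count, layer)
```

Here the loop index counts only the matching layers, whereas in the loop above `count` is the position within the full module list.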
