keras 'model.summary()',TensorFlow模型将打印输出形状子类化为“multiple”

vulvrdjw  于 2022-12-04  发布在  其他
关注(0)|答案(1)|浏览(233)

我尝试用下面的VggBlock实现Vgg网络。

class VggBlock(tf.keras.Model):
  def __init__(self, filters, repetitions):
    super(VggBlock, self).__init__()
    self.repetitions = repetitions

    self.conv_layers = [Conv2D(filters=filters, kernel_size=(3, 3), padding='same', activation='relu') for _ in range(repetitions)]
    self.max_pool = MaxPool2D(pool_size=(2, 2))

  def call(self, inputs):
    x = inputs
    for layer in self.conv_layers:
      x = layer(x)
    return self.max_pool(x)

test_block = VggBlock(filters=64, repetitions=2)
temp_inputs = Input(shape=(224, 224, 3))
test_block(temp_inputs)
test_block.summary()

然后,上面的代码将打印:

Model: "vgg_block"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d (Conv2D)             multiple                  1792      
                                                                 
 conv2d_1 (Conv2D)           multiple                  36928     
                                                                 
 max_pooling2d (MaxPooling2D  multiple                 0         
 )                                                               
                                                                 
=================================================================
Total params: 38,720
Trainable params: 38,720
Non-trainable params: 0
_________________________________________________________________

而且如果我用这些块构建Vgg,它的summary()也会打印“multiple”。
有一些问题与我的问题类似,例如:https://github.com/keras-team/keras/issues/13782model.summary() can't print output shape while using subclass model
但是,我不能把答案延伸到第二个环节:以变化的input_shape表示。
我如何处理summary(),以便使“多个”成为一个适当的形状。

muk1a3rh

muk1a3rh1#

您已经链接了一些解决方案。您似乎在这里着陆,因为无法确定每个层的输出形状。如下所述:
您可以在“功能”或“顺序”模型中执行所有这些操作(打印输入/输出形状),因为这些模型是层的静态图形。
相反,子类化模型是一段Python代码(一个调用方法),这里没有层的图,我们无法知道层是如何相互连接的(因为这是在调用的主体中定义的,而不是显式的数据结构),所以我们无法推断输入/输出形状。
您也可以尝试以下操作:
第一个

相关问题