从tf.keras.Model中检索Keras层属性

4smxwvx5  于 2023-05-18  发布在  其他
关注(0)|答案(1)|浏览(118)

我用下面的例子来说明我的问题:

class Encoder(K.layers.Layer):
    """Convolutional encoder: three conv + 2x2 max-pooling stages.

    Args:
        filters: sequence whose first three entries are the output channel
            counts of ``conv1``, ``conv2`` and ``conv3`` respectively.
    """

    def __init__(self, filters):
        super(Encoder, self).__init__()
        # Options shared by all three convolutions.
        conv_opts = {
            'kernel_size': 3,
            'strides': 1,
            'activation': 'relu',
            'padding': 'same',
        }
        self.conv1 = Conv2D(filters=filters[0], **conv_opts)
        self.conv2 = Conv2D(filters=filters[1], **conv_opts)
        self.conv3 = Conv2D(filters=filters[2], **conv_opts)
        # One pooling layer, reused after every convolution.
        self.pool = MaxPooling2D((2, 2), padding='same')

    def call(self, input_features):
        """Apply (conv -> pool) three times and return the pooled features."""
        x = input_features
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(conv(x))
        return x

class Decoder(K.layers.Layer):
    """Convolutional decoder mirroring ``Encoder``.

    Three (conv -> 2x2 up-sampling) stages followed by a single-channel
    sigmoid convolution, so every output value lies in [0, 1].

    Args:
        filters: sequence of three channel counts, used in reverse order
            (``filters[2]``, ``filters[1]``, ``filters[0]``).
    """

    def __init__(self, filters):
        super(Decoder, self).__init__()
        self.conv1 = Conv2D(filters=filters[2], kernel_size=3, strides=1, activation='relu', padding='same')
        self.conv2 = Conv2D(filters=filters[1], kernel_size=3, strides=1, activation='relu', padding='same')
        # 'valid' padding trims 2 pixels from each spatial dimension (3x3 kernel,
        # no padding); presumably chosen so the final output size matches the
        # encoder input -- TODO confirm for the input size in use.
        self.conv3 = Conv2D(filters=filters[0], kernel_size=3, strides=1, activation='relu', padding='valid')
        self.conv4 = Conv2D(1, 3, 1, activation='sigmoid', padding='same')
        self.upsample = UpSampling2D((2, 2))

    def call(self, encoded):
        """Decode ``encoded`` into a single-channel sigmoid map."""
        # No debug printing here: ``call`` runs on every trace/batch and a
        # print would flood the training log.
        x = self.conv1(encoded)
        x = self.upsample(x)
        x = self.conv2(x)
        x = self.upsample(x)
        x = self.conv3(x)
        x = self.upsample(x)
        return self.conv4(x)

class Autoencoder(K.Model):
    """Autoencoder composed of the custom ``Encoder`` and ``Decoder`` layers.

    Args:
        filters: channel counts forwarded unchanged to both ``Encoder`` and
            ``Decoder``.
    """

    def __init__(self, filters):
        super(Autoencoder, self).__init__()
        # Do not pre-assign ``self.loss`` here: Keras' ``Model.compile()`` sets
        # its own ``loss`` attribute, so a placeholder only shadows it until
        # compile() overwrites it.
        self.encoder = Encoder(filters)
        self.decoder = Decoder(filters)

    def call(self, input_features):
        """Encode ``input_features`` and return the decoder's reconstruction."""
        encoded = self.encoder(input_features)
        return self.decoder(encoded)

# Train the autoencoder to map noisy inputs back to their clean versions.
# ``filters``, ``batch_size`` and the x_* arrays are assumed to be defined
# earlier in the script (not shown here).
max_epochs = 5
model = Autoencoder(filters)
model.compile(optimizer='adam', loss='binary_crossentropy')

# NOTE: despite its name, ``loss`` receives the History object returned by fit().
loss = model.fit(
    x=x_train_noisy,
    y=x_train,
    batch_size=batch_size,
    epochs=max_epochs,
    validation_data=(x_test_noisy, x_test),
)

如你所见，model 是由继承自 keras.layers.Layer 的自定义层（Encoder 和 Decoder）组成的。如果我用 model.summary() 显示模型的架构，得到的是：

Model: "autoencoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 encoder (Encoder)           multiple                  14192     
                                                                 
 decoder (Decoder)           multiple                  16497     
                                                                 
=================================================================
Total params: 30,689
Trainable params: 30,689
Non-trainable params: 0

我希望能看到编码器层和解码器层内部更详细的结构描述。有什么办法吗？

km0tfn4u

km0tfn4u1#

之所以会得到这样的输出，是因为模型是用子类化（subclassing）API 构建的。与 Sequential 或函数式 API 不同，子类化模型在被调用之前没有静态的层结构图，因此 model.summary() 只能列出顶层子组件（Output Shape 显示为 multiple），plot_model 也无法画出其内部结构。这里有两个非常相关的帖子。

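不过，在你的情况下，可以通过修改实现方式让 `summary` 和 `plot_model` 变得有用。需要做两点：
1. 编码器和解码器子组件继承 `keras.Model`，而不是 `keras.layers.Layer`。
2. 在 `__init__` 方法中初始化各层时，确保层的定义顺序与它们在 `call` 方法中的调用顺序一致，并且每个池化/上采样层都单独创建一个实例（例如 pool1、pool2、pool3），而不是共享同一个。原因是下面的自动编码器会按定义顺序遍历子模型的 `.layers` 并逐层调用，每个层只会被调用一次。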

编码器

按照上面的第 1 点和第 2 点修改后：

class Encoder(keras.Model):
    """Encoder sub-model with a dedicated pooling layer per stage.

    Layers are created in exactly the order ``call`` applies them
    (conv1, pool1, conv2, pool2, conv3, pool3), because a caller that walks
    ``self.layers`` one by one depends on that order.

    Args:
        filters: channel counts for ``conv1``, ``conv2`` and ``conv3``.
    """

    def __init__(self, filters):
        super().__init__(name='Encoder')
        conv_kwargs = {
            'kernel_size': 3,
            'strides': 1,
            'activation': 'relu',
            'padding': 'same',
        }
        # Attribute assignment order defines ``self.layers`` order -- keep it.
        self.conv1 = Conv2D(filters=filters[0], **conv_kwargs)
        self.pool1 = MaxPooling2D((2, 2), padding='same')
        self.conv2 = Conv2D(filters=filters[1], **conv_kwargs)
        self.pool2 = MaxPooling2D((2, 2), padding='same')
        self.conv3 = Conv2D(filters=filters[2], **conv_kwargs)
        self.pool3 = MaxPooling2D((2, 2), padding='same')

    def call(self, input_features):
        """Run the three (conv -> pool) stages in sequence."""
        stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        )
        x = input_features
        for conv, pool in stages:
            x = pool(conv(x))
        return x

解码器

按照上面的第 1 点和第 2 点修改后：

class Decoder(keras.Model):
    """Decoder sub-model with a dedicated up-sampling layer per stage.

    Layers are created in the order ``call`` applies them (conv1, upsample1,
    conv2, upsample2, conv3, upsample3, conv4), so walking ``self.layers``
    reproduces the same computation.

    Args:
        filters: three channel counts, used in reverse order.
    """

    def __init__(self, filters):
        super().__init__(name='Decoder')
        relu_conv = {'kernel_size': 3, 'strides': 1, 'activation': 'relu'}
        # Attribute assignment order defines ``self.layers`` order -- keep it.
        self.conv1 = Conv2D(filters=filters[2], padding='same', **relu_conv)
        self.upsample1 = UpSampling2D((2, 2))
        self.conv2 = Conv2D(filters=filters[1], padding='same', **relu_conv)
        self.upsample2 = UpSampling2D((2, 2))
        # 'valid' padding trims 2 pixels per spatial dimension.
        self.conv3 = Conv2D(filters=filters[0], padding='valid', **relu_conv)
        self.upsample3 = UpSampling2D((2, 2))
        self.conv4 = Conv2D(filters=1, kernel_size=3, strides=1, activation='sigmoid', padding='same')

    def call(self, encoded):
        """Run three (conv -> upsample) stages, then the final sigmoid conv."""
        stages = (
            (self.conv1, self.upsample1),
            (self.conv2, self.upsample2),
            (self.conv3, self.upsample3),
        )
        x = encoded
        for conv, upsample in stages:
            x = upsample(conv(x))
        return self.conv4(x)

自动编码器

基于上面的第 1 点和第 2 点，自动编码器的构建方式如下：在 call 中依次调用编码器和解码器的每一层。

class Autoencoder(keras.Model):
    """Autoencoder that applies the encoder's and decoder's layers one by one.

    ``call`` deliberately iterates over ``encoder.layers`` and
    ``decoder.layers`` instead of calling ``self.encoder`` / ``self.decoder``,
    so each inner layer is invoked individually. This is only correct if each
    sub-model defines its layers in the order its own ``call`` would use them,
    with no layer shared between stages. Because the sub-models' own ``call``
    is skipped, ``summary()`` reports them as "0 (unused)".

    Args:
        filters: channel counts forwarded to ``Encoder`` and ``Decoder``.
    """

    def __init__(self, filters):
        super().__init__(name='Autoencoder')
        self.encoder = Encoder(filters)
        self.decoder = Decoder(filters)

    def call(self, input_features):
        """Feed the input through every encoder layer, then every decoder layer."""
        x = input_features
        for layer in [*self.encoder.layers, *self.decoder.layers]:
            x = layer(x)
        return x

构建模型

# Build with a concrete input shape (leading 1 = batch size) so parameter
# counts are known, then print the summary with nested sub-models expanded.
model = Autoencoder(filters)
model.build(input_shape=(1, 224, 224, 3))
model.summary(line_length=80, expand_nested=True, show_trainable=True)
Model: "Autoencoder"
___________________________________________________________________________________________
 Layer (type)                       Output Shape                    Param #     Trainable  
===========================================================================================
 Encoder (Encoder)                  multiple                        0 (unused)  Y          
|¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯|
| conv2d_69 (Conv2D)               multiple                        3584        Y          |
|                                                                                         |
| max_pooling2d_21 (MaxPooling2D)  multiple                        0           Y          |
|                                                                                         |
| conv2d_70 (Conv2D)               multiple                        147584      Y          |
|                                                                                         |
| max_pooling2d_22 (MaxPooling2D)  multiple                        0           Y          |
|                                                                                         |
| conv2d_71 (Conv2D)               multiple                        147584      Y          |
|                                                                                         |
| max_pooling2d_23 (MaxPooling2D)  multiple                        0           Y          |
¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯
 Decoder (Decoder)                  multiple                        0 (unused)  Y          
|¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯|
| conv2d_72 (Conv2D)               multiple                        147584      Y          |
|                                                                                         |
| up_sampling2d_9 (UpSampling2D)   multiple                        0           Y          |
|                                                                                         |
| conv2d_73 (Conv2D)               multiple                        147584      Y          |
|                                                                                         |
| up_sampling2d_10 (UpSampling2D)  multiple                        0           Y          |
|                                                                                         |
| conv2d_74 (Conv2D)               multiple                        147584      Y          |
|                                                                                         |
| up_sampling2d_11 (UpSampling2D)  multiple                        0           Y          |
|                                                                                         |
| conv2d_75 (Conv2D)               multiple                        1153        Y          |
¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯
===========================================================================================
Total params: 742,657
Trainable params: 742,657
Non-trainable params: 0
___________________________________________________________________________________________

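很好。但正如你所见，摘要中的 Output Shape 列只显示 multiple，没有提供有用的信息（另外，由于 Encoder/Decoder 自身的 call 从未被调用，它们显示为 0 (unused)）。为了解决这个问题，可以添加一个辅助方法 build_graph：它用 keras.Input 创建符号输入，并据此构建一个函数式模型，如下所示：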

class Autoencoder(K.Model):
    """Layer-by-layer autoencoder with a ``build_graph`` helper.

    ``call`` walks ``encoder.layers`` then ``decoder.layers`` in order, which
    requires each sub-model to define its layers in call order.
    ``build_graph`` wraps that computation in a functional model so
    ``summary()`` and ``plot_model`` can show concrete output shapes.

    Args:
        filters: channel counts forwarded to ``Encoder`` and ``Decoder``.
    """

    def __init__(self, filters):
        super().__init__(name='Autoencoder')
        self.encoder = Encoder(filters)
        self.decoder = Decoder(filters)

    def call(self, input_features):
        """Feed the input through every encoder layer, then every decoder layer."""
        x = input_features
        for layer in [*self.encoder.layers, *self.decoder.layers]:
            x = layer(x)
        return x

    def build_graph(self, input_shape):
        """Return a functional ``K.Model`` tracing this model's layers.

        Args:
            input_shape: per-sample shape without the batch dimension,
                e.g. ``(224, 224, 3)``.

        Returns:
            A functional model whose ``summary()`` lists every layer with a
            concrete output shape.
        """
        # A symbolic Input lets Keras record each layer's output shape.
        symbolic_input = K.Input(shape=input_shape)
        symbolic_output = self.call(symbolic_input)
        return K.Model(inputs=[symbolic_input], outputs=symbolic_output)

摘要

# ``build_graph`` creates a fresh Input and re-invokes every layer each time it
# is called, so build the functional wrapper once and reuse it for both the
# text summary and the diagram.
graph_model = model.build_graph(input_shape=(224, 224, 3))

graph_model.summary(expand_nested=True)
# OK

keras.utils.plot_model(
    graph_model,
    expand_nested=True,
    show_shapes=True,
    show_dtype=True,
    show_layer_activations=True,
    show_layer_names=True,
)
# OK

就是这样。如果你认为这应该得到开箱即用的支持，欢迎在 Keras 的 GitHub 仓库中提交 issue。

相关问题