keras - Custom layer input shape in TensorFlow

Posted by h7appiyu on 2023-03-02 in Other

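I am running into an input-shape problem with a fairly simple custom layer in my model. I know I could preprocess the data outside the model, but for portability and efficiency this processing has to be compiled into the model. Given my custom layer "ImagePreprocessingLayer":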

import tensorflow as tf

class ImagePreprocessingLayer(tf.keras.layers.Layer):
    def __init__(self):
        super(ImagePreprocessingLayer, self).__init__()
        self.trainable = False

    def call(self, inputs):
        # Determine the number of rows in the input image
        n_rows = tf.shape(inputs)[0]

        # If the number of rows is greater than 100, downsample to (100, 543, 3)
        if n_rows > 100:
            inputs = tf.image.resize(inputs, size=(100, 543))

        # If the number of rows is less than 100, pad with zeros until row 100
        elif n_rows < 100:
            padding = tf.zeros(shape=(100 - n_rows, 543, 3), dtype=inputs.dtype)
            inputs = tf.concat([inputs, padding], axis=0)

        # Finally, replace NaNs with 0s
        inputs = tf.where(tf.math.is_nan(inputs), tf.zeros_like(inputs), inputs)

        return inputs

I want to build the model like this:

# The height of inputs is unknown
inputs = tf.keras.layers.Input(shape=(None, 543, 3))

preprocessed_inputs = ImagePreprocessingLayer()(inputs)

x = tf.keras.layers.Flatten()(preprocessed_inputs)
x = tf.keras.layers.Dense(250, activation='relu')(x)

outputs = tf.keras.layers.Dense(250, activation='softmax')(x)

model = tf.keras.models.Model(inputs=inputs, outputs=outputs)

When the line that defines "preprocessed_inputs" runs, I get this error:

ValueError: Exception encountered when calling layer "image_preprocessing_layer_5" (type ImagePreprocessingLayer).

in user code:

File "/tmp/ipykernel_27/4007721329.py", line 19, in call  *
    inputs = tf.concat([inputs, padding], axis=0)

ValueError: Shape must be rank 4 but is rank 3 for '{{node image_preprocessing_layer_5/cond/cond/concat}} = ConcatV2[N=2, T=DT_FLOAT, Tidx=DT_INT32](image_preprocessing_layer_5/cond/cond/concat/Placeholder, image_preprocessing_layer_5/cond/cond/zeros, image_preprocessing_layer_5/cond/cond/concat/axis)' with input shapes: [?,?,543,3], [?,543,3], [].

Call arguments received by layer "image_preprocessing_layer_5" (type ImagePreprocessingLayer):

  • inputs=tf.Tensor(shape=(None, None, 543, 3), dtype=float32)
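I know this has something to do with the input dimensions, but does anyone know what it means? My layer works exactly as expected when I call it directly on a sample tensor: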

ImagePreprocessingLayer()(example_sample)
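The difference between the two calls can be reproduced outside the model. A sample such as example_sample is presumably rank 3, (rows, 543, 3), but inside a Keras model the layer receives a batch, (batch, rows, 543, 3). So tf.shape(inputs)[0] becomes the batch size, and the rank-3 padding can no longer be concatenated. A small illustration; the concrete shapes (90 rows, batch of 8) are assumptions chosen for the example. Run eagerly, the mismatch may surface as an InvalidArgumentError rather than the ValueError raised during graph construction above.

```python
import tensorflow as tf

# Assumed shapes: one unbatched sample vs. a batch of such samples.
sample = tf.ones((90, 543, 3))      # stands in for example_sample
batch = tf.ones((8, 90, 543, 3))    # what a layer receives inside a Keras model

# Axis 0 is the row count for the sample, but the batch size for the batch.
print(int(tf.shape(sample)[0]))  # 90
print(int(tf.shape(batch)[0]))   # 8

# Rank-3 padding, built the same way as in the question, cannot join a rank-4 batch.
padding = tf.zeros((100 - 90, 543, 3))
try:
    tf.concat([batch, padding], axis=0)
    concat_failed = False
except (tf.errors.InvalidArgumentError, ValueError):
    concat_failed = True
print(concat_failed)  # True
```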
Answer #1 by xeufq47z

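In your code, inputs and padding must have compatible shapes to be concatenated. Inside the model, inputs is a rank-4 batch, (batch, rows, 543, 3), while padding is rank 3, (rows, 543, 3), which is exactly what the error's "[?,?,543,3], [?,543,3]" is saying; tf.shape(inputs)[0] is also the batch size, not the row count. In addition, since one of the dimensions is None, I don't think layers.Flatten() can be used: it would produce a feature dimension of unknown size, and the following Dense layer needs that size to be known. However, here is one way to make it run.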
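If you prefer to keep the question's approach of padding at the bottom with tf.concat, the padding has to be rank 4 too, with the batch dimension taken from tf.shape at run time, and the concatenation has to run along the row axis (axis=1). A minimal sketch, assuming the row count never exceeds 100; pad_rows_with_concat and TARGET_ROWS are hypothetical names:

```python
import tensorflow as tf

TARGET_ROWS = 100  # target row count from the question

def pad_rows_with_concat(inputs):
    """Hypothetical helper: zero-pad a (batch, rows, 543, 3) tensor at the bottom.

    Assumes rows <= TARGET_ROWS. The padding is built with the same rank as
    inputs, so tf.concat accepts it.
    """
    shape = tf.shape(inputs)
    batch_size, rows = shape[0], shape[1]
    pad_shape = tf.stack([batch_size, TARGET_ROWS - rows, shape[2], shape[3]])
    padding = tf.zeros(pad_shape, dtype=inputs.dtype)
    return tf.concat([inputs, padding], axis=1)  # join along the row axis

out = pad_rows_with_concat(tf.ones((2, 90, 543, 3)))
print(out.shape)  # (2, 100, 543, 3)
```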

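import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

class ImagePreprocessingLayer(layers.Layer):
    def call(self, inputs):
        # inputs is a batch: (batch, rows, 543, 3), so the row count is axis 1
        shape = tf.shape(inputs)
        height = shape[1]

        # tf.cond chooses a branch at run time, so it works with a dynamic height
        outputs = tf.cond(
            tf.greater_equal(height, 100),
            lambda: self.resize(inputs),
            lambda: self.padded_resize(inputs, height)
        )
        # Replace NaNs with zeros
        outputs = tf.where(
            tf.math.is_nan(outputs),
            tf.zeros_like(outputs),
            outputs
        )
        return outputs

    def resize(self, inputs):
        # Resize every image in the batch to 100 x 543
        outputs = tf.image.resize(inputs, size=(100, 543))
        return outputs

    def padded_resize(self, inputs, height):
        # Zero-pad the row axis up to 100; an odd gap puts the extra row at the bottom
        padding_dims = 100 - height
        pad_top = padding_dims // 2
        pad_bottom = padding_dims - pad_top
        outputs = tf.pad(
            inputs, [
                (0, 0),
                (pad_top, pad_bottom),
                (0, 0),
                (0, 0)
            ]
        )
        return outputs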
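When the gap 100 - height is odd, padding gap // 2 rows on both sides leaves the result one row short (91 rows: 4 + 91 + 4 = 99), and tf.image.resize is never reached on that branch to correct it. A small sketch that gives the leftover row to the bottom; pad_rows_to is a hypothetical helper, and it assumes the input has at most target_rows rows:

```python
import tensorflow as tf

def pad_rows_to(images, target_rows=100):
    """Hypothetical helper: zero-pad axis 1 of (batch, rows, W, C) to target_rows.

    Assumes rows <= target_rows. An odd gap puts the extra row at the bottom.
    """
    gap = target_rows - tf.shape(images)[1]
    top = gap // 2
    bottom = gap - top
    return tf.pad(images, [[0, 0], [top, bottom], [0, 0], [0, 0]])

for rows in (90, 91):
    padded = pad_rows_to(tf.ones((3, rows, 543, 3)))
    print(rows, "->", padded.shape[1])
# 90 -> 100
# 91 -> 100
```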

Define the model:

inputs = keras.Input(shape=(None, 543, 3))
x = ImagePreprocessingLayer()(inputs)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(250, activation='relu')(x)
outputs = layers.Dense(250, activation='softmax')(x)
model = keras.Model(inputs=inputs, outputs=outputs)
model.summary()
Model: "model_39"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_59 (InputLayer)       [(None, None, 543, 3)]    0         
                                                                 
 image_preprocessing_layer_4  (None, None, 543, 3)     0         
 2 (ImagePreprocessingLayer)                                     
                                                                 
 global_average_pooling2d_22  (None, 3)                0         
  (GlobalAveragePooling2D)                                       
                                                                 
 dense_65 (Dense)            (None, 250)               1000      
                                                                 
 dense_66 (Dense)            (None, 250)               62750     
                                                                 
=================================================================
Total params: 63,750
Trainable params: 63,750
Non-trainable params: 0
_________________________________________________________________

Test:

print(model(tf.ones(shape=(3, 90, 543, 3))).shape)
print(model(tf.ones(shape=(3, 100, 543, 3))).shape)
print(model(tf.ones(shape=(3, 200, 543, 3))).shape)
(3, 250)
(3, 250)
(3, 250)
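Note that GlobalAveragePooling2D reduces each image to the mean of its 3 channels, as the (None, 3) row in the summary shows, so most spatial information is discarded. The height only comes out as None because tf.cond cannot infer a static row count; both branches always produce 100 rows. One possible alternative, sketched here and not checked against every TF/Keras version, is to declare that shape with tf.ensure_shape so that Flatten() and Dense can follow as in the question's model. FixedHeightPreprocessing and the constants are hypothetical names:

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

TARGET_ROWS, WIDTH, CHANNELS = 100, 543, 3

class FixedHeightPreprocessing(layers.Layer):
    """Hypothetical variant: resize or pad rows to TARGET_ROWS, then fix the static shape."""

    def call(self, inputs):
        height = tf.shape(inputs)[1]

        def pad():
            # Zero-pad the row axis; an odd gap puts the extra row at the bottom
            gap = TARGET_ROWS - height
            top = gap // 2
            return tf.pad(inputs, [[0, 0], [top, gap - top], [0, 0], [0, 0]])

        outputs = tf.cond(
            height >= TARGET_ROWS,
            lambda: tf.image.resize(inputs, (TARGET_ROWS, WIDTH)),
            pad,
        )
        outputs = tf.where(tf.math.is_nan(outputs), tf.zeros_like(outputs), outputs)
        # Both branches yield TARGET_ROWS rows, so declare the static shape
        return tf.ensure_shape(outputs, [None, TARGET_ROWS, WIDTH, CHANNELS])

inputs = keras.Input(shape=(None, WIDTH, CHANNELS))
features = FixedHeightPreprocessing()(inputs)  # static shape (None, 100, 543, 3)
x = layers.Flatten()(features)                 # (None, 100 * 543 * 3)
x = layers.Dense(250, activation="relu")(x)
outputs = layers.Dense(250, activation="softmax")(x)
model = keras.Model(inputs=inputs, outputs=outputs)

for rows in (90, 100, 200):
    print(model(tf.ones((3, rows, WIDTH, CHANNELS))).shape)
```

The trade-off is size: flattening 100 x 543 x 3 values feeds about 163,000 features into the first Dense layer, so it has roughly 40.7 million weights, compared with 1,000 after global average pooling.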
