我在模型中使用过于简单的自定义层时遇到输入形状问题。我知道我可以在模型外预处理数据,但为了提高可移植性和效率,必须使用模型编译此过程。鉴于我的自定义层“ImagePreprocessingLayer”,
class ImagePreprocessingLayer(tf.keras.layers.Layer):
def __init__(self):
super(ImagePreprocessingLayer, self).__init__()
self.trainable = False
def call(self, inputs):
# Determine the number of rows in the input image
n_rows = tf.shape(inputs)[0]
# If the number of rows is greater than 100, downsample to (100, 543, 3)
if n_rows > 100:
inputs = tf.image.resize(inputs, size=(100, 543))
# If the number of rows is less than 100, pad with zeros until row 100
elif n_rows < 100:
padding = tf.zeros(shape=(100 - n_rows, 543, 3), dtype=inputs.dtype)
inputs = tf.concat([inputs, padding], axis=0)
# Lastly filling na's with 0's
inputs = tf.where(tf.math.is_nan(inputs), tf.zeros_like(inputs), inputs)
return inputs
我想做模型,
# The height of inputs is unknown
inputs = tf.keras.layers.Input(shape=(None, 543, 3))
preprocessed_inputs = ImagePreprocessingLayer()(inputs)
x = tf.keras.layers.Flatten()(preprocessed_inputs)
x = tf.keras.layers.Dense(250, activation='relu')(x)
outputs = tf.keras.layers.Dense(250, activation='softmax')(x)
model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
在定义“preprocessed_inputs”时执行时,我收到错误,
ValueError: Exception encountered when calling layer "image_preprocessing_layer_5" (type ImagePreprocessingLayer).
in user code:
File "/tmp/ipykernel_27/4007721329.py", line 19, in call *
inputs = tf.concat([inputs, padding], axis=0)
ValueError: Shape must be rank 4 but is rank 3 for '{{node image_preprocessing_layer_5/cond/cond/concat}} = ConcatV2[N=2, T=DT_FLOAT, Tidx=DT_INT32](image_preprocessing_layer_5/cond/cond/concat/Placeholder, image_preprocessing_layer_5/cond/cond/zeros, image_preprocessing_layer_5/cond/cond/concat/axis)' with input shapes: [?,?,543,3], [?,543,3], [].
Call arguments received by layer "image_preprocessing_layer_5" (type ImagePreprocessingLayer):
·输入=tf.Tensor(形状=(无,无,543,3),数据类型=浮点数32)
我知道这与输入维度有关,但有人知道这意味着什么吗?当直接在样本Tensor上使用时,我的层可以完美地按预期工作,
ImagePreprocessingLayer()(example_sample)
1条答案
按热度按时间xeufq47z1#
在你的代码中,
inputs
和padding
应该是兼容的,这样才能连接。另外,由于其中一个输入是None
,我认为layers.Flatten()
是不可能使用的。但是,这里有一种方法可以让它运行。定义模型
试验