即使 Concatenate(拼接)层两个输入的形状相同(见下方打印结果),仍然报错,提示两者形状不同。
class UNet(keras.Model):
    """U-Net model (excerpt from a question; "..." marks code the asker elided).

    NOTE(review): the paste lost all indentation. The nesting below is
    reconstructed from usage: ``self.CONV_T(...)`` etc. implies the helper
    layer classes are attributes of UNet. Confirm against the original file.
    """

    def __init__(self, shape=(572, 572, 1), **kwargs):
        # NOTE(review): no super().__init__(**kwargs) is visible before this
        # first layer assignment. Subclassed Keras models normally call it
        # first, so confirm it was not dropped from the paste.
        # Concatenate along the last axis. The printed shapes, e.g.
        # (2, 56, 56, 512), are channels-last, so this joins along channels
        # and only the batch/height/width dimensions must match.
        self.concat = keras.layers.Concatenate(axis=-1) # concats through depth
        ...

    class CONV2_BLOCK(keras.layers.Layer):
        # Body elided in the question; call() constructs it with `filters=`.
        ...

    class CONV_T(keras.layers.Layer):
        """Up-sampling layer: 2x2 transposed convolution with stride 2.

        With kernel_size=2, strides=2 (and the default "valid" padding) the
        output height/width are exactly twice the input's, e.g. 28 -> 56.
        """

        def __init__(self, filters, **kwargs):
            super().__init__(**kwargs)
            self.conv_t = keras.layers.Conv2DTranspose(filters=filters, kernel_size=2, strides=2)

        def call(self, inputs):
            outputs = self.conv_t(inputs)
            return outputs

    class CROP(keras.layers.Layer):
        """Thin wrapper around Cropping2D used to trim skip connections.

        An int ``cropping`` removes that many pixels from each side of both
        height and width, so each spatial size shrinks by ``2 * cropping``.
        """

        def __init__(self, cropping, **kwargs):
            super().__init__(**kwargs)
            self.cropping = cropping
            self.crop = keras.layers.Cropping2D(cropping=self.cropping)

        def call(self, inputs):
            outputs = self.crop(inputs)
            return outputs

    def call(self, inputs):
        """Forward pass: encoder (x1..x4 saved as skips), bottleneck, decoder.

        NOTE(review): every ``self.CONV2_BLOCK(...)``, ``self.CONV_T(...)`` and
        ``self.CROP(...)`` below constructs a *new* layer on each forward pass.
        Their weights are therefore re-created every call and never reused or
        trained. Such layers are normally built once in __init__ and only
        applied here.
        """
        # self.conv_arr = [64, 128, 256, 512, 1024]
        # Per-skip crop = (skip size - upsampled size) / 2. For the 572x572 input
        # the second value must be 16, not 17. The traceback shows x3 cropped by
        # 17 gives 102, i.e. x3 is 136, and (136 - 104) / 2 = 16.
        # self.crop_arr = [4, 16, 40, 88] down to up
        x1 = self.CONV2_BLOCK(filters=64)(inputs)
        print(x1.shape)
        x = self.maxpool(x1)
        print(x.shape)
        ...
        x = self.CONV2_BLOCK(filters=1024)(x)
        print(x.shape)
        # Debug prints for the 512 stage only. Each builds its own throw-away
        # CONV_T/CROP instance, separate from the ones used in the concat below.
        print(f"convt shape{self.CONV_T(filters=512)(x).shape}")
        print(f"crop shape{self.CROP(cropping=4)(x4).shape}")
        x = self.concat([self.CONV_T(filters=512)(x), self.CROP(cropping=4)(x4)])
        x = self.CONV2_BLOCK(filters=512)(x)
        # NOTE(review): the elided 256 stage must use CONV_T(filters=256) and
        # CROP(cropping=16) on x3. The traceback shows filters=216 / cropping=17,
        # and cropping=17 yields 102x102 vs the 104x104 upsampled tensor.
        ...
        x = self.concat([self.CONV_T(filters=64)(x), self.CROP(cropping=88)(x1)])
        x = self.CONV2_BLOCK(filters=64)(x)
        outputs = self.conv_sz1(x)
        return outputs
以上代码的输出(已打印):
`convt shape(2, 56, 56, 512)`
`crop shape(2, 56, 56, 512)`
错误(注意:上面打印的是 512 通道阶段的形状;报错发生在之后第 83 行的 256 通道阶段):
```
---> 83 x = self.concat([self.CONV_T(filters=216)(x), self.CROP(cropping=17)(x3)])
     84 x = self.CONV2_BLOCK(filters=256)(x)
```
两个形状中的维度 1 必须相等:`shape[0] = [2,104,104,216]` 与 `shape[1] = [2,102,102,256]` `[Op:ConcatV2] name: concat`
1 条答案

回答 #1(uttx8gqw):
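做如下修改后,代码对我有效。我改动了两处:
- 将 `self.CONV_T(filters=216)(x)` 改为 `self.CONV_T(filters=256)(x)`;
- 将 `self.CROP(cropping=17)(x3)` 改为 `self.CROP(cropping=16)(x3)`。

原因:报错的是维度 1(高度)。上采样后的 x 是 104×104,而 x3 被裁剪成了 102×102,说明 x3 原本是 102 + 2×17 = 136。要与 104 对齐,每侧应裁剪 (136 − 104) / 2 = 16 个像素,而不是 17。216 则是 256 的笔误:沿通道轴(axis=-1)拼接允许通道数不同,所以它本身不会触发这个错误,但按 U-Net 结构这里应为 256。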