keras UNet：InvalidArgumentError：调用层 “concatenate”（类型 Concatenate）时遇到异常

t1qtbnec  于 2023-04-21  发布在  其他
关注(0)|答案(1)|浏览(121)

即使拼接层（Concatenate）两个输入的形状相同（见下方打印结果），仍然报错，提示两个输入的形状不一致。

class UNet(keras.Model):
  """U-Net excerpt from the question; parts of the model are elided with `...`."""

  def __init__(self, shape=(572, 572, 1), **kwargs):
    # NOTE(review): no `super().__init__(**kwargs)` is visible here; Keras requires it
    # before layers are assigned to `self` -- presumably it is in the elided part, confirm.
    self.concat = keras.layers.Concatenate(axis=-1) # concats through depth
    ...

  class CONV2_BLOCK(keras.layers.Layer):
    # Body elided in the question.
    ...

  class CONV_T(keras.layers.Layer):
    """Upsampling layer: a 2x2 transposed convolution with stride 2."""
    def __init__(self, filters, **kwargs):
      super().__init__(**kwargs)
      self.conv_t = keras.layers.Conv2DTranspose(filters=filters, kernel_size=2, strides=2)

    def call(self, inputs):
      outputs = self.conv_t(inputs)
      return outputs

  class CROP(keras.layers.Layer):
    """Thin wrapper around `keras.layers.Cropping2D(cropping=cropping)`."""
    def __init__(self, cropping, **kwargs):
      super().__init__(**kwargs)
      self.cropping = cropping
      self.crop = keras.layers.Cropping2D(cropping=self.cropping)

    def call(self, inputs):
      outputs = self.crop(inputs)
      return outputs

  def call(self, inputs):
    """Forward pass (partially elided): encoder, bottleneck, then decoder with skip concats."""
    # self.conv_arr = [64, 128, 256, 512, 1024]
    # self.crop_arr = [4, 17, 40, 88] down to up
    # NOTE(review): 17 is wrong for the x3 skip -- it yields 102x102 while the upsampled
    # map is 104x104 (see the error message); 16 is required.

    # NOTE(review): every `self.CONV2_BLOCK(...)` / `self.CONV_T(...)` / `self.CROP(...)`
    # below constructs a brand-new layer each time `call` runs, so these layers (and
    # their weights) are not attributes of the model.
    x1 = self.CONV2_BLOCK(filters=64)(inputs)
    print(x1.shape)
    x = self.maxpool(x1)
    print(x.shape)
    ...

    x = self.CONV2_BLOCK(filters=1024)(x)
    print(x.shape)

    # These prints only check the first (x4) skip connection. They also build their own
    # CONV_T/CROP instances, separate from the ones used in the concat on the next line.
    print(f"convt shape{self.CONV_T(filters=512)(x).shape}")
    print(f"crop shape{self.CROP(cropping=4)(x4).shape}")
    x = self.concat([self.CONV_T(filters=512)(x), self.CROP(cropping=4)(x4)])
    x = self.CONV2_BLOCK(filters=512)(x)

    ...

    x = self.concat([self.CONV_T(filters=64)(x), self.CROP(cropping=88)(x1)])
    x = self.CONV2_BLOCK(filters=64)(x)

    # Final 1x1 convolution (defined in the elided part of __init__).
    outputs = self.conv_sz1(x)

    return outputs

以上代码的输出:

convt shape(2, 56, 56, 512)
crop shape(2, 56, 56, 512)
（以上为打印结果）

错误

---> 83 x = self.concat([self.CONV_T(filters=216)(x), self.CROP(cropping=17)(x3)])
     84 x = self.CONV2_BLOCK(filters=256)(x)

两个形状的第 1 维必须相等：shape[0] = [2,104,104,216] vs. shape[1] = [2,102,102,256] [Op:ConcatV2] name: concat

uttx8gqw

uttx8gqw1#

以下代码在我这里可以正常运行：

class UNet(keras.Model):
  """U-Net built from unpadded 3x3 convolution blocks.

  argument: shape=(572, 572, 1) => default input shape (height, width, channels).
  `shape` is only stored on the model; the layers do not read it.

  The skip-connection crops are fixed for a 572x572 input. With that input the
  encoder blocks output 568, 280, 136 and 64 pixels per side, the bottleneck 28,
  and the final 1x1 convolution produces a 388x388 map with 2 channels.
  """

  def __init__(self, shape=(572, 572, 1), **kwargs):
    super().__init__(**kwargs)
    self.shape = shape
    self.maxpool = keras.layers.MaxPool2D(pool_size=2, strides=2)
    self.concat = keras.layers.Concatenate(axis=-1)  # concats through depth
    self.conv_sz1 = keras.layers.Conv2D(filters=2, kernel_size=1, padding="same")

    # Every sub-layer is created once, here, so the model tracks its weights
    # (trainable_variables, saving). Layers instantiated inside call() would not be
    # attributes of the model and would be rebuilt whenever call() runs eagerly.
    self.down_blocks = [self.CONV2_BLOCK(filters=f) for f in (64, 128, 256, 512)]
    self.bottleneck = self.CONV2_BLOCK(filters=1024)
    # Decoder, deepest level first. Each crop is (skip size - upsampled size) / 2 per
    # edge, e.g. (136 - 104) / 2 = 16 for the skip from the third encoder block.
    self.up_convs = [self.CONV_T(filters=f) for f in (512, 256, 128, 64)]
    self.crops = [self.CROP(cropping=c) for c in (4, 16, 40, 88)]
    self.up_blocks = [self.CONV2_BLOCK(filters=f) for f in (512, 256, 128, 64)]

  class CONV2_BLOCK(keras.layers.Layer):
    """Two 3x3 convolutions (default "valid" padding), each followed by batch norm and ReLU.

    Each convolution trims 2 pixels from height and width, so the block shrinks the
    spatial size by 4 in total.
    """
    def __init__(self, filters, **kwargs):
      super().__init__(**kwargs)
      self.filters = filters
      self.conv1 = keras.layers.Conv2D(filters=self.filters, kernel_size=3, use_bias=False)
      # One BatchNormalization per convolution: a single shared layer would make both
      # positions share gamma/beta and the same running mean/variance.
      self.batchnorm = keras.layers.BatchNormalization()
      self.batchnorm2 = keras.layers.BatchNormalization()
      self.relu = keras.layers.Activation(keras.activations.relu)  # stateless, safe to reuse
      self.conv2 = keras.layers.Conv2D(filters=self.filters, kernel_size=3, use_bias=False)

    def call(self, inputs):
      x = self.conv1(inputs)
      x = self.batchnorm(x)
      x = self.relu(x)
      x = self.conv2(x)
      x = self.batchnorm2(x)
      outputs = self.relu(x)
      return outputs

  class CONV_T(keras.layers.Layer):
    """Upsampling layer: a 2x2 transposed convolution with stride 2."""
    def __init__(self, filters, **kwargs):
      super().__init__(**kwargs)
      self.conv_t = keras.layers.Conv2DTranspose(filters=filters, kernel_size=2, strides=2)

    def call(self, inputs):
      outputs = self.conv_t(inputs)
      return outputs

  class CROP(keras.layers.Layer):
    """Thin wrapper around `keras.layers.Cropping2D(cropping=cropping)`."""
    def __init__(self, cropping, **kwargs):
      super().__init__(**kwargs)
      self.crop = keras.layers.Cropping2D(cropping=cropping)

    def call(self, inputs):
      outputs = self.crop(inputs)
      return outputs

  def call(self, inputs):
    """Run the encoder, bottleneck and decoder; return the 2-channel 1x1-conv output."""
    # Contracting path: keep each block's output (before pooling) as a skip connection.
    skips = []
    x = inputs
    for block in self.down_blocks:
      x = block(x)
      skips.append(x)
      x = self.maxpool(x)

    x = self.bottleneck(x)

    # Expanding path, deepest skip first: upsample, crop the matching skip to the same
    # spatial size, concatenate along channels, then convolve.
    for up_conv, crop, block, skip in zip(self.up_convs, self.crops, self.up_blocks, reversed(skips)):
      x = self.concat([up_conv(x), crop(skip)])
      x = block(x)

    return self.conv_sz1(x)

我修改的地方：

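  • self.CONV_T(filters=216)(x) → self.CONV_T(filters=256)(x)
  • 在同一行：self.CROP(cropping=17)(x3) → self.CROP(cropping=16)(x3)

报错发生在第二次拼接（x3 的跳跃连接）处，而不是打印了形状的第一次拼接。输入为 572×572 时，x3 为 136×136，上采样后的特征图为 104×104，因此每边需要裁剪 (136 − 104) / 2 = 16。用 cropping=17 会得到 102×102，正是报错中的形状。216 应是 256 的笔误；通道数不同本身不会导致按通道拼接失败。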

相关问题