无法构建具有整数输入的Keras图层

oyxsuwqo  于 2023-01-26  发布在  其他
关注(0)|答案(2)|浏览(237)

我有一个复杂的keras模型,其中一个层是一个自定义的预训练层,它需要“int32”作为输入。这个模型是作为一个继承自Model的类实现的,它是这样实现的:

class MyModel(tf.keras.models.Model):

    def __init__(self, size, input_shape):
        super(MyModel, self).__init__()
        self.layer = My_Layer()
        self.build(input_shape)

    def call(self, inputs):
        return self.layer(inputs)

但是当它到达self.build方法时,它抛出下一个错误:

ValueError: You cannot build your model by calling `build` if your layers do not support float type inputs. Instead, in order to instantiate and build your model, `call` your model on real tensor data (of the correct dtype).

我该怎么修呢?

7cjasjjr

7cjasjjr1#

使用model.build生成模型时会引发异常。
model.build 函数根据给定的输入形状构建模型。
引发错误的原因是,当我们尝试构建模型时,它首先调用一个带有x参数的模型,该参数取决于以下代码中的输入形状类型

if (isinstance(input_shape, list) and
    all(d is None or isinstance(d, int) for d in input_shape)):
  input_shape = tuple(input_shape)
if isinstance(input_shape, list):
  x = [base_layer_utils.generate_placeholders_from_shape(shape)
        for shape in input_shape]
elif isinstance(input_shape, dict):
  x = {
      k: base_layer_utils.generate_placeholders_from_shape(shape)
      for k, shape in input_shape.items()
  }
else:
  x = base_layer_utils.generate_placeholders_from_shape(input_shape)

x在这里是TensorFlow占位符。因此,当尝试使用x作为输入调用模型时,它将弹出TypeError,并且结果(block除外)将正常工作并给予错误。
假设输入形状为16x16,而不是使用self.build([(16,16)]),调用基于真实的Tensor的模型

inputs = tf.keras.Input(shape=(16,))
self.call(inputs)
rqenqsqc

rqenqsqc2#

变通方案

我在尝试将带有多个整型输入Tensor的模型导出为SavedModel时遇到了同样的问题。我通过覆盖build方法并手动指定self._build_input_shape来解决这个问题。

class MyModel(tf.keras.models.Model):
    def __init__(self, size, input_shape):
        super(MyModel, self).__init__()
        self.layer = My_Layer()
        self.build(input_shape)

    def call(self, inputs):
        return self.layer(inputs)

    def build(self, input_shapes):
        super(tf.keras.Model, self).build(input_shapes)

在原始代码中发生了什么

tf.keras.Model对象的默认build方法默认将输入Tensor视为浮点Tensor,最终抛出异常。tf.keras.Model的此类行为在此处定义,其中模型的输入由base_layer_utils.generate_placeholders_from_shape创建,它将dtype指定为float

变通方案会发生什么情况

由于tf.keras.Model.build最终会调用其超类的构建函数tf.keras.layer.Layer.build,因此解决方法跳过了导致问题的tf.keras.Model.build逻辑,但如果您依赖于tf.keras.Model.build中定义的其他逻辑,则可能必须在此之后添加补充代码

相关问题