keras 如何让`model.fit`在调用一个没有可训练层的朴素模型时立即停止训练?

xesrikrc  于 2024-01-08  发布在  其他
关注(0)|答案(1)|浏览(196)

我正在为时间序列预测模型编写一个训练管道,我使用一个朴素的季节模型作为基线,它只输出输入的最后一个out_steps

class Naive(tf.keras.Model):
def __init__(self, out_steps: int,
             **kwargs):
    super().__init__(**kwargs)
    self.out_steps = out_steps

def call(self, inputs, training=None):
    features = inputs

    return features[:, -self.out_steps:, :]

字符串
然后我可以使用通用训练阶段:

def train_model(model_name, **model_params):
    model = instantiate_model(model_name, **model_params)
    model.compile(loss='mse', optimizer='adam')
    model.fit(train_dataset)


有没有办法让fit明白模型没有可训练的层,并立即停止?

vuktfyat

vuktfyat1#

只有一行 Package :

if model.trainable_variables:
    model.fit(...)

字符串
model.trainable_variables包含可训练变量,对于没有这些变量的模型,this将为空,if条件为False。
但是请注意,如果你使用Sequential创建模型,并且没有在第一层提供输入形状,也没有调用模型的build()函数,这样的模型在调用fit之前也不会有任何变量!这些变量只会在第一次调用模型时创建。所以你必须小心,或者添加这样的东西:

if model.trainable_variables:
    model.fit(...)
elif not model.built:
    raise SomeError  # or print a warning

相关问题