在TensorFlow keras中,如何加载另一个模型的权重,同时跳过不兼容的层?

6ss1mwsb  于 2023-10-19  发布在  其他
关注(0)|答案(1)|浏览(141)

我试图从tf.keras.Model使用load_weights功能,但它似乎不能正常工作。
如果我调用model.load_weights(weights_path, by_name=True, skip_mismatch=True),我会得到一个形状不匹配的错误,这正是我希望"skip_mismatch"参数处理的问题。
这个代码片段是一个相对简单的mnist数据集案例,重现了我的错误。我在google colab中运行它,发生的事情就像我自己的代码一样。

import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os

#import mnist dataset in x,y fashion
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

#build a simple model with sequential dense layers
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28,28)),
    tf.keras.layers.Rescaling(1./255, input_shape=(28,28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10)
])

#compile the model
model.compile(optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
#create checkpoint directory
os.makedirs("checkpoints", exist_ok=True)
cp_callback =  tf.keras.callbacks.ModelCheckpoint("checkpoints/cp-{epoch:04d}.ckpt",
                save_weights_only=True,
            )

#train the model
history = model.fit(x_train, y_train, epochs=10, validation_data=(x_test, y_test), callbacks=[cp_callback])

model2 = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28,28)),
    tf.keras.layers.Rescaling(1./255, input_shape=(28,28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(5)
])
#compile model2
model2.compile(optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])

#load weights of model into model2
model2.load_weights("checkpoints/cp-0010.ckpt",by_name=True, skip_mismatch=True)

我得到
“在尝试恢复具有形状(5,)和名称dense_5/bias:0的变量时,接收到与形状(10,)不兼容的Tensor。”无论我使用skip_mismatch=True还是skip_mismatch=False
我是不是用错了功能?什么是正确的使用方法?

ecfdbz9o

ecfdbz9o1#

如果你看一下load_weights和保存_weights的文档,应该会更清楚。
load_weights

按名称加载重量

  • 如果权重保存为通过model.save_weights()创建的.h5文件,则可以使用参数by_name=True。*

在这种情况下,仅当权重共享相同的名称时,才会将它们加载到层中。这对于微调或迁移学习模型很有用,其中一些层已经改变。
save_weights

保存_format“tf”或“h5”。* 如果保存_format为None,则以“.h5”或“.keras”结尾的文件路径将默认为HDF5*。否则,None将变为“TF”。默认为None。

ModelCheckpoint开始:

保存_weights_only如果为True,则仅保存模型的权重(model.保存_weights(filepath)),否则保存完整模型(model.保存(filepath))。

似乎你不保存在.h5格式,因此不能使用by_name=True。ModelCheckpoint只调用model.save_weights(path)而不调用save_format,并且您的保存路径结束既不是.keras也不是.h5
注意:现在我不能测试这个解决方案,它只是引用文档。如果有必要的话,我会回来测试的。

相关问题