我试图从tf.keras.Model
使用load_weights
功能,但它似乎不能正常工作。
如果我调用model.load_weights(weights_path, by_name=True, skip_mismatch=True)
,我会得到一个形状不匹配的错误,这正是我希望"skip_mismatch"
参数处理的问题。
这个代码片段是一个相对简单的mnist
数据集案例,重现了我的错误。我在google colab中运行它,发生的事情就像我自己的代码一样。
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
#import mnist dataset in x,y fashion
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
#build a simple model with sequential dense layers
model = tf.keras.Sequential([
tf.keras.layers.Flatten(input_shape=(28,28)),
tf.keras.layers.Rescaling(1./255, input_shape=(28,28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10)
])
#compile the model
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
#create checkpoint directory
os.makedirs("checkpoints", exist_ok=True)
cp_callback = tf.keras.callbacks.ModelCheckpoint("checkpoints/cp-{epoch:04d}.ckpt",
save_weights_only=True,
)
#train the model
history = model.fit(x_train, y_train, epochs=10, validation_data=(x_test, y_test), callbacks=[cp_callback])
model2 = tf.keras.Sequential([
tf.keras.layers.Flatten(input_shape=(28,28)),
tf.keras.layers.Rescaling(1./255, input_shape=(28,28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(5)
])
#compile model2
model2.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
#load weights of model into model2
model2.load_weights("checkpoints/cp-0010.ckpt",by_name=True, skip_mismatch=True)
我得到
“在尝试恢复具有形状(5,)和名称dense_5/bias:0的变量时,接收到与形状(10,)不兼容的Tensor。”无论我使用skip_mismatch=True
还是skip_mismatch=False
。
我是不是用错了功能?什么是正确的使用方法?
1条答案
按热度按时间ecfdbz9o1#
如果你看一下load_weights和保存_weights的文档,应该会更清楚。
从
load_weights
:按名称加载重量
在这种情况下,仅当权重共享相同的名称时,才会将它们加载到层中。这对于微调或迁移学习模型很有用,其中一些层已经改变。
save_weights
:保存_format“tf”或“h5”。* 如果保存_format为None,则以“.h5”或“.keras”结尾的文件路径将默认为HDF5*。否则,None将变为“TF”。默认为None。
从
ModelCheckpoint
开始:保存_weights_only如果为True,则仅保存模型的权重(model.保存_weights(filepath)),否则保存完整模型(model.保存(filepath))。
似乎你不保存在.h5格式,因此不能使用
by_name=True
。ModelCheckpoint只调用model.save_weights(path)
而不调用save_format
,并且您的保存路径结束既不是.keras
也不是.h5
。注意:现在我不能测试这个解决方案,它只是引用文档。如果有必要的话,我会回来测试的。