我的目标是保存,然后使用save_weights
和load_weights
函数加载模型的权重。
为了向你展示一个最小的可重复的例子,这些是你可以在我的整个例子中使用的依赖项:
import numpy as np
import tensorflow as tf
from keras.initializers import he_uniform
from keras.layers import Conv2DTranspose, BatchNormalization, Reshape, Dense, Conv2D, Flatten
from keras.optimizers.legacy import Adam
from keras.src.datasets import mnist
from skimage.transform import resize
from sklearn.base import BaseEstimator
from tensorflow import keras
这是我的模型,一个(变分)自动编码器:
class VAE(keras.Model, BaseEstimator):
def __init__(self, encoder, decoder, epochs=None, l_rate=None, batch_size=None, patience=None, **kwargs):
super().__init__(**kwargs)
self.encoder = encoder
self.decoder = decoder
self.epochs = epochs
self.l_rate = l_rate
self.batch_size = batch_size
self.patience = patience
self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")
def call(self, inputs, training=None, mask=None):
_, _, z = self.encoder(inputs)
outputs = self.decoder(z)
return outputs
@property
def metrics(self):
return [
self.total_loss_tracker,
self.reconstruction_loss_tracker,
self.kl_loss_tracker,
]
def train_step(self, data):
data, labels = data
with tf.GradientTape() as tape:
# Forward pass
z_mean, z_log_var, z = self.encoder(data)
reconstruction = self.decoder(z)
# Compute losses
reconstruction_loss = tf.reduce_mean(
tf.reduce_sum(
keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
)
)
kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
total_loss = reconstruction_loss + kl_loss
# Compute gradient
grads = tape.gradient(total_loss, self.trainable_weights)
# Update weights
self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
# Update metrics
self.total_loss_tracker.update_state(total_loss)
self.reconstruction_loss_tracker.update_state(reconstruction_loss)
self.kl_loss_tracker.update_state(kl_loss)
return {
"loss": self.total_loss_tracker.result(),
"reconstruction_loss": self.reconstruction_loss_tracker.result(),
"kl_loss": self.kl_loss_tracker.result(),
}
def test_step(self, data):
data, labels = data
# Forward pass
z_mean, z_log_var, z = self.encoder(data)
reconstruction = self.decoder(z)
# Compute losses
reconstruction_loss = tf.reduce_mean(
tf.reduce_sum(
keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
)
)
kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
total_loss = reconstruction_loss + kl_loss
# Update metrics
self.total_loss_tracker.update_state(total_loss)
self.reconstruction_loss_tracker.update_state(reconstruction_loss)
self.kl_loss_tracker.update_state(kl_loss)
return {
"loss": self.total_loss_tracker.result(),
"reconstruction_loss": self.reconstruction_loss_tracker.result(),
"kl_loss": self.kl_loss_tracker.result(),
}
这是编码器:
@keras.saving.register_keras_serializable()
class Encoder(keras.layers.Layer):
def __init__(self, latent_dimension):
super(Encoder, self).__init__()
self.latent_dim = latent_dimension
seed = 42
self.conv1 = Conv2D(filters=64, kernel_size=3, activation="relu", strides=2, padding="same",
kernel_initializer=he_uniform(seed))
self.bn1 = BatchNormalization()
self.conv2 = Conv2D(filters=128, kernel_size=3, activation="relu", strides=2, padding="same",
kernel_initializer=he_uniform(seed))
self.bn2 = BatchNormalization()
self.conv3 = Conv2D(filters=256, kernel_size=3, activation="relu", strides=2, padding="same",
kernel_initializer=he_uniform(seed))
self.bn3 = BatchNormalization()
self.flatten = Flatten()
self.dense = Dense(units=100, activation="relu")
self.z_mean = Dense(latent_dimension, name="z_mean")
self.z_log_var = Dense(latent_dimension, name="z_log_var")
self.sampling = sample
def call(self, inputs, training=None, mask=None):
x = self.conv1(inputs)
x = self.bn1(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.conv3(x)
x = self.bn3(x)
x = self.flatten(x)
x = self.dense(x)
z_mean = self.z_mean(x)
z_log_var = self.z_log_var(x)
z = self.sampling(z_mean, z_log_var)
return z_mean, z_log_var, z
其中sample
函数定义如下:
def sample(z_mean, z_log_var):
batch = tf.shape(z_mean)[0]
dim = tf.shape(z_mean)[1]
epsilon = tf.random.normal(shape=(batch, dim))
stddev = tf.exp(0.5 * z_log_var)
return z_mean + stddev * epsilon
最后是解码器:
@keras.saving.register_keras_serializable()
class Decoder(keras.layers.Layer):
def __init__(self):
super(Decoder, self).__init__()
self.dense1 = Dense(units=4096, activation="relu")
self.bn1 = BatchNormalization()
self.dense2 = Dense(units=1024, activation="relu")
self.bn2 = BatchNormalization()
self.dense3 = Dense(units=4096, activation="relu")
self.bn3 = BatchNormalization()
seed = 42
self.reshape = Reshape((4, 4, 256))
self.deconv1 = Conv2DTranspose(filters=256, kernel_size=3, activation="relu", strides=2, padding="same",
kernel_initializer=he_uniform(seed))
self.bn4 = BatchNormalization()
self.deconv2 = Conv2DTranspose(filters=128, kernel_size=3, activation="relu", strides=1, padding="same",
kernel_initializer=he_uniform(seed))
self.bn5 = BatchNormalization()
self.deconv3 = Conv2DTranspose(filters=128, kernel_size=3, activation="relu", strides=2, padding="valid",
kernel_initializer=he_uniform(seed))
self.bn6 = BatchNormalization()
self.deconv4 = Conv2DTranspose(filters=64, kernel_size=3, activation="relu", strides=1, padding="valid",
kernel_initializer=he_uniform(seed))
self.bn7 = BatchNormalization()
self.deconv5 = Conv2DTranspose(filters=64, kernel_size=3, activation="relu", strides=2, padding="valid",
kernel_initializer=he_uniform(seed))
self.bn8 = BatchNormalization()
self.deconv6 = Conv2DTranspose(filters=1, kernel_size=2, activation="sigmoid", padding="valid",
kernel_initializer=he_uniform(seed))
def call(self, inputs, training=None, mask=None):
x = self.dense1(inputs)
x = self.bn1(x)
x = self.dense2(x)
x = self.bn2(x)
x = self.dense3(x)
x = self.bn3(x)
x = self.reshape(x)
x = self.deconv1(x)
x = self.bn4(x)
x = self.deconv2(x)
x = self.bn5(x)
x = self.deconv3(x)
x = self.bn6(x)
x = self.deconv4(x)
x = self.bn7(x)
x = self.deconv5(x)
x = self.bn8(x)
decoder_outputs = self.deconv6(x)
return decoder_outputs
下面是main
代码:
def normalize(x):
return (x - np.min(x)) / (np.max(x) - np.min(x))
def create_vae():
latent_dimension = 25
best_epochs = 2500
best_l_rate = 10 ** -5
best_batch_size = 32
best_patience = 30
encoder = Encoder(latent_dimension)
decoder = Decoder()
vae = VAE(encoder, decoder, best_epochs, best_l_rate, best_batch_size, best_patience)
vae.compile(Adam(best_l_rate))
return vae
if __name__ == '__main__':
(x_train, y_train), (x_test, y_test) = mnist.load_data()
new_shape = (40, 40) # VAE deals with (None, 40, 40, 1) tensors
x_train = np.array([resize(img, new_shape) for img in x_train])
x_test = np.array([resize(img, new_shape) for img in x_test])
x_train = np.expand_dims(x_train, axis=-1).astype("float32")
x_test = np.expand_dims(x_test, axis=-1).astype("float32")
x_train = normalize(x_train)
x_test = normalize(x_test)
# Let's consider the first 100 items only for speed purposes
x_train = x_train[:100]
y_train = y_train[:100]
x_test = x_test[:100]
y_test = y_test[:100]
model = create_vae()
model.fit(x_train, y_train, batch_size=64, epochs=10)
weights_before_load = model.get_weights()
model.save_weights("test-checkpoints/my-vae")
del model
model = create_vae()
model.load_weights("test-checkpoints/my-vae")
weights_after_load = model.get_weights()
for layer_num, (w_before, w_after) in enumerate(zip(weights_before_load, weights_after_load), start=1):
print(f"Layer {layer_num}:")
print(f"Same weights? {w_before.all() == w_after.all()}")
这是输出:
Layer 1:
Same weights? True
Layer 2:
Same weights? False # WHY FALSE HERE?
Layer 3:
Same weights? True
Layer 4:
Same weights? True
Layer 5:
Same weights? True
Layer 6:
Same weights? True
但我希望负载前后的重量是一样的!为什么我使用load_weights
方法加载后,第2层中的权重不相同?
此外,这是我收到的警告列表:
WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function.
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.iter
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_1
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_2
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.decay
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.learning_rate
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.conv1.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.conv1.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.bn1.gamma
...
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.conv3.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.bn3.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.bn3.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.bn3.moving_mean
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.bn3.moving_variance
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.dense.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.dense.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.z_mean.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.z_mean.bias
...
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn2.moving_variance
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.dense3.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.dense3.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn3.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn3.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn3.moving_mean
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn3.moving_variance
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv1.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv1.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn4.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn4.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn4.moving_mean
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn4.moving_variance
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv2.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv2.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn5.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn5.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn5.moving_mean
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn5.moving_variance
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv3.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv3.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn6.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn6.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn6.moving_mean
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn6.moving_variance
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv4.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv4.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn7.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn7.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn7.moving_mean
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn7.moving_variance
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv5.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv5.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn8.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn8.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn8.moving_mean
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn8.moving_variance
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv6.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv6.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.conv1.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.conv1.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.bn1.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.bn1.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.conv2.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.conv2.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.bn2.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.bn2.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.conv3.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.conv3.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.bn3.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.bn3.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.dense.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.dense.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.z_mean.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.z_mean.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.z_log_var.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.z_log_var.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.dense1.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.dense1.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.bn1.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.bn1.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.dense2.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.dense2.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.bn2.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.bn2.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.dense3.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.dense3.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.bn3.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.bn3.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.deconv1.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.deconv1.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.bn4.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.bn4.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.deconv2.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.deconv2.bias
...
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'v' for (root).decoder.deconv6.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'v' for (root).decoder.deconv6.bias
Process finished with exit code 0
如何解决此问题?
请注意,由于警告的大小,我截断了它们的输出。
这很奇怪,因为如果我在另一个模型上重新运行我的例子,你可以在这里看到:
def create_model():
model = tf.keras.Sequential([
keras.layers.Dense(512, activation='relu', input_shape=(784,)),
keras.layers.Dropout(0.2),
keras.layers.Dense(10)
])
model.compile(optimizer=Adam(),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
return model
if __name__ == '__main__':
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
y_train = y_train[:100]
y_test = y_test[:100]
x_train = x_train[:100].reshape(-1, 28 * 28) / 255.0
x_test = x_test[:100].reshape(-1, 28 * 28) / 255.0
model = create_model()
model.fit(x_train, y_train, batch_size=32, epochs=10)
model.save_weights("test-checkpoints/my-model")
weights_before = model.get_weights()
del model
model = create_model()
model.load_weights("test-checkpoints/my-model")
weights_after = model.get_weights()
for layer_num, (w_before, w_after) in enumerate(zip(weights_before, weights_after), start=1):
print(f"Layer {layer_num}:")
print(f"Same weights? {w_before.all() == w_after.all()}")
然后你可以注意到model.save_weights()
和model.load_weights()
工作正常,因为权重都是一样的:
Layer 1:
Same weights? True
Layer 2:
Same weights? True
Layer 3:
Same weights? True
Layer 4:
Same weights? True
1条答案
按热度按时间nhjlsmyf1#
在加载权重之前,我必须在这样的批处理上训练模型: