keras model.load_weights()不加载我之前使用model.save_weights()存储的权重

qzlgjiam  于 2023-10-19  发布在  其他
关注(0)|答案(1)|浏览(106)

我的目标是保存,然后使用save_weightsload_weights函数加载模型的权重。
为了向你展示一个最小的可重复的例子,这些是你可以在我的整个例子中使用的依赖项:

import numpy as np
import tensorflow as tf
from keras.initializers import he_uniform
from keras.layers import Conv2DTranspose, BatchNormalization, Reshape, Dense, Conv2D, Flatten
from keras.optimizers.legacy import Adam
from keras.src.datasets import mnist
from skimage.transform import resize
from sklearn.base import BaseEstimator
from tensorflow import keras

这是我的模型,一个(变分)自动编码器:

class VAE(keras.Model, BaseEstimator):
    def __init__(self, encoder, decoder, epochs=None, l_rate=None, batch_size=None, patience=None, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.epochs = epochs  
        self.l_rate = l_rate  
        self.batch_size = batch_size  
        self.patience = patience 
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    def call(self, inputs, training=None, mask=None):
        _, _, z = self.encoder(inputs)
        outputs = self.decoder(z)
        return outputs

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        data, labels = data
        with tf.GradientTape() as tape:
            # Forward pass
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)

            # Compute losses
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
                )
            )
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss

        # Compute gradient
        grads = tape.gradient(total_loss, self.trainable_weights)

        # Update weights
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        # Update metrics
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)

        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def test_step(self, data):
        data, labels = data
        # Forward pass
        z_mean, z_log_var, z = self.encoder(data)
        reconstruction = self.decoder(z)

        # Compute losses
        reconstruction_loss = tf.reduce_mean(
            tf.reduce_sum(
                keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
            )
        )
        kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
        kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
        total_loss = reconstruction_loss + kl_loss

        # Update metrics
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)

        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

这是编码器:

@keras.saving.register_keras_serializable()
class Encoder(keras.layers.Layer):
    def __init__(self, latent_dimension):
        super(Encoder, self).__init__()
        self.latent_dim = latent_dimension

        seed = 42

        self.conv1 = Conv2D(filters=64, kernel_size=3, activation="relu", strides=2, padding="same",
                            kernel_initializer=he_uniform(seed))
        self.bn1 = BatchNormalization()

        self.conv2 = Conv2D(filters=128, kernel_size=3, activation="relu", strides=2, padding="same",
                            kernel_initializer=he_uniform(seed))
        self.bn2 = BatchNormalization()

        self.conv3 = Conv2D(filters=256, kernel_size=3, activation="relu", strides=2, padding="same",
                            kernel_initializer=he_uniform(seed))
        self.bn3 = BatchNormalization()

        self.flatten = Flatten()
        self.dense = Dense(units=100, activation="relu")

        self.z_mean = Dense(latent_dimension, name="z_mean")
        self.z_log_var = Dense(latent_dimension, name="z_log_var")

        self.sampling = sample

    def call(self, inputs, training=None, mask=None):
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.flatten(x)
        x = self.dense(x)
        z_mean = self.z_mean(x)
        z_log_var = self.z_log_var(x)
        z = self.sampling(z_mean, z_log_var)
        return z_mean, z_log_var, z

其中sample函数定义如下:

def sample(z_mean, z_log_var):
    batch = tf.shape(z_mean)[0]
    dim = tf.shape(z_mean)[1]
    epsilon = tf.random.normal(shape=(batch, dim))
    stddev = tf.exp(0.5 * z_log_var)
    return z_mean + stddev * epsilon

最后是解码器:

@keras.saving.register_keras_serializable()
class Decoder(keras.layers.Layer):
    def __init__(self):
        super(Decoder, self).__init__()
        self.dense1 = Dense(units=4096, activation="relu")
        self.bn1 = BatchNormalization()

        self.dense2 = Dense(units=1024, activation="relu")
        self.bn2 = BatchNormalization()

        self.dense3 = Dense(units=4096, activation="relu")
        self.bn3 = BatchNormalization()

        seed = 42

        self.reshape = Reshape((4, 4, 256))
        self.deconv1 = Conv2DTranspose(filters=256, kernel_size=3, activation="relu", strides=2, padding="same",
                                       kernel_initializer=he_uniform(seed))
        self.bn4 = BatchNormalization()

        self.deconv2 = Conv2DTranspose(filters=128, kernel_size=3, activation="relu", strides=1, padding="same",
                                       kernel_initializer=he_uniform(seed))
        self.bn5 = BatchNormalization()

        self.deconv3 = Conv2DTranspose(filters=128, kernel_size=3, activation="relu", strides=2, padding="valid",
                                       kernel_initializer=he_uniform(seed))
        self.bn6 = BatchNormalization()

        self.deconv4 = Conv2DTranspose(filters=64, kernel_size=3, activation="relu", strides=1, padding="valid",
                                       kernel_initializer=he_uniform(seed))
        self.bn7 = BatchNormalization()

        self.deconv5 = Conv2DTranspose(filters=64, kernel_size=3, activation="relu", strides=2, padding="valid",
                                       kernel_initializer=he_uniform(seed))
        self.bn8 = BatchNormalization()

        self.deconv6 = Conv2DTranspose(filters=1, kernel_size=2, activation="sigmoid", padding="valid",
                                       kernel_initializer=he_uniform(seed))

    def call(self, inputs, training=None, mask=None):
        x = self.dense1(inputs)
        x = self.bn1(x)
        x = self.dense2(x)
        x = self.bn2(x)
        x = self.dense3(x)
        x = self.bn3(x)
        x = self.reshape(x)
        x = self.deconv1(x)
        x = self.bn4(x)
        x = self.deconv2(x)
        x = self.bn5(x)
        x = self.deconv3(x)
        x = self.bn6(x)
        x = self.deconv4(x)
        x = self.bn7(x)
        x = self.deconv5(x)
        x = self.bn8(x)
        decoder_outputs = self.deconv6(x)
        return decoder_outputs

下面是main代码:

def normalize(x):
    return (x - np.min(x)) / (np.max(x) - np.min(x))
    
def create_vae():
    latent_dimension = 25
    best_epochs = 2500
    best_l_rate = 10 ** -5
    best_batch_size = 32
    best_patience = 30

    encoder = Encoder(latent_dimension)
    decoder = Decoder()
    vae = VAE(encoder, decoder, best_epochs, best_l_rate, best_batch_size, best_patience)
    vae.compile(Adam(best_l_rate))
    return vae

if __name__ == '__main__':
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    new_shape = (40, 40)  # VAE deals with (None, 40, 40, 1) tensors
    x_train = np.array([resize(img, new_shape) for img in x_train])
    x_test = np.array([resize(img, new_shape) for img in x_test])
    x_train = np.expand_dims(x_train, axis=-1).astype("float32")
    x_test = np.expand_dims(x_test, axis=-1).astype("float32")

    x_train = normalize(x_train)
    x_test = normalize(x_test)

    # Let's consider the first 100 items only for speed purposes
    x_train = x_train[:100]
    y_train = y_train[:100]
    x_test = x_test[:100]
    y_test = y_test[:100]
    
    model = create_vae()
    model.fit(x_train, y_train, batch_size=64, epochs=10)
    weights_before_load = model.get_weights()
    model.save_weights("test-checkpoints/my-vae")

    del model

    model = create_vae()
    model.load_weights("test-checkpoints/my-vae")
    weights_after_load = model.get_weights()

    for layer_num, (w_before, w_after) in enumerate(zip(weights_before_load, weights_after_load), start=1):
        print(f"Layer {layer_num}:")
        print(f"Same weights? {w_before.all() == w_after.all()}")

这是输出:

Layer 1:
Same weights? True

Layer 2:
Same weights? False  # WHY FALSE HERE?

Layer 3:
Same weights? True

Layer 4:
Same weights? True

Layer 5:
Same weights? True

Layer 6:
Same weights? True

但我希望负载前后的重量是一样的!为什么我使用load_weights方法加载后,第2层中的权重不相同?
此外,这是我收到的警告列表:

WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function.
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.iter
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_1
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_2
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.decay
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.learning_rate
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.conv1.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.conv1.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.bn1.gamma
...
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.conv3.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.bn3.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.bn3.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.bn3.moving_mean
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.bn3.moving_variance
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.dense.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.dense.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.z_mean.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.z_mean.bias
...
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn2.moving_variance
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.dense3.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.dense3.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn3.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn3.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn3.moving_mean
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn3.moving_variance
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv1.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv1.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn4.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn4.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn4.moving_mean
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn4.moving_variance
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv2.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv2.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn5.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn5.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn5.moving_mean
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn5.moving_variance
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv3.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv3.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn6.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn6.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn6.moving_mean
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn6.moving_variance
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv4.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv4.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn7.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn7.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn7.moving_mean
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn7.moving_variance
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv5.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv5.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn8.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn8.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn8.moving_mean
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn8.moving_variance
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv6.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv6.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.conv1.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.conv1.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.bn1.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.bn1.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.conv2.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.conv2.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.bn2.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.bn2.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.conv3.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.conv3.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.bn3.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.bn3.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.dense.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.dense.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.z_mean.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.z_mean.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.z_log_var.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.z_log_var.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.dense1.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.dense1.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.bn1.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.bn1.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.dense2.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.dense2.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.bn2.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.bn2.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.dense3.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.dense3.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.bn3.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.bn3.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.deconv1.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.deconv1.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.bn4.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.bn4.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.deconv2.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.deconv2.bias
...
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'v' for (root).decoder.deconv6.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'v' for (root).decoder.deconv6.bias

Process finished with exit code 0

如何解决此问题?
请注意,由于警告的大小,我截断了它们的输出。
这很奇怪,因为如果我在另一个模型上重新运行我的例子,你可以在这里看到:

def create_model():
    model = tf.keras.Sequential([
        keras.layers.Dense(512, activation='relu', input_shape=(784,)),
        keras.layers.Dropout(0.2),
        keras.layers.Dense(10)
    ])

    model.compile(optimizer=Adam(),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

    return model

if __name__ == '__main__':
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    y_train = y_train[:100]
    y_test = y_test[:100]
    x_train = x_train[:100].reshape(-1, 28 * 28) / 255.0
    x_test = x_test[:100].reshape(-1, 28 * 28) / 255.0

    model = create_model()
    model.fit(x_train, y_train, batch_size=32, epochs=10)

    model.save_weights("test-checkpoints/my-model")
    weights_before = model.get_weights()

    del model

    model = create_model()
    model.load_weights("test-checkpoints/my-model")
    weights_after = model.get_weights()

    for layer_num, (w_before, w_after) in enumerate(zip(weights_before, weights_after), start=1):
        print(f"Layer {layer_num}:")
        print(f"Same weights? {w_before.all() == w_after.all()}")

然后你可以注意到model.save_weights()model.load_weights()工作正常,因为权重都是一样的:

Layer 1:
Same weights? True

Layer 2:
Same weights? True

Layer 3:
Same weights? True

Layer 4:
Same weights? True
nhjlsmyf

nhjlsmyf1#

在加载权重之前,我必须在这样的批处理上训练模型:

vae.train_on_batch(x_train[:1], x_train[:1])

相关问题