keras.Model.fit does not work correctly with a generator and sparse categorical crossentropy loss

Posted by atmip9wb on 2024-01-08 in Other

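tf.keras.Model.fit(x=generator) does not work correctly with the SparseCategoricalCrossentropy / sparse_categorical_crossentropy loss function when a generator supplies the training data. The same symptom is reported in "Accuracy killed when using ImageDataGenerator TensorFlow Keras".
Please advise whether this behavior is expected or, if the code is incorrect, point out what is wrong.
Below is a code excerpt. The full code is at the bottom.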

# --------------------------------------------------------------------------------
# CIFAR 10
# --------------------------------------------------------------------------------
USE_SPARCE_LABEL = True

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
x_train, x_validation, y_train, y_validation = train_test_split(
    x_train, y_train, test_size=0.2, random_state=42
)

# One Hot Encoding the labels when USE_SPARCE_LABEL is False
if not USE_SPARCE_LABEL:
    y_train = keras.utils.to_categorical(y_train, NUM_CLASSES)
    y_validation = keras.utils.to_categorical(y_validation, NUM_CLASSES)
    y_test = keras.utils.to_categorical(y_test, NUM_CLASSES)

# --------------------------------------------------------------------------------
# Model
# --------------------------------------------------------------------------------
model: Model = Model(
    inputs=inputs, outputs=outputs, name="cifar10"
)

# --------------------------------------------------------------------------------
# Compile
# --------------------------------------------------------------------------------
if USE_SPARCE_LABEL:
    loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)   # <--- causes the incorrect behavior
else:
    loss_fn=tf.keras.losses.CategoricalCrossentropy(from_logits=False)

learning_rate = 1e-3
model.compile(
    optimizer=Adam(learning_rate=learning_rate, beta_1=0.9, beta_2=0.999, epsilon=1e-08),
    loss=loss_fn,     # <---- sparse categorical causes the incorrect behavior
    metrics=["accuracy"]
)

# --------------------------------------------------------------------------------
# Train 
# --------------------------------------------------------------------------------
batch_size = 16
number_of_epochs = 10

def data_label_generator(x, y):
    def _f():
        index = 0
        length = len(x)
        try: 
            while True:                
                yield x[index:index+batch_size], y[index:index+batch_size]
                index = (index + batch_size) % length
        except StopIteration:
            return
        
    return _f

earlystop_callback = tf.keras.callbacks.EarlyStopping(
    patience=5,
    restore_best_weights=True,
    monitor='val_accuracy'
)

steps_per_epoch = len(y_train) // batch_size
validation_steps = (len(y_validation) // batch_size) - 1  # To avoid running out of data for validation

history = model.fit(
    x=data_label_generator(x_train, y_train)(),  # <--- Generator
    batch_size=batch_size,
    epochs=number_of_epochs,
    verbose=1,
    validation_data=data_label_generator(x_validation, y_validation)(),
    shuffle=True,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
    validation_batch_size=batch_size,
    callbacks=[
        earlystop_callback
    ]
)
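
A note on the generator above: it never raises StopIteration, because the index wraps around with a modulo, so it should not run out of data for either training or validation. The following is a minimal pure-Python sketch of the same slicing logic, using plain lists instead of the CIFAR-10 arrays (the name `cycling_batches` and the toy data are illustrative assumptions, not part of the original code). It shows the wrap-around, and also that batches drift when the data length is not a multiple of the batch size.

```python
from itertools import islice


def cycling_batches(x, y, batch_size):
    """Yield (x_batch, y_batch) slices forever, wrapping with a modulo."""
    index = 0
    length = len(x)
    while True:
        yield x[index:index + batch_size], y[index:index + batch_size]
        index = (index + batch_size) % length


# Toy data: 10 samples with batch size 4 (10 is not a multiple of 4).
x = list(range(10))
y = [v % 3 for v in x]

# Take the first five batches from the endless generator.
batches = [xb for xb, _ in islice(cycling_batches(x, y, 4), 5)]
for xb in batches:
    print(xb)
```

The third batch is short (two items) and the fourth starts at offset 2 instead of 0. With the sizes in the question (40,000 training and 10,000 validation samples, batch size 16) the lengths divide evenly, so every batch is full.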


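Symptoms

With sparse integer indices as labels and SparseCategoricalCrossentropy as the loss function (USE_SPARCE_LABEL=True), the accuracy values are unstable and low, which triggers early stopping.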

2500/2500 [...] - 24s 8ms/step - loss: 1.4824 - accuracy: 0.0998 - val_loss: 1.1893 - val_accuracy: 0.1003
Epoch 2/10
2500/2500 [...] - 21s 8ms/step - loss: 1.0730 - accuracy: 0.1010 - val_loss: 0.8896 - val_accuracy: 0.0832
Epoch 3/10
2500/2500 [...] - 20s 8ms/step - loss: 0.9272 - accuracy: 0.1016 - val_loss: 0.9150 - val_accuracy: 0.0720
Epoch 4/10
2500/2500 [...] - 20s 8ms/step - loss: 0.7987 - accuracy: 0.1019 - val_loss: 0.8087 - val_accuracy: 0.0864
Epoch 5/10
2500/2500 [...] - 20s 8ms/step - loss: 0.7081 - accuracy: 0.1012 - val_loss: 0.8707 - val_accuracy: 0.0928
Epoch 6/10
2500/2500 [...] - 21s 8ms/step - loss: 0.6056 - accuracy: 0.1019 - val_loss: 0.7688 - val_accuracy: 0.0851
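
An accuracy pinned near 0.10 on a 10-class problem while the loss keeps falling suggests that the metric, rather than the training, is at fault. One plausible explanation (an assumption, not something the logs confirm) is that a one-hot style accuracy is being applied to integer labels of shape (batch, 1): taking the argmax over a length-1 label row always gives 0, so the metric only counts how often the model predicts class 0. The pure-Python sketch below mimics both comparisons on toy data; the function names are illustrative and are not Keras APIs.

```python
def argmax(row):
    """Return the index of the largest value in a list."""
    return max(range(len(row)), key=row.__getitem__)


def one_hot_style_accuracy(y_true, y_pred):
    """Compare argmax(label row) with argmax(prediction row)."""
    hits = [argmax(t) == argmax(p) for t, p in zip(y_true, y_pred)]
    return sum(hits) / len(hits)


def sparse_style_accuracy(y_true, y_pred):
    """Compare the integer label itself with argmax(prediction row)."""
    hits = [t[0] == argmax(p) for t, p in zip(y_true, y_pred)]
    return sum(hits) / len(hits)


# Integer labels with shape (batch, 1), the layout CIFAR-10 labels load in.
y_true = [[2], [0], [1], [2]]
# Toy predictions that are correct for every sample.
y_pred = [
    [0.1, 0.2, 0.7],
    [0.8, 0.1, 0.1],
    [0.2, 0.7, 0.1],
    [0.1, 0.1, 0.8],
]

wrong = one_hot_style_accuracy(y_true, y_pred)
right = sparse_style_accuracy(y_true, y_pred)
print(wrong, right)
```

Even though every prediction is correct, the one-hot style comparison only credits the sample whose predicted class happens to be 0. With ten roughly balanced classes that share would sit around one in ten.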


With one-hot encoded labels and CategoricalCrossentropy as the loss function (USE_SPARCE_LABEL=False), training works as expected.

2500/2500 [...] - 24s 8ms/step - loss: 1.4146 - accuracy: 0.4997 - val_loss: 1.0906 - val_accuracy: 0.6105
Epoch 2/10
2500/2500 [...] - 21s 9ms/step - loss: 1.0306 - accuracy: 0.6375 - val_loss: 0.9779 - val_accuracy: 0.6532
Epoch 3/10
2500/2500 [...] - 22s 9ms/step - loss: 0.8780 - accuracy: 0.6925 - val_loss: 0.8194 - val_accuracy: 0.7127
Epoch 4/10
2500/2500 [...] - 21s 8ms/step - loss: 0.7641 - accuracy: 0.7315 - val_loss: 0.9330 - val_accuracy: 0.7014
Epoch 5/10
2500/2500 [...] - 21s 8ms/step - loss: 0.6797 - accuracy: 0.7614 - val_loss: 0.7908 - val_accuracy: 0.7311
Epoch 6/10
2500/2500 [...] - 21s 9ms/step - loss: 0.6182 - accuracy: 0.7841 - val_loss: 0.7371 - val_accuracy: 0.7533
Epoch 7/10
2500/2500 [...] - 21s 9ms/step - loss: 0.4981 - accuracy: 0.8217 - val_loss: 0.8221 - val_accuracy: 0.7373
Epoch 8/10
2500/2500 [...] - 22s 9ms/step - loss: 0.4363 - accuracy: 0.8437 - val_loss: 0.7865 - val_accuracy: 0.7525
Epoch 9/10
2500/2500 [...] - 23s 9ms/step - loss: 0.3962 - accuracy: 0.8596 - val_loss: 0.8198 - val_accuracy: 0.7505
Epoch 10/10
2500/2500 [...] - 22s 9ms/step - loss: 0.3463 - accuracy: 0.8776 - val_loss: 0.8472 - val_accuracy: 0.7512

Code

import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras import (
    __version__
)

from keras.layers import (
    Layer,
    Normalization,
    Conv2D,
    MaxPooling2D,
    BatchNormalization,
    Dense,
    Flatten,
    Dropout,
    Reshape,
    Activation,
    ReLU,
    LeakyReLU,
)
from keras.models import (
    Model,
)
from keras.optimizers import (
    Adam
)
from sklearn.model_selection import train_test_split

print("TensorFlow version: {}".format(tf.__version__))
tf.keras.__version__ = __version__
print("Keras version: {}".format(tf.keras.__version__))

# --------------------------------------------------------------------------------
# CIFAR 10
# --------------------------------------------------------------------------------
NUM_CLASSES = 10
INPUT_SHAPE = (32, 32, 3)
USE_SPARCE_LABEL = False   # Setting this to False makes it work as expected

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
x_train, x_validation, y_train, y_validation = train_test_split(
    x_train, y_train, test_size=0.2, random_state=42
)

# One Hot Encoding the labels
if not USE_SPARCE_LABEL:
    y_train = keras.utils.to_categorical(y_train, NUM_CLASSES)
    y_validation = keras.utils.to_categorical(y_validation, NUM_CLASSES)
    y_test = keras.utils.to_categorical(y_test, NUM_CLASSES)

# --------------------------------------------------------------------------------
# Model
# --------------------------------------------------------------------------------
inputs = tf.keras.Input(
    name='image',
    shape=INPUT_SHAPE,
    dtype=tf.float32
) 

x = Conv2D(                                           
    filters=32, 
    kernel_size=(3, 3), 
    strides=(1, 1), 
    padding="same",
    activation='relu', 
    input_shape=INPUT_SHAPE
)(inputs)
x = BatchNormalization()(x)
x = Conv2D(                                           
    filters=64, 
    kernel_size=(3, 3), 
    strides=(1, 1), 
    padding="same",
    activation='relu'
)(x)
x = MaxPooling2D(                                     
    pool_size=(2, 2)
)(x)
x = Dropout(0.20)(x)

x = Conv2D(                                           
    filters=128, 
    kernel_size=(3, 3), 
    strides=(1, 1), 
    padding="same",
    activation='relu'
)(x)
x = BatchNormalization()(x)
x = MaxPooling2D(                                     
    pool_size=(2, 2)
)(x)
x = Dropout(0.20)(x)

x = Flatten()(x)
x = Dense(300, activation="relu")(x)
x = BatchNormalization()(x)
x = Dropout(0.20)(x)
x = Dense(200, activation="relu")(x)
outputs = Dense(NUM_CLASSES, activation="softmax")(x)

model: Model = Model(
    inputs=inputs, outputs=outputs, name="cifar10"
)

# --------------------------------------------------------------------------------
# Compile
# --------------------------------------------------------------------------------
learning_rate = 1e-3

if USE_SPARCE_LABEL:
    loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
else:
    loss_fn=tf.keras.losses.CategoricalCrossentropy(from_logits=False)

model.compile(
    optimizer=Adam(learning_rate=learning_rate, beta_1=0.9, beta_2=0.999, epsilon=1e-08),
    loss=loss_fn,
    metrics=["accuracy"]
)
model.summary()

# --------------------------------------------------------------------------------
# Train
# --------------------------------------------------------------------------------
batch_size = 16
number_of_epochs = 10

def data_label_generator(x, y):
    def _f():
        index = 0
        length = len(x)
        try: 
            while True:                
                yield x[index:index+batch_size], y[index:index+batch_size]
                index = (index + batch_size) % length
        except StopIteration:
            return
        
    return _f

earlystop_callback = tf.keras.callbacks.EarlyStopping(
    patience=5,
    restore_best_weights=True,
    monitor='val_accuracy'
)

steps_per_epoch = len(y_train) // batch_size
validation_steps = (len(y_validation) // batch_size) - 1  # -1 to avoid running out of data for validation

history = model.fit(
    x=data_label_generator(x_train, y_train)(),
    batch_size=batch_size,
    epochs=number_of_epochs,
    verbose=1,
    validation_data=data_label_generator(x_validation, y_validation)(),
    shuffle=True,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
    validation_batch_size=batch_size,
    callbacks=[
        earlystop_callback
    ]
)

Environment

TensorFlow version: 2.14.1
Keras version: 2.14.0
Python 3.10.12
Ubuntu 22.04LTS

Solution

innat's answer worked.

model.compile(
    optimizer=Adam(learning_rate=learning_rate, beta_1=0.9, beta_2=0.999, epsilon=1e-08),
    #metrics=["accuracy"]
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')])
model.summary()

Answer #1 (hlswsv35)

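The behavior seen with metrics=["accuracy"] when the targets are sparse label vectors looks like a potential bug in the API. According to the documentation, the string identifier accuracy should be converted to the appropriate metric instance:
"When you pass the strings 'accuracy' or 'acc', we convert this to one of tf.keras.metrics.BinaryAccuracy, tf.keras.metrics.CategoricalAccuracy, tf.keras.metrics.SparseCategoricalAccuracy based on the shapes of the targets and of the model output."
In this case, you need to specify keras.metrics.SparseCategoricalAccuracy(name='accuracy') explicitly to make it work.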
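
To make the shape-based choice concrete, here is a minimal sketch of how such a dispatch could work. It is modeled on the rule the documentation describes and is an assumption for illustration, not Keras's actual implementation; `pick_accuracy_metric` is a hypothetical helper. It also shows one way the choice could go wrong: if the last dimension of the labels is unknown (`None`), which may be the case when the data comes from a plain Python generator, the labels no longer look sparse.

```python
def pick_accuracy_metric(y_true_shape, y_pred_shape):
    """Hypothetical shape-based resolution of the string "accuracy"."""
    true_last = y_true_shape[-1]
    pred_last = y_pred_shape[-1]
    if pred_last == 1:
        return "binary_accuracy"
    # Labels look sparse when they have lower rank than the predictions,
    # or when their last dimension is exactly 1.
    looks_sparse = len(y_true_shape) < len(y_pred_shape) or true_last == 1
    if looks_sparse:
        return "sparse_categorical_accuracy"
    return "categorical_accuracy"


# Known label shape (batch, 1) with a 10-way softmax output.
known = pick_accuracy_metric((None, 1), (None, 10))
# Same labels, but the last dimension is not known ahead of time.
unknown = pick_accuracy_metric((None, None), (None, 10))
print(known, unknown)
```

Under this assumed rule the unknown shape falls through to the one-hot style metric. Passing the metric object explicitly, as in the fix above, avoids relying on any such inference.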
