keras Tensorflow文本分类示例为什么需要from_logits=True?

ercv8c1e  于 2023-06-23  发布在  Git
关注(0)|答案(1)|浏览(170)

我正在运行Tensorflow here中的一个基本文本分类示例。
有一件事我不明白,为什么我们需要使用from_logits=TrueBinaryCrossentropy损失?当我试图删除它并将activation="sigmoid"添加到最后一个Dense层时,binary_accuracy在训练时根本不移动。

更改代码:

model = tf.keras.Sequential([
  layers.Embedding(max_features + 1, embedding_dim),
  layers.Dropout(0.2),
  layers.GlobalAveragePooling1D(),
  layers.Dropout(0.2),
  layers.Dense(1, activation="sigmoid")]) # <-- Add activation = sigmoid here

model.compile(loss=losses.BinaryCrossentropy(), # <-- Remove from_logits=True here
              optimizer='adam',
              metrics=tf.metrics.BinaryAccuracy(threshold=0.0))

epochs = 10
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs)

培训产出:

Epoch 1/10
    625/625 [==============================] - 4s 4ms/step - loss: 0.6635 - binary_accuracy: 0.4981 - val_loss: 0.6149 - val_binary_accuracy: 0.5076
    Epoch 2/10
    625/625 [==============================] - 2s 4ms/step - loss: 0.5492 - binary_accuracy: 0.4981 - val_loss: 0.4990 - val_binary_accuracy: 0.5076
    Epoch 3/10
    625/625 [==============================] - 2s 4ms/step - loss: 0.4453 - binary_accuracy: 0.4981 - val_loss: 0.4208 - val_binary_accuracy: 0.5076
    Epoch 4/10
    625/625 [==============================] - 2s 4ms/step - loss: 0.3792 - binary_accuracy: 0.4981 - val_loss: 0.3741 - val_binary_accuracy: 0.5076
    Epoch 5/10
    625/625 [==============================] - 3s 4ms/step - loss: 0.3360 - binary_accuracy: 0.4981 - val_loss: 0.3454 - val_binary_accuracy: 0.5076
    Epoch 6/10
    625/625 [==============================] - 3s 4ms/step - loss: 0.3054 - binary_accuracy: 0.4981 - val_loss: 0.3262 - val_binary_accuracy: 0.5076
    Epoch 7/10
    625/625 [==============================] - 3s 4ms/step - loss: 0.2813 - binary_accuracy: 0.4981 - val_loss: 0.3126 - val_binary_accuracy: 0.5076
    Epoch 8/10
    625/625 [==============================] - 3s 4ms/step - loss: 0.2616 - binary_accuracy: 0.4981 - val_loss: 0.3033 - val_binary_accuracy: 0.5076
    Epoch 9/10
    625/625 [==============================] - 3s 4ms/step - loss: 0.2456 - binary_accuracy: 0.4981 - val_loss: 0.2967 - val_binary_accuracy: 0.5076
    Epoch 10/10
    625/625 [==============================] - 2s 4ms/step - loss: 0.2306 - binary_accuracy: 0.4981 - val_loss: 0.2920 - val_binary_accuracy: 0.5076
bhmjp9jg

bhmjp9jg1#

看起来模型正在正常训练,但目前显示模型如何训练的计算方法是错误的。
我认为BinaryAccuracy中的threshold正在实现度量的结果。例如,因为您将损失函数的输入更改为sigmoid之后的一个,所以值的范围将在01之间,但您的BinaryAccuracythreshold现在是0.0,应该是0.5

如果您想根据需要修改模型架构,请尝试将该值更改为0.5

相关问题