keras 加权稀疏分类交叉熵

lnlaulya  于 2022-11-13  发布在  其他
关注(0)|答案(3)|浏览(209)

我正在处理一个语义分割问题,其中我感兴趣的两个类(除了背景)在图像像素中是安静的不平衡的。我实际上是使用稀疏分类交叉熵作为损失,由于训练掩码的编码方式。有没有考虑类权重的版本?我还没能找到它,我以前从来没有研究过tf的源代码,但是API页面上的源代码链接似乎并没有链接到损失函数的真实的实现。

z8dt9xmd

z8dt9xmd1#

据我所知,你可以在www.example.com中使用类权重model.fit计算任何损失函数。我已经用它来计算categorical_cross_entropy,它是有效的。它只是用类权重来衡量损失,所以我认为没有理由不使用sparse_categorical_cross_entropy。

1szpjjfi

1szpjjfi2#

我认为this是在Keras中为sparse_categorical_crossentropy加权的解决方案。他们使用以下方法向数据集添加“第二个掩码”(包含掩码图像的每个类的权重)。

def add_sample_weights(image, label):
  # The weights for each class, with the constraint that:
  #     sum(class_weights) == 1.0
  class_weights = tf.constant([2.0, 2.0, 1.0])
  class_weights = class_weights/tf.reduce_sum(class_weights)

  # Create an image of `sample_weights` by using the label at each pixel as an 
  # index into the `class weights` .
  sample_weights = tf.gather(class_weights, indices=tf.cast(label, tf.int32))

  return image, label, sample_weights

train_dataset.map(add_sample_weights).element_spec

然后,他们只使用tf.keras.losses.SparseCategoricalCrossentropy作为损失函数,拟合如下:

weighted_model.fit(
    train_dataset.map(add_sample_weights),
    epochs=1,
    steps_per_epoch=10)
osh3o9ms

osh3o9ms3#

看起来Keras稀疏分类交叉熵不适用于类权重。我已经找到了Keras稀疏分类交叉熵损失的this实现,它对我有效。链接中的实现有一个小bug,可能是由于一些版本不兼容,所以我已经修复了它。

import tensorflow as tf
    from tensorflow import keras

    class WeightedSCCE(keras.losses.Loss):
        def __init__(self, class_weight, from_logits=False, name='weighted_scce'):
            if class_weight is None or all(v == 1. for v in class_weight):
                self.class_weight = None
            else:
                self.class_weight = tf.convert_to_tensor(class_weight,
                    dtype=tf.float32)
            self.name = name
            self.reduction = keras.losses.Reduction.NONE
            self.unreduced_scce = keras.losses.SparseCategoricalCrossentropy(
                from_logits=from_logits, name=name,
                reduction=self.reduction)
    
        def __call__(self, y_true, y_pred, sample_weight=None):
            loss = self.unreduced_scce(y_true, y_pred, sample_weight)
            if self.class_weight is not None:
                weight_mask = tf.gather(self.class_weight, y_true)
                loss = tf.math.multiply(loss, weight_mask)
            return loss

损失应通过将权重列表或数组作为参数来调用。

相关问题