我有一个单标签、多类分类问题,也就是说,一个给定的样本正好在一个类中(比如说,类3),但是为了训练的目的,预测类2或5仍然是可以的,不会严重地惩罚模型。
例如,1个样本的地面真实值是5个类别的[0,1,1,0,1],而不是一个热点向量。这意味着,预测上述类别(2,3或5)的任何一个(不一定是全部)的模型是好的。
对于每个批次,预测的输出维度的形状为bs x n x nc
,其中bs是批次大小,n是每个点的样本数,nc是类数。
对于每一个批次,我希望我的损失函数比较nc
类的n
Tensor,然后在n
上求平均值。
例如:当尺寸为32 x 8 x 5000时,一个批中有32个 * 批点 *(对于bs=32)。每个批处理点有8个 * 矢量点 ,每个矢量点有5000个类。对于给定的批处理点,我希望计算所有(8) 向量点 *,计算它们的平均值,并对其余的 * 批点 *(32)进行计算。最终损失将是每个 * 批点 * 的所有损失的损失。
我如何设计这样一个损失函数呢?任何帮助都将不胜感激
附注:如果问题有歧义,请告诉我
1条答案
按热度按时间qyswt5oh1#
实现这一点的一种方法是在网络输出上使用S形函数,它消除了softmax函数所具有的类得分之间的隐式相互依赖性。
对于损失函数,您可以根据任何目标类的“最高”预测值计算损失,并忽略所有其他类预测值。
由于我们只关心"最高“类概率,所以我们将所有其他类得分设置为对于其中一个类所获得的最大值:
从中我们可以使用一个常规的交叉熵损失函数
当检查梯度时,我们看到,这“仅”更新了获得最高分数的预测以及任何类之外的所有预测。
其他目标类的预测不会接收梯度更新。请注意,如果可能类的比率非常高,则这可能会减慢收敛速度。