pytorch 单标签、多类问题的自定义丢失

yvgpqqbh  于 2022-11-09  发布在  其他
关注(0)|答案(1)|浏览(191)

我有一个单标签、多类分类问题,也就是说,一个给定的样本正好在一个类中(比如说,类3),但是为了训练的目的,预测类2或5仍然是可以的,不会严重地惩罚模型。
例如,1个样本的地面真实值是5个类别的[0,1,1,0,1],而不是一个热点向量。这意味着,预测上述类别(2,3或5)的任何一个(不一定是全部)的模型是好的。
对于每个批次,预测的输出维度的形状为bs x n x nc,其中bs是批次大小,n是每个点的样本数,nc是类数。
对于每一个批次,我希望我的损失函数比较nc类的nTensor,然后在n上求平均值。
例如:当尺寸为32 x 8 x 5000时,一个批中有32个 * 批点 *(对于bs=32)。每个批处理点有8个 * 矢量点 ,每个矢量点有5000个类。对于给定的批处理点,我希望计算所有(8) 向量点 *,计算它们的平均值,并对其余的 * 批点 *(32)进行计算。最终损失将是每个 * 批点 * 的所有损失的损失。
我如何设计这样一个损失函数呢?任何帮助都将不胜感激
附注:如果问题有歧义,请告诉我

qyswt5oh

qyswt5oh1#

实现这一点的一种方法是在网络输出上使用S形函数,它消除了softmax函数所具有的类得分之间的隐式相互依赖性。
对于损失函数,您可以根据任何目标类的“最高”预测值计算损失,并忽略所有其他类预测值。

  1. # your model output
  2. y_out = torch.tensor([[0.1, 0.2, 0.95, 0.1, 0.01]], requires_grad=True)
  3. # class labels
  4. y = torch.tensor([[0,1,1,0,1]])

由于我们只关心"最高“类概率,所以我们将所有其他类得分设置为对于其中一个类所获得的最大值:

  1. class_mask = y == 1
  2. max_class_score = torch.max(y_out[class_mask])
  3. y_hat = torch.where(class_mask, max_class_score, y_out)

从中我们可以使用一个常规的交叉熵损失函数

  1. loss_fn = torch.nn.CrossEntropyLoss()
  2. loss = loss_fn(y_hat, y.float())
  3. loss.backward()

当检查梯度时,我们看到,这“仅”更新了获得最高分数的预测以及任何类之外的所有预测。

  1. >>> y_out.grad
  2. tensor([[ 0.3326, 0.0000, -0.6653, 0.3326, 0.0000]])

其他目标类的预测不会接收梯度更新。请注意,如果可能类的比率非常高,则这可能会减慢收敛速度。

展开查看全部

相关问题