pytorch中是否有稀疏分类交叉熵的版本?

cu6pst1q  于 2022-11-09  发布在  其他
关注(0)|答案(1)|浏览(138)

我看到一个数独解算器CNN使用TensorFlow框架将稀疏分类交叉熵作为损失函数,我想知道Pytorch是否有类似的函数?如果没有,我怎么可能使用Pytorch计算2D阵列的损失?

1szpjjfi

1szpjjfi1#

下面是一个使用nn.CrossEntropyLoss进行图像分割的示例,图像分割的批处理包含大小为1、宽度为2、高度为2和3的类。
图像分割是一个像素级的分类问题。当然,你也可以使用nn.CrossEntropyLoss进行基本的图像分类。
问题中的数独问题可以被看作是一个图像分割问题,其中你有10个类(10位数字)(尽管神经网络不适合解决像数独这样的组合问题,因为它已经有了有效的精确解决算法)。

nn.CrossEntropyLoss直接接受地面真实值标签作为[0,N_CLASSES[中的整数(不需要对标签进行onehot编码)

import torch
from torch import nn
import numpy as np

# logits predicted

x = np.array([[
    [[1,0,0],[1,0,0]], # predict class 0 for pixel (0,0) and class 0 for pixel (0,1)
    [[0,1,0],[0,0,1]], # predict class 1 for pixel (1,0) and class 2 for pixel (1,1)
]])*5  # multiply by 5 to give bigger losses
print("logits map :")
print(x)

# ground truth labels

y = np.array([[
    [0,1], # must predict class 0 for pixel (0,0) and class 1 for pixel (0,1)
    [1,2], # must predict class 1 for pixel (1,0) and class 2 for pixel (1,1)
]])  
print("\nlabels map :")
print(y)

x=torch.Tensor(x).permute((0,3,1,2))  # shape of preds must be (N, C, H, W) instead of (N, H, W, C)
y=torch.Tensor(y).long() #  shape of labels must be (N, H, W) and type must be long integer

losses = nn.CrossEntropyLoss(reduction="none")(x, y)  # reduction="none" to get the loss by pixel 
print("\nLosses map :")
print(losses)

# notice that the loss is big only for pixel (0,1) where we predicted 0 instead of 1

相关问题