pytorch 使用我定制的损失函数,损失不能反向传播到模型参数

yqyhoc1h  于 2024-01-09  发布在  其他
关注(0)|答案(1)|浏览(229)

我设计了一个定制的损失:

class CustomIndicesEdgeAccuracyLoss(torch.nn.Module):
    def __init__(self, num_classes: int, selected_indices: list):
        super(CustomIndicesEdgeAccuracyLoss, self).__init__()
        self.num_classes = num_classes
        self.selected_indices = selected_indices

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        batch_size, num_classes, feature_size = input.shape
        selected_input = input[::, ::, self.selected_indices]
        selected_target = target[::, self.selected_indices]
        selected_preds = torch.argmax(selected_input, dim=1)
        edge_acc = torch.eq(selected_preds, selected_target).sum()/torch.numel(selected_preds)
        loss = 1 – edge_acc
        loss.requires_grad = True

        return loss

字符串
但是loss不会反向传播到模型参数,换句话说,模型参数的梯度总是为0,模型参数不能更新。可能的原因是什么?我应该如何修改代码?
下面是forward()的局部变量的一些信息:

input.shape: torch.Size([64, 3, 5])
target.shape:torch.Size([64, 5])
selected_input.shape: torch.Size([64, 3, 2]) 
selected_target.shape:torch.Size([64, 2])


PS.我在here中问了同样的问题,所以我会把答案从这篇文章复制到那里,反之亦然。

3zwtqj6y

3zwtqj6y1#

你不能使用精度作为损失函数,因为它是不可微的。通过torch.argmax没有梯度传播。你需要使用一个可微的损失函数。

相关问题