pytorch 属性错误:“NoneType”对象没有属性“data”,fgsm攻击时出现训练错误

bn31dyow  于 2023-04-21  发布在  其他
关注(0)|答案(1)|浏览(467)

我试图训练一个代码,但我得到了这个错误。下面是我的代码:

for i in range(num_iter):
            # Forward pass to get logits
            logits = self(images)

            # Calculate loss
            loss = self.criterion(logits, masks)

            # Zero gradients
            self.optim.zero_grad()

            # Backward pass to compute gradients
            loss.backward()

            images = images.detach()

            images = torch.tensor(images, dtype=torch.float32, requires_grad=True)

            # Get the gradients of the loss w.r.t. the input image
            data_grad = images.grad.data

            # Generate the perturbed image using FGSM
            perturbed_images = fgsm(images, epsilon, data_grad)

            # Re-classify the perturbed image
            logits_perturbed = self(perturbed_images)

代码只是整个代码的一部分。
我看到的错误是:

File "/home/xx/xx/xx.py", line 73, in training_step
    data_grad = images.grad.data
AttributeError: 'NoneType' object has no attribute 'data'

尝试训练模型,期望它训练得很好,但它不起作用。

ctehm74n

ctehm74n1#

当第一次创建Tensorrequire_grad=True时,它的梯度初始值将是None,这会导致在尝试访问其数据属性时出现属性错误。

AttributeError: 'NoneType' object has no attribute 'data'

因此,您应该执行以下操作:首先创建可训练的图像,并将其传递给模型,如下所示

for i in range(num_iter):
            # Forward pass to get logits
            tr_images = torch.tensor(images, dtype=torch.float32, requires_grad=True) 
             # create trainable images(requires_grad=True)

            logits = self(tr_images)
            # Calculate loss
            loss = self.criterion(logits, masks)

            # Zero gradients
            self.optim.zero_grad() 

            # Backward pass to compute gradients
            loss.backward()

            # Get the gradients of the loss w.r.t. the input image
            data_grad = tr_images.grad.data

            # Generate the perturbed image using FGSM
            perturbed_images = fgsm(images, epsilon, data_grad)

            # Re-classify the perturbed image
            logits_perturbed = self(perturbed_images)

相关问题