Pytorch矩阵乘法错误(形状不匹配)

ddarikpa  于 2023-10-20  发布在  其他
关注(0)|答案(2)|浏览(115)

我从pytorch实现了一个类,它看起来像

class GenderClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3,32,(3,3)),
            nn.ReLU(),
            nn.Conv2d(32,64,(3,3)),
            nn.ReLU(),
            nn.Conv2d(64,64,(3,3)),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64*104*74,2),
            nn.Sigmoid()
        )
        
    def forward(self,x):
        return self.model(x)

我在训练RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x7696 and 492544x2)时得到了这个错误。输入形状是110*80大小的图像,具有3个通道(3,110,80)。
训练代码看起来像

for epoch in range(30): # train for 10 epochs    
    for batch in dataset: 
        
        X,y = batch 
        X, y = X.to('cuda'), y 
        yhat = clf(X) 
        loss = loss_fn(yhat, y) 

        # Apply backprop 
        opt.zero_grad()
        loss.backward() 
        opt.step() 

    print(f"Epoch:{epoch} loss is {loss.item()}")
with open("model.pt","wb") as f:
    save(clf.state_dict(),f)```
um6iljoc

um6iljoc1#

两个合理的解决方案:
1.错误消息引用了输入批量大小的计算。它表示64x7696,这意味着它将输入解释为一批64个项目,每个项目的大小为7696。因此,也就是说,您的数据输入大小为(batch_size,3,110,80)。
要进行调试,请在将X var传递给模型之前打印它:

print(X.shape)

检查并验证批量大小,以及图像增强或预处理正在改变它。
1.最后一个卷积层的输出形状和全连接层的输入形状之间的另一个潜在失配。错误消息是关于将大小为(64x7696)的矩阵乘以另一个大小为(492544x2)的矩阵,这是无效的。主要的变化是线性层:我在第一个层之后添加了一个额外的线性层,输出大小为128个单位。您可以根据您的实验调整此线性层的输出大小。最后一个线性层输出2个单位,没有激活函数,因为你在损失函数中使用了sigmoid激活。

class GenderClassifier(nn.Module): 
    def __init__(self): 
       super().__init__() 
       self.model = nn.Sequential(
           nn.Conv2d(3, 32, (3, 3)), 
           nn.ReLU(), 
           nn.Conv2d(32, 64, (3, 3)), 
           nn.ReLU(), 
           nn.Conv2d(64, 64, (3, 3)), 
           nn.ReLU(), 
           nn.Flatten(), 
           nn.Linear(64 * 104 * 74, 128), 
           nn.ReLU(), nn.Linear(128, 2),
          )

     def forward(self, x):
        return self.model(x)
jv4diomz

jv4diomz2#

首先,您的输入形状结构应该如下所示:
(N_BATCH_SIZE、C_CHANNEL_SIZE、H_HEIGHT、W_WIDTH)
但是您提供的形状(3,110,180)与网络不兼容。
我测试了(5,3,110,180),(1000,3,110,180)[不同的批量大小)形状和网络运行良好。
但是当我输入模型(3,110,180)时,我得到了同样的错误。

相关问题