PyTorch complains that the input and label batch sizes don't match

fivyi3re · posted 2022-12-18

I'm using Hugging Face's BertForSequenceClassification.from_pretrained() to implement a BERT model.
The model predicts 1 of 24 classes. I'm using a batch size of 32 and a sequence length of 66.
When I call the model during training, I get the following error:

ValueError: Expected input batch_size (32) to match target batch_size (768).

However, my target's shape is 32x24. Somewhere inside the model call it seems to get flattened to 768x1. Here is a test I ran:

for i in train_dataloader:
    i = tuple(t.to(device) for t in i)
    print(i[0].shape, i[1].shape, i[2].shape) # here i[2].shape is (32, 24)
    output = model(i[0], attention_mask=i[1], labels=i[2]) # here PyTorch complains that i[2]'s shape is now (768, 1)
    print(output.logits.shape)
    break

This outputs:

torch.Size([32, 66]) torch.Size([32, 66]) torch.Size([32, 24])
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-68-c69db6168cc3> in <module>
      2     i = tuple(t.to(device) for t in i)
      3     print(i[0].shape, i[1].shape, i[2].shape)
----> 4     output = model(i[0], attention_mask=i[1], labels=i[2])
      5     print(output.logits.shape)
      6     break

4 frames
/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
   3024     if size_average is not None or reduce is not None:
   3025         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3026     return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
   3027 
   3028 

ValueError: Expected input batch_size (32) to match target batch_size (768).
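The 768 in the error is 32 × 24: BertForSequenceClassification flattens the labels with labels.view(-1) before computing the cross-entropy loss, so a (32, 24) one-hot matrix turns into 768 targets. A minimal sketch reproducing the mismatch outside the model (variable names hypothetical):

```python
import torch
import torch.nn.functional as F

batch_size, n_classes = 32, 24
logits = torch.randn(batch_size, n_classes)
# hypothetical one-hot label batch, shape (32, 24)
onehot = torch.eye(n_classes)[torch.randint(0, n_classes, (batch_size,))]

try:
    # mimics the labels.view(-1) flattening inside the model's loss computation
    F.cross_entropy(logits, onehot.view(-1).long())
except (ValueError, RuntimeError) as e:
    print(e)  # batch-size mismatch: 32 inputs vs 768 targets
```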


axr492tv #1

PyTorch's CrossEntropyLoss implementation expects the target to be integer class indices, not one-hot class vectors, so target should have shape [batch_size], not [batch_size, n_classes].
You can collapse the classes quite simply as follows (assuming each class vector really is one-hot):

# row of class indices [0, 1, ..., n_classes-1], repeated for each batch element
raveler = torch.arange(0, n_classes).unsqueeze(0).expand(batch_size, n_classes)
# multiplying by a one-hot row and summing picks out the index of the 1
target = (target * raveler).sum(dim=1)
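Equivalently, and a little more directly, torch.argmax collapses each one-hot row to its class index. A minimal sketch with hypothetical dimensions matching the question:

```python
import torch

batch_size, n_classes = 32, 24
# hypothetical one-hot label batch, shape (32, 24)
onehot = torch.eye(n_classes)[torch.randint(0, n_classes, (batch_size,))]

# one-hot (32, 24) -> integer class indices (32,)
target = onehot.argmax(dim=1)
print(target.shape)  # torch.Size([32])
```

Passing this (32,) integer target as labels gives the shape CrossEntropyLoss expects.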
