PyTorch TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found object

6l7fqoea posted on 2023-11-19 in Other

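I'm trying to loop over the batches of my training set. The data is in a .tsv file with 3 columns: Quality (1 if the two sentences are similar, 0 otherwise), #1 String (the first sentence) and #2 String (the second sentence).
I tried converting X and y to other types, such as lists, but the error persists.
Do you have any suggestions? Thanks!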

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, random_split

def get_dataloaders(ds, lengths=[0.6, 0.2, 0.2], batch_size=32, seed=42, num_workers=2):
    # Split with a seeded generator so the train/val/test split is reproducible
    train_set, val_set, test_set = random_split(ds, lengths=lengths, generator=torch.Generator().manual_seed(seed))

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, val_loader, test_loader

data_dir = "/data.tsv"
data = pd.read_csv(data_dir, sep='\t')
y = data['Quality'].values                             #dtype: int64
X = data[['#1 String', '#2 String']].values            #dtype: O

data_input = np.column_stack((X, y))                   #dtype: O

train_loader, val_loader, test_loader = get_dataloaders(data_input)

for batch in train_loader:     # TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found object
    print(batch)

----------------------------------------------------------------------------------

TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/collate.py", line 265, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/collate.py", line 119, in collate
    return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/collate.py", line 169, in collate_numpy_array_fn
    raise TypeError(default_collate_err_msg_format.format(elem.dtype))
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found object
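The traceback ends in `collate_numpy_array_fn`: each item the DataLoader fetches through `random_split` is one row of `data_input`, i.e. a 1-D NumPy array, and because `np.column_stack` had to mix strings with integers, that row has dtype `object`, which `default_collate` rejects. A minimal reproduction, assuming a made-up two-row DataFrame in place of `/data.tsv` (the sentences are invented for illustration):

```python
import numpy as np
import pandas as pd
from torch.utils.data import default_collate

# Hypothetical two-row stand-in for data.tsv, same column names as the question
data = pd.DataFrame({
    "Quality": [1, 0],
    "#1 String": ["a cat sat", "it rains"],
    "#2 String": ["a cat was sitting", "the sun shines"],
})
y = data["Quality"].values                    # dtype int64
X = data[["#1 String", "#2 String"]].values   # dtype object (Python str)
data_input = np.column_stack((X, y))          # str + int forces dtype object

print(data_input.dtype)       # object
print(data_input[0].dtype)    # object: this 1-D row is what the DataLoader receives per item

# Collating two rows by hand hits the same check as the DataLoader worker
try:
    default_collate([data_input[0], data_input[1]])
    msg = "no error"
except TypeError as err:
    msg = str(err)
print(msg)
```

Converting X and y to lists does not help as long as they are stacked back into one object array before indexing.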


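Edit: I realize that what I noted in the comments is the dtype of X and y, not their type. y has dtype int64 and X has dtype object. But I assume that once I combine the two columns into a single array, it has to be dtype object. How should I fix this?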
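The columns do not have to be merged into one array. A minimal sketch, assuming a hypothetical `SentencePairDataset` class and the same made-up two-row DataFrame: each item is returned as a Python tuple `(str, str, label tensor)`, and `default_collate` passes strings through as a plain sequence per field and stacks the label tensors, so iterating over the loader no longer raises.

```python
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

class SentencePairDataset(Dataset):
    """Hypothetical Dataset returning one (sentence1, sentence2, label) tuple per row."""

    def __init__(self, df):
        # Keep the columns separate instead of column_stack-ing them into one object array
        self.s1 = df["#1 String"].tolist()  # list of Python str
        self.s2 = df["#2 String"].tolist()
        self.labels = torch.tensor(df["Quality"].values, dtype=torch.long)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # str values are collated as a sequence of str; 0-d label tensors are stacked
        return self.s1[idx], self.s2[idx], self.labels[idx]

# Hypothetical data in place of data.tsv
df = pd.DataFrame({
    "Quality": [1, 0],
    "#1 String": ["a cat sat", "it rains"],
    "#2 String": ["a cat was sitting", "the sun shines"],
})
ds = SentencePairDataset(df)
loader = DataLoader(ds, batch_size=2, shuffle=False)
s1_batch, s2_batch, label_batch = next(iter(loader))
print(list(s1_batch), label_batch)
```

Such a dataset could be passed to `get_dataloaders` in place of `data_input`. The sentences still arrive as raw strings, though, so they would still have to be converted to numbers before they reach a model.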

Answer #1 (csbfibhn):

PyTorch's default batching cannot handle object-dtype inputs. You need to convert (featurize) the strings into a numeric form first.
For example:
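One way to do this, sketched under clearly stated assumptions: a toy lower-cased whitespace tokenizer and a vocabulary built on the fly. The helper names `build_vocab` and `encode`, the reserved ids `PAD`/`UNK` and `max_len=16` are all hypothetical choices, and the DataFrame is made-up data in place of `/data.tsv`. Each sentence becomes a fixed-length row of integer ids, so every dataset item is numeric and `default_collate` can stack it.

```python
import pandas as pd
import torch
from torch.utils.data import TensorDataset, DataLoader

PAD, UNK = 0, 1  # hypothetical reserved ids for padding and unknown tokens

def build_vocab(sentences):
    """Map each lower-cased whitespace token to an integer id (toy tokenizer)."""
    vocab = {"<pad>": PAD, "<unk>": UNK}
    for s in sentences:
        for tok in s.lower().split():
            vocab.setdefault(tok, len(vocab))
    return vocab

def encode(sentences, vocab, max_len=16):
    """Turn each sentence into max_len token ids: truncate, then right-pad with PAD."""
    rows = []
    for s in sentences:
        ids = [vocab.get(tok, UNK) for tok in s.lower().split()][:max_len]
        rows.append(ids + [PAD] * (max_len - len(ids)))
    return torch.tensor(rows, dtype=torch.long)

# Hypothetical data in place of data.tsv
df = pd.DataFrame({
    "Quality": [1, 0],
    "#1 String": ["a cat sat", "it rains"],
    "#2 String": ["a cat was sitting", "the sun shines"],
})
s1 = df["#1 String"].tolist()
s2 = df["#2 String"].tolist()
vocab = build_vocab(s1 + s2)
x1 = encode(s1, vocab)  # shape (num_rows, 16), dtype int64
x2 = encode(s2, vocab)
y = torch.tensor(df["Quality"].values, dtype=torch.long)

ds = TensorDataset(x1, x2, y)  # every item is numeric, so default_collate can stack it
loader = DataLoader(ds, batch_size=2, shuffle=False)
b1, b2, by = next(iter(loader))
print(b1.shape, by)
```

A real project would more likely use a pretrained tokenizer or TF-IDF features, and should build the vocabulary from the training split only, so that validation and test tokens do not leak into it.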
