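I'm trying to iterate over the batches of my training set. The training set is a .tsv file with 3 columns: Quality (1 means the two sentences are similar, 0 means they are not), #1 String (the first sentence) and #2 String (the second sentence).
I tried converting X and y to other types, such as lists, but the error persists.
Do you have any suggestions? Thanks!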
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, random_split

def get_dataloaders(ds, lengths=(0.6, 0.2, 0.2), batch_size=32, seed=42, num_workers=2):
    # Split reproducibly into train/validation/test subsets (fractional lengths)
    train_set, val_set, test_set = random_split(ds, lengths=lengths, generator=torch.Generator().manual_seed(seed))
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, val_loader, test_loader
data_dir = "/data.tsv"
data = pd.read_csv(data_dir, sep='\t')
y = data['Quality'].values  # dtype: int64
X = data[['#1 String', '#2 String']].values  # dtype: object
data_input = np.column_stack((X, y))  # dtype: object
train_loader, val_loader, test_loader = get_dataloaders(data_input)
for batch in train_loader:  # raises TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found object
    pass
----------------------------------------------------------------------------------
TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/collate.py", line 265, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/collate.py", line 119, in collate
    return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/collate.py", line 169, in collate_numpy_array_fn
    raise TypeError(default_collate_err_msg_format.format(elem.dtype))
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found object
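Edit: I realize that what I wrote in the comments is the dtype of X and y, not their type. y has dtype int64 and X has dtype object. But I think that once I combine the two into a single array, it necessarily ends up with dtype object. How should I fix this?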
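The reasoning in the edit can be checked without PyTorch. Below is a minimal numpy-only sketch using two made-up rows in place of data.tsv (the sentences are hypothetical, and `dtype=object` mimics what pandas `.values` returns for string columns). `np.column_stack` upcasts the int64 labels to object, so every row that `random_split` hands to the `DataLoader` is itself an object array, which is exactly what `collate_numpy_array_fn` rejects in the traceback.

```python
import numpy as np

# Hypothetical stand-ins for data[['#1 String', '#2 String']].values and data['Quality'].values;
# pandas returns string columns as object arrays, so dtype=object mirrors that.
X = np.array([["a cat sat down", "a cat was sitting"],
              ["it is raining", "the sun is out"]], dtype=object)
y = np.array([1, 0], dtype=np.int64)

# The int64 label column is promoted to the common dtype, which is object.
data_input = np.column_stack((X, y))

# Indexing the dataset yields one row; default_collate receives a list of such rows.
row = data_input[0]
print(X.dtype, y.dtype, data_input.dtype, row.dtype)  # object int64 object object
```

Converting X to a list does not help for the same reason: as long as strings and labels are stacked into one numpy array, the common dtype is object.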
1 Answer
PyTorch does not accept inputs of object dtype. You need to featurize the strings into a numeric form first.
For example:
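A minimal sketch of one possible featurization, not necessarily what the answer had in mind: a hypothetical whitespace vocabulary maps each lower-cased token to an integer id, and each sentence is padded or truncated to a fixed length, giving plain int64 arrays that `default_collate` can stack. `build_vocab`, `encode`, `max_len` and the reserved `PAD`/`UNK` ids are illustrative assumptions; a real sentence-similarity model would more likely use a pretrained tokenizer.

```python
import numpy as np

PAD, UNK = 0, 1  # reserved ids (illustrative choice): padding and unknown token

def build_vocab(sentences):
    """Map every lower-cased whitespace token to an integer id, starting after PAD/UNK."""
    vocab = {}
    for s in sentences:
        for tok in s.lower().split():
            if tok not in vocab:
                vocab[tok] = len(vocab) + 2
    return vocab

def encode(sentences, vocab, max_len):
    """Turn sentences into an int64 array of shape (n, max_len), padding or truncating each row."""
    ids = np.full((len(sentences), max_len), PAD, dtype=np.int64)
    for i, s in enumerate(sentences):
        toks = [vocab.get(t, UNK) for t in s.lower().split()][:max_len]
        ids[i, :len(toks)] = toks
    return ids

# Hypothetical sentence pairs standing in for the '#1 String' / '#2 String' columns
s1 = ["A cat sat down", "It is raining"]
s2 = ["A cat was sitting", "The sun is out"]
vocab = build_vocab(s1 + s2)
ids1 = encode(s1, vocab, max_len=6)
ids2 = encode(s2, vocab, max_len=6)
labels = np.array([1, 0], dtype=np.int64)
print(ids1.dtype, ids1.shape)  # int64 (2, 6)
```

With the question's data, the inputs would be `data['#1 String'].tolist()` and `data['#2 String'].tolist()`. The three arrays could then be wrapped as `torch.utils.data.TensorDataset(torch.from_numpy(ids1), torch.from_numpy(ids2), torch.from_numpy(labels))` and passed to `get_dataloaders`, so that each batch is three int64 tensors.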