HuggingFace dataset to PyTorch

8yoxcaq7  posted on 2022-12-13 in Other

I want to load a dataset from Hugging Face and turn it into a PyTorch DataLoader. This is my script.

import random

import torch
from datasets import load_dataset
from PIL import Image
from torchvision import transforms

dataset = load_dataset('cats_vs_dogs', split='train[:1000]')
trans = transforms.Compose([transforms.Resize((256,256)), transforms.PILToTensor()])

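# Convert images to RGB and randomly either flip them vertically (label 1)
# or leave them unflipped (label 0).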
def encode(examples):
  num = random.randint(0,1)
  if num:
    examples["image"] = [image.convert("RGB").transpose(Image.FLIP_TOP_BOTTOM) for image in examples["image"]]
    examples['labels'] = [1] * len(examples['image'])
  else:
    examples["image"] = [image.convert("RGB") for image in examples["image"]]
    examples['labels'] = [0] * len(examples['image'])
  return examples

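# Resize to 256x256 and turn each PIL image into a tensor; applied lazily
# on access via set_transform below.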
def annot(examples):
  examples['image'] = [trans(img) for img in examples['image']]
  return examples

dataset = dataset.map(encode, batched=True, remove_columns=['image_file_path'], batch_size=256)

dataset.set_transform(annot)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)

Here I randomly flip the images and set the label according to whether the image was flipped.
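Note that because encode is mapped with batched=True and batch_size=256, the single random draw flips (or keeps) an entire batch at once, so all 256 images in a batch share the same label. If per-image flipping is the intent, a hypothetical per-example variant (a sketch reusing the imports and dataset above, not part of the original script) could be mapped without batched=True:

def encode_one(example):
  # One random draw per image: flip it vertically and label it 1, or keep it and label it 0.
  example["image"] = example["image"].convert("RGB")
  if random.randint(0, 1):
    example["image"] = example["image"].transpose(Image.FLIP_TOP_BOTTOM)
    example["labels"] = 1
  else:
    example["labels"] = 0
  return example

# dataset = dataset.map(encode_one, remove_columns=['image_file_path'])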
If I print the dataset,

>>> print(dataset)
Dataset({
    features: ['image', 'labels'],
    num_rows: 1000
})

If I inspect any single example,

>>> dataset['image'][0].shape
torch.Size([3, 256, 256])

The error I get after applying the transform is

>>> next(iter(dataloader))['image']
AttributeError: 'bytes' object has no attribute 'dtype'

The full traceback is

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-180-b773e67ad66a> in <module>()
----> 1 next(iter(dataloader))['image']

16 frames
/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py in __next__(self)
    519             if self._sampler_iter is None:
    520                 self._reset()
--> 521             data = self._next_data()
    522             self._num_yielded += 1
    523             if self._dataset_kind == _DatasetKind.Iterable and \

/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
    559     def _next_data(self):
    560         index = self._next_index()  # may raise StopIteration
--> 561         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    562         if self._pin_memory:
    563             data = _utils.pin_memory.pin_memory(data)

/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     47     def fetch(self, possibly_batched_index):
     48         if self.auto_collation:
---> 49             data = [self.dataset[idx] for idx in possibly_batched_index]
     50         else:
     51             data = self.dataset[possibly_batched_index]

/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)
     47     def fetch(self, possibly_batched_index):
     48         if self.auto_collation:
---> 49             data = [self.dataset[idx] for idx in possibly_batched_index]
     50         else:
     51             data = self.dataset[possibly_batched_index]

/usr/local/lib/python3.7/dist-packages/datasets/arrow_dataset.py in __getitem__(self, key)
   1764         """Can be used to index columns (by string names) or rows (by integer index or iterable of indices or bools)."""
   1765         return self._getitem(
-> 1766             key,
   1767         )
   1768 

/usr/local/lib/python3.7/dist-packages/datasets/arrow_dataset.py in _getitem(self, key, decoded, **kwargs)
   1749         pa_subtable = query_table(self._data, key, indices=self._indices if self._indices is not None else None)
   1750         formatted_output = format_table(
-> 1751             pa_subtable, key, formatter=formatter, format_columns=format_columns, output_all_columns=output_all_columns
   1752         )
   1753         return formatted_output

/usr/local/lib/python3.7/dist-packages/datasets/formatting/formatting.py in format_table(table, key, formatter, format_columns, output_all_columns)
    530     python_formatter = PythonFormatter(features=None)
    531     if format_columns is None:
--> 532         return formatter(pa_table, query_type=query_type)
    533     elif query_type == "column":
    534         if key in format_columns:

/usr/local/lib/python3.7/dist-packages/datasets/formatting/formatting.py in __call__(self, pa_table, query_type)
    279     def __call__(self, pa_table: pa.Table, query_type: str) -> Union[RowFormat, ColumnFormat, BatchFormat]:
    280         if query_type == "row":
--> 281             return self.format_row(pa_table)
    282         elif query_type == "column":
    283             return self.format_column(pa_table)

/usr/local/lib/python3.7/dist-packages/datasets/formatting/torch_formatter.py in format_row(self, pa_table)
     56     def format_row(self, pa_table: pa.Table) -> dict:
     57         row = self.numpy_arrow_extractor().extract_row(pa_table)
---> 58         return self.recursive_tensorize(row)
     59 
     60     def format_column(self, pa_table: pa.Table) -> "torch.Tensor":

/usr/local/lib/python3.7/dist-packages/datasets/formatting/torch_formatter.py in recursive_tensorize(self, data_struct)
     52 
     53     def recursive_tensorize(self, data_struct: dict):
---> 54         return map_nested(self._recursive_tensorize, data_struct, map_list=False)
     55 
     56     def format_row(self, pa_table: pa.Table) -> dict:

/usr/local/lib/python3.7/dist-packages/datasets/utils/py_utils.py in map_nested(function, data_struct, dict_only, map_list, map_tuple, map_numpy, num_proc, types, disable_tqdm, desc)
    314         mapped = [
    315             _single_map_nested((function, obj, types, None, True, None))
--> 316             for obj in logging.tqdm(iterable, disable=disable_tqdm, desc=desc)
    317         ]
    318     else:

/usr/local/lib/python3.7/dist-packages/datasets/utils/py_utils.py in <listcomp>(.0)
    314         mapped = [
    315             _single_map_nested((function, obj, types, None, True, None))
--> 316             for obj in logging.tqdm(iterable, disable=disable_tqdm, desc=desc)
    317         ]
    318     else:

/usr/local/lib/python3.7/dist-packages/datasets/utils/py_utils.py in _single_map_nested(args)
    265 
    266     if isinstance(data_struct, dict):
--> 267         return {k: _single_map_nested((function, v, types, None, True, None)) for k, v in pbar}
    268     else:
    269         mapped = [_single_map_nested((function, v, types, None, True, None)) for v in pbar]

/usr/local/lib/python3.7/dist-packages/datasets/utils/py_utils.py in <dictcomp>(.0)
    265 
    266     if isinstance(data_struct, dict):
--> 267         return {k: _single_map_nested((function, v, types, None, True, None)) for k, v in pbar}
    268     else:
    269         mapped = [_single_map_nested((function, v, types, None, True, None)) for v in pbar]

/usr/local/lib/python3.7/dist-packages/datasets/utils/py_utils.py in _single_map_nested(args)
    249     # Singleton first to spare some computation
    250     if not isinstance(data_struct, dict) and not isinstance(data_struct, types):
--> 251         return function(data_struct)
    252 
    253     # Reduce logging to keep things readable in multiprocessing with tqdm

/usr/local/lib/python3.7/dist-packages/datasets/formatting/torch_formatter.py in _recursive_tensorize(self, data_struct)
     49             if data_struct.dtype == np.object:  # pytorch tensors cannot be instantied from an array of objects
     50                 return [self.recursive_tensorize(substruct) for substruct in data_struct]
---> 51         return self._tensorize(data_struct)
     52 
     53     def recursive_tensorize(self, data_struct: dict):

/usr/local/lib/python3.7/dist-packages/datasets/formatting/torch_formatter.py in _tensorize(self, value)
     36 
     37         default_dtype = {}
---> 38         if np.issubdtype(value.dtype, np.integer):
     39             default_dtype = {"dtype": torch.int64}
     40         elif np.issubdtype(value.dtype, np.floating):

How can I fix this and pass the dataset to the DL model? Thanks.

1dkrff03 1#

I think you may just need to use set_format(), for example:

...
dataset.set_transform(annot)
dataset.set_format("torch")
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
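If that does the trick, a quick sanity check (a sketch reusing the dataloader above; the expected shapes assume batch_size=32 and the 256x256 resize from the question) is to pull one batch and look at it:

# Sketch: with the torch format set, "image" and "labels" are expected to
# come back as tensors rather than raw bytes.
batch = next(iter(dataloader))
print(batch["image"].shape)   # expected to be roughly torch.Size([32, 3, 256, 256])
print(batch["labels"].shape)  # expected to be torch.Size([32])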
