如何 编辑 tensorflow 数据 集 ?

u5rb5r59  于 2022-11-16  发布在  其他
关注(0)|答案(2)|浏览(136)

我通过tensorflow_dataset.load()导入了CIFAR10数据集。
这样得到<PrefetchDataset element_spec={'id': TensorSpec(shape=(), dtype=tf.string, name=None), 'image': TensorSpec(shape=(32, 32, 3), dtype=tf.uint8, name=None), 'label': TensorSpec(shape=(), dtype=tf.int64, name=None)}>
这个数据集有一个id列。我想删除这个id。因为这个id在jax中会引起异常。为什么JAX会抛出一个未过滤的堆栈跟踪?我想我可以把它转换成panda Dataframe ,但他们是一个更好的方法吗?

vx6bjr1n

vx6bjr1n1#

试试看:

result = ds.map(lambda x: {
    'image': x['image'],
    'label': x['label']
})

result.element_spec
>>> <MapDataset element_spec={'image': TensorSpec(shape=(32, 32, 3), dtype=tf.uint8, name=None), 'label': TensorSpec(shape=(), dtype=tf.int64, name=None)}>
nzk0hqpo

nzk0hqpo2#

如果使用tfds.as_numpy,则可以获取字典形式的数据集,并轻松删除列:

import tensorflow_datasets as tfds

ds_builder = tfds.builder('cifar10')
ds_builder.download_and_prepare()
data = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))

print(type(data))
# <class 'dict'>

print(data.keys())
# dict_keys(['id', 'image', 'label'])

del data['id']
print(data.keys())
# dict_keys(['image', 'label'])

这也是将数据集导入JAX所需的表单。

相关问题