我通过tensorflow_dataset.load()
导入了CIFAR10数据集。
这样得到<PrefetchDataset element_spec={'id': TensorSpec(shape=(), dtype=tf.string, name=None), 'image': TensorSpec(shape=(32, 32, 3), dtype=tf.uint8, name=None), 'label': TensorSpec(shape=(), dtype=tf.int64, name=None)}>
这个数据集有一个id列。我想删除这个id。因为这个id在jax中会引起异常。为什么JAX会抛出一个未过滤的堆栈跟踪?我想我可以把它转换成panda Dataframe ,但他们是一个更好的方法吗?
2条答案
按热度按时间vx6bjr1n1#
试试看:
nzk0hqpo2#
如果使用
tfds.as_numpy
,则可以获取字典形式的数据集,并轻松删除列:这也是将数据集导入JAX所需的表单。