keras 如何迭代测试数据集批处理?

t3psigkw  于 2022-11-13  发布在  其他
关注(0)|答案(1)|浏览(203)

我有一个关于测试模型的问题。我使用tf.keras.utils.image_dataset_from_directory创建了一个模型测试集,如下所示:

batch_size = 32
test_dataset = tf.keras.utils.image_dataset_from_directory(
    '/content/drive/MyDrive/test',
    image_size = (224, 224),
    batch_size = batch_size,
    shuffle = False
)

然后我得到的输出是Found 150 files belonging to 3 classes.之后,我想使用以下语句迭代测试数据集批处理:

labels_batch = []
for dataset in test_dataset.unbatch():
  image_batch, label_batch = dataset
  labels = label_batch.numpy()
  labels_batch.append(labels)

我了解到,在数据集〈类'元组'〉的结构中,由两个位置组成,分别是image_batchlabel_batch,它们是**〈类'tensorflow.python.framework.ops. EagerTensor'〉**。
因此,image_batch[0]应表示test_dataset中的第一个图像。当我想显示第一个图像的数组时,我使用命令print(image_batch[0]),如所示,所有图像的数组为shape=(224, 3),但我认为所有图像的大小应为shape=(224,224,3)那么,我必须使用什么命令来访问每个图像的数组?
我在google colab中使用TensorFlow版本2.9。我不确定test_dataset.unbatch()问题是否在这里?

hfyxw5xn

hfyxw5xn1#

unbatch方法实际上返回每个单独的图像,要获得在每次迭代时返回批处理的批处理迭代器,您应该改为调用batch方法,或者只使用数据集迭代器,即:

for dataset in test_dataset:

所以在代码中,image_batch是一个shape(224,224,3)的图像,而image_batch[0]是一个shape(224,3)的数组,因为您对第一维进行了切片。
您可能希望查看dataset documentation以了解每个方法的说明。

相关问题