如何获取TensorFlow数据集中的类名?

6jygbczu  于 2023-08-06  发布在  其他
关注(0)|答案(3)|浏览(99)

我试图了解如何使用tensorflow数据集,tfds。
数据集就是这种目录

-dataset
  -train
    -class_name1
      -files...
    -class_name2
      -files...
   -val
    -class_name1
      -files...
    -class_name2
      -files...
  -test
    -class_name1
      -files...
    -class_name2
      -files...

字符串
下面是一些代码:

import tensorflow_datasets as tfds
builder = tfds.ImageFolder('/content/dataset')
train_ds, val_ds, test_ds = builder.as_dataset(split=['train', 'val', 'test'], shuffle_files=True, as_supervised=True)

x

print(builder.info)
Output:
tfds.core.DatasetInfo(
    name='image_folder',
    version=1.0.0,
    description='Generic image classification dataset.',
    homepage='https://www.tensorflow.org/datasets/catalog/image_folder',
    features=FeaturesDict({
        'image': Image(shape=(None, None, 3), dtype=tf.uint8),
        'image/filename': Text(shape=(), dtype=tf.string),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=243),
    }),
    total_num_examples=73090,
    splits={
        'test': 5849,
        'train': 58469,
        'val': 8772,
    },
    supervised_keys=('image', 'label'),
    citation="""""",
    redistribution_info=,
)

的字符串
当我绘图,做分类报告,混淆矩阵等,我希望能够使用class_names而不是整数标签。
有没有一些简单的命令可以给予我访问class_names?(有243个班,不是2个)

kiayqfof

kiayqfof1#

您的初始构建器:

builder = tfds.ImageFolder(path_to_data)

字符串
关于提取类名。他们是隐藏的。在此生成器的train_ds.class_names中不可用。
你必须直接使用你的构建器。
获取builder.infohttps://www.tensorflow.org/datasets/api_docs/python/tfds/core/DatasetInfo
.featureshttps://www.tensorflow.org/datasets/api_docs/python/tfds/features/FeaturesDict
然后在FeaturesDict的'label'键中有一个ClassLabelhttps://www.tensorflow.org/datasets/api_docs/python/tfds/features/ClassLabel),

print(builder.info.features['label'])
ClassLabel(shape=(), dtype=tf.int64, num_classes=99)


在它里面,你得到了一个names args/属性,描述为:
names:列表< str >,整数类的字符串名称。保持提供名称的顺序。
最后,你直接得到了答案:

class_names = builder.info.features['label'].names

frebpwbc

frebpwbc2#

这个问题有点老了。然而,提供了一个答案。
首先,您可以使用iter()函数在您的数据集上创建迭代器,例如,在您的训练数据集上,然后使用next()函数访问每个观察结果。
然后,您可以将.int2str()方法应用于从builder.info对象访问的标签,以将整数标签转换为相应的类名。
例如,假设您正在绘制第一组32个带有相应标签的图像,您可以使用for循环编写以下代码。

plt.figure(figsize = (15,10))
iterator =  iter(train_ds)
for i in range(32):
  img, label = next(iterator)
  plt.subplot(4,8,i+1)
  plt.imshow(img)
  plt.title(builder.info.features['label'].int2str(label))
plt.tight_layout()
plt.show()

字符串
希望这对你有帮助。

uoifb46i

uoifb46i3#

您可以访问train_ds, val_ds, test_ds,而它们也有class_names,您可以使用标签作为索引来访问它,正如fpierrem所提到的
下面是一个情节的例子:

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(train_ds.class_names[labels[i]])
        plt.axis("off")
    
plt.show()

字符串
参见https://www.tensorflow.org/tutorials/keras/classification#import_the_fashion_mnist_dataset和https://www.tensorflow.org/tutorials/load_data/images#visualize_the_data

相关问题