Keras/TensorFlow image classification: found 2 million files, but only 416k of them are used

Posted by 5anewei6 on 2022-11-13 in Other

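I'm currently building a basic image classification model with TensorFlow. The code follows the tutorial at https://www.tensorflow.org/tutorials/images/classification almost exactly, except that I'm using my own data.
This is how I currently generate the datasets: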

#Set up information on the data
batch_size = 32
img_height = 100
img_width = 100

#Generate training dataset
train_ds = tf.keras.utils.image_dataset_from_directory(
  Directory,
  validation_split=0.8,
  subset="training",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)

#Generate val dataset
val_ds = tf.keras.utils.image_dataset_from_directory(
  Directory,
  validation_split=0.2,
  subset="validation",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)

But after running it on the cluster, I see the following in the terminal output:

2022-09-30 09:49:26.936639: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2022-09-30 09:49:26.956813: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Found 2080581 files belonging to 2 classes.
Using 416117 files for training.
Found 2080581 files belonging to 2 classes.
Using 416116 files for validation.

I don't have much experience with TensorFlow and don't really know how to fix this. Can anyone point me in the right direction?

Answer 1, by jjjwad0x:

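You are keeping only 20% of the data for training (2080581 × 20% ≈ 416117), because validation_split is 0.8 in the training call, so 80% of the files are held out from the training subset. I think you actually want it the other way around, with the same validation_split (and seed) in both calls: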

import tensorflow as tf

#Generate training dataset: validation_split=0.2 leaves 80% of the files for training
train_ds = tf.keras.utils.image_dataset_from_directory(
  Directory,
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)

#Generate val dataset: the same validation_split and seed give the complementary 20%
val_ds = tf.keras.utils.image_dataset_from_directory(
  Directory,
  validation_split=0.2,
  subset="validation",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)
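The file counts in the log can be reproduced with plain arithmetic. The sketch below assumes that Keras reserves the last int(validation_split * n) files (after the seeded shuffle) for the validation subset and gives the remaining files to the training subset; that is an assumption about its internals, but it reproduces the 416117 and 416116 figures from the log exactly. `split_sizes` is a hypothetical helper for illustration, not part of Keras.

```python
def split_sizes(n_files, validation_split):
    """Return (num_train, num_val), assuming the last int(validation_split * n)
    shuffled files go to validation and the rest go to training."""
    num_val = int(validation_split * n_files)
    num_train = n_files - num_val
    return num_train, num_val


N = 2080581  # total number of files reported in the log

# Question's settings: the training call uses 0.8, the validation call uses 0.2
train_q, _ = split_sizes(N, 0.8)
_, val_q = split_sizes(N, 0.2)
print(train_q, val_q, N - train_q - val_q)  # → 416117 416116 1248348

# Answer's settings: both calls use 0.2
train_a, val_a = split_sizes(N, 0.2)
print(train_a, val_a, N - train_a - val_a)  # → 1664465 416116 0
```

Under this assumption, the original settings use only about 40% of the files in total: the first 20% go to training, the last 20% go to validation, and the 1248348 files in between are never used. With matching validation_split and seed, the two subsets cover every file exactly once.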

See the docs for more information and for this example.
