keras tf.data.dataset apply()不会更新数据集

wpx232ag  于 2022-11-13  发布在  其他
关注(0)|答案(1)|浏览(200)

我用image_dataset_from_directory加载一个图像数据集,它会给我一个PrefetchDataset,里面有我的图像和它们的相关标签one-hot encoded。
为了构建一个二进制图像分类器,我想转换我的PrefetchDataset标签,以了解一个图像是照片还是其他东西。
我是这样写的:

batch_size = 32
img_height = 250
img_width = 250

train_ds = image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  color_mode="rgb",
  subset="training",
  seed=69,
  crop_to_aspect_ratio=False,
  image_size=(img_height, img_width),
  batch_size=batch_size)

class_names = train_ds.class_names
# ['Painting', 'Photo', 'Schematics', 'Sketch', 'Text'] in my case

# Convert label to 1 is a photo or else 0
i = 1 # class_names.index('Photo')

def is_photo(batch):
    for images, labels in batch:
        bool_labels = tf.constant([int(l == 1) for l in labels],
                                  dtype=np.int32)
        labels = bool_labels
    return batch

new_train_ds = train_ds.apply(is_photo)

我的问题是new_train_ds没有遵从train_ds,这让我认为apply方法一定有问题。我还检查了bool_labels,它工作得很好。
有没有人对如何解决这个问题有一个想法。

jfewjypa

jfewjypa1#

也许可以试试这样的方法:

train_ds = train_ds.map(lambda x, y: (x, tf.cast(y == 1, dtype=tf.int64)))

相关问题