我用image_dataset_from_directory
加载一个图像数据集,它会给我一个PrefetchDataset
,里面有我的图像和它们的相关标签one-hot encoded。
为了构建一个二进制图像分类器,我想转换我的PrefetchDataset
标签,以了解一个图像是照片还是其他东西。
我是这样写的:
batch_size = 32
img_height = 250
img_width = 250
train_ds = image_dataset_from_directory(
data_dir,
validation_split=0.2,
color_mode="rgb",
subset="training",
seed=69,
crop_to_aspect_ratio=False,
image_size=(img_height, img_width),
batch_size=batch_size)
class_names = train_ds.class_names
# ['Painting', 'Photo', 'Schematics', 'Sketch', 'Text'] in my case
# Convert label to 1 is a photo or else 0
i = 1 # class_names.index('Photo')
def is_photo(batch):
for images, labels in batch:
bool_labels = tf.constant([int(l == 1) for l in labels],
dtype=np.int32)
labels = bool_labels
return batch
new_train_ds = train_ds.apply(is_photo)
我的问题是new_train_ds
没有遵从train_ds
,这让我认为apply
方法一定有问题。我还检查了bool_labels
,它工作得很好。
有没有人对如何解决这个问题有一个想法。
1条答案
按热度按时间jfewjypa1#
也许可以试试这样的方法: