keras 具有特定类的Tensorflow数据集管道

webghufk  于 2022-11-13  发布在  其他
关注(0)|答案(1)|浏览(140)

我希望使用具有特定类索引的数据集管道。

  • 例如:

如果我使用CIFAR-10数据集,我可以加载数据集,如下所示:

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

它加载所有的类标签(10个类)。我可以使用下面的代码创建一个管道:

train_dataset = tf.data.Dataset.from_tensor_slices((x_train,y_train)).batch(64)
test_dataset = tf.data.Dataset.from_tensor_slices(x_test,y_test)).batch(64)

这对于训练Keras模型很有效。

  • 现在我想创建一个有几个样本的管道(而不是使用所有10个类样本,可能只使用5个样本)。有什么方法可以创建这样的管道吗?
zbwhf8kr

zbwhf8kr1#

您可以使用tf.data.Dataset.filter

import tensorflow as tf

class_indexes_to_keep = tf.constant([0, 3, 4, 6, 8], dtype=tf.int64)

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

y_train = y_train.astype(int)
y_test = y_test.astype(int)
train_dataset = tf.data.Dataset.from_tensor_slices((x_train,y_train)).filter(lambda x, y: tf.reduce_any(y == class_indexes_to_keep)).batch(64)
test_dataset = tf.data.Dataset.from_tensor_slices((x_test,y_test)).filter(lambda x, y: tf.reduce_any(y == class_indexes_to_keep)).batch(64)

要转换为分类标签,您可以尝试:

import tensorflow as tf

one_hot_encode = tf.keras.utils.to_categorical(tf.range(10, dtype=tf.int64), num_classes=10)
class_indexes_to_keep = tf.gather(one_hot_encode, tf.constant([0, 3, 4, 6, 8], dtype=tf.int64))

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
y_train = y_train.astype(int)

train_dataset = tf.data.Dataset.from_tensor_slices((x_train,y_train)).map(lambda x, y: (x, tf.one_hot(y, 10)[0]))
train_dataset = train_dataset.filter(lambda x, y: tf.reduce_any(tf.reduce_all(y == class_indexes_to_keep, axis=-1))).batch(64)

相关问题