keras 使用TensorFlow数据集的验证集

dgsult0t  于 2023-03-08  发布在  其他
关注(0)|答案(1)|浏览(161)

Train and evaluate with Keras开始:
当从Dataset对象训练时,不支持参数validation_split(从训练数据生成维持集),因为此功能要求能够索引数据集的样本,而使用Dataset API通常无法实现这一点。
是否有变通方案?如何仍然使用TF数据集的验证集?

s2j5cfk0

s2j5cfk01#

不,您不能使用use validation_split(文档中有明确的描述),但是您可以创建validation_data,并“手动”创建Dataset
您可以在同一tensorflow 教程中看到一个示例:

# Prepare the training dataset
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

# Prepare the validation dataset
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(64)

model.fit(train_dataset, epochs=3, validation_data=val_dataset)

您可以使用简单的切片从numpy数组((x_train, y_train)(x_val, y_val))创建这两个数据集,如下所示:

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]

创建tf.data.Dataset对象还有其他方法,请参阅tf.data.Dataset文档和相关教程/笔记本。

相关问题