Keras:字典作为验证数据

mo49yndu  于 2022-11-13  发布在  其他
关注(0)|答案(1)|浏览(135)

从Keras手册中,我了解到变量validation_data可以是:

  • NumPy数组或Tensor的元组(x_val, y_val)
  • NumPy数组的元组(x_val, y_val, val_sample_weights)
  • 一个tf.data.Dataset。
  • 返回(inputs, targets)(inputs, targets, sample_weights)的Python生成器或keras.utils.sequence。

我的问题是:既然我使用了多个命名输入,那么我是否可以使用元组(x_val, y_val)作为validation_data,其中x_val是NumPy数组的字典(键等于模型输入的名称),y_val是一个简单的NumPy数组?
谢谢你的帮助。

r9f1avp5

r9f1avp51#

由于您使用了多个命名输入,因此不能为validation_data参数传递元组(x_val, y_val)(至少目前Keras不支持)。根据TensorFlowKeras文档:

验证数据将覆盖验证拆分验证数据可以是:

  • Numpy数组或Tensor的元组**(x_瓦尔,y_val)**。
  • NumPy数组的元组**(x_瓦尔,y_val,val_sample_weights)**。
  • 一个数据集。
    tf.distribute.experimental.ParameterServerStrategy尚不支持返回*(输入,目标)(输入,目标,样本权重).验证数据的Python生成器或keras.utils.Sequence**。
    可能的解决方案:

一个可能的解决方案是连接训练和验证数据集,并将其作为xy的参数传递给fit方法,同时使用validation_split指定验证部分。请注意:
验证数据是在混洗之前从所提供的x和y数据中的最后一个样本中选择的。

更多详细信息

假设您的数据集有两个输入(例如in1和in2)和两个输出(例如out1和out2)。

可选阅读

您可以首先根据需要重排训练和验证数据集:

concat_xy_train=np.concatenate((train_in1, train_in2, train_out1, train_out2), axis=1)
concat_xy_val=np.concatenate((val_in1, val_in2, val_out1, val_out2), axis=1)
np.random.shuffle(concat_xy_train)
np.random.shuffle(concat_xy_val)

然后,您可以撷取图征和标示:

shuf_train_in1 = concat_xy_train[:,:len_in1]
shuf_train_in2 = concat_xy_train[:,len_in1:len_in1+len_in2]
shuf_train_out1 = concat_xy_train[:,len_in1+len_in2:len_in1+len_in2+len_out1]
shuf_train_out2 = concat_xy_train[:,len_in1+len_in2+len_out1:]

shuf_val_in1 = concat_xy_val[:,:len_in1]
shuf_val_in2 = concat_xy_val[:,len_in1:len_in1+len_in2]
shuf_val_out1 = concat_xy_val[:,len_in1+len_in2:len_in1+len_in2+len_out1]
shuf_val_out2 = concat_xy_val[:,len_in1+len_in2+len_out1:]

训练和验证数据集的连接

train_val_in1 = np.concatenate((shuf_train_in1, shuf_val_in1), axis=0)
train_val_in2 = np.concatenate((shuf_train_in2, shuf_val_in2), axis=0)
train_val_out1 = np.concatenate((shuf_train_out1, shuf_val_out1), axis=0)
train_val_out2 = np.concatenate((shuf_train_out2, shuf_val_out2), axis=0)

适合车型

拟合模型时:

model.fit(
    {"in1": train_val_in1, "in2": train_val_in2},
    {"out1": train_val_out1, "out2": train_val_out2},
    validation_split=len_val/(len_val+len_train),
...

相关问题