train_on_batch()在keras中有什么用?

o3imoua4  于 2023-01-02  发布在  其他
关注(0)|答案(5)|浏览(165)

train_on_batch()fit()有何不同?在哪些情况下我们应该使用train_on_batch()

q3qa4bjr

q3qa4bjr1#

对于这个问题,第一作者给出了一个简单的答案:
使用fit_generator,您还可以使用验证数据的生成器。一般来说,我建议使用fit_generator,但使用train_on_batch也可以。这些方法只是为了在不同的用例中使用方便,没有“正确”的方法。
train_on_batch允许您根据您提供的样本集合明确更新权重,而不考虑任何固定的批处理大小。您可以在需要时使用此选项:您可以使用这种方法在传统训练集的多个批处理上维护自己的迭代,但允许fitfit_generator为您迭代批处理可能更简单。
使用train_on_batch的一种情况是,在一批新样本上更新预先训练好的模型。假设您已经训练并部署了一个模型,稍后您收到了一组以前从未使用过的新训练样本。您可以使用train_on_batch仅在这些样本上直接更新现有模型。其他方法也可以做到这一点。但是对于这种情况使用train_on_batch是相当明确的。
除了像这样的特殊情况(出于某种教学原因,您需要在不同的训练批处理之间维护自己的游标,或者需要在特殊批处理上进行某种类型的半在线训练更新),最好始终使用fit(用于适合内存的数据)或fit_generator(用于将数据流批处理作为生成器)。

cngwdvgl

cngwdvgl2#

train_on_batch()使您能够更好地控制LSTM的状态,例如,在使用有状态LSTM并需要控制对model.reset_states()的调用时。您可能有多个系列的数据,并需要在每个系列之后重置状态,这可以使用train_on_batch()来完成。但是如果你使用.fit(),那么网络将在所有的数据序列上训练,而不需要重置状态。没有对与错,这取决于你使用的是什么数据,以及您希望网络如何运行。

mctunoxg

mctunoxg3#

如果您使用的是大型数据集,并且没有易于序列化的数据(如高秩numpy数组)来写入tfrecord,那么Train_on_batch的性能也会比fit和fit生成器有所提高。
在这种情况下,您可以将数组保存为numpy文件并加载其中较小的子集(traina.npy,trainb.npy等),当整个集合不适合内存时,您可以使用tf.data.Dataset.from_tensor_slices,然后使用train_on_batch处理您的子数据集,然后加载另一个数据集并再次调用train on batch,等等。现在你已经训练了你的整个数据集,并且可以精确地控制你的数据集训练你的模型的多少和什么。然后你可以定义你自己的时期,批量大小,等等,用简单的循环和函数从你的数据集中抓取。

kknvjkwl

kknvjkwl4#

@nbro answer确实有帮助,只是为了增加一些场景,假设你正在训练一些seq to seq模型或者一个带有一个或多个编码器的大型网络。我们可以使用train_on_batch创建自定义训练循环,并使用我们的一部分数据直接在编码器上验证,而不使用回调。为复杂的验证过程编写回调可能很困难。有几种情况下我们希望在批处理上训练。
问候你,卡希克

rwqw0loc

rwqw0loc5#

来自Keras -模型培训API:

  • fit:为模型训练固定数量的时期(数据集上的迭代)。
  • 批次培养:对单个批次数据运行单个梯度更新

当我们一次使用一批训练数据集更新鉴别器和生成器时,我们可以在GAN中使用它。我在Jason Brownlee的教程(How to Develop a 1D Generative Adversarial Network From Scratch in Keras)中看到使用train_on_batch。

快速搜索提示:键入Control+F,然后在搜索框中键入要搜索的术语(例如train_on_batch)。

相关问题