tensorflow 在Petastorm中创建训练和有效数据集

m528fe3b  于 2023-05-18  发布在  Storm
关注(0)|答案(1)|浏览(202)

版本:Python3.7.13,Tensorflow-2.9.1,Petastorm-0.12.1
在petastorm中,使用从petastorm创建的数据集训练模型的唯一方法似乎是在Reader上下文管理器中fit模型,如下所示:

with make_batch_reader(train_s3_paths, schema_fields=cols+['target']) as tr_reader:
    dataset = make_petastorm_dataset(tr_reader).shuffle(10000).repeat(n_epochs).map(parse)
    history = model.fit(dataset)

我想传入训练数据集和验证数据集,如何做到这一点?

with make_batch_reader(train_s3_paths, schema_fields=cols+['target']) as tr_reader:
    tr_dataset = make_petastorm_dataset(tr_reader).shuffle(10000).repeat(n_epochs).map(parse)
    with make_batch_reader(val_s3_paths, schema_fields=cols+['target']) as val_reader:
         val_dataset = make_petastorm_dataset(val_reader).shuffle(10000).repeat(n_epochs).map(parse)
         history = model.fit(tr_dataset, validation_data=val_dataset)

这是解决我面临的问题的有效方法吗?是否有替代方法,例如在上下文管理器之外使用数据集或根本不使用上下文管理器?

n53p2ov0

n53p2ov01#

我不确定make_batch_reader是否正确,但是“with”语句可以接受多个语句。阅读this了解更多信息。
对你来说,这应该行得通-

with make_batch_reader(train_s3_paths, schema_fields=cols+['target']) as tr_reader, make_batch_reader(val_s3_paths, schema_fields=cols+['target']) as val_reader:
    tr_dataset = make_petastorm_dataset(tr_reader).shuffle(10000).repeat(n_epochs).map(parse)
    val_dataset = make_petastorm_dataset(val_reader).shuffle(10000).repeat(n_epochs).map(parse)
    history = model.fit(tr_dataset, validation_data=val_dataset)

相关问题