tensorflow 如何使用大型.npy文件进行培训?

chhkpiq4  于 2023-10-23  发布在  其他
关注(0)|答案(1)|浏览(107)

我在做一个分类器我的数据保存在.npy文件中,该文件超过40GB大,所以很明显很难将其放入内存中。我发现在加载fle时使用mmap_mode可能会有帮助,但我不确定如何将其用于训练。目前我使用以下方法:

data_train_np = np.load('./train.npy',mmap_mode='r')
train_dataset = tf.data.Dataset.from_tensor_slices((data_train_np, labels))
train_dataset = train_dataset.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)

我正在使用fit()进行训练,但我认为这并没有将所有数据都输入到模型中。任何人都可以提供帮助,或者告诉在mmap_mode中使用numpy数据进行训练的正确方法是什么?

cs7cruho

cs7cruho1#

有几个选项可以做到这一点。
我将假设您正在对图像进行的转换以获得最终矩阵是非常昂贵的,并且不希望在数据加载时进行预处理。如果不是这种情况,您可能会受益于阅读与图像处理相关的tutorial
我建议您使用python生成器读取数据,并使用tf.data创建数据管道,将这些.npy文件提供给您的模型。基本思想非常简单。你可以使用一个 Package 器从一个生成器中获取数据,生成器将根据需要读取文件。相关文档和示例在这里。
现在,一旦你得到了工作,我认为这将是一个好主意,你optimize你的管道,特别是如果你计划在多个GPU或多台计算机上训练。

相关问题