我刚刚升级到Tensorflow 2.3,我想制作自己的数据生成器用于训练。使用Tensorflow 1.x,我这样做:
def get_data_generator(test_flag):
item_list = load_item_list(test_flag)
print('data loaded')
while True:
X = []
Y = []
for _ in range(BATCH_SIZE):
x, y = get_random_augmented_sample(item_list)
X.append(x)
Y.append(y)
yield np.asarray(X), np.asarray(Y)
data_generator_train = get_data_generator(False)
data_generator_test = get_data_generator(True)
model.fit_generator(data_generator_train, validation_data=data_generator_test,
epochs=10000, verbose=2,
use_multiprocessing=True,
workers=8,
validation_steps=100,
steps_per_epoch=500,
)
此代码在Tensorflow 1.x中运行良好。在系统中创建了8个进程。处理器和显卡加载完美。“数据加载”打印了8次。
使用Tensorflow 2.3时,我收到警告:
警告:tensorflow :多处理可能会与TensorFlow发生不良交互,从而导致不确定性死锁。对于高性能数据管道tf.data建议访问www.example.com。
“数据加载”被打印一次(应该是8次)。GPU没有被充分利用。它也有内存泄漏每个时期,所以训练将停止后,几个时期。使用_multiprocessing标志没有帮助。
如何在tensorflow(keras)2.x中制作一个生成器/迭代器,使其可以轻松地跨多个CPU进程并行化?死锁和数据顺序并不重要。
1条答案
按热度按时间kzipqqlq1#
使用
tf.data
流水线,您可以在多个点进行并行化。根据数据存储和读取的方式,您可以并行阅读。您还可以并行扩展,并且可以在训练时预取数据,因此您的GPU(或其他硬件)永远不会对数据感到饥饿。在下面的代码中,我演示了如何并行化扩展和添加预取。