如何优化我的tensorflow代码(tensor上的迭代)?

2uluyalo  于 2021-09-08  发布在  Java
关注(0)|答案(1)|浏览(431)

您好,我写了一些代码,对我的Tensor的形状进行评估,并在每一步中为Tensor添加一个值。问题是大量迭代的速度很慢(我需要很多迭代)

for i in range(0, num_samples):
    for j in range(0, int(numTimeSteps.numpy())):
        random_num = tf.squeeze(tf.random.normal([1], mean, std_dev, dtype=dtype))
        tensor[i,j].assign(tf.math.add(random_num, estr[i,j]).numpy())

在每次迭代中,我向Tensor中添加一个随机数(从正态分布中提取)
num_samples由用户指定,可以在1000-100.000范围内
numtimespes是80
有没有办法优化这一点?

pcrecxhr

pcrecxhr1#

tf应该很少依赖循环或numpy。如果您看到其中一个,请尝试将其矢量化并使用本机tf ops进行操作。
也许像

estr = estr + tf.random.normal(shape=estr.shape)

请注意,tf.random很快就会因为不够随机而被弃用-请参阅文档中的注解。首选方法是tf.random.generator。

相关问题