可以使用model = tf.function(model, jit_compile=True)
启用XLA。有些模型类型以这种方式更快,有些则更慢。到目前为止,一切顺利。
但是为什么model = tf.function(model, jit_compile=None)
在某些情况下可以显著加快速度(没有TPU)?jit_compile
文档状态:
如果是None
(默认值),则在TPU上运行时使用XLA编译函数,在其他设备上运行时通过常规函数执行路径执行。
我在两台非TPU(甚至非GPU)机器上运行测试(安装了最新的TensorFlow(2.13.0
))。
import timeit
import numpy as np
import tensorflow as tf
model_plain = tf.keras.applications.efficientnet_v2.EfficientNetV2S()
model_jit_compile_true = tf.function(tf.keras.applications.efficientnet_v2.EfficientNetV2S(), jit_compile=True)
model_jit_compile_false = tf.function(tf.keras.applications.efficientnet_v2.EfficientNetV2S(), jit_compile=False)
model_jit_compile_none = tf.function(tf.keras.applications.efficientnet_v2.EfficientNetV2S(), jit_compile=None)
def run(model):
model(np.random.random(size=(1, 384, 384, 3)))
# warmup
run(model_plain)
run(model_jit_compile_true)
run(model_jit_compile_false)
run(model_jit_compile_none)
runs = 10
duration_plain = timeit.timeit(lambda: run(model_plain), number=runs) / runs
duration_jit_compile_true = timeit.timeit(lambda: run(model_jit_compile_true), number=runs) / runs
duration_jit_compile_false = timeit.timeit(lambda: run(model_jit_compile_false), number=runs) / runs
duration_jit_compile_none = timeit.timeit(lambda: run(model_jit_compile_none), number=runs) / runs
print(f"{duration_plain=}")
print(f"{duration_jit_compile_true=}")
print(f"{duration_jit_compile_false=}")
print(f"{duration_jit_compile_none=}")
duration_plain=0.53095479644835
duration_jit_compile_true=1.5860380740836262
duration_jit_compile_false=0.09831228516995907
duration_jit_compile_none=0.09407951850444078
1条答案
按热度按时间q3aa05251#
但是为什么在某些情况下model = tf.function(model,jit_compile=None)可以显著加快速度(没有TPU)呢?
加速主要是由于图形模式enabled by
tf.function
,比model_plain
中使用的急切执行快得多。最重要的是,我们有XLA编译的次要影响with
jit_compile
flag,但它们在很大程度上取决于计算架构。例如,在GPU加速器下编译时,数字看起来会有很大不同。最后但并非最不重要的是,应该纠正基准测试方法,以考虑到10次运行的变化和有问题的用例(否则,结果将是误导性的,甚至是矛盾的,例如:由于高变化,
XLA=None
平均看起来更快)。为了将来的参考,让我们明确**this profiling pattern from Tensorflow docs是不准确的**以下经过修正和扩展的代码片段,在带有GPU的Kaggle笔记本上执行,表明改进主要来自图形模式,XLA编译提供了一些进一步的加速。
从统计上看,我们有:
duration_plain > duration_jit_compile_false = duration_jit_compile_none = duration_tffunc > duration_jit_compile_true
,从输出中可以看出:一个完整的例子是see this public notebook。
注意:这种测量变化的方法是有用的,但not fully accurate。