为什么tensorflow.function(没有jit_compile)可以加速Keras模型的前向传递?

djmepvbi  于 2023-10-19  发布在  其他
关注(0)|答案(1)|浏览(114)

可以使用model = tf.function(model, jit_compile=True)启用XLA。有些模型类型以这种方式更快,有些则更慢。到目前为止,一切顺利。
但是为什么model = tf.function(model, jit_compile=None)在某些情况下可以显著加快速度(没有TPU)?
jit_compile文档状态:
如果是None(默认值),则在TPU上运行时使用XLA编译函数,在其他设备上运行时通过常规函数执行路径执行。
我在两台非TPU(甚至非GPU)机器上运行测试(安装了最新的TensorFlow(2.13.0))。

import timeit

import numpy as np
import tensorflow as tf

model_plain = tf.keras.applications.efficientnet_v2.EfficientNetV2S()
model_jit_compile_true = tf.function(tf.keras.applications.efficientnet_v2.EfficientNetV2S(), jit_compile=True)
model_jit_compile_false = tf.function(tf.keras.applications.efficientnet_v2.EfficientNetV2S(), jit_compile=False)
model_jit_compile_none = tf.function(tf.keras.applications.efficientnet_v2.EfficientNetV2S(), jit_compile=None)

def run(model):
    model(np.random.random(size=(1, 384, 384, 3)))

# warmup
run(model_plain)
run(model_jit_compile_true)
run(model_jit_compile_false)
run(model_jit_compile_none)

runs = 10
duration_plain = timeit.timeit(lambda: run(model_plain), number=runs) / runs
duration_jit_compile_true = timeit.timeit(lambda: run(model_jit_compile_true), number=runs) / runs
duration_jit_compile_false = timeit.timeit(lambda: run(model_jit_compile_false), number=runs) / runs
duration_jit_compile_none = timeit.timeit(lambda: run(model_jit_compile_none), number=runs) / runs

print(f"{duration_plain=}")
print(f"{duration_jit_compile_true=}")
print(f"{duration_jit_compile_false=}")
print(f"{duration_jit_compile_none=}")
duration_plain=0.53095479644835
duration_jit_compile_true=1.5860380740836262
duration_jit_compile_false=0.09831228516995907
duration_jit_compile_none=0.09407951850444078
q3aa0525

q3aa05251#

但是为什么在某些情况下model = tf.function(model,jit_compile=None)可以显著加快速度(没有TPU)呢?

加速主要是由于图形模式enabled by tf.function,比model_plain中使用的急切执行快得多。

最重要的是,我们有XLA编译的次要影响with jit_compile flag,但它们在很大程度上取决于计算架构。例如,在GPU加速器下编译时,数字看起来会有很大不同。
最后但并非最不重要的是,应该纠正基准测试方法,以考虑到10次运行的变化和有问题的用例(否则,结果将是误导性的,甚至是矛盾的,例如:由于高变化,XLA=None平均看起来更快)。为了将来的参考,让我们明确**this profiling pattern from Tensorflow docs是不准确的**

# average runtime on 10 repetitions without variance is inaccurate
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))

以下经过修正和扩展的代码片段,在带有GPU的Kaggle笔记本上执行,表明改进主要来自图形模式XLA编译提供了一些进一步的加速

import timeit

import numpy as np
import tensorflow as tf

model_plain = tf.keras.applications.efficientnet_v2.EfficientNetV2S()
model_tffunc = tf.function(tf.keras.applications.efficientnet_v2.EfficientNetV2S(), jit_compile=None)
model_jit_compile_true = tf.function(tf.keras.applications.efficientnet_v2.EfficientNetV2S(), jit_compile=True)
model_jit_compile_false = tf.function(tf.keras.applications.efficientnet_v2.EfficientNetV2S(), jit_compile=False)
model_jit_compile_none = tf.function(tf.keras.applications.efficientnet_v2.EfficientNetV2S(), jit_compile=None)

x = np.random.random(size=(1, 384, 384, 3))

def run(model):
    model(x)

# warmup
run(model_plain)
run(model_tffunc)
run(model_jit_compile_true)
run(model_jit_compile_false)
run(model_jit_compile_none)

# benchmarking
duration_plain = %timeit -o run(model_plain)
duration_tffunc = %timeit -o run(model_tffunc)
duration_jit_compile_true = %timeit -o run(model_jit_compile_true)
duration_jit_compile_false = %timeit -o run(model_jit_compile_false)
duration_jit_compile_none = %timeit -o run(model_jit_compile_none)

print(f"{str(duration_plain)=}")
print(f"{str(duration_tffunc)=}")
print(f"{str(duration_jit_compile_true)=}")
print(f"{str(duration_jit_compile_false)=}")
print(f"{str(duration_jit_compile_none)=}")

从统计上看,我们有:duration_plain > duration_jit_compile_false = duration_jit_compile_none = duration_tffunc > duration_jit_compile_true,从输出中可以看出:

369 ms ± 3.62 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
16.1 ms ± 2.13 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
11.6 ms ± 882 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
15.9 ms ± 508 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
15.5 ms ± 450 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
str(duration_plain)='369 ms ± 3.62 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)'
str(duration_tffunc)='16.1 ms ± 2.13 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)'
str(duration_jit_compile_true)='11.6 ms ± 882 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)'
str(duration_jit_compile_false)='15.9 ms ± 508 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)'
str(duration_jit_compile_none)='15.5 ms ± 450 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)'

一个完整的例子是see this public notebook
注意:这种测量变化的方法是有用的,但not fully accurate

相关问题