tensorflow 使用`jit_compile=True`的`tf.raw_ops.SqrtGrad`的不同行为

xv8emn3q  于 4个月前  发布在  其他
关注(0)|答案(2)|浏览(104)

问题类型

Bug

你是否在TensorFlow Nightly版本中复现了这个bug?

问题来源

source

TensorFlow版本

2.14.0

自定义代码

OS平台和发行版

  • 无响应*

移动设备

  • 无响应*

Python版本

  • 无响应*

Bazel版本

  • 无响应*

GCC/编译器版本

  • 无响应*

CUDA/cuDNN版本

11.8

GPU型号和内存

GPU 0: NVIDIA GeForce RTX 2070 GPU 1: NVIDIA GeForce RTX 2070 GPU 2: NVIDIA GeForce RTX 2070 GPU 3: NVIDIA GeForce RTX 2070

当前行为?

tf.raw_ops.SqrtGrad操作在一个启用了JIT编译的tf.function中被调用(jit_compile=True)时,它产生的结果与没有启用JIT编译的相同操作产生的结果不同。这种不一致性在GPU设备上执行代码时被观察到。

重现问题的独立代码

import tensorflow as tf
import traceback

class Network(tf.Module):
    def __init__(self):
        super().__init__()

    @tf.function(jit_compile=True)
    def __call__(self, x):
      real_part = tf.random.normal([], dtype=tf.float64)
      imag_part = tf.random.normal([], dtype=tf.float64)
      tensor = tf.complex(real_part, imag_part)
      tensor = tf.cast(tensor,dtype=tf.complex128)
      x = tf.raw_ops.SqrtGrad(y=x, dy=tensor)        
      return x

m = Network()
real_part = tf.random.normal([], dtype=tf.float64)
imag_part = tf.random.normal([], dtype=tf.float64)
tensor = tf.complex(real_part, imag_part)
tensor = tf.cast(tensor,dtype=tf.complex128)
inp = {
    "x": tensor,
}

with tf.device('/GPU:0'):
    tf.config.run_functions_eagerly(True)
    no_op_res = m(**inp)
    tf.config.run_functions_eagerly(False)
    with tf.device('/GPU:0'):
        op_res = m(**inp)

    tf.debugging.assert_near(tf.cast(no_op_res, tf.float64), tf.cast(op_res, tf.float64), atol=0.001, rtol=0.001)

相关日志输出

File "/home/guihuan/LLM/results/tf-2/2023-10-22-20-21/test.py", line 33, in <module>
    tf.debugging.assert_near(tf.cast(no_op_res, tf.float64), tf.cast(op_res, tf.float64), atol=0.001, rtol=0.001)
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/tensorflow/python/ops/control_flow_assert.py", line 102, in Assert
    raise errors.InvalidArgumentError(
tensorflow.python.framework.errors_impl.InvalidArgumentError: Expected 'tf.Tensor(False, shape=(), dtype=bool)' to be true. Summarized data: b''
b'x and y not equal to tolerance rtol = tf.Tensor(0.001, shape=(), dtype=float64), atol = tf.Tensor(0.001, shape=(), dtype=float64)'
b'x (shape=() dtype=float64) = '
-0.006697387971180855
b'y (shape=() dtype=float64) = '
0.07167101474792367
edqdpe6u

edqdpe6u1#

你好,@zoux1a!
我能够使用jit_compile=True和jit_compile=False来复现这个问题。在这里,我附上了一张gist的图片。
谢谢!

xt0899hw

xt0899hw2#

似乎与tf.function有关,因为在有和没有jit_compile的情况下都无法复制。

相关问题