tensorflow 在保存模型时,tf.function内部的自动转换变量(AutoCast variable)不会正确转换,

cs7cruho  于 10个月前  发布在  其他
关注(0)|答案(9)|浏览(85)

问题类型

Bug

来源

source

Tensorflow版本

2.9.3

自定义代码

OS平台和发行版

Windows 10

移动设备

  • 无响应*

Python版本

3.9

Bazel版本

  • 无响应*

GCC/编译器版本

  • 无响应*

CUDA/cuDNN版本

  • 无响应*

GPU型号和内存

  • 无响应*

当前行为?

  1. Error when calling model.save()
  2. You can work around the error by commenting out `tf.function` over `call()`.

独立代码重现问题

  1. import tensorflow as tf
  2. import keras
  3. import keras.layers
  4. class CosineSimilarityLayer(keras.layers.Layer):
  5. def __init__(
  6. self, num_classes: int, name: str = None
  7. ):
  8. super().__init__(name=name)
  9. self.num_classes = num_classes
  10. self._weights = None
  11. def build(self, input_shape):
  12. self._weights = self.add_weight(
  13. name="W",
  14. shape=(
  15. input_shape[-1],
  16. self.num_classes,
  17. ),
  18. initializer="glorot_normal",
  19. trainable=True,
  20. dtype=self.dtype
  21. )
  22. super().build(input_shape)
  23. def compute_output_shape(self):
  24. return None, self.num_classes
  25. # If you comment out tf.function here, it works.
  26. @tf.function
  27. def call(self, inputs: tf.Tensor):
  28. embedding = inputs
  29. # normalize feature
  30. embedding_normalized = tf.nn.l2_normalize(embedding, axis=1)
  31. # get centroids
  32. weights_normalized = tf.nn.l2_normalize(self._weights, axis=1, )
  33. logits = embedding_normalized @ weights_normalized
  34. return logits
  35. def get_config(self):
  36. config = super().get_config().copy()
  37. config.update(
  38. {
  39. "num_classes": self.num_classes,
  40. }
  41. )
  42. return config
  43. def main():
  44. tf.keras.mixed_precision.set_global_policy("mixed_float16")
  45. layer = CosineSimilarityLayer(num_classes=100)
  46. input = tf.zeros(shape=(10, 100), dtype=tf.float16)
  47. model = keras.models.Sequential([
  48. layer
  49. ])
  50. model(input)
  51. model.save("autocast_issue")
  52. if __name__ == "__main__":
  53. main()

相关日志输出

  1. File "*\autocast_issue.py", line 37, in call *
  2. logits = embedding_normalized @ weights_normalized
  3. TypeError: Input 'b' of 'MatMul' Op has type float32 that does not match type float16 of argument 'a'.
7rtdyuoh

7rtdyuoh1#

我不太明白。图层调用函数只有一个参数。所以文档中指定的限制不应该在这里适用。

p1tboqfb

p1tboqfb2#

@JustASquid,
是的,策略仅适用于将第一个参数转换为float16的层调用的第一个参数,但是变量weights_normalized的类型是float32,这是TensorFlow的默认数据类型,因此在操作embedding_normalized @ weights_normalized中引发了TypeError: Input 'b' of 'MatMul' Op has type float32 that does not match type float16 of argument 'a'
错误不是由策略引发的。策略按预期将第一个参数的数据类型进行了转换。但是操作embedding_normalized @ weights_normalized不支持不同数据类型的乘法。两个参数都应该是相同的类型。请同时参考Tensorflow的[tensordot](embedding_normalized @ weights_normalized) API,关于相同的内容。
因此,您需要保持点积的相同数据类型,这就是错误的原因,而不是策略tf.keras.mixed_precision.set_global_policy
谢谢!

t3irkdon

t3irkdon3#

我认为这里存在一些混淆。在混合精度策略下,权重变量被创建为AutoCastVariable类型,当它与float16输入Tensor相乘时,应该自动转换为float16。例如,查看Dense层的实现,它在进行卷积操作时也是依赖于这一点的。

nvbavucw

nvbavucw4#

@JustASquid ,

感谢指出。我深入了解了Policy,是的,您在这里是正确的。权重变量应该自动转换为 Policy dtype ,而且似乎在调用模型时也发生了这种情况。我插入了一些调试语句。我观察到,当 call() 方法在 model.save() 的最后一次调用中与 @tf.function 一起 Package model.save() 时,不知何故返回了 float32 结果,这可能是问题所在。而没有 @tf.function Package 器的所有调用都返回了 float16 作为 self.dtype 的结果。请参阅附件 gist

这里似乎存在一些序列化问题,tf.function 作为 Package 器。我们需要仔细研究这个问题,并请求您尽可能找到问题的根源。

谢谢!

xwbd5t1u

xwbd5t1u5#

Reed,看起来tf.functions和AutocastVariable之间存在一些问题。当一个函数在Keras下被追踪(我猜enable_auto_cast_variables可能在某个地方被使用),变量会被正确地转换为预期的数据类型。但是当一个函数在核心TF下被追踪时,没有上下文,所以会使用普通的float32值。

是否有可能在tf.function内部始终将Autocast Variables解析为其计算类型?

2guxujil

2guxujil6#

enable_auto_cast_variablesLayer.__call__ 中被调用,以便在 Layer.call 内部对 AutoCastVariables 进行类型转换。
当在 model.save() 调用过程中发生错误时,我在堆栈跟踪中看到了函数的重新追踪。它在 Layer.__call__ 外部被重新追踪,因此 enable_auto_cast_variables 的作用域不生效,导致了错误。@k-w-w,你知道为什么它会在 Layer.__call__ 外部被重新追踪吗?我对 TF 代码库的 tf.function 部分不是很熟悉。
不幸的是,我们不能总是将 AutocastVariables 解析为 tf.functions 内部的计算侧,因为只有层具有计算类型,而不是变量。有意将变量仅在 Layer.call 内部使用时进行类型转换。

rslzwgfq

rslzwgfq7#

@reedwm 它正在被重写,因为函数被标记为 "tf.function"。当 SavedModel 代码看到一个 tf.function 时,它会在空白上下文中重新跟踪(例如,在分布策略内部被跟踪的情况下)。我们不想保存分布式函数。我不太确定在这个情况下缓存键的哪个部分发生了变化。

x7rlezfr

x7rlezfr8#

你好@JustASquid!
我可以在2.9、2.10和2.11中复制这个问题。
@SuryanarayanaY!
我在2.9、2.10和2.11中附上了gist供参考。
谢谢!

wkftcu5l

wkftcu5l9#

你好@JustASquid,

根据文档,这个策略有一个限制如下:

目前,只有层调用方法的第一个参数会被转换(尽管这在未来的次要版本中可能会改变)

由于你将第一个参数 input 传递为 float16,将第二个参数 weights_normalized 传递为 float32,策略只转换了第一个参数,而第二个参数保持不变,从而导致类型错误。因此,我将第二个参数也更改为与策略中提到的参数 (float16) 相匹配,这样就没有错误了。我甚至尝试将第一个参数从 float16 更改为 float32,策略能够将其转换为 float16,代码执行没有错误。请参考附加的 gist-nightly-2.12V

希望这能解答你的疑问。谢谢!

相关问题