TensorFlow: loading a model trained with tfa.optimizers.NovoGrad fails with ValueError: Shapes (3, 8) and () are incompatible

zvms9eto posted 5 months ago in Other


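System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): No.
  • OS platform and distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 18.04
  • Mobile device (e.g., iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on a mobile device: n/a
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use the command below): v2.3.0-54-gfcc4b966f1 2.3.1
  • Python version: 3.6.10
  • Bazel version (if compiling from source): n/a
  • GCC/compiler version (if compiling from source): n/a
  • CUDA/cuDNN version: n/a
  • GPU model and memory: n/a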

You can collect some of this information using our environment capture script. You can also obtain the TensorFlow version with the following commands:

  1. TF 1.0: python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"
  2. TF 2.0: python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"

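Describe the current behavior

I am training a model with tfa.optimizers.NovoGrad ( https://www.tensorflow.org/addons/api_docs/python/tfa/optimizers/NovoGrad ). The model appears to train fine, but when I try to load it with tf.keras.models.load_model("my_model"), I get a ValueError about shapes. A minimal example that reproduces the behavior: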

# Train a small model and save it as a SavedModel
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(8))
model.add(tf.keras.layers.Dense(2))
model.compile(optimizer=tfa.optimizers.NovoGrad(), loss=tf.keras.losses.sparse_categorical_crossentropy)
callbacks = [tf.keras.callbacks.ModelCheckpoint("scratch")]

x = np.random.random((2, 3))
y = np.random.randint(0, 2, (2,))
model.fit(x, y, batch_size=1, epochs=1, callbacks=callbacks)

Now close the Python session, open a new one, and try to load the model:

import tensorflow as tf
model = tf.keras.models.load_model("scratch")

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/keras/saving/save.py", line 187, in load_model
    return saved_model_load.load(filepath, compile, options)
  File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 121, in load
    path, options=options, loader_cls=KerasObjectLoader)
  File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/saved_model/load.py", line 633, in load_internal
    ckpt_options)
  File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 194, in __init__
    super(KerasObjectLoader, self).__init__(*args, **kwargs)
  File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/saved_model/load.py", line 131, in __init__
    self._restore_checkpoint()
  File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/saved_model/load.py", line 328, in _restore_checkpoint
    self._checkpoint_options).expect_partial()
  File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/training/tracking/util.py", line 1320, in restore
    checkpoint=checkpoint, proto_id=0).restore(self._graph_view.root)
  File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/training/tracking/base.py", line 209, in restore
    restore_ops = trackable._restore_from_checkpoint_position(self)  # pylint: disable=protected-access
  File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/training/tracking/base.py", line 908, in _restore_from_checkpoint_position
    visit_queue=visit_queue))
  File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/training/tracking/base.py", line 943, in _single_restoration_from_checkpoint_position
    if child_position.bind_object(trackable=local_object):
  File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/training/tracking/base.py", line 258, in bind_object 
    slot_name=slot_restoration.slot_name)
  File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py", line 1236, in _create_or_restore_slot_variable
    slot_variable_position.restore(slot_variable)
  File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/training/tracking/base.py", line 209, in restore
    restore_ops = trackable._restore_from_checkpoint_position(self)  # pylint: disable=protected-access
  File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/training/tracking/base.py", line 914, in _restore_from_checkpoint_position
    tensor_saveables, python_saveables))
  File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/training/tracking/util.py", line 297, in restore_saveables
    validated_saveables).restore(self.save_path_tensor, self.options)
  File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/training/saving/functional_saver.py", line 340, in restore
    restore_ops = restore_fn()
  File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/training/saving/functional_saver.py", line 316, in restore_fn
    restore_ops.update(saver.restore(file_prefix, options))
  File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/training/saving/functional_saver.py", line 111, in restore
    restored_tensors, restored_shapes=None)
  File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/training/saving/saveable_object_util.py", line 127, in restore
    self.handle_op, self._var_shape, restored_tensor)
  File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/ops/resource_variable_ops.py", line 311, in shape_safe_assign_variable_handle
    shape.assert_is_compatible_with(value_tensor.shape)
  File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/framework/tensor_shape.py", line 1134, in assert_is_compatible_with
    raise ValueError("Shapes %s and %s are incompatible" % (self, other))
ValueError: Shapes (3, 8) and () are incompatible
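
The last frames of the traceback show which check fails: while restoring an optimizer slot variable (`_create_or_restore_slot_variable`), the variable has static shape (3, 8) but the value read from the checkpoint is a scalar with shape (). The sketch below reproduces only that shape check in isolation with `tf.TensorShape`. Reading (3, 8) as the kernel shape of the first Dense layer (3 input features, 8 units) is an assumption based on the reproduction script, not something the traceback states.

```python
import tensorflow as tf

# Static shape of the variable being restored. In the reproduction script this
# matches a Dense(8) kernel fed with 3 input features (assumption, see above).
variable_shape = tf.TensorShape([3, 8])

# Static shape of the value read back from the checkpoint: a scalar.
checkpoint_value_shape = tf.TensorShape([])

try:
    # Same call as the second-to-last frame of the traceback.
    variable_shape.assert_is_compatible_with(checkpoint_value_shape)
    error_message = None
except ValueError as err:
    error_message = str(err)

print(error_message)
```

A rank-2 shape can never be compatible with a rank-0 shape, so the check fails regardless of the values stored in the checkpoint.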

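Describe the expected behavior

The saved model should load without errors.

Standalone code to reproduce the issue

The standalone code is provided above.

Other info / logs: Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.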

ukxgm1gy #1

I was able to reproduce the issue mentioned in the report. Please find the gists for nightly, TF 2.4, and TF 2.3 here.

vc6uscn9 #2

@Feynman27 I tried to reproduce this issue but ran into an error when importing TensorFlow Addons. Please take a look at this gist. Thanks!

w1e3prcc #3

Sorry, this is still an issue. Here is the updated gist for reference. Thanks!

j91ykkif #4

I am running into the same issue. Are there any updates?

xwmevbvl #5

Please look into this issue, thanks!

k0pti3hp #6

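The workaround is to load the checkpoint without its optimizer and then save the weights back to the checkpoint file. Loading from that file afterwards does not crash, but the optimizer state is lost.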
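
A rough sketch of a related weights-only route is shown below. It is not the commenter's exact procedure: instead of rewriting an existing checkpoint, it saves only the layer weights, rebuilds the architecture in code, and loads the weights into it, so no optimizer slot variables are ever restored. Several details are assumptions made for illustration: the helper `build_model`, the file name `scratch.weights.h5`, the use of the stock `"sgd"` optimizer as a stand-in for `tfa.optimizers.NovoGrad()` (so the sketch does not depend on `tensorflow_addons`), and `from_logits=True` in the loss.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_model():
    """Hypothetical helper: same architecture as the reproduction script."""
    return tf.keras.Sequential([
        tf.keras.layers.Dense(8),
        tf.keras.layers.Dense(2),
    ])


model = build_model()
# Stand-in optimizer (assumption). In the thread's setup this would be
# tfa.optimizers.NovoGrad().
model.compile(
    optimizer="sgd",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

x = np.random.random((2, 3)).astype("float32")
y = np.random.randint(0, 2, (2,))
model.fit(x, y, batch_size=1, epochs=1, verbose=0)

# Save only the layer weights; the optimizer state is deliberately dropped.
weights_path = os.path.join(tempfile.mkdtemp(), "scratch.weights.h5")
model.save_weights(weights_path)

# In a fresh session: rebuild the architecture, create its variables,
# then load the saved weights into it.
restored = build_model()
restored.build(input_shape=(None, 3))
restored.load_weights(weights_path)

# Check that every weight array survived the round trip.
weights_match = all(
    np.allclose(a, b)
    for a, b in zip(model.get_weights(), restored.get_weights())
)
print("weights match:", weights_match)
```

As with the workaround described above, the restored model has to be compiled again with a fresh optimizer before training can continue, so any accumulated optimizer state is gone.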
