Please make sure that this is a bug. As per our
GitHub Policy,
we only address code/doc bugs, performance issues, feature requests and
build/installation issues on GitHub. tag:bug_template
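System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): No.
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 18.04
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on a mobile device: n/a
- TensorFlow installed from (source or binary): binary
- TensorFlow version (use command below): v2.3.0-54-gfcc4b966f1 2.3.1
- Python version: 3.6.10
- Bazel version (if compiling from source): n/a
- GCC/Compiler version (if compiling from source): n/a
- CUDA/cuDNN version: n/a
- GPU model and memory: n/a
You can collect some of this information using our environment capture script. You can also obtain the TensorFlow version with: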
- TF 1.0:
python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"
- TF 2.0:
python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"
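Describe the current behavior
I am training a model with tfa.optimizers.NovoGrad
( https://www.tensorflow.org/addons/api_docs/python/tfa/optimizers/NovoGrad ). The model appears to train fine, but when I try to load it with tf.keras.models.load_model("my_model")
I get a ValueError
about incompatible shapes. A minimal example that reproduces the behavior: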
# Train a small model and save it as a SavedModel
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(8))
model.add(tf.keras.layers.Dense(2))
model.compile(optimizer=tfa.optimizers.NovoGrad(), loss=tf.keras.losses.sparse_categorical_crossentropy)
callbacks = [tf.keras.callbacks.ModelCheckpoint("scratch")]
x = np.random.random((2, 3))
y = np.random.randint(0, 2, (2,))
model.fit(x, y, batch_size=1, epochs=1, callbacks=callbacks)
Now close the Python session, open a new one, and try to load the model:
import tensorflow as tf
model = tf.keras.models.load_model("scratch")
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/keras/saving/save.py", line 187, in load_model
return saved_model_load.load(filepath, compile, options)
File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 121, in load
path, options=options, loader_cls=KerasObjectLoader)
File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/saved_model/load.py", line 633, in load_internal
ckpt_options)
File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 194, in __init__
super(KerasObjectLoader, self).__init__(*args, **kwargs)
File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/saved_model/load.py", line 131, in __init__
self._restore_checkpoint()
File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/saved_model/load.py", line 328, in _restore_checkpoint
self._checkpoint_options).expect_partial()
File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/training/tracking/util.py", line 1320, in restore
checkpoint=checkpoint, proto_id=0).restore(self._graph_view.root)
File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/training/tracking/base.py", line 209, in restore
restore_ops = trackable._restore_from_checkpoint_position(self) # pylint: disable=protected-access
File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/training/tracking/base.py", line 908, in _restore_from_checkpoint_position
visit_queue=visit_queue))
File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/training/tracking/base.py", line 943, in _single_restoration_from_checkpoint_position
if child_position.bind_object(trackable=local_object):
File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/training/tracking/base.py", line 258, in bind_object
slot_name=slot_restoration.slot_name)
File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py", line 1236, in _create_or_restore_slot_variable
slot_variable_position.restore(slot_variable)
File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/training/tracking/base.py", line 209, in restore
restore_ops = trackable._restore_from_checkpoint_position(self) # pylint: disable=protected-access
File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/training/tracking/base.py", line 914, in _restore_from_checkpoint_position
tensor_saveables, python_saveables))
File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/training/tracking/util.py", line 297, in restore_saveables
validated_saveables).restore(self.save_path_tensor, self.options)
File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/training/saving/functional_saver.py", line 340, in restore
restore_ops = restore_fn()
File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/training/saving/functional_saver.py", line 316, in restore_fn
restore_ops.update(saver.restore(file_prefix, options))
File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/training/saving/functional_saver.py", line 111, in restore
restored_tensors, restored_shapes=None)
File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/training/saving/saveable_object_util.py", line 127, in restore
self.handle_op, self._var_shape, restored_tensor)
File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/ops/resource_variable_ops.py", line 311, in shape_safe_assign_variable_handle
shape.assert_is_compatible_with(value_tensor.shape)
File "/home/tblstri/anaconda3/envs/TENSORFLOW_CPU/lib/python3.6/site-packages/tensorflow/python/framework/tensor_shape.py", line 1134, in assert_is_compatible_with
raise ValueError("Shapes %s and %s are incompatible" % (self, other))
ValueError: Shapes (3, 8) and () are incompatible
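The last frame reports that a variable of shape (3, 8) is being assigned a scalar value. One way to narrow this down is to list the shapes stored in the checkpoint and look for entries saved as scalars. The following is a minimal sketch, not part of the original report. It assumes the SavedModel directory keeps its checkpoint under variables/variables and that optimizer slot entries carry the ".OPTIMIZER_SLOT" marker in their names; the helper names are hypothetical.

```python
import os


def checkpoint_prefix(saved_model_dir):
    """Return the checkpoint prefix inside a SavedModel directory (assumed layout)."""
    return os.path.join(saved_model_dir, "variables", "variables")


def slot_entries(entries):
    """Keep only (name, shape) pairs whose name looks like an optimizer slot variable."""
    # Assumption: object-based checkpoints mark slot variables with ".OPTIMIZER_SLOT".
    return [(name, list(shape)) for name, shape in entries if ".OPTIMIZER_SLOT" in name]


def list_slot_shapes(saved_model_dir):
    """Read the checkpoint and return the stored shapes of optimizer slot variables."""
    import tensorflow as tf  # imported lazily so the helpers above work without TensorFlow

    # tf.train.list_variables returns (name, shape) pairs without restoring any values.
    return slot_entries(tf.train.list_variables(checkpoint_prefix(saved_model_dir)))
```

Calling list_slot_shapes("scratch") in a session where TensorFlow is installed returns the slot entries with their stored shapes. An entry with shape [] next to a (3, 8) kernel would be consistent with the error message above.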
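Describe the expected behavior
The saved model should load without errors.
Standalone code to reproduce the issue
The standalone code is provided above.
Other info / logs
Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.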
6 answers
ukxgm1gy1#
I was able to reproduce the issue described in the report. Please find the gists for nightly, tf 2.4 and tf 2.3 here.
vc6uscn92#
@Feynman27 I tried to reproduce the issue, but ran into an error while importing TensorFlow Addons. Please see this gist. Thanks!
w1e3prcc3#
Sorry, this is still an issue. Here is the updated gist for our reference. Thanks!
j91ykkif4#
I am running into the same issue. Are there any updates?
xwmevbvl5#
Please take a look at this issue. Thanks!
k0pti3hp6#
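A workaround is to compile the model from the checkpoint without an optimizer and then save the weights back to a checkpoint file. Loading from that file afterwards does not crash. It does, however, discard the optimizer state.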
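One possible reading of this workaround, sketched under assumptions rather than taken from the answer itself: rebuild the same architecture with no optimizer attached, restore only the layer weights from the checkpoint inside the SavedModel directory, and write them to a new weights-only checkpoint. The function names, the "weights_only/ckpt" path and the variables/variables layout are assumptions for illustration.

```python
import os


def checkpoint_prefix(saved_model_dir):
    """Return the checkpoint prefix inside a SavedModel directory (assumed layout)."""
    return os.path.join(saved_model_dir, "variables", "variables")


def build_model():
    """Recreate the architecture from the report, built up front via input_shape."""
    import tensorflow as tf  # imported lazily; needed only when the model is built

    return tf.keras.Sequential([
        tf.keras.layers.Dense(8, input_shape=(3,)),
        tf.keras.layers.Dense(2),
    ])


def resave_weights_only(model_fn, saved_model_dir, weights_path):
    """Restore layer weights into an uncompiled model and save them without optimizer state."""
    model = model_fn()  # left uncompiled, so no optimizer (and no slot variables) is attached
    status = model.load_weights(checkpoint_prefix(saved_model_dir))
    status.expect_partial()  # the optimizer entries in the checkpoint are skipped on purpose
    model.save_weights(weights_path)
    return model
```

A call such as resave_weights_only(build_model, "scratch", "weights_only/ckpt") would produce a checkpoint that a freshly built model can read with load_weights. As the answer notes, the optimizer state is lost, so training resumes with a newly initialized optimizer.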