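I've been following the TensorFlow text generation tutorial, which includes two models, `MyModel` and `OneStep`. `MyModel` is an RNN that operates on vectorized strings; `OneStep` essentially wraps `MyModel` and operates directly on strings.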
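The tutorial saves and loads `OneStep`, and I followed that step successfully. Now I want to save and reload `MyModel`, which the tutorial does not do. When I try to call the reloaded model with `return_state=True`, I get an error: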
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/tmp/ipykernel_23/2335414736.py in <module>
1 # TODO: Loaded model gives an error
2 for input_example_batch, target_example_batch in train_ds.take(1):
----> 3 example_batch_predictions, example_states = loaded_model(input_example_batch, False, None, return_state=True)
4 print(example_batch_predictions.shape, "# (batch_size, sequence_length, vocab_size)")
5 print(example_states.shape, " # (batch_size, rnn_units)")
/opt/conda/lib/python3.7/site-packages/tensorflow/python/saved_model/load.py in _call_attribute(instance, *args, **kwargs)
662
663 def _call_attribute(instance, *args, **kwargs):
--> 664 return instance.__call__(*args, **kwargs)
665
666
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
883
884 with OptionalXlaContext(self._jit_compile):
--> 885 result = self._call(*args, **kwds)
886
887 new_tracing_count = self.experimental_get_tracing_count()
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
931 # This is the first call of __call__, so we have to initialize.
932 initializers = []
--> 933 self._initialize(args, kwds, add_initializers_to=initializers)
934 finally:
935 # At this point we know that the initialization is complete (or less
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
758 self._concrete_stateful_fn = (
759 self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
--> 760 *args, **kwds))
761
762 def invalid_creator_scope(*unused_args, **unused_kwds):
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
3064 args, kwargs = None, None
3065 with self._lock:
-> 3066 graph_function, _ = self._maybe_define_function(args, kwargs)
3067 return graph_function
3068
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
3461
3462 self._function_cache.missed.add(call_context_key)
-> 3463 graph_function = self._create_graph_function(args, kwargs)
3464 self._function_cache.primary[cache_key] = graph_function
3465
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
3306 arg_names=arg_names,
3307 override_flat_arg_shapes=override_flat_arg_shapes,
-> 3308 capture_by_value=self._capture_by_value),
3309 self._function_attributes,
3310 function_spec=self.function_spec,
/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes, acd_record_initial_resource_uses)
1005 _, original_func = tf_decorator.unwrap(python_func)
1006
-> 1007 func_outputs = python_func(*func_args, **func_kwargs)
1008
1009 # invariant: `func_outputs` contains only Tensors, CompositeTensors,
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
666 # the function a weak reference to itself to avoid a reference cycle.
667 with OptionalXlaContext(compile_with_xla):
--> 668 out = weak_wrapped_fn().__wrapped__(*args, **kwds)
669 return out
670
/opt/conda/lib/python3.7/site-packages/tensorflow/python/saved_model/function_deserialization.py in restored_function_body(*args, **kwargs)
292 .format(_pretty_format_positional(args), kwargs,
293 len(saved_function.concrete_functions),
--> 294 "\n\n".join(signature_descriptions)))
295
296 concrete_function_objects = []
ValueError: Could not find matching function to call loaded from the SavedModel. Got:
Positional arguments (4 total):
* Tensor("inputs:0", shape=(64, 113), dtype=int64)
* False
* None
* True
Keyword arguments: {}
Expected these arguments to match one of the following 4 option(s):
Option 1:
Positional arguments (4 total):
* TensorSpec(shape=(None, 113), dtype=tf.int64, name='input_1')
* False
* None
* False
Keyword arguments: {}
Option 2:
Positional arguments (4 total):
* TensorSpec(shape=(None, 113), dtype=tf.int64, name='inputs')
* False
* None
* False
Keyword arguments: {}
Option 3:
Positional arguments (4 total):
* TensorSpec(shape=(None, 113), dtype=tf.int64, name='inputs')
* True
* None
* False
Keyword arguments: {}
Option 4:
Positional arguments (4 total):
* TensorSpec(shape=(None, 113), dtype=tf.int64, name='input_1')
* True
* None
* False
Keyword arguments: {}
I think this is caused by the custom argument in the `call` method. Here is a minimal example that reproduces the problem:
```python
import tensorflow as tf


class CustomModel(tf.keras.models.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(10)

    def call(self, inputs, custom_param=False):
        return self.dense(inputs)


model = CustomModel()
sample_inputs = tf.zeros((16, 30))
print('Sample inputs:', sample_inputs)
sample_outputs = model(sample_inputs)
print('Sample outputs:', sample_outputs)

model.save('saved_model')
loaded_model = tf.keras.models.load_model('saved_model')
# Raises ValueError: no saved function matches custom_param=True
sample_outputs_2 = loaded_model(sample_inputs, custom_param=True)
print('Sample outputs 2:', sample_outputs_2)
```
Calling the reloaded model with `custom_param` set to anything other than its default value always seems to fail.
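A plausible reading of the error is that a SavedModel only stores the concrete functions that were traced before saving, and a plain Python argument such as `custom_param` (or `return_state`) is part of what identifies a trace. In the error listing above, every saved option has `False` in the last position, which fits that reading. The sketch below reproduces the behaviour with a plain `tf.Module` rather than a Keras model, assuming TF 2.x; the `Scaler` class and its `double` argument are hypothetical stand-ins for `CustomModel` and `custom_param`.

```python
import tempfile
import tensorflow as tf  # any TF 2.x; uses only tf.Module / tf.function / tf.saved_model


class Scaler(tf.Module):
    """Hypothetical stand-in for CustomModel; `double` plays the role of custom_param."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(2.0)

    @tf.function
    def __call__(self, x, double=False):
        # `double` is a plain Python bool, so each distinct value gets its own trace.
        y = x * self.w
        return y * 2.0 if double else y


def save_and_reload(module):
    path = tempfile.mkdtemp()
    tf.saved_model.save(module, path)
    return tf.saved_model.load(path)


x = tf.constant([1.0, 2.0])

# Case 1: only the default value is traced before saving.
only_default = Scaler()
only_default(x)
loaded = save_and_reload(only_default)
print(loaded(x).numpy())  # [2. 4.]
try:
    loaded(x, double=True)
    untraced_call_failed = False
except (ValueError, TypeError) as err:  # no saved concrete function matches
    untraced_call_failed = True
    print('double=True failed after reload:', type(err).__name__)

# Case 2: both values are traced before saving.
both = Scaler()
both(x)
both(x, double=True)
loaded_both = save_and_reload(both)
print(loaded_both(x, double=True).numpy())  # [4. 8.]
```

Tracing every needed argument value before `tf.saved_model.save` is one workaround for plain `tf.function`s; this sketch does not check whether Keras's own `model.save` would pick up such extra traces.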
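Is this a bug or by design? How can I modify the model so that during training it returns only the output sequence, but during inference it returns both the output sequence and the state? That way I can feed the state back into the model and generate additional characters at inference time.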
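One way to get "sequence only for training, sequence plus state for inference" without relying on Python arguments surviving the reload is the pattern the tutorial's `OneStep` already follows: keep `call` flexible for training, and export a separate `tf.function` with a fixed, tensor-only `input_signature` that always returns the state. This is a different technique from the `get_config`/`from_config` fix in the answer. It is sketched under the assumption of TF 2.x with Keras 2 (`tf.keras`); `MyModel` here only approximates the tutorial's model, and `InferenceExport`, `step`, and the sizes are hypothetical.

```python
import tempfile
import tensorflow as tf  # assumes TF 2.x with Keras 2 (tf.keras)

# Hypothetical sizes; the tutorial derives vocab_size from its data.
VOCAB_SIZE, EMBED_DIM, RNN_UNITS = 66, 32, 64


class MyModel(tf.keras.Model):
    """Approximation of the tutorial's MyModel (not copied verbatim)."""

    def __init__(self, vocab_size, embedding_dim, rnn_units):
        super().__init__()
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(
            rnn_units, return_sequences=True, return_state=True)
        self.dense = tf.keras.layers.Dense(vocab_size)

    def call(self, inputs, states=None, return_state=False, training=False):
        x = self.embedding(inputs, training=training)
        # initial_state=None lets the GRU start from zeros.
        x, states = self.gru(x, initial_state=states, training=training)
        x = self.dense(x, training=training)
        # Training code uses the default and gets only the logits.
        return (x, states) if return_state else x


class InferenceExport(tf.Module):
    """Hypothetical wrapper: one fixed, tensor-only signature for generation."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    @tf.function(input_signature=[
        tf.TensorSpec([None, None], tf.int64),         # token ids
        tf.TensorSpec([None, RNN_UNITS], tf.float32),  # GRU state
    ])
    def step(self, inputs, states):
        # return_state is a Python constant inside this trace, so the
        # saved graph always returns (logits, new_states).
        return self.model(inputs, states=states, return_state=True)


model = MyModel(VOCAB_SIZE, EMBED_DIM, RNN_UNITS)
batch = tf.zeros((2, 7), tf.int64)
model(batch)  # create the weights eagerly before exporting

path = tempfile.mkdtemp()
tf.saved_model.save(InferenceExport(model), path)
loaded = tf.saved_model.load(path)

states = tf.zeros((2, RNN_UNITS))
logits, states = loaded.step(batch, states)
print(logits.shape, states.shape)  # (2, 7, 66) (2, 64)
```

Because the exported signature takes only tensors, the reloaded `step` can be called in a generation loop, feeding `states` back in each time, without depending on which Python-argument combinations were traced.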
1 answer
I figured it out. The solution is in the TensorFlow docs, although it isn't spelled out very clearly.
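For the code above, the loaded model has type `keras.saving.saved_model.load.CustomModel`, which is not the same as the original class. To get the original type back:

- The `CustomModel` class needs `get_config` and `from_config` methods.
- When loading the model, pass the custom class in the `custom_objects` dictionary.

Then `loaded_model` has type `CustomModel`, and calling it with `custom_param=True` works.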
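A minimal sketch of those two steps, assuming TF 2.x with Keras 2 (`tf.keras`, where `model.save(..., save_format='tf')` writes the SavedModel directory format used in the question; Keras 3 in TF 2.16+ saves differently). The `units` argument and the doubling branch are illustrative additions, so that `get_config` has something to return and `custom_param` has a visible effect.

```python
import tempfile
import tensorflow as tf  # assumes TF 2.x with Keras 2 (tf.keras, TF < 2.16)


class CustomModel(tf.keras.models.Model):
    def __init__(self, units=10, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.dense = tf.keras.layers.Dense(units)

    def call(self, inputs, custom_param=False):
        outputs = self.dense(inputs)
        # Illustrative Python-level branch on the custom argument.
        return outputs * 2.0 if custom_param else outputs

    def get_config(self):
        # Everything __init__ needs to rebuild the model.
        return {'units': self.units}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


model = CustomModel()
sample_inputs = tf.ones((16, 30))
model(sample_inputs)  # build the weights before saving

path = tempfile.mkdtemp()
model.save(path, save_format='tf')

# Passing the class lets Keras rebuild it via from_config instead of
# reviving a generic placeholder class from the saved traces.
loaded_model = tf.keras.models.load_model(
    path, custom_objects={'CustomModel': CustomModel})
outputs = loaded_model(sample_inputs, custom_param=True)
print(isinstance(loaded_model, CustomModel), outputs.shape)
```

Since the reloaded object is the real `CustomModel`, its Python `call` runs as written, so any value of `custom_param` (or `return_state` in the tutorial's `MyModel`) is accepted.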