TensorFlow: optimizing and (re)saving a SavedModel with Grappler

Asked by pod7payv on 2022-11-16

I have a **(TF2) saved-model**, one that is full of training-op clutter, and I am trying to optimize it for inference with Grappler, but I then want to save it back as a TF2 saved-model (to keep the regular workflow away from TF1).
This is what I currently have:

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph
from tensorflow.lite.python.util import run_graph_optimizations, get_grappler_config

# Load the saved-model and get the inference concrete function
sm = tf.saved_model.load('path/to/savedmodel/dir')
func = sm.signatures['serving_default']

# Replace variables with constants in order to get rid of the training clutter
frozen_func, graph_def = convert_variables_to_constants_v2_as_graph(func)

# Use grappler to optimize the concrete function graph after replacing vars with constants
input_tensors = [tsr for tsr in frozen_func.inputs if tsr.dtype != tf.resource]
output_tensors = frozen_func.outputs
graph_def = run_graph_optimizations(graph_def, input_tensors, output_tensors,
                                    config=get_grappler_config(["constfold", "function"]),
                                    graph=frozen_func.graph)

# Here the intention is to somehow reconvert the optimized graph-def into a concrete function
# and subsequently re-save that as a TF2(not TF1!) saved-model, is there a way to do that?
frozen_func_graph = tf.Graph()
with frozen_func_graph.as_default():
    tf.import_graph_def(graph_def, name='')

# ... what now?
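
As a purely illustrative side note (not part of the original snippet), a quick way to confirm that Grappler actually changed something is to compare the frozen graph with the optimized GraphDef:

# Illustrative check only: node counts and remaining op types before/after Grappler
print('frozen graph nodes:   ', len(frozen_func.graph.as_graph_def().node))
print('optimized graph nodes:', len(graph_def.node))
print('ops remaining:', sorted({node.op for node in graph_def.node}))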

The problem is that, since direct use of tf.Graph is deprecated in TF2, I intend to convert the optimized graph back into a TF2 saved-model. I would like to do this by manually constructing a ConcreteFunction that wraps this optimized graph, but as far as I have researched there seems to be no way to do that. Which basically means I would still have to use the TF1 compat API, which is exactly what I want to avoid.
The ugly (really ugly) option that I would really like to avoid is (haven't tried it yet, but it would probably work; a rough sketch follows the list below):

  • Construct a TF1 saved-model with the v1 API, saving it via tf.compat.v1.saved_model.builder.SavedModelBuilder
  • Load that TF1 saved-model back with the v2 API (i.e. with tf.saved_model.load instead of tf.compat.v1.saved_model.load; the former automatically converts a TF1 saved-model into a TF2 one)
  • (Re)save the converted TF2 saved-model
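
For reference, a rough, untested sketch of that TF1 roundtrip (paths are placeholders; graph_def, input_tensors and output_tensors are the ones built above; whether the final re-save works exactly as written is an assumption), mostly to show why this route feels so clunky:

import tensorflow as tf

# Step 1: dump the Grappler-optimized graph_def as a TF1 saved-model with the v1 builder
tf1_dir = 'path/to/tf1_savedmodel/dir'  # placeholder path
with tf.Graph().as_default() as g:
    tf.import_graph_def(graph_def, name='')
    with tf.compat.v1.Session(graph=g) as sess:
        builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(tf1_dir)
        signature = tf.compat.v1.saved_model.signature_def_utils.predict_signature_def(
            inputs={tsr.name: g.get_tensor_by_name(tsr.name) for tsr in input_tensors},
            outputs={tsr.name: g.get_tensor_by_name(tsr.name) for tsr in output_tensors})
        builder.add_meta_graph_and_variables(
            sess, [tf.saved_model.SERVING],
            signature_def_map={'serving_default': signature})
        builder.save()

# Steps 2 + 3: load it back with the v2 loader (which upgrades TF1 saved-models on load)
# and re-save it as a TF2 saved-model
upgraded = tf.saved_model.load(tf1_dir)
tf.saved_model.save(upgraded, 'path/to/tf2_savedmodel/dir', signatures=upgraded.signatures)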

Is there a way to do this cleanly? Ideally I would also not be forced to dump the optimized saved-model to disk if I don't want to, although it seems that building a saved-model purely in memory is not possible? (not a big deal though)


Answer 1 (mbzjlibv)

Finally found it. Not ideal, since I am using (more or less) internal TF2 API calls, but at least the TF1 compat API is not used at all.

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph
from tensorflow.lite.python.util import run_graph_optimizations, get_grappler_config
from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference
from tensorflow.python.eager import context, wrap_function

# Load the saved-model and get the inference concrete function
sm = tf.saved_model.load('path/to/savedmodel/dir')
func = sm.signatures['serving_default'] # note: key might differ according to what your model's inference function is

# Replace variables with constants in order to get rid of the training clutter
frozen_func, graph_def = convert_variables_to_constants_v2_as_graph(func)

# Use grappler to optimize the concrete function graph after replacing vars with constants
input_tensors = [tsr for tsr in frozen_func.inputs if tsr.dtype != tf.resource]
output_tensors = frozen_func.outputs
graph_def = run_graph_optimizations(graph_def, input_tensors, output_tensors,
                                    config=get_grappler_config(["constfold", "function"]),
                                    graph=frozen_func.graph)

# Optimize for inference
input_tsr_names = [tsr.name for tsr in input_tensors]
output_tsr_names = [tsr.name for tsr in output_tensors]
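# optimize_for_inference works on node names rather than tensor names, hence stripping the ':<output index>' suffix below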
input_node_names = list(set([tsr_name.rsplit(':', 1)[0] for tsr_name in input_tsr_names]))
output_node_names = list(set([tsr_name.rsplit(':', 1)[0] for tsr_name in output_tsr_names]))
graph_def = optimize_for_inference(input_graph_def=graph_def,
                                   input_node_names=input_node_names,
                                   placeholder_type_enum=tf.dtypes.float32.as_datatype_enum,
                                   output_node_names=output_node_names,
                                   toco_compatible=True)

# This next part inspired from _construct_concrete_function function here: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/convert_to_constants.py#L1062

# Remove old functions to use updated functions from graph def - not sure if this is actually needed here, didn't look into it
for f in graph_def.library.function:
    if context.context().has_function(f.signature.name):
        context.context().remove_function(f.signature.name)

# GraphDef to concrete function
opt_frozen_func = wrap_function.function_from_graph_def(graph_def,
                                                        input_tsr_names,
                                                        output_tsr_names)

# Wrap concrete function into module to export as saved-model
class OptimizedFrozenModel(tf.Module):
    def __init__(self, name=None):
        super().__init__(name)

module = OptimizedFrozenModel()
module.__call__ = opt_frozen_func

# Export frozen & optimized saved-model
tf.saved_model.save(module, 'path/to/optimized_savedmodel/dir', signatures=opt_frozen_func)
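
As a quick sanity check (not part of the original answer; shapes depend on your model and must be fully defined for the dummy-input trick below to work), the re-saved model can be loaded back with the plain TF2 loader and its signature called:

import tensorflow as tf

reloaded = tf.saved_model.load('path/to/optimized_savedmodel/dir')
infer = reloaded.signatures['serving_default']

# Inspect what the signature expects and returns before calling it
print(infer.structured_input_signature)
print(infer.structured_outputs)

# Build zero-filled dummy inputs from the input spec (assumes fully defined shapes)
_, kwarg_specs = infer.structured_input_signature
dummy_inputs = {name: tf.zeros(spec.shape, dtype=spec.dtype)
                for name, spec in kwarg_specs.items()}
print(infer(**dummy_inputs))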
