Keras / TensorFlow: google.protobuf.message.DecodeError: Wrong wire type in tag

5w9g7ksd posted on 2023-06-30 in Go

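For my project, I am trying to run inference with a trained model stored in saved_model.pb. I doubt the error comes from my code, which is shown below; an installation problem seems more likely: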

from scipy.io import wavfile
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

# Read the audio input (sample rate and samples) before building the graph
samplerate, data = wavfile.read('sound.wav')

with tf.Graph().as_default() as graph:  # Make a fresh graph the default graph
    with tf.Session(graph=graph) as sess:
        print("load graph")

        # Load the protobuf file from disk and parse it into a GraphDef
        with tf.gfile.GFile("saved_model.pb", 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())  # <- this call raises the DecodeError below

        # Import the GraphDef into the current default graph
        # (for a frozen graph, the weights are embedded as constants)
        tf.import_graph_def(graph_def, name="")

        # Print the name of every operation in the graph
        for op in graph.get_operations():
            print("Operation Name :", op.name)         # Operation name
            print("Tensor Stats :", str(op.values()))  # Output tensors of the op

        # INFERENCE here
        l_input = graph.get_tensor_by_name('Inputs/fifo_queue_Dequeue:0')    # Input tensor
        l_output = graph.get_tensor_by_name('upscore32/conv2d_transpose:0')  # Output tensor

        print("Shape of input : ", l_input.shape)  # static shape of the input tensor

        # Run the model on the single audio clip
        session_out = sess.run(l_output, feed_dict={l_input: data})

        # class_names is not defined anywhere in the question
        print("Predicted class:", class_names[session_out[0].argmax()])

The traceback is as follows:

Traceback (most recent call last):
  File "/home/pi/model_inference/test.py", line 11, in <module>
    graph_def.ParseFromString(f.read())
  File "/usr/local/lib/python3.7/dist-packages/google/protobuf/message.py", line 199, in ParseFromString
    return self.MergeFromString(serialized)
  File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/python_message.py", line 1145, in MergeFromString
    if self._InternalParse(serialized, 0, length) != length:
  File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/python_message.py", line 1212, in InternalParse
    pos = field_decoder(buffer, new_pos, end, self, field_dict)
  File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/decoder.py", line 754, in DecodeField
    if value._InternalParse(buffer, pos, new_pos) != new_pos:
  File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/python_message.py", line 1212, in InternalParse
    pos = field_decoder(buffer, new_pos, end, self, field_dict)
  File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/decoder.py", line 733, in DecodeRepeatedField
    if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
  File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/python_message.py", line 1212, in InternalParse
    pos = field_decoder(buffer, new_pos, end, self, field_dict)
  File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/decoder.py", line 888, in DecodeMap
    if submsg._InternalParse(buffer, pos, new_pos) != new_pos:
  File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/python_message.py", line 1199, in InternalParse
    buffer, new_pos, wire_type)  # pylint: disable=protected-access
  File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/decoder.py", line 989, in _DecodeUnknownField
    (data, pos) = _DecodeUnknownFieldSet(buffer, pos)
  File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/decoder.py", line 968, in _DecodeUnknownFieldSet
    (data, pos) = _DecodeUnknownField(buffer, pos, wire_type)
  File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/decoder.py", line 993, in _DecodeUnknownField
    raise _DecodeError('Wrong wire type in tag.')
google.protobuf.message.DecodeError: Wrong wire type in tag.
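For context on the message: every field in a serialized protobuf starts with a tag, a varint equal to (field_number << 3) | wire_type. The pure-Python decoder seen in the traceback raises "Wrong wire type in tag." when it meets a tag with a wire type it cannot handle, such as the unused values 6 or 7. In practice this usually means the bytes are being read as the wrong message type, or the file is damaged. To see what a .pb file actually contains, here is a minimal sketch that lists its top-level tags without needing TensorFlow. The function names are hypothetical and not part of any library. In TensorFlow's .proto definitions, GraphDef declares field 1 as `node` (repeated NodeDef, so LEN). SavedModel declares field 1 as `saved_model_schema_version` (int64, so VARINT) and field 2 as `meta_graphs` (LEN).

```python
# Sketch: list the top-level fields of a serialized protobuf message (no TensorFlow needed).
# Each field starts with a tag varint: (field_number << 3) | wire_type.
from typing import List, Tuple

WIRE_TYPES = {0: "VARINT", 1: "I64", 2: "LEN", 5: "I32"}  # proto3 wire types (groups 3/4 omitted)


def read_varint(buf: bytes, pos: int) -> Tuple[int, int]:
    """Decode a base-128 varint starting at pos; return (value, next_pos)."""
    result, shift = 0, 0
    while True:
        if pos >= len(buf):
            raise ValueError("truncated varint")
        byte = buf[pos]
        pos += 1
        result |= (byte & 0x7F) << shift
        if not byte & 0x80:
            return result, pos
        shift += 7


def top_level_fields(buf: bytes) -> List[Tuple[int, str]]:
    """Return (field_number, wire_type_name) for each top-level field in buf."""
    fields = []
    pos = 0
    while pos < len(buf):
        tag, pos = read_varint(buf, pos)
        field_number, wire_type = tag >> 3, tag & 0x7
        if wire_type not in WIRE_TYPES:
            # 6 and 7 are never valid; groups (3/4) do not occur in proto3 messages.
            raise ValueError("unsupported wire type %d for field %d" % (wire_type, field_number))
        fields.append((field_number, WIRE_TYPES[wire_type]))
        if wire_type == 0:
            _, pos = read_varint(buf, pos)  # skip the varint value
        elif wire_type == 1:
            pos += 8  # skip a fixed 64-bit value
        elif wire_type == 2:
            length, pos = read_varint(buf, pos)
            pos += length  # skip the length-delimited payload
        else:
            pos += 4  # wire type 5: skip a fixed 32-bit value
        if pos > len(buf):
            raise ValueError("field %d runs past the end of the buffer" % field_number)
    return fields
```

For example, `top_level_fields(open("saved_model.pb", "rb").read())[:3]` may start with `(1, 'VARINT'), (2, 'LEN')`. If it does, the file is most likely a SavedModel message and not a GraphDef. This is a heuristic, not a guarantee.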

Worth noting: I am trying this on a Raspberry Pi 4, so it runs Linux. I would be glad of any hints on what to do. Thanks in advance!
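If an installation problem is suspected, a first step is to confirm which TensorFlow and protobuf packages Python actually imports. The traceback runs through google/protobuf/internal/python_message.py, which suggests the pure-Python protobuf backend is active. Below is a minimal sketch; the helper names are hypothetical. It relies only on `importlib.import_module` and protobuf's `api_implementation.Type()`.

```python
# Sketch: report installed versions and the active protobuf backend.
from importlib import import_module
from typing import Dict, Iterable


def report_versions(modules: Iterable[str] = ("tensorflow", "google.protobuf")) -> Dict[str, str]:
    """Map each module name to its __version__, or a note if it cannot be imported."""
    versions = {}
    for name in modules:
        try:
            versions[name] = str(getattr(import_module(name), "__version__", "unknown"))
        except ImportError:
            versions[name] = "not installed"
    return versions


def protobuf_backend() -> str:
    """Return 'python', 'cpp' or 'upb' for the protobuf runtime in use."""
    try:
        from google.protobuf.internal import api_implementation
    except ImportError:
        return "not installed"
    return api_implementation.Type()
```

Running `print(report_versions(), protobuf_backend())` on the Pi shows what is really installed. Matching versions alone would not fix a file that holds a different message type, though.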

Answer 1 (dohp0rv5)

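It looks like your file "saved_model.pb" is not a saved (wire-format) protobuf of message type GraphDef. Could you check how it was saved and look for instructions on loading it back? Guessing from the name alone: could it be a Keras model that you have to load with tf.keras.models.load_model?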
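If the file was written by `tf.saved_model.save` or Keras `model.save`, it is a SavedModel message, not a GraphDef. In that case the whole export directory has to be loaded, not just the .pb file. Keras users can try `tf.keras.models.load_model(export_dir)`. For code in the question's tf.compat.v1 graph/session style, here is a minimal sketch using `tf.compat.v1.saved_model.loader.load`. It assumes the export carries the "serve" tag and a "serving_default" signature with exactly one input and one output. Those assumptions do not come from the question, and the function name is hypothetical.

```python
# Sketch: TF1-style loading of a SavedModel directory into a graph + session.
# Assumptions: "serve" tag, "serving_default" signature with one input and one output.


def run_with_v1_loader(export_dir, batch):
    """Load the SavedModel at export_dir (a directory) and run it on `batch`."""
    import tensorflow.compat.v1 as tf  # deferred so this sketch imports without TensorFlow

    with tf.Graph().as_default() as graph, tf.Session(graph=graph) as sess:
        # Reads saved_model.pb as a SavedModel message and restores the weights from variables/.
        meta_graph_def = tf.saved_model.loader.load(
            sess, [tf.saved_model.tag_constants.SERVING], export_dir)
        signature = meta_graph_def.signature_def["serving_default"]
        (input_info,) = signature.inputs.values()
        (output_info,) = signature.outputs.values()
        input_tensor = graph.get_tensor_by_name(input_info.name)
        output_tensor = graph.get_tensor_by_name(output_info.name)
        return sess.run(output_tensor, feed_dict={input_tensor: batch})
```

Unlike `GraphDef.ParseFromString`, the loader also restores variable values. That matters because a SavedModel's graph usually does not embed its weights as constants.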

Answer 2 (bttbmeg0)

Make sure the object is an actual object and not a NoneType object.

Answer 3 (dgenwo3n)

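Try tf.saved_model.load(path), where path is the path to the folder the model was saved in (the one containing the assets and variables folders). Follow this link for the TensorFlow Object Detection API inference tutorial.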
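Below is a minimal sketch of this suggestion in TF2 style. It makes assumptions not found in the question: TensorFlow 2.x running eagerly (so without `tf.disable_v2_behavior()`), and an exported "serving_default" signature with exactly one input. The function name is hypothetical. `model_dir` is the folder, not the saved_model.pb file inside it.

```python
# Sketch: run a SavedModel directory's "serving_default" signature with tf.saved_model.load.
# Assumptions: TensorFlow 2.x in eager mode, one input in the signature.


def run_saved_model(model_dir, batch):
    """Run the serving_default signature on `batch`; return outputs as NumPy arrays."""
    import tensorflow as tf  # deferred so this sketch imports without TensorFlow

    loaded = tf.saved_model.load(model_dir)  # folder with saved_model.pb, variables/, assets/
    infer = loaded.signatures["serving_default"]
    # For signature functions, structured_input_signature is ((), {input_name: TensorSpec}).
    ((input_name, spec),) = infer.structured_input_signature[1].items()
    # Signature functions only accept keyword arguments; cast to the expected dtype.
    outputs = infer(**{input_name: tf.constant(batch, dtype=spec.dtype)})
    # Signatures return a dict mapping output names to tensors.
    return {name: tensor.numpy() for name, tensor in outputs.items()}
```

The samples returned by `wavfile.read` would probably still need reshaping or casting to match `spec.shape`. That depends on how the model was trained.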
