For my project, I am trying to run inference with a trained model stored in saved_model.pb. I suspect the error is caused by my code, which you can see here, but it seems more likely to be an installation problem:
from PIL import Image
import numpy as np
import scipy
from scipy import misc
import matplotlib.pyplot as plt
import tensorflow.compat.v1 as tf
from tensorflow.python.platform import gfile  # provides gfile.FastGFile used below
tf.disable_v2_behavior()

with tf.Graph().as_default() as graph:  # Set default graph as graph
    with tf.Session() as sess:
        # Load the graph in graph_def
        print("load graph")
        # We load the protobuf file from the disk and parse it to retrieve the unserialized graph_def
        with gfile.FastGFile("saved_model.pb", 'rb') as f:
            from scipy.io import wavfile
            samplerate, data = wavfile.read('sound.wav')
            # Set FCN graph to the default graph
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            sess.graph.as_default()
            # Import a graph_def into the current default Graph (in this case, the weights are (typically) embedded in the graph)
            tf.import_graph_def(
                graph_def,
                input_map=None,
                return_elements=None,
                name="",
                op_dict=None,
                producer_op_list=None
            )
            # Print the name of operations in the session
            for op in graph.get_operations():
                print("Operation Name :", op.name)            # Operation name
                print("Tensor Stats :", str(op.values()))     # Tensor name
            # INFERENCE here
            l_input = graph.get_tensor_by_name('Inputs/fifo_queue_Dequeue:0')    # Input tensor
            l_output = graph.get_tensor_by_name('upscore32/conv2d_transpose:0')  # Output tensor
            print("Shape of input : ", tf.shape(l_input))
            sess.run(tf.global_variables_initializer())
            # Run model on a single input
            Session_out = sess.run(l_output, feed_dict={l_input: data})
            print("Predicted class:", class_names[Session_out[0].argmax()])
The traceback is as follows:
Traceback (most recent call last):
File "/home/pi/model_inference/test.py", line 11, in <module>
graph_def.ParseFromString(f.read())
File "/usr/local/lib/python3.7/dist-packages/google/protobuf/message.py", line 199, in ParseFromString
return self.MergeFromString(serialized)
File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/python_message.py", line 1145, in MergeFromString
if self._InternalParse(serialized, 0, length) != length:
File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/python_message.py", line 1212, in InternalParse
pos = field_decoder(buffer, new_pos, end, self, field_dict)
File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/decoder.py", line 754, in DecodeField
if value._InternalParse(buffer, pos, new_pos) != new_pos:
File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/python_message.py", line 1212, in InternalParse
pos = field_decoder(buffer, new_pos, end, self, field_dict)
File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/decoder.py", line 733, in DecodeRepeatedField
if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/python_message.py", line 1212, in InternalParse
pos = field_decoder(buffer, new_pos, end, self, field_dict)
File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/decoder.py", line 888, in DecodeMap
if submsg._InternalParse(buffer, pos, new_pos) != new_pos:
File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/python_message.py", line 1199, in InternalParse
buffer, new_pos, wire_type) # pylint: disable=protected-access
File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/decoder.py", line 989, in _DecodeUnknownField
(data, pos) = _DecodeUnknownFieldSet(buffer, pos)
File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/decoder.py", line 968, in _DecodeUnknownFieldSet
(data, pos) = _DecodeUnknownField(buffer, pos, wire_type)
File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/decoder.py", line 993, in _DecodeUnknownField
raise _DecodeError('Wrong wire type in tag.')
google.protobuf.message.DecodeError: Wrong wire type in tag.
It is worth noting that I am trying this on a Raspberry Pi 4 (so it is running Linux). I would appreciate any hints on what to do. Thanks in advance!
3 Answers
dohp0rv51#
It looks like your file "saved_model.pb" is not a saved (wire-format) protobuf of message type GraphDef. Maybe you can look at how the model was saved and find some instructions on how to load it back? Just guessing from the name: could it be a Keras model that you have to load with tf.keras.models.load_model?
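For illustration, a minimal sketch of that route, assuming the file really did come from a Keras model.save() export; the directory name "export_dir" and the input shape below are placeholders, not taken from the question:

import numpy as np
import tensorflow as tf

# Minimal sketch, assuming saved_model.pb sits inside a Keras SavedModel export
# directory (here called "export_dir" as a placeholder). Note that load_model
# expects that directory (or an .h5 file), not the .pb file on its own.
model = tf.keras.models.load_model("export_dir")
model.summary()

# Hypothetical input shape; replace with whatever the model was trained on.
dummy_input = np.zeros((1, 224, 224, 3), dtype=np.float32)
prediction = model.predict(dummy_input)
print("Predicted class index:", prediction[0].argmax())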
bttbmeg02#
Make sure the object you parse is an actual object and not a NoneType object.
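As a minimal sketch of that kind of check (using the file name from the question), you could verify the bytes before parsing:

import tensorflow.compat.v1 as tf

# Read the file and confirm we actually got data back before parsing it;
# calling ParseFromString on None or empty bytes leads to confusing errors.
with tf.gfile.GFile("saved_model.pb", "rb") as f:
    serialized = f.read()

if not serialized:
    raise ValueError("saved_model.pb is empty or could not be read")

graph_def = tf.GraphDef()
graph_def.ParseFromString(serialized)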
dgenwo3n3#
Try using tf.saved_model.load(path), where path is the path to the folder in which the model was saved (the one containing the assets and variables folders). See the TensorFlow Object Detection API inference tutorial for a full walkthrough.
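A minimal sketch of that approach, assuming the SavedModel lives in a placeholder folder "export_dir" and exposes the common "serving_default" signature; the input shape is also a placeholder:

import tensorflow as tf  # TF 2.x style API

# Load the whole SavedModel directory (the folder that holds saved_model.pb,
# assets/ and variables/), not the .pb file on its own.
loaded = tf.saved_model.load("export_dir")
print("Available signatures:", list(loaded.signatures.keys()))

# "serving_default" is the usual signature name; adjust if the model uses another.
infer = loaded.signatures["serving_default"]

# Hypothetical input shape; some models require keyword arguments that match
# the signature's input names instead of a positional call.
dummy_input = tf.zeros((1, 224, 224, 3), dtype=tf.float32)
outputs = infer(dummy_input)
for name, tensor in outputs.items():
    print(name, tensor.shape)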