下面是我在Java中加载TensorFlow模型进行目标检测、并在Storm中执行的代码:
package object_det.object_det;
import com.google.protobuf.TextFormat;
import java.awt.BorderLayout;
import java.awt.GraphicsEnvironment;
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferByte;
import java.io.IOException;
import java.io.PrintStream;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map.Entry;
import java.util.Objects;
import javax.swing.ImageIcon;
import javax.swing.JFrame;
import javax.swing.JLabel;
import object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap;
import object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem;
import org.apache.storm.LocalCluster;
import org.apache.storm.StormSubmitter;
import org.apache.storm.kafka.BrokerHosts;
import org.apache.storm.kafka.KafkaSpout;
import org.apache.storm.kafka.SpoutConfig;
import org.apache.storm.kafka.StringScheme;
import org.apache.storm.kafka.ZkHosts;
import org.apache.storm.spout.SchemeAsMultiScheme;
import org.apache.storm.topology.BasicOutputCollector;
import org.apache.storm.topology.OutputFieldsDeclarer;
import org.apache.storm.topology.TopologyBuilder;
import org.apache.storm.topology.base.BaseBasicBolt;
import org.apache.storm.tuple.Tuple;
import org.json.simple.JSONObject;
import org.json.simple.parser.JSONParser;
import org.json.simple.parser.ParseException;
import org.opencv.core.Core;
import org.opencv.core.Mat;
import org.opencv.core.Point;
import org.opencv.core.Scalar;
import org.opencv.imgproc.Imgproc;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.framework.ConfigProto;
import org.tensorflow.framework.GPUOptions;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;
import org.tensorflow.framework.TensorInfo;
import org.tensorflow.types.UInt8;
public class Objectdetect1 {
static int left = 0;
static int bot = 0;
static int top = 0;
static int right = 0;
public static class PrinterBolt extends BaseBasicBolt {
int ii = 1;
JFrame frame=new JFrame();
JLabel jLabel = new JLabel();
static {
nu.pattern.OpenCV.loadShared();
}
public void declareOutputFields(OutputFieldsDeclarer declarer) {
}
public void execute(Tuple tuple, BasicOutputCollector collector) {
String output = tuple.getString(0);
output = output.replaceAll("\\[", "").replaceAll("\\]","");
JSONParser parser = new JSONParser();
JSONObject json = null;
try {
json = (JSONObject) parser.parse(output);
} catch (ParseException e) {
e.printStackTrace();
}
long rows = (Long) json.get("rows");
int row=(int)rows;
long columns = (Long) json.get("cols");
int cols=(int)columns;
long type_data = (Long) json.get("type");
int typedata=(int)type_data;
String base64 = (String) json.get("data");
String cameraId = (String) json.get("cameraId");
String timestamp = (String) json.get("timestamp");
Mat mat1 = new Mat(row,cols, typedata);
mat1.put(0, 0, Base64.getDecoder().decode(base64));
String[] labels = null;
try {
labels = loadLabels("label.pbtxt");
} catch (Exception e) {
e.printStackTrace();
}
try (SavedModelBundle model = SavedModelBundle.load("model", "serve")) {
printSignature(model);
List<Tensor<?>> outputs = null;
try (Tensor<UInt8> input = makeImageTensor(mat1))
{
outputs =
model
.session()
.runner()
.feed("image_tensor", input)
.fetch("detection_scores")
.fetch("detection_classes")
.fetch("detection_boxes")
.run();
}
try (Tensor<Float> scoresT = outputs.get(0).expect(Float.class);
Tensor<Float> classesT = outputs.get(1).expect(Float.class);
Tensor<Float> boxesT = outputs.get(2).expect(Float.class)) {
int maxObjects = (int) scoresT.shape()[1];
float[] scores = scoresT.copyTo(new float[1][maxObjects])[0];
float[] classes = classesT.copyTo(new float[1][maxObjects])[0];
float[][] boxes = boxesT.copyTo(new float[1][maxObjects][4])[0];
boolean foundSomething = false;
int cnt = 0;
for (int i = 0; i < scores.length; ++i) {
if (scores[i] < 0.5) {
continue;
}
cnt ++;
foundSomething = true;
System.out.printf("\tFound %-20s (score: %.4f)\n", labels[(int) classes[i]], scores[i]);
left = (int) Math.round(boxes[i][1] * cols);
top = (int) Math.round(boxes[i][0] * row);
right = (int) Math.round(boxes[i][3] * cols);
bot = (int) Math.round(boxes[i][2] * row);
Imgproc.rectangle(mat1, new Point(right,bot), new Point(left,top),new Scalar(0,69,255),2);
Imgproc.putText(mat1,labels[(int) classes[i]] , new Point(left,top), Core.FONT_HERSHEY_PLAIN, 1.6, new Scalar(240,248,255),2);
BufferedImage bimg = bufferedImage(mat1);
ImageIcon imageIcon = new ImageIcon(bimg);
jLabel.setIcon(imageIcon);
frame.getContentPane().add(jLabel, BorderLayout.CENTER);
frame.pack();
frame.setLocationRelativeTo(null);
frame.setVisible(true);
}
if (!foundSomething) {
System.out.println("No objects detected with a high enough score.");
BufferedImage bimg = bufferedImage(mat1);
ImageIcon imageIcon = new ImageIcon(bimg);
jLabel.setIcon(imageIcon);
frame.getContentPane().add(jLabel, BorderLayout.CENTER);
frame.pack();
frame.setLocationRelativeTo(null);
frame.setVisible(true);
}
}
}
catch (Exception e) {
e.printStackTrace();
}
}
private static void printSignature(SavedModelBundle model) throws Exception {
MetaGraphDef m = MetaGraphDef.parseFrom(model.metaGraphDef());
SignatureDef sig = m.getSignatureDefOrThrow("serving_default");
int numInputs = sig.getInputsCount();
int i = 1;
System.out.println("MODEL SIGNATURE");
System.out.println("Inputs:");
for (Entry<String, TensorInfo> entry : sig.getInputsMap().entrySet()) {
TensorInfo t = entry.getValue();
System.out.printf(
"%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n",
i++, numInputs, entry.getKey(), t.getName(), t.getDtype());
}
int numOutputs = sig.getOutputsCount();
i = 1;
System.out.println("Outputs:");
for (Entry<String, TensorInfo> entry : sig.getOutputsMap().entrySet()) {
TensorInfo t = entry.getValue();
System.out.printf(
"%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n",
i++, numOutputs, entry.getKey(), t.getName(), t.getDtype());
}
System.out.println("-----------------------------------------------");
}
private static String[] loadLabels(String filename) throws Exception {
String text = new String(Files.readAllBytes(Paths.get(filename)), StandardCharsets.UTF_8);
StringIntLabelMap.Builder builder = StringIntLabelMap.newBuilder();
TextFormat.merge(text, builder);
StringIntLabelMap proto = builder.build();
int maxId = 0;
for (StringIntLabelMapItem item : proto.getItemList()) {
if (item.getId() > maxId) {
maxId = item.getId();
}
}
String[] ret = new String[maxId + 1];
for (StringIntLabelMapItem item : proto.getItemList()) {
ret[item.getId()] = item.getDisplayName();
}
return ret;
}
private static void bgr2rgb(byte[] data) {
for (int i = 0; i < data.length; i += 3) {
byte tmp = data[i];
data[i] = data[i + 2];
data[i + 2] = tmp;
}
}
public static BufferedImage bufferedImage(Mat m) {
int type = BufferedImage.TYPE_BYTE_GRAY;
if ( m.channels() > 1 ) {
type = BufferedImage.TYPE_3BYTE_BGR;
}
BufferedImage image = new BufferedImage(m.cols(),m.rows(), type);
m.get(0,0,((DataBufferByte)image.getRaster().getDataBuffer()).getData()); // get all the pixels
return image;
}
private static Tensor<UInt8> makeImageTensor(Mat m) throws IOException {
BufferedImage img = bufferedImage(m);
if (img.getType() != BufferedImage.TYPE_3BYTE_BGR) {
throw new IOException(
String.format(
"Expected 3-byte BGR encoding in BufferedImage, found %d (file: %s). This code could be made more robust",
img.getType()));
}
byte[] data = ((DataBufferByte) img.getData().getDataBuffer()).getData();
bgr2rgb(data);
final long BATCH_SIZE = 1;
final long CHANNELS = 3;
long[] shape = new long[] {BATCH_SIZE, img.getHeight(), img.getWidth(), CHANNELS};
return Tensor.create(UInt8.class, shape, ByteBuffer.wrap(data));
}
private static void printUsage(PrintStream s) {
s.println("USAGE: <model> <label_map> <image> [<image>] [<image>]");
s.println("");
s.println("Where");
s.println("<model> is the path to the SavedModel directory of the model to use.");
s.println(" For example, the saved_model directory in tarballs from ");
s.println(
" https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md)");
s.println("");
s.println(
"<label_map> is the path to a file containing information about the labels detected by the model.");
s.println(" For example, one of the .pbtxt files from ");
s.println(
" https://github.com/tensorflow/models/tree/master/research/object_detection/data");
s.println("");
s.println("<image> is the path to an image file.");
s.println(" Sample images can be found from the COCO, Kitti, or Open Images dataset.");
s.println(
" See: https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md");
}
}
public static void main(String[] args) {
final BrokerHosts zkrHosts = new ZkHosts(args[0]);
final String kafkaTopic = args[1];
final String zkRoot = args[2];
final String clientId = args[3];
final SpoutConfig kafkaConf = new SpoutConfig(zkrHosts, kafkaTopic, zkRoot, clientId);
kafkaConf.fetchSizeBytes = 30971520;
kafkaConf.scheme = new SchemeAsMultiScheme(new StringScheme());
final TopologyBuilder topologyBuilder = new TopologyBuilder();
topologyBuilder.setSpout("kafka-spout", new KafkaSpout(kafkaConf), 1);
topologyBuilder.setBolt("print-messages", new PrinterBolt()).shuffleGrouping("kafka-spout");
final LocalCluster localCluster = new LocalCluster();
localCluster.submitTopology("kafka-topology", new HashMap<Object, Object>(), topologyBuilder.createTopology());
}
}
上面的代码使用以下命令打包:
mvn clean install shade:shade
当使用以下命令将这个jar提交到单节点集群时:
storm jar file.jar object_det.object_det.Objectdetect1 zookeeperhost:2181 topicname /brokers test
代码可以成功执行。
但是当把同一个jar提交到多节点Hadoop集群时,会出现以下错误:
Running: /usr/jdk64/jdk1.8.0_112/bin/java -server -Ddaemon.name=
-Dstorm.options= -Dstorm.home=/usr/hdp/3.1.0.0-78/storm -Dstorm.log.dir=/var/log/storm -Djava.library.path=/usr/local/lib:/opt/local/lib:/usr/lib -Dstorm.conf.file= -cp /usr/hdp/3.1.0.0-78/storm/*:/usr/hdp/3.1.0.0-78/storm/lib/*:/usr/hdp/3.1.0.0-78/storm/extlib/* org.apache.storm.daemon.ClientJarTransformerRunner org.apache.storm.hack.StormShadeTransformer strmjr2-0.0.1-SNAPSHOT.jar /tmp/011dcea098a811e9b8d1f9e5e43755af.jar Exception in thread "main" java.lang.IllegalArgumentException at org.apache.storm.hack.shade.org.objectweb.asm.ClassReader.<init>(Unknown Source) at org.apache.storm.hack.shade.org.objectweb.asm.ClassReader.<init>(Unknown Source) at org.apache.storm.hack.shade.org.objectweb.asm.ClassReader.<init>(Unknown Source) at org.apache.storm.hack.DefaultShader.addRemappedClass(DefaultShader.java:182) at org.apache.storm.hack.DefaultShader.shadeJarStream(DefaultShader.java:103) at org.apache.storm.hack.StormShadeTransformer.transform(StormShadeTransformer.java:35) at org.apache.storm.daemon.ClientJarTransformerRunner.main(ClientJarTransformerRunner.java:37) Running: /usr/jdk64/jdk1.8.0_112/bin/java -Ddaemon.name=
-Dstorm.options= -Dstorm.home=/usr/hdp/3.1.0.0-78/storm -Dstorm.log.dir=/var/log/storm -Djava.library.path=/usr/local/lib:/opt/local/lib:/usr/lib -Dstorm.conf.file= -cp /usr/hdp/3.1.0.0-78/storm/*:/usr/hdp/3.1.0.0-78/storm/lib/*:/usr/hdp/3.1.0.0-78/storm/extlib/*:/tmp/011dcea098a811e9b8d1f9e5e43755af.jar:/usr/hdp/current/storm-supervisor/conf:/usr/hdp/3.1.0.0-78/storm/bin
-Dstorm.jar=/tmp/011dcea098a811e9b8d1f9e5e43755af.jar -Dstorm.dependency.jars= -Dstorm.dependency.artifacts={} artifactid.groupid.main
Error: Could not find or load main class
是否可以在Hadoop集群中运行使用TensorFlow模型的Storm拓扑?如果可以,请提供帮助。
暂无答案!
目前还没有任何答案,快来回答吧!