如何在hadoop集群中使用tensorflow模型运行storm拓扑

j0pj023g  于 2021-05-29  发布在  Hadoop
关注(0)|答案(0)|浏览(202)

下面是我用 Java 加载 TensorFlow 模型进行目标检测、并在 Storm 中运行的代码:

package object_det.object_det;

import java.awt.BorderLayout;
import java.awt.GraphicsEnvironment;
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferByte;
import java.io.IOException;
import java.io.PrintStream;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map.Entry;
import javax.swing.ImageIcon;
import javax.swing.JFrame;
import javax.swing.JLabel;
import org.apache.storm.LocalCluster;
import org.apache.storm.StormSubmitter;
import org.apache.storm.kafka.BrokerHosts;
import org.apache.storm.kafka.KafkaSpout;
import org.apache.storm.kafka.SpoutConfig;
import org.apache.storm.kafka.StringScheme;
import org.apache.storm.kafka.ZkHosts;
import org.apache.storm.spout.SchemeAsMultiScheme;
import org.apache.storm.topology.BasicOutputCollector;
import org.apache.storm.topology.OutputFieldsDeclarer;
import org.apache.storm.topology.TopologyBuilder;
import org.apache.storm.topology.base.BaseBasicBolt;
import org.apache.storm.tuple.Tuple;
import org.json.simple.JSONObject;
import org.json.simple.parser.JSONParser;
import org.json.simple.parser.ParseException;
import org.opencv.core.Core;
import org.opencv.core.Mat;
import org.opencv.core.Point;
import org.opencv.core.Scalar;
import org.opencv.imgproc.Imgproc;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.framework.ConfigProto;
import org.tensorflow.framework.GPUOptions;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;
import org.tensorflow.framework.TensorInfo;
import org.tensorflow.types.UInt8;
import com.google.protobuf.TextFormat;
import object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap;
import object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem;

/**
 * Storm topology that reads JSON-encoded OpenCV frames from Kafka and runs a TensorFlow
 * SavedModel object detector on each frame inside {@link PrinterBolt}.
 *
 * <p>Arguments: {@code <zookeeperHosts> <kafkaTopic> <zkRoot> <clientId> [remote]}. Without the
 * optional {@code remote} flag the topology runs inside an in-process {@code LocalCluster} (the
 * original behaviour). With it, the topology is handed to {@code StormSubmitter}, so it runs on
 * the cluster that the {@code storm jar} command is configured for.
 *
 * <p>NOTE(review): the bolt resolves {@code "model"} and {@code "label.pbtxt"} as relative paths
 * and loads the OpenCV native library in a static initializer. On a multi-node cluster these
 * presumably need to be available on every supervisor host, relative to the worker's working
 * directory. Verify this for your deployment.
 */
public class Objectdetect1 {

    // NOTE(review): these statics used to receive the coordinates of the last drawn box. They are
    // no longer written: each detection now uses locals, so executors sharing one worker JVM
    // cannot race on them. They are kept only so that existing references still compile.
    static int left = 0;
    static int bot = 0;
    static int top = 0;
    static int right = 0;

    /**
     * Sink bolt. It decodes a frame message, runs object detection on it, draws a box and label
     * for every detection scoring at least {@link #SCORE_THRESHOLD}, and shows the annotated frame
     * in a Swing window when a display is available.
     */
    public static class PrinterBolt extends BaseBasicBolt {

        /** Detections scoring below this value are ignored. */
        private static final float SCORE_THRESHOLD = 0.5f;

        /** SavedModel directory and tag, resolved relative to the worker's working directory. */
        private static final String MODEL_DIR = "model";
        private static final String MODEL_TAG = "serve";

        /** Label map in StringIntLabelMap text-proto format. */
        private static final String LABEL_FILE = "label.pbtxt";

        // Not used by this class; kept because the field is package-visible.
        int ii = 1;

        // Created lazily on first display. They are transient because Storm serializes the bolt
        // when the topology is submitted. Building them eagerly would also run new JFrame() on
        // the submitting client, which throws HeadlessException on a headless host.
        transient JFrame frame;
        transient JLabel jLabel;

        // Loaded once per bolt instance on first use rather than once per tuple; released in cleanup().
        private transient SavedModelBundle model;
        private transient String[] labels;

        static {
            nu.pattern.OpenCV.loadShared();
        }

        /** Declares no output streams: this bolt only prints and displays results. */
        public void declareOutputFields(OutputFieldsDeclarer declarer) {
        }

        /**
         * Processes one frame message.
         *
         * <p>The tuple's first field is a JSON object with integer {@code rows}, {@code cols} and
         * {@code type} (the OpenCV Mat type) and a Base64 string {@code data} holding the pixel
         * bytes. Malformed messages and detection failures are logged and skipped, so one bad
         * frame does not take down the worker.
         */
        public void execute(Tuple tuple, BasicOutputCollector collector) {
            JSONObject json = parseFrameJson(tuple.getString(0));
            if (json == null) {
                return;
            }
            Long rowsValue = longField(json, "rows");
            Long colsValue = longField(json, "cols");
            Long typeValue = longField(json, "type");
            Object dataValue = json.get("data");
            if (rowsValue == null || colsValue == null || typeValue == null
                    || !(dataValue instanceof String)) {
                System.err.println(
                        "Skipping frame message: expected integer rows/cols/type and string data");
                return;
            }
            int row = rowsValue.intValue();
            int cols = colsValue.intValue();
            int typedata = typeValue.intValue();

            Mat mat1 = null;
            try {
                mat1 = new Mat(row, cols, typedata);
                mat1.put(0, 0, Base64.getDecoder().decode((String) dataValue));
                detectAndDisplay(mat1, row, cols);
            } catch (Exception e) {
                e.printStackTrace();
            } finally {
                // Mat wraps native memory; free it now instead of waiting for finalization.
                if (mat1 != null) {
                    mat1.release();
                }
            }
        }

        /** Closes the TensorFlow model and disposes of the display window, if either was created. */
        public void cleanup() {
            if (model != null) {
                model.close();
                model = null;
            }
            if (frame != null) {
                frame.dispose();
                frame = null;
                jLabel = null;
            }
        }

        /**
         * Runs the detector on {@code image}, draws each sufficiently confident detection onto it
         * in place, then displays the annotated image once.
         *
         * @param image BGR frame to analyse and annotate
         * @param rows image height in pixels, used to scale box coordinates
         * @param cols image width in pixels, used to scale box coordinates
         * @throws Exception if the model cannot be loaded or inference fails
         */
        private void detectAndDisplay(Mat image, int rows, int cols) throws Exception {
            String[] labelNames = labels();
            SavedModelBundle bundle = model();

            List<Tensor<?>> outputs;
            try (Tensor<UInt8> input = makeImageTensor(image)) {
                outputs = bundle
                        .session()
                        .runner()
                        .feed("image_tensor", input)
                        .fetch("detection_scores")
                        .fetch("detection_classes")
                        .fetch("detection_boxes")
                        .run();
            }
            try (Tensor<Float> scoresT = outputs.get(0).expect(Float.class);
                 Tensor<Float> classesT = outputs.get(1).expect(Float.class);
                 Tensor<Float> boxesT = outputs.get(2).expect(Float.class)) {

                int maxObjects = (int) scoresT.shape()[1];
                float[] scores = scoresT.copyTo(new float[1][maxObjects])[0];
                float[] classes = classesT.copyTo(new float[1][maxObjects])[0];
                float[][] boxes = boxesT.copyTo(new float[1][maxObjects][4])[0];

                boolean foundSomething = false;
                for (int i = 0; i < scores.length; ++i) {
                    if (scores[i] < SCORE_THRESHOLD) {
                        continue;
                    }
                    foundSomething = true;
                    String label = labelFor(labelNames, (int) classes[i]);
                    System.out.printf("\tFound %-20s (score: %.4f)\n", label, scores[i]);

                    // Box entries are [y0, x0, y1, x1] fractions of the frame size; scale to pixels.
                    int boxLeft = (int) Math.round(boxes[i][1] * cols);
                    int boxTop = (int) Math.round(boxes[i][0] * rows);
                    int boxRight = (int) Math.round(boxes[i][3] * cols);
                    int boxBottom = (int) Math.round(boxes[i][2] * rows);
                    Imgproc.rectangle(image, new Point(boxRight, boxBottom),
                            new Point(boxLeft, boxTop), new Scalar(0, 69, 255), 2);
                    Imgproc.putText(image, label, new Point(boxLeft, boxTop),
                            Core.FONT_HERSHEY_PLAIN, 1.6, new Scalar(240, 248, 255), 2);
                }

                if (!foundSomething) {
                    System.out.println("No objects detected with a high enough score.");
                }
                display(image);
            }
        }

        /** Returns the SavedModel, loading it and printing its signature on first use. */
        private SavedModelBundle model() throws Exception {
            if (model == null) {
                model = SavedModelBundle.load(MODEL_DIR, MODEL_TAG);
                printSignature(model);
            }
            return model;
        }

        /**
         * Returns the label names, loading them on first use. Returns null if loading fails; the
         * failure is logged and the load is retried on the next tuple.
         */
        private String[] labels() {
            if (labels == null) {
                try {
                    labels = loadLabels(LABEL_FILE);
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
            return labels;
        }

        /**
         * Returns the display name for {@code classId}. Falls back to {@code "class <id>"} when no
         * labels are loaded or the id is missing or out of range.
         */
        private static String labelFor(String[] labelNames, int classId) {
            if (labelNames != null && classId >= 0 && classId < labelNames.length
                    && labelNames[classId] != null) {
                return labelNames[classId];
            }
            return "class " + classId;
        }

        /**
         * Shows {@code image} in a reusable Swing window. Does nothing on headless hosts, which
         * cluster workers often are.
         *
         * <p>NOTE(review): as in the original, Swing is driven from the bolt thread rather than
         * the Event Dispatch Thread.
         */
        private void display(Mat image) {
            if (GraphicsEnvironment.isHeadless()) {
                return;
            }
            if (frame == null) {
                frame = new JFrame();
                jLabel = new JLabel();
                frame.getContentPane().add(jLabel, BorderLayout.CENTER);
            }
            jLabel.setIcon(new ImageIcon(bufferedImage(image)));
            frame.pack();
            frame.setLocationRelativeTo(null);
            frame.setVisible(true);
        }

        /**
         * Parses the raw tuple payload after removing every '[' and ']' character. Presumably the
         * producer wraps the object in a JSON array; confirm against the producer.
         *
         * @return the parsed object, or null (after logging) if the payload is missing, is not
         *     valid JSON, or is not a JSON object
         */
        private static JSONObject parseFrameJson(String raw) {
            if (raw == null) {
                System.err.println("Skipping tuple with no payload");
                return null;
            }
            String cleaned = raw.replaceAll("\\[", "").replaceAll("\\]", "");
            try {
                return (JSONObject) new JSONParser().parse(cleaned);
            } catch (ParseException | ClassCastException e) {
                e.printStackTrace();
                return null;
            }
        }

        /** Returns {@code json[key]} if it is an integer (json-simple parses integers as Long), else null. */
        private static Long longField(JSONObject json, String key) {
            Object value = json.get(key);
            return value instanceof Long ? (Long) value : null;
        }

        /** Prints the inputs and outputs of the model's "serving_default" signature to stdout. */
        private static void printSignature(SavedModelBundle model) throws Exception {
            MetaGraphDef m = MetaGraphDef.parseFrom(model.metaGraphDef());
            SignatureDef sig = m.getSignatureDefOrThrow("serving_default");
            int numInputs = sig.getInputsCount();
            int i = 1;
            System.out.println("MODEL SIGNATURE");
            System.out.println("Inputs:");
            for (Entry<String, TensorInfo> entry : sig.getInputsMap().entrySet()) {
                TensorInfo t = entry.getValue();
                System.out.printf(
                        "%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n",
                        i++, numInputs, entry.getKey(), t.getName(), t.getDtype());
            }
            int numOutputs = sig.getOutputsCount();
            i = 1;
            System.out.println("Outputs:");
            for (Entry<String, TensorInfo> entry : sig.getOutputsMap().entrySet()) {
                TensorInfo t = entry.getValue();
                System.out.printf(
                        "%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n",
                        i++, numOutputs, entry.getKey(), t.getName(), t.getDtype());
            }
            System.out.println("-----------------------------------------------");
        }

        /**
         * Reads a StringIntLabelMap text proto and returns an array indexed by item id that holds
         * each item's display name. Ids that do not appear in the file map to null.
         */
        private static String[] loadLabels(String filename) throws Exception {
            String text = new String(Files.readAllBytes(Paths.get(filename)), StandardCharsets.UTF_8);
            StringIntLabelMap.Builder builder = StringIntLabelMap.newBuilder();
            TextFormat.merge(text, builder);
            StringIntLabelMap proto = builder.build();
            int maxId = 0;
            for (StringIntLabelMapItem item : proto.getItemList()) {
                if (item.getId() > maxId) {
                    maxId = item.getId();
                }
            }
            String[] ret = new String[maxId + 1];
            for (StringIntLabelMapItem item : proto.getItemList()) {
                ret[item.getId()] = item.getDisplayName();
            }
            return ret;
        }

        /** Swaps the first and third byte of each 3-byte pixel in place (BGR to RGB). A trailing partial pixel is left untouched. */
        private static void bgr2rgb(byte[] data) {
            for (int i = 0; i + 2 < data.length; i += 3) {
                byte tmp = data[i];
                data[i] = data[i + 2];
                data[i + 2] = tmp;
            }
        }

        /**
         * Copies {@code m} into a new BufferedImage. A single-channel Mat becomes TYPE_BYTE_GRAY;
         * any multi-channel Mat is treated as TYPE_3BYTE_BGR.
         *
         * <p>NOTE(review): the pixels are copied into a byte raster, so this assumes an 8-bit
         * (CV_8U) Mat. Confirm that the producer sends that type.
         */
        public static BufferedImage bufferedImage(Mat m) {
            int type = BufferedImage.TYPE_BYTE_GRAY;
            if (m.channels() > 1) {
                type = BufferedImage.TYPE_3BYTE_BGR;
            }
            BufferedImage image = new BufferedImage(m.cols(), m.rows(), type);
            m.get(0, 0, ((DataBufferByte) image.getRaster().getDataBuffer()).getData()); // copy all pixels
            return image;
        }

        /**
         * Builds a {@code [1, height, width, 3]} UInt8 RGB tensor from a BGR Mat.
         *
         * @throws IOException if the Mat does not convert to a 3-byte BGR image
         */
        private static Tensor<UInt8> makeImageTensor(Mat m) throws IOException {
            BufferedImage img = bufferedImage(m);
            if (img.getType() != BufferedImage.TYPE_3BYTE_BGR) {
                throw new IOException(
                        String.format(
                                "Expected 3-byte BGR encoding in BufferedImage, found %d. This code could be made more robust",
                                img.getType()));
            }
            // getData() returns a copy of the raster, so the channel swap below leaves img unchanged.
            byte[] data = ((DataBufferByte) img.getData().getDataBuffer()).getData();
            bgr2rgb(data);
            final long BATCH_SIZE = 1;
            final long CHANNELS = 3;
            long[] shape = new long[] {BATCH_SIZE, img.getHeight(), img.getWidth(), CHANNELS};
            return Tensor.create(UInt8.class, shape, ByteBuffer.wrap(data));
        }

        /**
         * Prints command-line usage text. NOTE(review): not called anywhere in this class, and it
         * describes a model/label/image CLI rather than this topology's arguments.
         */
        private static void printUsage(PrintStream s) {
            s.println("USAGE: <model> <label_map> <image> [<image>] [<image>]");
            s.println("");
            s.println("Where");
            s.println("<model> is the path to the SavedModel directory of the model to use.");
            s.println("        For example, the saved_model directory in tarballs from ");
            s.println(
                    "        https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md)");
            s.println("");
            s.println(
                    "<label_map> is the path to a file containing information about the labels detected by the model.");
            s.println("            For example, one of the .pbtxt files from ");
            s.println(
                    "            https://github.com/tensorflow/models/tree/master/research/object_detection/data");
            s.println("");
            s.println("<image> is the path to an image file.");
            s.println("        Sample images can be found from the COCO, Kitti, or Open Images dataset.");
            s.println(
                    "        See: https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md");
        }
    }

    /**
     * Builds the Kafka spout, wires in the detection bolt, and submits the topology.
     *
     * @param args {@code <zookeeperHosts> <kafkaTopic> <zkRoot> <clientId> [remote]}. With
     *     {@code remote} the topology is submitted through {@code StormSubmitter} (cluster mode);
     *     otherwise it runs in an in-process {@code LocalCluster}.
     * @throws IllegalArgumentException if fewer than four arguments are given
     * @throws IllegalStateException if cluster submission fails
     */
    public static void main(String[] args) {
        if (args.length < 4) {
            throw new IllegalArgumentException(
                    "USAGE: <zookeeperHosts> <kafkaTopic> <zkRoot> <clientId> [remote]");
        }
        final BrokerHosts zkrHosts = new ZkHosts(args[0]);
        final String kafkaTopic = args[1];
        final String zkRoot = args[2];
        final String clientId = args[3];
        final boolean remote = args.length > 4 && "remote".equalsIgnoreCase(args[4]);

        final SpoutConfig kafkaConf = new SpoutConfig(zkrHosts, kafkaTopic, zkRoot, clientId);
        // Raised fetch size (~30 MB) so a single message can hold a whole encoded frame.
        kafkaConf.fetchSizeBytes = 30971520;
        kafkaConf.scheme = new SchemeAsMultiScheme(new StringScheme());

        final TopologyBuilder topologyBuilder = new TopologyBuilder();
        topologyBuilder.setSpout("kafka-spout", new KafkaSpout(kafkaConf), 1);
        topologyBuilder.setBolt("print-messages", new PrinterBolt()).shuffleGrouping("kafka-spout");

        if (remote) {
            // A LocalCluster only runs inside this client JVM. Submitting through StormSubmitter is
            // what actually deploys the topology to the cluster's supervisors.
            try {
                StormSubmitter.submitTopology("kafka-topology", new HashMap<String, Object>(),
                        topologyBuilder.createTopology());
            } catch (Exception e) {
                throw new IllegalStateException("Failed to submit topology to the cluster", e);
            }
        } else {
            final LocalCluster localCluster = new LocalCluster();
            localCluster.submitTopology("kafka-topology", new HashMap<Object, Object>(),
                    topologyBuilder.createTopology());
        }
    }
}

上面的代码使用 `mvn clean install shade:shade` 打包。
当这个 jar 通过以下命令提交到单节点集群时:
`storm jar file.jar object_det.object_det.Objectdetect1 zookeeperhost:2181 topicname /brokers test`
代码成功执行。
但是在多节点 Hadoop 集群中提交同一个 jar 时,出现以下错误:

Running: /usr/jdk64/jdk1.8.0_112/bin/java -server -Ddaemon.name=
            -Dstorm.options= -Dstorm.home=/usr/hdp/3.1.0.0-78/storm -Dstorm.log.dir=/var/log/storm -Djava.library.path=/usr/local/lib:/opt/local/lib:/usr/lib -Dstorm.conf.file= -cp /usr/hdp/3.1.0.0-78/storm/*:/usr/hdp/3.1.0.0-78/storm/lib/*:/usr/hdp/3.1.0.0-78/storm/extlib/* org.apache.storm.daemon.ClientJarTransformerRunner org.apache.storm.hack.StormShadeTransformer strmjr2-0.0.1-SNAPSHOT.jar /tmp/011dcea098a811e9b8d1f9e5e43755af.jar Exception in thread "main" java.lang.IllegalArgumentException   at org.apache.storm.hack.shade.org.objectweb.asm.ClassReader.<init>(Unknown Source)     at org.apache.storm.hack.shade.org.objectweb.asm.ClassReader.<init>(Unknown Source)     at org.apache.storm.hack.shade.org.objectweb.asm.ClassReader.<init>(Unknown Source)     at org.apache.storm.hack.DefaultShader.addRemappedClass(DefaultShader.java:182)     at org.apache.storm.hack.DefaultShader.shadeJarStream(DefaultShader.java:103)   at org.apache.storm.hack.StormShadeTransformer.transform(StormShadeTransformer.java:35)     at org.apache.storm.daemon.ClientJarTransformerRunner.main(ClientJarTransformerRunner.java:37) Running: /usr/jdk64/jdk1.8.0_112/bin/java -Ddaemon.name=
            -Dstorm.options= -Dstorm.home=/usr/hdp/3.1.0.0-78/storm -Dstorm.log.dir=/var/log/storm -Djava.library.path=/usr/local/lib:/opt/local/lib:/usr/lib -Dstorm.conf.file= -cp /usr/hdp/3.1.0.0-78/storm/*:/usr/hdp/3.1.0.0-78/storm/lib/*:/usr/hdp/3.1.0.0-78/storm/extlib/*:/tmp/011dcea098a811e9b8d1f9e5e43755af.jar:/usr/hdp/current/storm-supervisor/conf:/usr/hdp/3.1.0.0-78/storm/bin
            -Dstorm.jar=/tmp/011dcea098a811e9b8d1f9e5e43755af.jar -Dstorm.dependency.jars= -Dstorm.dependency.artifacts={} artifactid.groupid.main

        Error: Could not find or load main class

是否可以在 Hadoop 集群中运行使用 TensorFlow 模型的 Storm 拓扑?如果可以,请提供帮助。

暂无答案!

目前还没有任何答案,快来回答吧!

相关问题