Usage and code examples of the org.nd4j.linalg.factory.Nd4j.argMax() method

x33g5p2x, reposted on 2022-01-24, category: Other

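This article collects code examples for the Java method org.nd4j.linalg.factory.Nd4j.argMax() and shows how Nd4j.argMax() is used in practice. The examples come mainly from platforms such as GitHub, Stack Overflow, and Maven, and were extracted from selected projects, so they should serve as useful references. Details of the Nd4j.argMax() method:
Package path: org.nd4j.linalg.factory.Nd4j
Class name: Nd4j
Method name: argMax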

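Introduction to Nd4j.argMax

No description is available. Judging from the Javadoc quoted in the first code example below, the method returns the index of the highest value along the specified dimension(s).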
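
As a rough illustration of what "index of the highest value along a dimension" means, the following library-free Java sketch computes the argmax of a 2-D array along dimension 0 (one index per column) and along dimension 1 (one index per row). The class `ArgMaxSketch` and its method are hypothetical names and not part of ND4J, and resolving ties to the first index is an assumption of this sketch rather than documented ND4J behavior.

```java
import java.util.Arrays;

// Hypothetical helper for illustration only; this is not ND4J code.
class ArgMaxSketch {

    // Returns, for each slice along `dimension`, the index of the largest value.
    // dimension 0: one result per column; dimension 1: one result per row.
    // Ties resolve to the first (lowest) index, which is an assumption of this sketch.
    static int[] argMax(double[][] m, int dimension) {
        if (dimension != 0 && dimension != 1) {
            throw new IllegalArgumentException("dimension must be 0 or 1");
        }
        int rows = m.length;
        int cols = m[0].length;
        int outer = (dimension == 0) ? cols : rows; // number of results
        int inner = (dimension == 0) ? rows : cols; // length scanned per result
        int[] result = new int[outer];
        for (int o = 0; o < outer; o++) {
            int best = 0;
            for (int i = 1; i < inner; i++) {
                double candidate = (dimension == 0) ? m[i][o] : m[o][i];
                double current = (dimension == 0) ? m[best][o] : m[o][best];
                if (candidate > current) {
                    best = i;
                }
            }
            result[o] = best;
        }
        return result;
    }

    public static void main(String[] args) {
        double[][] m = {
            {1.0, 9.0, 3.0},
            {7.0, 2.0, 8.0}
        };
        System.out.println(Arrays.toString(argMax(m, 0))); // [1, 0, 1]
        System.out.println(Arrays.toString(argMax(m, 1))); // [1, 2]
    }
}
```

In the ND4J examples that follow, the same choice of dimension appears as the second argument, for instance `Nd4j.argMax(guesses, 1)` to get one class index per row.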

Code examples

Code example source: deeplearning4j/nd4j

    /**
     * This method returns index of highest value along specified dimension(s)
     *
     * @param dimension
     * @return
     */
    @Override
    public INDArray argMax(int... dimension) {
        return Nd4j.argMax(this, dimension);
    }

Code example source: deeplearning4j/dl4j-examples

    INDArray argMaxAlongDim0 = Nd4j.argMax(originalArray, 0); // Index of the max value, along dimension 0
    System.out.println("\n\nargmax along dimension 0: " + argMaxAlongDim0);
    INDArray argMinAlongDim0 = Nd4j.getExecutioner().exec(new IMin(originalArray), 0); // Index of the min value, along dimension 0

Code example source: apache/tika

    private List<RecognisedObject> predict(INDArray predictions)
    {
        List<RecognisedObject> objects = new ArrayList<>();
        int[] topNPredictions = new int[topN];
        float[] topNProb = new float[topN];
        String[] outLabels = new String[topN];
        // brute force collect top N
        int i = 0;
        for (int batch = 0; batch < predictions.size(0); batch++) {
            INDArray currentBatch = predictions.getRow(batch).dup();
            while (i < topN) {
                topNPredictions[i] = Nd4j.argMax(currentBatch, 1).getInt(0, 0);
                topNProb[i] = currentBatch.getFloat(batch, topNPredictions[i]);
                currentBatch.putScalar(0, topNPredictions[i], 0);
                outLabels[i] = imageNetLabels.getLabel(topNPredictions[i]);
                objects.add(new RecognisedObject(outLabels[i], "eng", outLabels[i], topNProb[i]));
                i++;
            }
        }
        return objects;
    }
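
The brute-force top-N loop above repeatedly takes the argmax and then overwrites the winning entry so that the next pass finds the runner-up. Below is a library-free sketch of that idea for a single row of scores. The class `TopNSketch` and the sample values are hypothetical. Unlike the snippet above, this sketch works on a copy and masks the winner with negative infinity instead of 0, a deliberate change so that it also behaves sensibly for non-positive scores.

```java
import java.util.Arrays;

// Hypothetical helper for illustration only; this is not ND4J or Tika code.
class TopNSketch {

    // Returns the indices of the n largest scores, largest first,
    // by running argmax n times and masking each winner.
    static int[] topN(double[] scores, int n) {
        if (n < 0 || n > scores.length) {
            throw new IllegalArgumentException("n must be between 0 and scores.length");
        }
        double[] work = scores.clone(); // leave the caller's array untouched
        int[] top = new int[n];
        for (int k = 0; k < n; k++) {
            int best = 0;
            for (int i = 1; i < work.length; i++) {
                if (work[i] > work[best]) {
                    best = i;
                }
            }
            top[k] = best;
            work[best] = Double.NEGATIVE_INFINITY; // exclude this index from later passes
        }
        return top;
    }

    public static void main(String[] args) {
        double[] probabilities = {0.05, 0.60, 0.10, 0.25}; // made-up values
        System.out.println(Arrays.toString(topN(probabilities, 3))); // [1, 3, 2]
    }
}
```

Each pass is a full scan, so the cost grows with n times the row length; for a small n such as 5 that is usually acceptable, while sorting the indices once may be preferable for large n.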

Code example source: dkpro/dkpro-tc

    private static void eval(INDArray labels, INDArray p, INDArray outMask, StringBuilder sb)
    {
        Pair<INDArray, INDArray> pair = EvaluationUtils.extractNonMaskedTimeSteps(labels, p,
                outMask);
        INDArray realOutcomes = pair.getFirst();
        INDArray guesses = pair.getSecond();
        // Length of real labels must be same as length of predicted labels
        if (realOutcomes.length() != guesses.length())
            throw new IllegalArgumentException(
                    "Unable to evaluate. Outcome matrices not same length");
        INDArray guessIndex = Nd4j.argMax(guesses, 1);
        INDArray realOutcomeIndex = Nd4j.argMax(realOutcomes, 1);
        int nExamples = guessIndex.length();
        for (int i = 0; i < nExamples; i++) {
            int actual = (int) realOutcomeIndex.getDouble(i);
            int predicted = (int) guessIndex.getDouble(i);
            sb.append(actual + "\t" + predicted + "\n");
        }
    }
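
The `eval` method above reduces each row of the label matrix and each row of the prediction matrix to a class index by taking the argmax along dimension 1, then writes the two indices side by side. A library-free sketch of that pairing step follows. The class `EvalSketch`, its method names, and the sample values are hypothetical, and plain `double[][]` arrays stand in for `INDArray`.

```java
// Hypothetical helper for illustration only; this is not ND4J or DKPro TC code.
class EvalSketch {

    // Index of the largest value in one row (first index wins on ties).
    static int rowArgMax(double[] row) {
        int best = 0;
        for (int i = 1; i < row.length; i++) {
            if (row[i] > row[best]) {
                best = i;
            }
        }
        return best;
    }

    // Produces one "actual<TAB>predicted" line per example.
    static String actualVsPredicted(double[][] labels, double[][] predictions) {
        if (labels.length != predictions.length) {
            throw new IllegalArgumentException(
                    "Unable to evaluate. Outcome matrices not same length");
        }
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < labels.length; i++) {
            int actual = rowArgMax(labels[i]);         // one-hot label -> class index
            int predicted = rowArgMax(predictions[i]); // probabilities -> class index
            sb.append(actual).append('\t').append(predicted).append('\n');
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        double[][] labels = {{0, 1, 0}, {1, 0, 0}};                   // made-up one-hot labels
        double[][] predictions = {{0.2, 0.7, 0.1}, {0.3, 0.3, 0.4}};  // made-up probabilities
        System.out.print(actualVsPredicted(labels, predictions));
    }
}
```

With these sample values the first example is classified correctly (1 and 1) and the second is not (0 and 2).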

Code example source: de.tudarmstadt.ukp.inception.app/inception-imls-dl4j

    int sampleIdx = 0;
    for (Sample sample : aData) {
        INDArray argMax = Nd4j.argMax(predicted, 1);

Code example source: dkpro/dkpro-tc

    private void eval(INDArray labels, INDArray p, INDArray outMask, StringBuilder sb)
    {
        Pair<INDArray, INDArray> pair = EvaluationUtils.extractNonMaskedTimeSteps(labels, p,
                outMask);
        INDArray realOutcomes = pair.getFirst();
        INDArray guesses = pair.getSecond();
        // Length of real labels must be same as length of predicted labels
        if (realOutcomes.length() != guesses.length())
            throw new IllegalArgumentException(
                    "Unable to evaluate. Outcome matrices not same length");
        INDArray guessIndex = Nd4j.argMax(guesses, 1);
        INDArray realOutcomeIndex = Nd4j.argMax(realOutcomes, 1);
        int nExamples = guessIndex.length();
        for (int i = 0; i < nExamples; i++) {
            int actual = (int) realOutcomeIndex.getDouble(i);
            int predicted = (int) guessIndex.getDouble(i);
            sb.append(
                    vectorize.getTagset()[actual] + "\t" + vectorize.getTagset()[predicted] + "\n");
        }
    }

Code example source: CampagneLaboratory/variationanalysis

    if (trueLabels != null) {
        INDArray trueLabelRow = trueLabels;
        trueMaxIndices = Nd4j.argMax(trueLabelRow, 0);
        INDArray predictedMaxIndices = Nd4j.argMax(predictedRow, 0);

Code example source: org.deeplearning4j/deeplearning4j-modelimport

    INDArray currentBatch = predictions.getRow(batch).dup();
    while (i < 5) {
        top5[i] = Nd4j.argMax(currentBatch, 1).getInt(0, 0);
        top5Prob[i] = currentBatch.getFloat(batch, top5[i]);
        currentBatch.putScalar(0, top5[i], 0);

Code example source: org.deeplearning4j/deeplearning4j-zoo

    /**
     * Given predictions from the trained model this method will return a string
     * listing the top five matches and the respective probabilities
     * @param predictions
     * @return
     */
    public String decodePredictions(INDArray predictions) {
        String predictionDescription = "";
        int[] top5 = new int[5];
        float[] top5Prob = new float[5];
        // brute force collect top 5
        int i = 0;
        for (int batch = 0; batch < predictions.size(0); batch++) {
            predictionDescription += "Predictions for batch ";
            if (predictions.size(0) > 1) {
                predictionDescription += String.valueOf(batch);
            }
            predictionDescription += " :";
            INDArray currentBatch = predictions.getRow(batch).dup();
            while (i < 5) {
                top5[i] = Nd4j.argMax(currentBatch, 1).getInt(0, 0);
                top5Prob[i] = currentBatch.getFloat(batch, top5[i]);
                currentBatch.putScalar(0, top5[i], 0);
                predictionDescription += "\n\t" + String.format("%3f", top5Prob[i] * 100) + "%, "
                        + predictionLabels.get(top5[i]);
                i++;
            }
        }
        return predictionDescription;
    }

Code example source: org.deeplearning4j/deeplearning4j-nn

    } else if (costArray != null) {
        guessIndex = Nd4j.argMax(guesses.mulRowVector(costArray), 1);
    } else {
        guessIndex = Nd4j.argMax(guesses, 1);
    // ...
    INDArray realOutcomeIndex = Nd4j.argMax(realOutcomes, 1);
    int nExamples = guessIndex.length();
    // ...
    INDArray realOutcomeIndex = Nd4j.argMax(realOutcomes, 1);
    int nExamples = realOutcomeIndex.length();
    for (int i = 0; i < nExamples; i++) {
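
The `costArray` branch above selects the class by taking the argmax over cost-weighted probabilities (`guesses.mulRowVector(costArray)`) rather than over the raw probabilities. The library-free sketch below shows how such a weighting can change the selected class for a single row. The class `CostArgMaxSketch`, its method, and the sample values are hypothetical.

```java
// Hypothetical helper for illustration only; this is not Deeplearning4j code.
class CostArgMaxSketch {

    // Returns the index i that maximizes probabilities[i] * cost[i].
    // Ties resolve to the first index, which is an assumption of this sketch.
    static int costWeightedArgMax(double[] probabilities, double[] cost) {
        if (probabilities.length != cost.length) {
            throw new IllegalArgumentException("probabilities and cost must have the same length");
        }
        int best = 0;
        double bestScore = probabilities[0] * cost[0];
        for (int i = 1; i < probabilities.length; i++) {
            double score = probabilities[i] * cost[i];
            if (score > bestScore) {
                bestScore = score;
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] probabilities = {0.6, 0.4}; // made-up values
        double[] uniformCost = {1.0, 1.0};   // equivalent to a plain argmax
        double[] skewedCost = {1.0, 2.0};    // class 1 counts twice as much
        System.out.println(costWeightedArgMax(probabilities, uniformCost)); // 0
        System.out.println(costWeightedArgMax(probabilities, skewedCost));  // 1
    }
}
```

With uniform costs the result matches a plain argmax; with the skewed costs the weighted scores are 0.6 and 0.8, so class 1 is chosen even though its raw probability is lower.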

Code example source: org.deeplearning4j/deeplearning4j-nn

    INDArray row = Nd4j.linspace(0, maxTsLength - 1, maxTsLength);
    INDArray temp = mask.mulRowVector(row);
    INDArray lastElementIdx = Nd4j.argMax(temp, 1);
    fwdPassTimeSteps = new int[fwdPassShape[0]];
    for (int i = 0; i < fwdPassTimeSteps.length; i++) {
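
The fragment above multiplies each mask row by the row vector 0, 1, ..., maxTsLength - 1 and then takes the argmax along dimension 1. When the mask holds only 0s and 1s, this gives the index of the last unmasked time step in each row. A library-free sketch of the same trick follows; the class `LastStepSketch`, its method, and the sample mask are hypothetical.

```java
import java.util.Arrays;

// Hypothetical helper for illustration only; this is not Deeplearning4j code.
class LastStepSketch {

    // For each row of a 0/1 mask, returns argmax over t of mask[row][t] * t,
    // i.e. the index of the last time step whose mask value is 1.
    // A row that is all zeros also yields 0, so that case is ambiguous.
    static int[] lastUnmaskedStep(double[][] mask) {
        int[] result = new int[mask.length];
        for (int r = 0; r < mask.length; r++) {
            int best = 0;
            double bestValue = 0.0; // time step 0 always contributes mask * 0 = 0
            for (int t = 1; t < mask[r].length; t++) {
                double value = mask[r][t] * t; // weight each mask entry by its position
                if (value > bestValue) {
                    bestValue = value;
                    best = t;
                }
            }
            result[r] = best;
        }
        return result;
    }

    public static void main(String[] args) {
        double[][] mask = {
            {1, 1, 1, 0, 0}, // sequence of length 3
            {1, 1, 1, 1, 1}, // sequence of length 5
            {1, 0, 0, 0, 0}  // sequence of length 1
        };
        System.out.println(Arrays.toString(lastUnmaskedStep(mask))); // [2, 4, 0]
    }
}
```

Weighting by position is what makes the last 1 win: a plain argmax over a 0/1 mask would return the first 1 instead, assuming ties resolve to the first index.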

Code example source: Waikato/wekaDeeplearning4j

    INDArray lastTimeStepIndices;
    if (labelsMask != null) {
        lastTimeStepIndices = Nd4j.argMax(labelsMask, 1);
    } else {
        lastTimeStepIndices = Nd4j.zeros(features.size(0), 1);
