org.nd4j.linalg.factory.Nd4j.argMax()方法的使用及代码示例

x33g5p2x  于2022-01-24 转载在 其他  
字(5.8k)|赞(0)|评价(0)|浏览(225)

本文整理了Java中org.nd4j.linalg.factory.Nd4j.argMax()方法的一些代码示例，展示了Nd4j.argMax()的具体用法。这些代码示例主要来源于GitHub、Stack Overflow、Maven等平台，是从一些精选项目中提取出来的代码，具有较强的参考意义，能在一定程度上帮助到你。Nd4j.argMax()方法的具体详情如下：
包路径:org.nd4j.linalg.factory.Nd4j
类名称:Nd4j
方法名:argMax

Nd4j.argMax介绍

返回输入数组沿指定维度（dimension）上最大值所在的索引，结果以 INDArray 形式返回。

代码示例

代码示例来源:origin: deeplearning4j/nd4j

/**
 * Locates, along the requested dimension(s), the position of the largest element of this array.
 *
 * <p>This is a thin convenience wrapper around {@code Nd4j.argMax(this, dimension)}.
 *
 * @param dimension the dimension(s) along which the maximum is searched
 * @return the array of indices produced by {@code Nd4j.argMax} for this array
 */
@Override
public INDArray argMax(int... dimension) {
  final INDArray source = this;
  return Nd4j.argMax(source, dimension);
}

代码示例来源:origin: deeplearning4j/dl4j-examples

// Compare the index of the maximum (argmax) and minimum (IMin) along dimension 0.
// NOTE(review): originalArray is declared earlier in the example, outside this excerpt.
INDArray argMaxAlongDim0 = Nd4j.argMax(originalArray,0);                            //Index of the max value, along dimension 0
System.out.println("\n\nargmax along dimension 0:   " + argMaxAlongDim0);
INDArray argMinAlongDim0 = Nd4j.getExecutioner().exec(new IMin(originalArray),0);   //Index of the min value, along dimension 0

代码示例来源:origin: apache/tika

/**
 * Turns model output into recognised objects, keeping the {@code topN} best-scoring
 * labels of every row of {@code predictions}.
 *
 * <p>Selection is a brute-force top-N. For each row, the argmax of a private copy of the
 * row is taken {@code topN} times, and the picked entry is zeroed after each pick so the
 * next argmax finds the next-highest value.
 *
 * @param predictions model scores; assumed to hold one row per batch element and one
 *     column per label known to {@code imageNetLabels} (TODO confirm against the model)
 * @return the recognised objects, {@code topN} per row, in row order
 */
private List<RecognisedObject> predict(INDArray predictions)
  {
    List<RecognisedObject> objects = new ArrayList<>();
    for (int batch = 0; batch < predictions.size(0); batch++) {
      // Work on a copy so zeroing picked entries never mutates the caller's array.
      INDArray currentBatch = predictions.getRow(batch).dup();
      // The counter is per row: every row contributes its own top-N.
      for (int i = 0; i < topN; i++) {
        int classIndex = Nd4j.argMax(currentBatch, 1).getInt(0, 0);
        // currentBatch is a single-row copy, so its row index is always 0 (not `batch`),
        // matching the putScalar(0, ...) call below.
        float probability = currentBatch.getFloat(0, classIndex);
        currentBatch.putScalar(0, classIndex, 0);
        String label = imageNetLabels.getLabel(classIndex);
        objects.add(new RecognisedObject(label, "eng", label, probability));
      }
    }
    return objects;
  }
}

代码示例来源:origin: org.nd4j/nd4j-api

/**
 * Locates, along the requested dimension(s), the position of the largest element of this array.
 *
 * <p>This is a thin convenience wrapper around {@code Nd4j.argMax(this, dimension)}.
 *
 * @param dimension the dimension(s) along which the maximum is searched
 * @return the array of indices produced by {@code Nd4j.argMax} for this array
 */
@Override
public INDArray argMax(int... dimension) {
  final INDArray source = this;
  return Nd4j.argMax(source, dimension);
}

代码示例来源:origin: dkpro/dkpro-tc

/**
 * Appends one "actual TAB predicted" line per example to {@code sb}. The numbers written
 * are the argmax indices (along dimension 1) of the gold labels and of the predictions.
 *
 * @param labels gold label matrix passed to EvaluationUtils.extractNonMaskedTimeSteps
 * @param p prediction matrix passed to EvaluationUtils.extractNonMaskedTimeSteps
 * @param outMask output mask passed to EvaluationUtils.extractNonMaskedTimeSteps
 * @param sb destination the result lines are appended to
 * @throws IllegalArgumentException if the extracted label and prediction matrices differ
 *     in total element count
 */
private static void eval(INDArray labels, INDArray p, INDArray outMask, StringBuilder sb)
  {
    // Presumably keeps only the un-masked time steps (see EvaluationUtils); first = gold, second = guesses.
    Pair<INDArray, INDArray> nonMasked =
        EvaluationUtils.extractNonMaskedTimeSteps(labels, p, outMask);
    INDArray expected = nonMasked.getFirst();
    INDArray predictedScores = nonMasked.getSecond();

    if (expected.length() != predictedScores.length()) {
      throw new IllegalArgumentException(
          "Unable to evaluate. Outcome matrices not same length");
    }

    // Index of the highest value along dimension 1, i.e. the chosen class per row
    // (assumes rows are examples — TODO confirm).
    INDArray predictedIdx = Nd4j.argMax(predictedScores, 1);
    INDArray expectedIdx = Nd4j.argMax(expected, 1);

    int rowCount = predictedIdx.length();
    int row = 0;
    while (row < rowCount) {
      int actual = (int) expectedIdx.getDouble(row);
      int predicted = (int) predictedIdx.getDouble(row);
      sb.append(actual).append('\t').append(predicted).append('\n');
      row++;
    }
  }
}

代码示例来源:origin: de.tudarmstadt.ukp.inception.app/inception-imls-dl4j

int sampleIdx = 0;
for (Sample sample : aData) {
  INDArray argMax = Nd4j.argMax(predicted, 1);

代码示例来源:origin: inception-project/inception

int sampleIdx = 0;
for (Sample sample : aData) {
  INDArray argMax = Nd4j.argMax(predicted, 1);

代码示例来源:origin: dkpro/dkpro-tc

/**
 * Appends one "actualTag TAB predictedTag" line per example to {@code sb}. The argmax
 * index (along dimension 1) of the gold labels and of the predictions is translated
 * through {@code vectorize.getTagset()}.
 *
 * @param labels gold label matrix passed to EvaluationUtils.extractNonMaskedTimeSteps
 * @param p prediction matrix passed to EvaluationUtils.extractNonMaskedTimeSteps
 * @param outMask output mask passed to EvaluationUtils.extractNonMaskedTimeSteps
 * @param sb destination the result lines are appended to
 * @throws IllegalArgumentException if the extracted label and prediction matrices differ
 *     in total element count
 */
private void eval(INDArray labels, INDArray p, INDArray outMask, StringBuilder sb)
{
  Pair<INDArray, INDArray> nonMasked =
      EvaluationUtils.extractNonMaskedTimeSteps(labels, p, outMask);
  INDArray expected = nonMasked.getFirst();
  INDArray predictedScores = nonMasked.getSecond();
  if (expected.length() != predictedScores.length()) {
    throw new IllegalArgumentException(
        "Unable to evaluate. Outcome matrices not same length");
  }
  INDArray predictedIdx = Nd4j.argMax(predictedScores, 1);
  INDArray expectedIdx = Nd4j.argMax(expected, 1);
  int rowCount = predictedIdx.length();
  for (int row = 0; row < rowCount; ++row) {
    int actual = (int) expectedIdx.getDouble(row);
    int predicted = (int) predictedIdx.getDouble(row);
    // Both tag lookups happen before anything is appended, so a bad index leaves sb untouched.
    String line = vectorize.getTagset()[actual] + "\t" + vectorize.getTagset()[predicted] + "\n";
    sb.append(line);
  }
}

代码示例来源:origin: CampagneLaboratory/variationanalysis

if (trueLabels != null) {
  INDArray trueLabelRow = trueLabels;
  trueMaxIndices = Nd4j.argMax(trueLabelRow, 0);
INDArray predictedMaxIndices = Nd4j.argMax(predictedRow, 0);

代码示例来源:origin: org.deeplearning4j/deeplearning4j-modelimport

INDArray currentBatch = predictions.getRow(batch).dup();
while (i < 5) {
  top5[i] = Nd4j.argMax(currentBatch, 1).getInt(0, 0);
  top5Prob[i] = currentBatch.getFloat(batch, top5[i]);
  currentBatch.putScalar(0, top5[i], 0);

代码示例来源:origin: org.deeplearning4j/deeplearning4j-zoo

/**
 * Given predictions from the trained model, returns a string listing the top five
 * matches and their probabilities for every row (batch element) of {@code predictions}.
 *
 * <p>Top-5 selection is brute force. For each row, the argmax of a private copy of the
 * row is taken five times, and the picked entry is zeroed after each pick.
 *
 * @param predictions model scores; assumed to hold one row per batch element and one
 *     column per entry of {@code predictionLabels} (TODO confirm)
 * @return a human-readable description of the five best predictions of each row
 */
public String decodePredictions(INDArray predictions) {
  StringBuilder predictionDescription = new StringBuilder();
  //brute force collect top 5
  for (int batch = 0; batch < predictions.size(0); batch++) {
    predictionDescription.append("Predictions for batch ");
    if (predictions.size(0) > 1) {
      predictionDescription.append(batch);
    }
    predictionDescription.append(" :");
    // Work on a copy so zeroing picked entries never mutates the caller's array.
    INDArray currentBatch = predictions.getRow(batch).dup();
    // The counter is per row: every row gets its own top 5.
    for (int i = 0; i < 5; i++) {
      int classIndex = Nd4j.argMax(currentBatch, 1).getInt(0, 0);
      // currentBatch is a single-row copy, so its row index is always 0 (not `batch`),
      // matching the putScalar(0, ...) call below.
      float probability = currentBatch.getFloat(0, classIndex);
      currentBatch.putScalar(0, classIndex, 0);
      // NOTE(review): "%3f" is width 3 with the default six decimals; "%.3f" may have been intended.
      predictionDescription.append("\n\t")
          .append(String.format("%3f", probability * 100))
          .append("%, ")
          .append(predictionLabels.get(classIndex));
    }
  }
  return predictionDescription.toString();
}

代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn

} else if (costArray != null) {
  guessIndex = Nd4j.argMax(guesses.mulRowVector(costArray), 1);
} else {
  guessIndex = Nd4j.argMax(guesses, 1);
INDArray realOutcomeIndex = Nd4j.argMax(realOutcomes, 1);
int nExamples = guessIndex.length();
INDArray realOutcomeIndex = Nd4j.argMax(realOutcomes, 1);
int nExamples = realOutcomeIndex.length();
for (int i = 0; i < nExamples; i++) {

代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn

INDArray row = Nd4j.linspace(0, maxTsLength - 1, maxTsLength);
INDArray temp = mask.mulRowVector(row);
INDArray lastElementIdx = Nd4j.argMax(temp, 1);
fwdPassTimeSteps = new int[fwdPassShape[0]];
for (int i = 0; i < fwdPassTimeSteps.length; i++) {

代码示例来源:origin: Waikato/wekaDeeplearning4j

INDArray lastTimeStepIndices;
if (labelsMask != null){
 lastTimeStepIndices = Nd4j.argMax(labelsMask, 1);
} else {
 lastTimeStepIndices = Nd4j.zeros(features.size(0), 1);

相关文章