本文整理了Java中org.nd4j.linalg.factory.Nd4j.argMax()方法的一些代码示例,展示了Nd4j.argMax()的具体用法。这些代码示例主要来源于Github/Stackoverflow/Maven等平台,是从一些精选项目中提取出来的代码,具有较强的参考意义,能在一定程度上帮助到你。Nd4j.argMax()方法的具体详情如下:
包路径:org.nd4j.linalg.factory.Nd4j
类名称:Nd4j
方法名:argMax
方法描述:暂无
代码示例来源:origin: deeplearning4j/nd4j
/**
 * Finds, along the requested dimension(s), the position of the largest value in this array.
 *
 * <p>This is a thin convenience wrapper: it forwards to the static
 * {@code Nd4j.argMax(this, dimension)} factory method with this array as the input.
 *
 * @param dimension the dimension(s) along which the maximum is searched
 * @return the array of arg-max indices produced by {@code Nd4j.argMax}
 */
@Override
public INDArray argMax(int... dimension) {
    final INDArray maxIndices = Nd4j.argMax(this, dimension);
    return maxIndices;
}
代码示例来源:origin: deeplearning4j/dl4j-examples
// Excerpt: locate the extreme values of originalArray along dimension 0.
// The maximum uses the Nd4j.argMax factory method; the minimum is obtained by
// executing an IMin op directly through the executioner along the same dimension.
INDArray argMaxAlongDim0 = Nd4j.argMax(originalArray,0); //Index of the max value, along dimension 0
System.out.println("\n\nargmax along dimension 0: " + argMaxAlongDim0);
INDArray argMinAlongDim0 = Nd4j.getExecutioner().exec(new IMin(originalArray),0); //Index of the min value, along dimension 0
代码示例来源:origin: apache/tika
/**
 * Turns model output into recognised objects, keeping the {@code topN} highest-scoring
 * classes of every row of {@code predictions}.
 *
 * <p>Selection is a repeated arg-max on a private copy of each row. After each pick, the chosen
 * entry is overwritten with negative infinity so it can never be selected again, not even when
 * the remaining scores are all zero.
 *
 * @param predictions model scores; row {@code b} is read as the class scores of batch entry
 *     {@code b} (assumed to be a 2-D {@code [batch, classes]} array — TODO confirm against the model)
 * @return one {@link RecognisedObject} per selected class, rows in batch order and classes in
 *     descending score order within a row; never {@code null}
 */
private List<RecognisedObject> predict(INDArray predictions)
{
    List<RecognisedObject> objects = new ArrayList<>();
    for (int batch = 0; batch < predictions.size(0); batch++) {
        // Work on a copy so the caller's predictions are left untouched.
        INDArray currentBatch = predictions.getRow(batch).dup();
        // Never request more classes than the row holds; otherwise arg-max would start
        // returning entries that were already consumed.
        long limit = Math.min(topN, currentBatch.length());
        // The counter is scoped to this row so every batch entry gets its own top-N.
        for (int i = 0; i < limit; i++) {
            int classIndex = Nd4j.argMax(currentBatch, 1).getInt(0, 0);
            // currentBatch holds a single row, so its row index is always 0 (not `batch`).
            float probability = currentBatch.getFloat(0, classIndex);
            // -Infinity (unlike 0) can never tie with a remaining score.
            currentBatch.putScalar(0, classIndex, Double.NEGATIVE_INFINITY);
            String label = imageNetLabels.getLabel(classIndex);
            objects.add(new RecognisedObject(label, "eng", label, probability));
        }
    }
    return objects;
}
}
代码示例来源:origin: org.nd4j/nd4j-api
/**
 * Computes the index of the maximum element along the given dimension(s) of this array.
 *
 * <p>Implementation note: all work is delegated to {@code Nd4j.argMax}, passing this
 * array as the operand.
 *
 * @param dimension one or more dimensions to search along
 * @return the indices of the maximum values, as returned by {@code Nd4j.argMax}
 */
@Override
public INDArray argMax(int... dimension) {
    INDArray indicesOfMax;
    indicesOfMax = Nd4j.argMax(this, dimension);
    return indicesOfMax;
}
代码示例来源:origin: dkpro/dkpro-tc
/**
 * Writes one {@code actual<TAB>predicted} line of class indices per non-masked time step.
 *
 * <p>Masked time steps are removed with {@code EvaluationUtils.extractNonMaskedTimeSteps}.
 * The class index of each remaining row is the arg-max along dimension 1.
 *
 * @param labels  expected outcomes
 * @param p       network predictions
 * @param outMask output mask used to drop masked time steps
 * @param sb      destination; receives one line per example
 * @throws IllegalArgumentException if the unmasked label and prediction matrices differ in length
 */
private static void eval(INDArray labels, INDArray p, INDArray outMask, StringBuilder sb)
{
    Pair<INDArray, INDArray> unmasked = EvaluationUtils.extractNonMaskedTimeSteps(labels, p,
            outMask);
    INDArray expected = unmasked.getFirst();
    INDArray predictions = unmasked.getSecond();

    // Labels and predictions must line up entry for entry.
    if (expected.length() != predictions.length()) {
        throw new IllegalArgumentException(
                "Unable to evaluate. Outcome matrices not same length");
    }

    INDArray predictedClasses = Nd4j.argMax(predictions, 1);
    INDArray expectedClasses = Nd4j.argMax(expected, 1);
    int rowCount = predictedClasses.length();
    int row = 0;
    while (row < rowCount) {
        int actual = (int) expectedClasses.getDouble(row);
        int predicted = (int) predictedClasses.getDouble(row);
        sb.append(actual).append('\t').append(predicted).append('\n');
        row++;
    }
}
}
代码示例来源:origin: de.tudarmstadt.ukp.inception.app/inception-imls-dl4j
// Excerpt: per-row class indices (arg-max along dimension 1) of the `predicted` matrix.
int sampleIdx = 0;
for (Sample sample : aData) {
// NOTE(review): this call does not depend on `sample`. Unless `predicted` is reassigned
// later in the loop (not visible in this excerpt), it could be hoisted above the loop.
INDArray argMax = Nd4j.argMax(predicted, 1);
代码示例来源:origin: inception-project/inception
// Excerpt: per-row class indices (arg-max along dimension 1) of the `predicted` matrix.
int sampleIdx = 0;
for (Sample sample : aData) {
// NOTE(review): this call does not depend on `sample`. Unless `predicted` is reassigned
// later in the loop (not visible in this excerpt), it could be hoisted above the loop.
INDArray argMax = Nd4j.argMax(predicted, 1);
代码示例来源:origin: dkpro/dkpro-tc
/**
 * Writes one {@code actualTag<TAB>predictedTag} line per non-masked time step.
 *
 * <p>Masked steps are removed with {@code EvaluationUtils.extractNonMaskedTimeSteps}. Each
 * remaining row's class index is its arg-max along dimension 1. That index is mapped to a tag
 * name through {@code vectorize.getTagset()}.
 *
 * @param labels  expected outcomes
 * @param p       network predictions
 * @param outMask output mask used to drop masked time steps
 * @param sb      destination; receives one line per example
 * @throws IllegalArgumentException if the unmasked label and prediction matrices differ in length
 */
private void eval(INDArray labels, INDArray p, INDArray outMask, StringBuilder sb)
{
    Pair<INDArray, INDArray> unmasked = EvaluationUtils.extractNonMaskedTimeSteps(labels, p,
            outMask);
    INDArray expected = unmasked.getFirst();
    INDArray predictions = unmasked.getSecond();

    // Labels and predictions must line up entry for entry.
    if (expected.length() != predictions.length()) {
        throw new IllegalArgumentException(
                "Unable to evaluate. Outcome matrices not same length");
    }

    INDArray predictedClasses = Nd4j.argMax(predictions, 1);
    INDArray expectedClasses = Nd4j.argMax(expected, 1);
    int rowCount = predictedClasses.length();
    for (int row = 0; row < rowCount; ++row) {
        int actual = (int) expectedClasses.getDouble(row);
        int predicted = (int) predictedClasses.getDouble(row);
        // Build the whole line first so nothing is appended if a tag lookup fails.
        String line = vectorize.getTagset()[actual] + "\t" + vectorize.getTagset()[predicted]
                + "\n";
        sb.append(line);
    }
}
代码示例来源:origin: CampagneLaboratory/variationanalysis
// Excerpt: only when true labels are available, take the arg-max along dimension 0 of both
// the true labels and the predicted row.
if (trueLabels != null) {
INDArray trueLabelRow = trueLabels;
trueMaxIndices = Nd4j.argMax(trueLabelRow, 0);
INDArray predictedMaxIndices = Nd4j.argMax(predictedRow, 0);
代码示例来源:origin: org.deeplearning4j/deeplearning4j-modelimport
// Excerpt: brute-force top-5 selection on a copy of one prediction row. The chosen entry
// is overwritten with 0 after each pick so the next arg-max finds the next-highest value.
INDArray currentBatch = predictions.getRow(batch).dup();
while (i < 5) {
top5[i] = Nd4j.argMax(currentBatch, 1).getInt(0, 0);
// NOTE(review): currentBatch presumably holds a single row (it comes from getRow(batch)),
// yet it is read at row `batch`. For batch > 0 this looks out of range; row 0, as used by
// putScalar below, is probably intended.
top5Prob[i] = currentBatch.getFloat(batch, top5[i]);
currentBatch.putScalar(0, top5[i], 0);
代码示例来源:origin: org.deeplearning4j/deeplearning4j-zoo
/**
 * Describes, for every row of {@code predictions}, the five highest-scoring labels together
 * with their probabilities expressed as percentages.
 *
 * <p>Selection is a repeated arg-max on a private copy of each row. After each pick, the chosen
 * entry is overwritten with negative infinity so it can never be selected again. Rows with fewer
 * than five entries list only as many labels as they contain.
 *
 * @param predictions model scores; row {@code b} is read as the class scores of batch entry
 *     {@code b} (assumed to be a 2-D {@code [batch, classes]} array — TODO confirm)
 * @return a human-readable, multi-line description of the top matches for every batch entry
 */
public String decodePredictions(INDArray predictions) {
    final int topK = 5;
    StringBuilder predictionDescription = new StringBuilder();
    long batchCount = predictions.size(0);
    for (int batch = 0; batch < batchCount; batch++) {
        predictionDescription.append("Predictions for batch ");
        // The batch number is only printed when there is more than one batch entry.
        if (batchCount > 1) {
            predictionDescription.append(batch);
        }
        predictionDescription.append(" :");
        // Work on a copy so the caller's predictions are left untouched.
        INDArray currentBatch = predictions.getRow(batch).dup();
        long limit = Math.min(topK, currentBatch.length());
        // The counter is scoped to this row so every batch entry gets its own top-5.
        for (int i = 0; i < limit; i++) {
            int classIndex = Nd4j.argMax(currentBatch, 1).getInt(0, 0);
            // currentBatch holds a single row, so its row index is always 0 (not `batch`).
            float probability = currentBatch.getFloat(0, classIndex);
            // -Infinity (unlike 0) can never tie with a remaining score.
            currentBatch.putScalar(0, classIndex, Double.NEGATIVE_INFINITY);
            // NOTE(review): "%3f" means minimum width 3 with the default six decimals, in the
            // default locale; "%.3f" may have been intended. The output format is kept unchanged.
            predictionDescription.append("\n\t")
                    .append(String.format("%3f", probability * 100))
                    .append("%, ")
                    .append(predictionLabels.get(classIndex));
        }
    }
    return predictionDescription.toString();
}
代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn
} else if (costArray != null) {
// With a cost array, guesses are multiplied by costArray as a row vector before the
// per-row arg-max, so the costs weight which class is chosen.
guessIndex = Nd4j.argMax(guesses.mulRowVector(costArray), 1);
} else {
guessIndex = Nd4j.argMax(guesses, 1);
// NOTE(review): this excerpt appears to merge two revisions of the same code.
// `realOutcomeIndex` and `nExamples` are each declared twice below, which would not compile
// as-is. One revision takes nExamples from guessIndex, the other from realOutcomeIndex.
INDArray realOutcomeIndex = Nd4j.argMax(realOutcomes, 1);
int nExamples = guessIndex.length();
INDArray realOutcomeIndex = Nd4j.argMax(realOutcomes, 1);
int nExamples = realOutcomeIndex.length();
for (int i = 0; i < nExamples; i++) {
代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn
// Excerpt: find, per example, the index of the last time step selected by `mask`.
// A row of time-step indices 0..maxTsLength-1 is built first.
INDArray row = Nd4j.linspace(0, maxTsLength - 1, maxTsLength);
// The mask is multiplied by the index row, so the arg-max along dimension 1 lands on the
// highest index whose mask entry is non-zero (assumes a non-negative 0/1 mask — confirm).
INDArray temp = mask.mulRowVector(row);
INDArray lastElementIdx = Nd4j.argMax(temp, 1);
fwdPassTimeSteps = new int[fwdPassShape[0]];
for (int i = 0; i < fwdPassTimeSteps.length; i++) {
代码示例来源:origin: Waikato/wekaDeeplearning4j
INDArray lastTimeStepIndices;
if (labelsMask != null){
// NOTE(review): if labelsMask is a 0/1 mask such as [1,1,1,0], every unmasked entry ties
// for the maximum. If argMax resolves ties to the first index, this returns the FIRST
// unmasked step rather than the last. Multiplying the mask by a 0..T-1 index row before
// the arg-max would avoid this. Confirm the mask layout and tie-breaking.
lastTimeStepIndices = Nd4j.argMax(labelsMask, 1);
} else {
// Without a labels mask, fall back to a (features.size(0), 1) array of zeros.
lastTimeStepIndices = Nd4j.zeros(features.size(0), 1);
内容来源于网络,如有侵权,请联系作者删除!