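This article collects code examples for the Java method org.nd4j.linalg.factory.Nd4j.sortWithIndices() and shows how Nd4j.sortWithIndices() is used in practice. The examples come mainly from platforms such as GitHub, Stack Overflow, and Maven, and were extracted from selected projects, so they are useful as reference material. Details of the Nd4j.sortWithIndices() method are as follows: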
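Package path: org.nd4j.linalg.factory.Nd4j
Class name: Nd4j
Method name: sortWithIndices
Description: Sort an ndarray along a particular dimension. In the examples below, element 0 of the returned INDArray[] is read as the indices and element 1 as the sorted values.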
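To make the returned pair easier to picture, here is a minimal plain-Java sketch of the same idea for a one-dimensional array. It does not call ND4J and is not ND4J's implementation; the class name SortWithIndicesSketch and its helper method are hypothetical, chosen only for illustration. It assumes the layout the examples in this article rely on: element 0 holds the original positions in ranked order, element 1 holds the sorted values.

```java
import java.util.Arrays;
import java.util.Comparator;

public class SortWithIndicesSketch {

    /**
     * Hypothetical helper (not part of ND4J): sorts a 1-D array and also
     * reports where each sorted value came from.
     * Returns {indices, sortedValues}; indices are stored as doubles,
     * which is why the examples in this article cast them back to int.
     */
    static double[][] sortWithIndices(double[] values, boolean ascending) {
        // Start with the identity permutation 0, 1, 2, ...
        Integer[] order = new Integer[values.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }

        // Sort the positions by the value they point at.
        Comparator<Integer> byValue = Comparator.comparingDouble(i -> values[i]);
        Arrays.sort(order, ascending ? byValue : byValue.reversed());

        // Materialize both outputs in ranked order.
        double[] indices = new double[values.length];
        double[] sorted = new double[values.length];
        for (int rank = 0; rank < order.length; rank++) {
            indices[rank] = order[rank];
            sorted[rank] = values[order[rank]];
        }
        return new double[][] {indices, sorted};
    }

    public static void main(String[] args) {
        double[][] result = sortWithIndices(new double[] {0.3, 0.9, 0.1}, false);
        System.out.println(Arrays.toString(result[0])); // prints [1.0, 0.0, 2.0]
        System.out.println(Arrays.toString(result[1])); // prints [0.9, 0.3, 0.1]
    }
}
```

Reading the two outputs together, rank i of the sorted values belongs to the original position stored at rank i of the indices, which is the pairing the top-n and binning examples below depend on.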
Code example source: origin: de.datexis/texoo-core
@Override
public Collection<String> getNearestNeighbours(INDArray v, int n) {
    // find maximum entries (descending sort on a copy of v)
    INDArray[] sorted = Nd4j.sortWithIndices(v.dup(), 0, false); // index,value
    if (sorted[0].sumNumber().doubleValue() == 0.) { // TODO: sortWithIndices could be run on -1 / 0 / 1 ?
        log.warn("NearestNeighbour on zero vector - please check vector alignment!");
    }
    INDArray idx = sorted[0]; // ranked indexes
    // get top n, keeping only entries with a positive value
    ArrayList<String> result = new ArrayList<>(n);
    for (int i = 0; i < n; i++) {
        if (sorted[1].getDouble(i) > 0.) {
            result.add(getWord(idx.getInt(i)));
        }
    }
    return result;
}
Code example source: origin: neo4j-graph-analytics/ml-models
public void logBins(double[][] embedding) {
    INDArray indArray = Nd4j.create(embedding);
    for (int column = 0; column < embedding[0].length; column++) {
        int remaining = embedding.length;
        int binNumber = 0;
        INDArray slice = indArray.slice(column, 1);
        // ascending sort; indArrays[0] holds the original row index of each rank
        INDArray[] indArrays = Nd4j.sortWithIndices(slice, 0, true);
        INDArray indices = indArrays[0];
        for (int node = 0; node < embedding.length; node++) {
            // each bin covers half of the nodes that are still unassigned
            if (node + remaining == embedding.length) {
                remaining /= 2;
                binNumber++;
            }
            embedding[(int) indices.getDouble(node)][column] = binNumber - 1;
        }
    }
}
Code example source: origin: neo4j-graph-analytics/ml-models
public void linearBins(double[][] embedding, int numBins) {
    INDArray indArray = Nd4j.create(embedding);
    for (int column = 0; column < embedding[0].length; column++) {
        INDArray slice = indArray.slice(column, 1);
        // ascending sort; indArrays[0] holds the original row index of each rank
        INDArray[] indArrays = Nd4j.sortWithIndices(slice, 0, true);
        INDArray indices = indArrays[0];
        int maxRank = embedding.length;
        for (int rank = 0; rank < indices.size(0); rank++) {
            // map rank in [0, maxRank) linearly onto a bin in [0, numBins)
            embedding[(int) indices.getDouble(rank)][column] = (int) (((double) rank / maxRank) * numBins);
        }
    }
}
Code example source: origin: org.deeplearning4j/nearestneighbor-core
public void search() {
    results = new ArrayList<>();
    distances = new ArrayList<>();
    //initial search
    //vpTree.search(target,k,results,distances);
    //fill till there is k results
    //by going down the list
    //   if(results.size() < k) {
    INDArray distancesArr = Nd4j.create(vpTree.getItems().rows(), 1);
    vpTree.calcDistancesRelativeTo(target, distancesArr);
    INDArray[] sortWithIndices = Nd4j.sortWithIndices(distancesArr, 0, !vpTree.isInvert());
    results.clear();
    distances.clear();
    if (vpTree.getItems().isVector()) {
        for (int i = 0; i < k; i++) {
            int idx = sortWithIndices[0].getInt(i);
            results.add(new DataPoint(idx, Nd4j.scalar(vpTree.getItems().getDouble(idx))));
            // sortWithIndices[1] holds the sorted distances, so rank i pairs with item idx
            distances.add(sortWithIndices[1].getDouble(i));
        }
    } else {
        for (int i = 0; i < k; i++) {
            int idx = sortWithIndices[0].getInt(i);
            results.add(new DataPoint(idx, vpTree.getItems().getRow(idx)));
            // sortWithIndices[1] holds the sorted distances, so rank i pairs with item idx
            distances.add(sortWithIndices[1].getDouble(i));
        }
    }
}
Code example source: origin: org.deeplearning4j/deeplearning4j-nn
INDArray[] maxWithIndices = Nd4j.sortWithIndices(outcome, -1, false);
INDArray indexes = maxWithIndices[0];
Code example source: origin: neo4j-graph-analytics/ml-models
public void logBins(INDArray indArray) {
    for (int column = 0; column < indArray.size(1); column++) {
        int remaining = indArray.size(0);
        int binNumber = 0;
        INDArray slice = indArray.slice(column, 1);
        // ascending sort; indArrays[0] holds the original row index of each rank
        INDArray[] indArrays = Nd4j.sortWithIndices(slice, 0, true);
        INDArray indices = indArrays[0];
        for (int node = 0; node < indArray.size(0); node++) {
            // each bin covers half of the nodes that are still unassigned
            if (node + remaining == indArray.size(0)) {
                remaining /= 2;
                binNumber++;
            }
            indArray.putScalar((int) indices.getDouble(node), column, binNumber - 1);
        }
    }
}