Usage and code examples for the org.nd4j.linalg.factory.Nd4j.sortWithIndices() method

Reposted by x33g5p2x on 2022-01-24, category: Other

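This article collects code examples for the Java method org.nd4j.linalg.factory.Nd4j.sortWithIndices() and shows how Nd4j.sortWithIndices() is used in practice. The examples come mainly from platforms such as GitHub, Stack Overflow, and Maven, and were extracted from selected projects, so they should serve as useful references. The details of Nd4j.sortWithIndices() are as follows:
Package path: org.nd4j.linalg.factory.Nd4j
Class name: Nd4j
Method name: sortWithIndices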

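Introduction to Nd4j.sortWithIndices

Sort an ndarray along a particular dimension. In the examples below, the call takes the array, the dimension, and a boolean ascending flag, and returns a two-element INDArray[]: the indices at position 0 and the sorted values at position 1.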
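To make the idea of "sorting with indices" concrete, here is a minimal plain-Java sketch for a one-dimensional array. It does not call ND4J; the class and method names (ArgSortSketch, sortedIndices, gather) are hypothetical and exist only for illustration.

```java
import java.util.Arrays;
import java.util.Comparator;

// Plain-Java illustration of sorting with indices; this is NOT ND4J code.
class ArgSortSketch {

    // Returns the original positions of the values, listed in sorted order.
    static int[] sortedIndices(double[] values, boolean ascending) {
        Integer[] order = new Integer[values.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        Comparator<Integer> byValue = Comparator.comparingDouble(i -> values[i]);
        Arrays.sort(order, ascending ? byValue : byValue.reversed());
        int[] result = new int[order.length];
        for (int i = 0; i < result.length; i++) {
            result[i] = order[i];
        }
        return result;
    }

    // Collects the values in the order given by the indices.
    static double[] gather(double[] values, int[] indices) {
        double[] sorted = new double[indices.length];
        for (int rank = 0; rank < indices.length; rank++) {
            sorted[rank] = values[indices[rank]];
        }
        return sorted;
    }

    public static void main(String[] args) {
        double[] v = {0.3, 0.9, 0.1};
        int[] idx = sortedIndices(v, false);                 // descending
        System.out.println(Arrays.toString(idx));            // [1, 0, 2]
        System.out.println(Arrays.toString(gather(v, idx))); // [0.9, 0.3, 0.1]
    }
}
```

The pair (idx, gather(v, idx)) plays the same role here as the two arrays that the examples in this article read from the result of sortWithIndices.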

Code examples

Code example source: de.datexis/texoo-core

    @Override
    public Collection<String> getNearestNeighbours(INDArray v, int n) {
        // find maximum entries: sort a copy of v in descending order
        INDArray[] sorted = Nd4j.sortWithIndices(v.dup(), 0, false); // index,value
        // note: sorted[0] holds the indices, so this sum is taken over indices, not values
        if(sorted[0].sumNumber().doubleValue() == 0.) // TODO: sortWithIndices could be run on -1 / 0 / 1 ?
            log.warn("NearestNeighbour on zero vector - please check vector alignment!");
        INDArray idx = sorted[0]; // ranked indexes
        // get top n, keeping only entries with a positive value
        ArrayList<String> result = new ArrayList<>(n);
        for(int i=0; i<n; i++) {
            if(sorted[1].getDouble(i) > 0.) result.add(getWord(idx.getInt(i)));
        }
        return result;
    }

Code example source: neo4j-graph-analytics/ml-models

    public void logBins(double[][] embedding) {
        INDArray indArray = Nd4j.create(embedding);
        for (int column = 0; column < embedding[0].length; column++) {
            int remaining = embedding.length;
            int binNumber = 0;
            INDArray slice = indArray.slice(column, 1);
            // indices of the rows, ordered by ascending value in this column
            INDArray[] indArrays = Nd4j.sortWithIndices(slice, 0, true);
            INDArray indices = indArrays[0];
            for (int node = 0; node < embedding.length; node++) {
                // bin widths halve: roughly the first half of the ranks gets bin 0, the next quarter bin 1, ...
                if (node + remaining == embedding.length) {
                    remaining /= 2;
                    binNumber++;
                }
                embedding[(int) indices.getDouble(node)][column] = binNumber - 1;
            }
        }
    }

Code example source: neo4j-graph-analytics/ml-models

    public void linearBins(double[][] embedding, int numBins) {
        INDArray indArray = Nd4j.create(embedding);
        for (int column = 0; column < embedding[0].length; column++) {
            INDArray slice = indArray.slice(column, 1);
            // indices of the rows, ordered by ascending value in this column
            INDArray[] indArrays = Nd4j.sortWithIndices(slice, 0, true);
            INDArray indices = indArrays[0];
            int maxRank = embedding.length;
            for (int rank = 0; rank < indices.size(0); rank++) {
                // replace each value with the bin of its rank, in [0, numBins)
                embedding[(int) indices.getDouble(rank)][column] = (int) (((double) rank / maxRank) * numBins);
            }
        }
    }
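The rank-to-bin formula used above can be checked on a small column with a plain-Java sketch. It does not use ND4J, and the names RankBinSketch and linearBins are hypothetical; only the formula (rank / n) * numBins, truncated to an int, is taken from the example.

```java
import java.util.Arrays;
import java.util.Comparator;

// Plain-Java illustration of rank-based linear binning; this is NOT ND4J code.
class RankBinSketch {

    // Returns, for each position, the bin of that value's ascending rank.
    static int[] linearBins(double[] column, int numBins) {
        Integer[] order = new Integer[column.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        Comparator<Integer> byValue = Comparator.comparingDouble(i -> column[i]);
        Arrays.sort(order, byValue);
        int[] bins = new int[column.length];
        for (int rank = 0; rank < order.length; rank++) {
            // order[rank] is the original position of the value with this rank
            bins[order[rank]] = (int) (((double) rank / column.length) * numBins);
        }
        return bins;
    }

    public static void main(String[] args) {
        double[] column = {0.5, 0.1, 0.9, 0.3};
        System.out.println(Arrays.toString(linearBins(column, 2))); // [1, 0, 1, 0]
    }
}
```

With four values and two bins, the two smallest values (0.1 and 0.3) land in bin 0 and the two largest (0.5 and 0.9) in bin 1, regardless of their magnitudes.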

Code example source: org.deeplearning4j/nearestneighbor-core

    public void search() {
        results = new ArrayList<>();
        distances = new ArrayList<>();
        //initial search
        //vpTree.search(target,k,results,distances);
        //fill till there is k results
        //by going down the list
        // if(results.size() < k) {
        INDArray distancesArr = Nd4j.create(vpTree.getItems().rows(), 1);
        vpTree.calcDistancesRelativeTo(target, distancesArr);
        INDArray[] sortWithIndices = Nd4j.sortWithIndices(distancesArr, 0, !vpTree.isInvert());
        results.clear();
        distances.clear();
        if (vpTree.getItems().isVector()) {
            for (int i = 0; i < k; i++) {
                int idx = sortWithIndices[0].getInt(i);
                results.add(new DataPoint(idx, Nd4j.scalar(vpTree.getItems().getDouble(idx))));
                // fixed: the sorted values are read by rank i, not by the original index idx
                distances.add(sortWithIndices[1].getDouble(i));
            }
        } else {
            for (int i = 0; i < k; i++) {
                int idx = sortWithIndices[0].getInt(i);
                results.add(new DataPoint(idx, vpTree.getItems().getRow(idx)));
                // fixed: the sorted values are read by rank i, not by the original index idx
                distances.add(sortWithIndices[1].getDouble(i));
            }
        }
    }
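The difference between a rank position and an original index is easy to mix up in this kind of nearest-neighbour lookup. The following plain-Java sketch, which does not use ND4J and whose names (NearestSketch, nearestIndices) are hypothetical, picks the k smallest distances and reports the original index of each.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

// Plain-Java illustration of picking the k nearest items; this is NOT ND4J code.
class NearestSketch {

    // Returns the original indices of the k smallest distances, closest first.
    static List<Integer> nearestIndices(double[] distances, int k) {
        Integer[] order = new Integer[distances.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        Comparator<Integer> byDistance = Comparator.comparingDouble(i -> distances[i]);
        Arrays.sort(order, byDistance);
        int limit = Math.min(k, order.length);
        return new ArrayList<>(Arrays.asList(order).subList(0, limit));
    }

    public static void main(String[] args) {
        double[] distances = {4.0, 1.0, 3.0, 2.0};
        for (int idx : nearestIndices(distances, 2)) {
            // idx is an original index, so it is valid for the unsorted array
            System.out.println(idx + " -> " + distances[idx]);
        }
        // prints "1 -> 1.0" then "3 -> 2.0"
    }
}
```

In this sketch the original array is never reordered, so looking up distances[idx] is correct. An array of already sorted values would instead be read by rank (0, 1, ..., k-1).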

Code example source: org.deeplearning4j/deeplearning4j-nn

    // descending sort (ascending = false); position 0 of the result holds the indices
    INDArray[] maxWithIndices = Nd4j.sortWithIndices(outcome, -1, false);
    INDArray indexes = maxWithIndices[0];

Code example source: neo4j-graph-analytics/ml-models

    public void logBins(INDArray indArray) {
        for (int column = 0; column < indArray.size(1); column++) {
            int remaining = indArray.size(0);
            int binNumber = 0;
            INDArray slice = indArray.slice(column, 1);
            // indices of the rows, ordered by ascending value in this column
            INDArray[] indArrays = Nd4j.sortWithIndices(slice, 0, true);
            INDArray indices = indArrays[0];
            for (int node = 0; node < indArray.size(0); node++) {
                if (node + remaining == indArray.size(0)) {
                    remaining /= 2;
                    binNumber++;
                }
                indArray.putScalar((int) indices.getDouble(node), column, binNumber - 1);
            }
        }
    }
