Using the org.nd4j.linalg.factory.Nd4j.pullRows() method, with code examples

Reposted by x33g5p2x on 2022-01-24 (category: Other)

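This article collects code examples of the Java method org.nd4j.linalg.factory.Nd4j.pullRows() to show how Nd4j.pullRows() is used in practice. The examples come mainly from GitHub, Stack Overflow, Maven, and similar platforms and were extracted from selected projects, so they should be a useful reference. Details of Nd4j.pullRows() are as follows: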
Package: org.nd4j.linalg.factory
Class name: Nd4j
Method name: pullRows

Nd4j.pullRows overview

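From the Javadoc: this method produces a concatenated array made of tensors fetched from the source array along a given dimension, at the specified indexes. For a 2D array, the tensors along dimension 1 are the rows, so pullRows(matrix, 1, indexes) returns the requested rows stacked into a new matrix.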
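To make the description concrete, here is a plain-Java sketch of the same idea for a 2D array stored as `double[][]`. It is an illustration, not the ND4J implementation: the class and method names are hypothetical, and the result shapes (dimension 1 pulls rows into `[k][cols]`, dimension 0 pulls columns into `[rows][k]`) follow our reading of the 2D case.

```java
import java.util.Arrays;

class PullRows2dSketch {
    // Illustrative only: mimics pullRows for 2D input (hypothetical helper, not ND4J code).
    // dimension 1 -> each index selects a row,    result is [indexes.length][cols]
    // dimension 0 -> each index selects a column, result is [rows][indexes.length]
    static double[][] pullRows(double[][] source, int dimension, int[] indexes) {
        if (dimension != 0 && dimension != 1)
            throw new IllegalArgumentException("2D input supports dimension 0 or 1 only");
        if (indexes == null || indexes.length == 0)
            throw new IllegalArgumentException("indexes must not be empty");
        int rows = source.length;
        int cols = source[0].length;
        int limit = (dimension == 1) ? rows : cols; // number of tensors along this dimension
        for (int idx : indexes)
            if (idx < 0 || idx >= limit)
                throw new IllegalArgumentException("index out of range: " + idx);

        if (dimension == 1) {
            double[][] out = new double[indexes.length][];
            for (int i = 0; i < indexes.length; i++)
                out[i] = Arrays.copyOf(source[indexes[i]], cols);
            return out;
        }
        double[][] out = new double[rows][indexes.length];
        for (int r = 0; r < rows; r++)
            for (int i = 0; i < indexes.length; i++)
                out[r][i] = source[r][indexes[i]];
        return out;
    }

    public static void main(String[] args) {
        double[][] m = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
        System.out.println(Arrays.deepToString(pullRows(m, 1, new int[]{2, 0})));
        // [[7.0, 8.0, 9.0], [1.0, 2.0, 3.0]]
        System.out.println(Arrays.deepToString(pullRows(m, 0, new int[]{1})));
        // [[2.0], [5.0], [8.0]]
    }
}
```

The indexes can come in any order and may repeat; the result simply follows the order given.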

Code examples

Example source: deeplearning4j/nd4j

    /**
     * Produces a concatenated array made of the tensors fetched from the source
     * array along the given dimension, at the specified indexes.
     *
     * @param source          source tensor
     * @param sourceDimension dimension along which tensors are taken
     *                        (for a 2D array, 1 selects rows)
     * @param indexes         indexes of the tensors to fetch from the source array
     * @return a new array containing the fetched tensors
     */
    public static INDArray pullRows(INDArray source, int sourceDimension, int[] indexes) {
        // Delegate to the four-argument overload, using the global default ordering
        return pullRows(source, sourceDimension, indexes, Nd4j.order());
    }
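The three-argument overload only fills in the ordering of the result. As a rough sketch of what that `order` character controls, the hypothetical helper below lays the pulled rows out in a flat buffer, row-major for `'c'` and column-major for `'f'`. `DEFAULT_ORDER` is a stand-in for `Nd4j.order()` and is assumed to be `'c'`.

```java
import java.util.Arrays;

class PullRowsOrderSketch {
    // Hypothetical stand-in for Nd4j.order(); assumed here to default to 'c' (row-major).
    static final char DEFAULT_ORDER = 'c';

    // Overload without an order: delegates using the default, like the method above.
    static double[] pullRowsFlat(double[][] source, int[] indexes) {
        return pullRowsFlat(source, indexes, DEFAULT_ORDER);
    }

    // Pulls the given rows into a [k x cols] result and returns its flat buffer.
    static double[] pullRowsFlat(double[][] source, int[] indexes, char order) {
        if (order != 'c' && order != 'f')
            throw new IllegalArgumentException("Unknown order: " + order);
        int k = indexes.length;
        int cols = source[0].length;
        double[] buf = new double[k * cols];
        for (int i = 0; i < k; i++) {
            for (int j = 0; j < cols; j++) {
                double v = source[indexes[i]][j];
                if (order == 'c')
                    buf[i * cols + j] = v; // 'c': each pulled row is contiguous
                else
                    buf[j * k + i] = v;    // 'f': each column is contiguous
            }
        }
        return buf;
    }

    public static void main(String[] args) {
        double[][] m = {{1, 2}, {3, 4}, {5, 6}};
        int[] idx = {2, 0};
        System.out.println(Arrays.toString(pullRowsFlat(m, idx)));      // [5.0, 6.0, 1.0, 2.0]
        System.out.println(Arrays.toString(pullRowsFlat(m, idx, 'f'))); // [5.0, 1.0, 6.0, 2.0]
    }
}
```

The logical values are the same in both cases; only the memory layout of the result differs.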

Example source: deeplearning4j/nd4j

    /**
     * Get whole rows from the passed indices.
     *
     * @param rindices indices of the rows to return
     */
    @Override
    public INDArray getRows(int[] rindices) {
        Nd4j.getCompressor().autoDecompress(this); // decompress first if the array is stored compressed
        if (!isMatrix() && !isVector())
            throw new IllegalArgumentException("Unable to get rows from a non matrix or vector");
        if (isVector())
            return Nd4j.pullRows(this, 1, rindices); // tensors along dimension 1 are rows
        else {
            // Matrix: allocate the result, then copy each requested row into it
            INDArray ret = Nd4j.create(rindices.length, columns());
            for (int i = 0; i < rindices.length; i++)
                ret.putRow(i, getRow(rindices[i]));
            return ret;
        }
    }
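The matrix branch of `getRows` builds a fresh result and copies the requested rows into it, so, as we read the code above, the returned array does not share storage with the source. A plain-Java sketch of that branch (hypothetical names, `double[][]` instead of `INDArray`) makes the copy semantics easy to check; duplicate and reordered indices are copied in the order given.

```java
import java.util.Arrays;

class GetRowsSketch {
    // Sketch of the matrix branch: allocate a [k x cols] result, then copy each requested row in.
    static double[][] getRows(double[][] matrix, int[] rindices) {
        int cols = matrix[0].length;
        double[][] ret = new double[rindices.length][cols];
        for (int i = 0; i < rindices.length; i++)
            System.arraycopy(matrix[rindices[i]], 0, ret[i], 0, cols); // role of ret.putRow(i, getRow(...))
        return ret;
    }

    public static void main(String[] args) {
        double[][] m = {{1, 2}, {3, 4}};
        double[][] rows = getRows(m, new int[]{1, 1, 0}); // duplicates and reordering are allowed
        rows[0][0] = 99;                                  // modifies the copy only
        System.out.println(Arrays.deepToString(rows));    // [[99.0, 4.0], [3.0, 4.0], [1.0, 2.0]]
        System.out.println(Arrays.deepToString(m));       // [[1.0, 2.0], [3.0, 4.0]]
    }
}
```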

Example source: deeplearning4j/nd4j

    /**
     * Get whole columns from the passed indices.
     *
     * @param cindices indices of the columns to return
     */
    @Override
    public INDArray getColumns(int... cindices) {
        if (!isMatrix() && !isVector())
            throw new IllegalArgumentException("Unable to get columns from a non matrix or vector");
        if (isVector()) {
            // Pull along dimension 0, keeping this array's own ordering for the result
            return Nd4j.pullRows(this, 0, cindices, this.ordering());
        } else {
            // Matrix: allocate the result, then copy each requested column into it
            INDArray ret = Nd4j.create(rows(), cindices.length);
            for (int i = 0; i < cindices.length; i++)
                ret.putColumn(i, getColumn(cindices[i]));
            return ret;
        }
    }
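For a row vector, every column is a single element, so the vector branch of `getColumns` amounts to an element gather. A minimal plain-Java sketch of that case, assuming a `[1 x n]` row vector stored as `double[]` (the class and method names are hypothetical):

```java
import java.util.Arrays;

class GetColumnsSketch {
    // Row-vector case: each "column" of a [1 x n] vector is one element,
    // so taking columns means gathering elements in the requested order.
    static double[] getColumns(double[] rowVector, int... cindices) {
        double[] out = new double[cindices.length];
        for (int i = 0; i < cindices.length; i++) {
            int c = cindices[i];
            if (c < 0 || c >= rowVector.length)
                throw new IndexOutOfBoundsException("column " + c + " out of range");
            out[i] = rowVector[c];
        }
        return out;
    }

    public static void main(String[] args) {
        double[] v = {10, 20, 30, 40};
        System.out.println(Arrays.toString(getColumns(v, 3, 0))); // [40.0, 10.0]
    }
}
```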

Example source: deeplearning4j/nd4j

    INDArray subset = Nd4j.pullRows(as2d, 1, rowsToPull); // tensor along dimension 1 == rows
    return subset;
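The snippet does not show where `rowsToPull` comes from. One common case, assumed here purely for illustration, is a per-row mask where only rows with a non-zero mask value should be kept. A hypothetical helper that builds such an index array:

```java
import java.util.Arrays;
import java.util.stream.IntStream;

class RowsFromMaskSketch {
    // Hypothetical helper: indexes of the rows whose mask value is non-zero, in ascending order.
    static int[] rowsFromMask(double[] mask) {
        return IntStream.range(0, mask.length)
                .filter(i -> mask[i] != 0.0)
                .toArray();
    }

    public static void main(String[] args) {
        double[] mask = {1, 0, 1, 1, 0};
        System.out.println(Arrays.toString(rowsFromMask(mask))); // [0, 2, 3]
    }
}
```

An all-zero mask yields an empty array; since pullRows may reject empty indexes, callers would likely need to handle that case separately.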

Example source: org.deeplearning4j/deeplearning4j-nn

    // Pull the same rows from labels and predictions so they stay aligned
    labels2d = Nd4j.pullRows(labels2d, 1, rowsToPull);
    predicted2d = Nd4j.pullRows(predicted2d, 1, rowsToPull);
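Pulling the same `rowsToPull` from both arrays keeps row `i` of the labels paired with row `i` of the predictions. The sketch below shows why that matters by computing accuracy on the pulled rows with an argmax comparison; the metric and all names are illustrative assumptions, not the DL4J evaluation code.

```java
class MaskedAccuracySketch {
    // Gathers the given rows (the role Nd4j.pullRows(x, 1, rows) plays above).
    static double[][] pull(double[][] x, int[] rows) {
        double[][] out = new double[rows.length][];
        for (int i = 0; i < rows.length; i++)
            out[i] = x[rows[i]].clone();
        return out;
    }

    static int argMax(double[] row) {
        int best = 0;
        for (int j = 1; j < row.length; j++)
            if (row[j] > row[best]) best = j;
        return best;
    }

    // Fraction of pulled rows whose predicted class matches the labelled class.
    static double accuracy(double[][] labels, double[][] predicted, int[] rowsToPull) {
        double[][] l = pull(labels, rowsToPull);
        double[][] p = pull(predicted, rowsToPull); // same indexes, so row i still pairs with row i
        int correct = 0;
        for (int i = 0; i < l.length; i++)
            if (argMax(l[i]) == argMax(p[i])) correct++;
        return (double) correct / l.length;
    }

    public static void main(String[] args) {
        double[][] labels = {{1, 0}, {0, 1}, {1, 0}};
        double[][] predicted = {{0.9, 0.1}, {0.2, 0.8}, {0.3, 0.7}};
        System.out.println(accuracy(labels, predicted, new int[]{0, 2})); // 0.5
    }
}
```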

Example source: org.deeplearning4j/deeplearning4j-nn

    prob = Nd4j.pullRows(prob, 1, rowsToPull); // 1: tensor along dimension 1 (rows)
    label = Nd4j.pullRows(label, 1, rowsToPull);

Example source: org.deeplearning4j/deeplearning4j-nn

    @Override
    public INDArray preOutput(boolean training) {
        // Expected input shape: [numExamples, 1], each entry an integer index
        if (input.columns() != 1) {
            throw new DL4JInvalidInputException(
                    "Cannot do forward pass for embedding layer with input more than one column. "
                            + "Expected input shape: [numExamples,1] with each entry being an integer index "
                            + layerId());
        }
        // Read the integer index of each example
        int[] indexes = new int[(int) input.length()];
        for (int i = 0; i < indexes.length; i++)
            indexes[i] = input.getInt(i, 0);
        INDArray weights = getParam(DefaultParamInitializer.WEIGHT_KEY);
        INDArray bias = getParam(DefaultParamInitializer.BIAS_KEY);
        // Embedding lookup: row i of the result is row indexes[i] of the weight matrix
        INDArray rows = Nd4j.pullRows(weights, 1, indexes);
        rows.addiRowVector(bias); // add the bias to every row, in place
        return rows;
    }
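`preOutput` is an embedding lookup: each example's integer index selects one row of the weight matrix, and the bias is then added to every selected row. A plain-Java sketch of the same computation, with hypothetical names and `double[][]` in place of `INDArray`:

```java
import java.util.Arrays;

class EmbeddingLookupSketch {
    // weights: [vocabSize][embeddingDim], bias: [embeddingDim], one integer index per example.
    static double[][] preOutput(double[][] weights, double[] bias, int[] indexes) {
        double[][] rows = new double[indexes.length][];
        for (int i = 0; i < indexes.length; i++) {
            rows[i] = weights[indexes[i]].clone(); // like Nd4j.pullRows(weights, 1, indexes)
            for (int j = 0; j < bias.length; j++)
                rows[i][j] += bias[j];             // like rows.addiRowVector(bias)
        }
        return rows;
    }

    public static void main(String[] args) {
        double[][] weights = {{0, 1}, {2, 3}, {4, 5}};
        double[] bias = {10, 20};
        System.out.println(Arrays.deepToString(preOutput(weights, bias, new int[]{2, 0, 2})));
        // [[14.0, 25.0], [10.0, 21.0], [14.0, 25.0]]
    }
}
```

Because the pulled rows are copies, adding the bias in place leaves the weight matrix untouched.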

Example source: org.nd4j/nd4j-parameter-server-node_2.11

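    INDArray syn1Neg = storage.getArray(WordVectorStorage.SYN_1_NEGATIVE);
    // Pull the requested syn0 rows (word vectors) into a new 'c'-ordered array
    INDArray words = Nd4j.pullRows(storage.getArray(WordVectorStorage.SYN_0), 1, cbr.getSyn0rows(), 'c');
    INDArray neue = words.mean(0); // average the pulled vectors along dimension 0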
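`SYN_0` follows the word2vec naming for the input word-vector table. Pulling a set of rows and calling `mean(0)` averages those vectors into one, which is how the CBOW variant of word2vec combines context words; we assume that is the intent here. A plain-Java sketch of the averaging step (hypothetical names):

```java
import java.util.Arrays;

class ContextMeanSketch {
    // Averages the selected rows element-wise: the analogue of pullRows(syn0, 1, rows).mean(0).
    static double[] meanOfRows(double[][] syn0, int[] rows) {
        int dim = syn0[0].length;
        double[] mean = new double[dim];
        for (int r : rows)
            for (int j = 0; j < dim; j++)
                mean[j] += syn0[r][j];
        for (int j = 0; j < dim; j++)
            mean[j] /= rows.length;
        return mean;
    }

    public static void main(String[] args) {
        double[][] syn0 = {{1, 2}, {3, 4}, {5, 6}};
        System.out.println(Arrays.toString(meanOfRows(syn0, new int[]{0, 2}))); // [3.0, 4.0]
    }
}
```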

Example source: org.nd4j/nd4j-parameter-server-node

    INDArray words = Nd4j.pullRows(storage.getArray(WordVectorStorage.SYN_0), 1, rowsA, 'c');
    INDArray mean = words.mean(0);
