org.nd4j.linalg.factory.Nd4j.pullRows()方法的使用及代码示例

x33g5p2x  于2022-01-24 转载在 其他  
字(5.2k)|赞(0)|评价(0)|浏览(89)

本文整理了 Java 中 org.nd4j.linalg.factory.Nd4j.pullRows() 方法的一些代码示例，展示了 Nd4j.pullRows() 的具体用法。这些代码示例主要来源于 GitHub、Stack Overflow、Maven 等平台，是从一些精选项目中提取出来的代码，具有较强的参考意义，能在一定程度上帮助到你。Nd4j.pullRows() 方法的具体详情如下：
包路径:org.nd4j.linalg.factory.Nd4j
类名称:Nd4j
方法名:pullRows

Nd4j.pullRows介绍

[英]This method produces a concatenated array consisting of tensors fetched from the source array along the given dimension at the specified indexes
[中]该方法沿指定维度、按指定索引从源数组中取出张量，并将它们拼接成一个数组返回

代码示例

代码示例来源:origin: deeplearning4j/nd4j

/**
 * Gathers the tensors of {@code source} taken along {@code sourceDimension} at the given
 * {@code indexes} and concatenates them into one array.
 *
 * <p>This overload uses the ordering returned by {@link Nd4j#order()} and delegates to the
 * four-argument {@code pullRows} overload.
 *
 * @param source          array to gather tensors from
 * @param sourceDimension dimension of {@code source} along which tensors are taken
 * @param indexes         indexes of the tensors to fetch from {@code source}
 * @return the concatenated array produced by the four-argument overload
 */
public static INDArray pullRows(INDArray source, int sourceDimension, int[] indexes) {
  final char resultOrder = Nd4j.order();
  return pullRows(source, sourceDimension, indexes, resultOrder);
}

代码示例来源:origin: deeplearning4j/nd4j

/**
 * Get whole rows from the passed indices.
 *
 * <p>Vectors are handled by {@link Nd4j#pullRows(INDArray, int, int[])} along dimension 1.
 * Matrices are assembled by copying each requested row into a freshly created
 * {@code rindices.length x columns()} array.
 *
 * @param rindices indices of the rows to fetch, in the order they should appear in the result
 * @return an array containing the requested rows
 * @throws IllegalArgumentException if this array is neither a matrix nor a vector
 */
@Override
public INDArray getRows(int[] rindices) {
  Nd4j.getCompressor().autoDecompress(this);
  if (!isMatrix() && !isVector()) {
    // Message previously (incorrectly) said "columns" although this method fetches rows.
    throw new IllegalArgumentException("Unable to get rows from a non matrix or vector");
  }
  if (isVector()) {
    return Nd4j.pullRows(this, 1, rindices);
  }
  INDArray ret = Nd4j.create(rindices.length, columns());
  for (int i = 0; i < rindices.length; i++) {
    ret.putRow(i, getRow(rindices[i]));
  }
  return ret;
}

代码示例来源:origin: deeplearning4j/nd4j

/**
 * Returns the requested columns, in the order given, as a single array.
 *
 * <p>For a vector the columns are gathered with {@code Nd4j.pullRows} along dimension 0 using
 * this array's own ordering. For a matrix, a {@code rows() x cindices.length} result is created
 * and filled one column at a time.
 *
 * @param cindices indices of the columns to fetch
 * @return an array holding the selected columns
 * @throws IllegalArgumentException if this array is neither a matrix nor a vector
 */
@Override
public INDArray getColumns(int... cindices) {
  if (isMatrix() || isVector()) {
    if (isVector()) {
      return Nd4j.pullRows(this, 0, cindices, this.ordering());
    }
    INDArray result = Nd4j.create(rows(), cindices.length);
    int target = 0;
    for (int sourceColumn : cindices) {
      result.putColumn(target++, getColumn(sourceColumn));
    }
    return result;
  }
  throw new IllegalArgumentException("Unable to get columns from a non matrix or vector");
}

代码示例来源:origin: deeplearning4j/nd4j

// Gather the entries of rowsToPull from as2d (tensors along dimension 1) into a single array.
INDArray subset = Nd4j.pullRows(as2d, 1, rowsToPull); //Tensor along dimension 1 == rows
return subset;

代码示例来源:origin: org.nd4j/nd4j-api

/**
 * Gathers the tensors of {@code source} taken along {@code sourceDimension} at the given
 * {@code indexes} and concatenates them into one array.
 *
 * <p>This overload uses the ordering returned by {@link Nd4j#order()} and delegates to the
 * four-argument {@code pullRows} overload.
 *
 * @param source          array to gather tensors from
 * @param sourceDimension dimension of {@code source} along which tensors are taken
 * @param indexes         indexes of the tensors to fetch from {@code source}
 * @return the concatenated array produced by the four-argument overload
 */
public static INDArray pullRows(INDArray source, int sourceDimension, int[] indexes) {
  final char resultOrder = Nd4j.order();
  return pullRows(source, sourceDimension, indexes, resultOrder);
}

代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn

// Keep only the rows listed in rowsToPull (tensors along dimension 1) for both labels and
// predictions; using the same index array for both keeps them row-aligned.
labels2d = Nd4j.pullRows(labels2d, 1, rowsToPull);
predicted2d = Nd4j.pullRows(predicted2d, 1, rowsToPull);

代码示例来源:origin: org.nd4j/nd4j-api

/**
 * Returns the requested columns, in the order given, as a single array.
 *
 * <p>For a vector the columns are gathered with {@code Nd4j.pullRows} along dimension 0 using
 * this array's own ordering. For a matrix, a {@code rows() x cindices.length} result is created
 * and filled one column at a time.
 *
 * @param cindices indices of the columns to fetch
 * @return an array holding the selected columns
 * @throws IllegalArgumentException if this array is neither a matrix nor a vector
 */
@Override
public INDArray getColumns(int... cindices) {
  if (isMatrix() || isVector()) {
    if (isVector()) {
      return Nd4j.pullRows(this, 0, cindices, this.ordering());
    }
    INDArray result = Nd4j.create(rows(), cindices.length);
    int target = 0;
    for (int sourceColumn : cindices) {
      result.putColumn(target++, getColumn(sourceColumn));
    }
    return result;
  }
  throw new IllegalArgumentException("Unable to get columns from a non matrix or vector");
}

代码示例来源:origin: org.nd4j/nd4j-api

/**
 * Get whole rows from the passed indices.
 *
 * <p>Vectors are handled by {@link Nd4j#pullRows(INDArray, int, int[])} along dimension 1.
 * Matrices are assembled by copying each requested row into a freshly created
 * {@code rindices.length x columns()} array.
 *
 * @param rindices indices of the rows to fetch, in the order they should appear in the result
 * @return an array containing the requested rows
 * @throws IllegalArgumentException if this array is neither a matrix nor a vector
 */
@Override
public INDArray getRows(int[] rindices) {
  Nd4j.getCompressor().autoDecompress(this);
  if (!isMatrix() && !isVector()) {
    // Message previously (incorrectly) said "columns" although this method fetches rows.
    throw new IllegalArgumentException("Unable to get rows from a non matrix or vector");
  }
  if (isVector()) {
    return Nd4j.pullRows(this, 1, rindices);
  }
  INDArray ret = Nd4j.create(rindices.length, columns());
  for (int i = 0; i < rindices.length; i++) {
    ret.putRow(i, getRow(rindices[i]));
  }
  return ret;
}

代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn

// Select the same rows (rowsToPull) from both prob and label so they stay row-aligned.
prob = Nd4j.pullRows(prob, 1, rowsToPull); //1: tensor along dim 1
label = Nd4j.pullRows(label, 1, rowsToPull);

代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn

/**
 * Computes the embedding layer's pre-activation output.
 *
 * <p>Each example's integer index selects one row of the weight matrix, and the bias row
 * vector is then added to every selected row.
 *
 * @param training not used by this implementation
 * @return an array with one row per example: the selected weight row plus the bias
 * @throws DL4JInvalidInputException if the input has more than one column, or if any index lies
 *     outside {@code [0, weights.rows())}
 */
@Override
public INDArray preOutput(boolean training) {
  // Input is expected to have shape [numExamples, 1], with each entry being an integer index.
  if (input.columns() != 1) {
    throw new DL4JInvalidInputException(
            "Cannot do forward pass for embedding layer with input more than one column. "
                    + "Expected input shape: [numExamples,1] with each entry being an integer index "
                    + layerId());
  }
  INDArray weights = getParam(DefaultParamInitializer.WEIGHT_KEY);
  INDArray bias = getParam(DefaultParamInitializer.BIAS_KEY);
  int numWeightRows = weights.rows();
  int[] indexes = new int[input.length()];
  for (int i = 0; i < indexes.length; i++) {
    int index = input.getInt(i, 0);
    // Reject out-of-range indexes up front with a descriptive error instead of handing them
    // to the pullRows gather unchecked.
    if (index < 0 || index >= numWeightRows) {
      throw new DL4JInvalidInputException(
              "Cannot do forward pass for embedding layer: input index " + index
                      + " (example " + i + ") is out of range [0, " + numWeightRows + "). "
                      + layerId());
    }
    indexes[i] = index;
  }
  INDArray rows = Nd4j.pullRows(weights, 1, indexes);
  rows.addiRowVector(bias);
  return rows;
}

代码示例来源:origin: org.nd4j/nd4j-api

// Gather the entries of rowsToPull from as2d (tensors along dimension 1) into a single array.
INDArray subset = Nd4j.pullRows(as2d, 1, rowsToPull); //Tensor along dimension 1 == rows
return subset;

代码示例来源:origin: org.nd4j/nd4j-parameter-server-node_2.11

// Fetch the SYN_1_NEGATIVE array from storage (not used within this excerpt).
INDArray syn1Neg = storage.getArray(WordVectorStorage.SYN_1_NEGATIVE);
// Gather the SYN_0 rows requested by cbr into a 'c'-ordered array.
INDArray words = Nd4j.pullRows(storage.getArray(WordVectorStorage.SYN_0), 1, cbr.getSyn0rows(), 'c');
// Mean along dimension 0, i.e. across the gathered rows (assuming a 2-D result — confirm).
INDArray neue = words.mean(0);

代码示例来源:origin: org.nd4j/nd4j-parameter-server-node

// Fetch the SYN_1_NEGATIVE array from storage (not used within this excerpt).
INDArray syn1Neg = storage.getArray(WordVectorStorage.SYN_1_NEGATIVE);
// Gather the SYN_0 rows requested by cbr into a 'c'-ordered array.
INDArray words = Nd4j.pullRows(storage.getArray(WordVectorStorage.SYN_0), 1, cbr.getSyn0rows(), 'c');
// Mean along dimension 0, i.e. across the gathered rows (assuming a 2-D result — confirm).
INDArray neue = words.mean(0);

代码示例来源:origin: org.nd4j/nd4j-parameter-server-node

// Gather the SYN_0 rows listed in rowsA into a 'c'-ordered array, then take the mean along
// dimension 0, i.e. across those rows (assuming a 2-D result — confirm).
INDArray words = Nd4j.pullRows(storage.getArray(WordVectorStorage.SYN_0), 1, rowsA, 'c');
INDArray mean = words.mean(0);

代码示例来源:origin: org.nd4j/nd4j-parameter-server-node_2.11

// Gather the SYN_0 rows listed in rowsA into a 'c'-ordered array, then take the mean along
// dimension 0, i.e. across those rows (assuming a 2-D result — confirm).
INDArray words = Nd4j.pullRows(storage.getArray(WordVectorStorage.SYN_0), 1, rowsA, 'c');
INDArray mean = words.mean(0);

相关文章