org.nd4j.linalg.factory.Nd4j.getDataBufferFactory()方法的使用及代码示例

x33g5p2x  于2022-01-24 转载在 其他  
字(10.2k)|赞(0)|评价(0)|浏览(118)

本文整理了Java中org.nd4j.linalg.factory.Nd4j.getDataBufferFactory()方法的一些代码示例,展示了Nd4j.getDataBufferFactory()的具体用法。这些代码示例主要来源于GitHub/Stack Overflow/Maven等平台,是从一些精选项目中提取出来的代码,具有较强的参考意义,希望能在一定程度上帮助到你。Nd4j.getDataBufferFactory()方法的具体详情如下:
包路径:org.nd4j.linalg.factory.Nd4j
类名称:Nd4j
方法名:getDataBufferFactory

Nd4j.getDataBufferFactory介绍

暂无

代码示例

代码示例来源:origin: deeplearning4j/nd4j

  1. /**
  2. * Return the minor pointers. (columns for CSR, rows for CSC,...)
  3. * */
  4. public DataBuffer getVectorCoordinates() {
  5. return Nd4j.getDataBufferFactory().create(columnsPointers, 0, length());
  6. }

代码示例来源:origin: deeplearning4j/nd4j

  1. @Override
  2. public DataBuffer data() {
  3. return Nd4j.getDataBufferFactory().create(values, 0, length());
  4. }

代码示例来源:origin: deeplearning4j/nd4j

  1. public DataBuffer getPointerE() {
  2. return Nd4j.getDataBufferFactory().create(pointerE, 0, rows());
  3. }

代码示例来源:origin: deeplearning4j/nd4j

  1. DataBuffer dataBuffer = Nd4j.getDataBufferFactory().createInt(result);
  2. return dataBuffer;

代码示例来源:origin: deeplearning4j/nd4j

  1. public DataBuffer getPointerB() {
  2. return Nd4j.getDataBufferFactory().create(pointerB, 0, rows());
  3. }

代码示例来源:origin: deeplearning4j/nd4j

  /**
   * Creates a CSR sparse array from a values buffer and its index arrays.
   *
   * @param data            values buffer; stored directly as {@code values}
   * @param columnsPointers column pointers, loaded into a new int buffer sized {@code data.length()}
   * @param pointerB        loaded into a new int buffer sized {@code rows}
   * @param pointerE        loaded into a new int buffer sized {@code rows};
   *                        must have the same length as {@code pointerB}
   * @param shape           shape handed to the shape-info provider and to {@code init}
   */
  public BaseSparseNDArrayCSR(DataBuffer data, int[] columnsPointers, int[] pointerB, int[] pointerE, int[] shape) {
      checkArgument(pointerB.length == pointerE.length);
      setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape));
      init(shape);

      this.values = data;

      // The column buffer is sized after the values buffer, then filled from the given array.
      this.columnsPointers = Nd4j.getDataBufferFactory().createInt(data.length());
      this.columnsPointers.setData(columnsPointers);
      this.length = columnsPointers.length;

      // Both pointer buffers have a constant size: the current value of rows.
      this.pointerB = Nd4j.getDataBufferFactory().createInt(rows);
      this.pointerB.setData(pointerB);

      this.pointerE = Nd4j.getDataBufferFactory().createInt(rows);
      this.pointerE.setData(pointerE);
  }

代码示例来源:origin: deeplearning4j/nd4j

  1. @Override
  2. public INDArray bitmapEncode(INDArray indArray, double threshold) {
  3. DataBuffer buffer = Nd4j.getDataBufferFactory().createInt(indArray.length() / 16 + 5);
  4. INDArray ret = Nd4j.createArrayFromShapeBuffer(buffer, indArray.shapeInfoDataBuffer());
  5. bitmapEncode(indArray, ret, threshold);
  6. return ret;
  7. }

代码示例来源:origin: deeplearning4j/nd4j

  1. init(shape);
  2. int valuesSpace = (int) (data.length * THRESHOLD_MEMORY_ALLOCATION);
  3. this.values = Nd4j.getDataBufferFactory().createDouble(valuesSpace);
  4. this.values.setData(data);
  5. this.columnsPointers = Nd4j.getDataBufferFactory().createInt(valuesSpace);
  6. this.columnsPointers.setData(columnsPointers);
  7. this.length = columnsPointers.length;
  8. int pointersSpace = rows;
  9. this.pointerB = Nd4j.getDataBufferFactory().createInt(pointersSpace);
  10. this.pointerB.setData(pointerB);
  11. this.pointerE = Nd4j.getDataBufferFactory().createInt(pointersSpace);
  12. this.pointerE.setData(pointerE);

代码示例来源:origin: deeplearning4j/nd4j

  1. /**
  2. * Returns the underlying indices of the element of the given index
  3. * such as there really are in the original ndarray
  4. *
  5. * @param i the index of the element+
  6. * @return a dataBuffer containing the indices of element
  7. * */
  8. public DataBuffer getUnderlyingIndicesOf(int i) {
  9. int from = underlyingRank() * i;
  10. //int to = from + underlyingRank();
  11. int[] res = new int[underlyingRank()];
  12. for(int j = 0; j< underlyingRank(); j++){
  13. res[j] = indices.getInt(from + j);
  14. }
  15. ///int[] arr = Arrays.copyOfRange(indices.asInt(), from, to);
  16. return Nd4j.getDataBufferFactory().createInt(res);
  17. }

代码示例来源:origin: deeplearning4j/nd4j

  1. /**
  2. * Returns the indices of the element of the given index in the array context
  3. *
  4. * @param i the index of the element
  5. * @return a dataBuffer containing the indices of element
  6. * */
  7. public DataBuffer getIndicesOf(int i) {
  8. int from = underlyingRank() * i;
  9. int to = from + underlyingRank(); //not included
  10. int[] arr = new int[rank];
  11. int j = 0; // iterator over underlying indices
  12. int k = 0; //iterator over hiddenIdx
  13. for (int dim = 0; dim < rank; dim++) {
  14. if (k < hiddenDimensions().length && hiddenDimensions()[k] == j) {
  15. arr[dim] = 0;
  16. k++;
  17. } else {
  18. arr[dim] = indices.getInt(j);
  19. j++;
  20. }
  21. }
  22. return Nd4j.getDataBufferFactory().createInt(arr);
  23. }

代码示例来源:origin: deeplearning4j/nd4j

  1. @Override
  2. public INDArray getrf(INDArray A) {
  3. // FIXME: int cast
  4. if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE)
  5. throw new ND4JArraySizeException();
  6. int m = (int) A.rows();
  7. int n = (int) A.columns();
  8. INDArray INFO = Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createInt(1),
  9. Nd4j.getShapeInfoProvider().createShapeInformation(new int[] {1, 1}).getFirst());
  10. int mn = Math.min(m, n);
  11. INDArray IPIV = Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createInt(mn),
  12. Nd4j.getShapeInfoProvider().createShapeInformation(new int[] {1, mn}).getFirst());
  13. if (A.data().dataType() == DataBuffer.Type.DOUBLE)
  14. dgetrf(m, n, A, IPIV, INFO);
  15. else if (A.data().dataType() == DataBuffer.Type.FLOAT)
  16. sgetrf(m, n, A, IPIV, INFO);
  17. else
  18. throw new UnsupportedOperationException();
  19. if (INFO.getInt(0) < 0) {
  20. throw new Error("Parameter #" + INFO.getInt(0) + " to getrf() was not valid");
  21. } else if (INFO.getInt(0) > 0) {
  22. log.warn("The matrix is singular - cannot be used for inverse op. Check L matrix at row " + INFO.getInt(0));
  23. }
  24. return IPIV;
  25. }

代码示例来源:origin: deeplearning4j/nd4j

  /**
   * Creates a native random generator with the given seed and state size.
   *
   * @param seed             value stored as both the seed and the amplifier
   * @param numberOfElements size of the double state buffer to allocate
   */
  public NativeRandom(long seed, long numberOfElements) {
      // Plain field initialisation.
      this.seed = seed;
      this.amplifier = seed;
      this.generation = 1;
      this.numberOfElements = numberOfElements;

      this.nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
      this.stateBuffer = Nd4j.getDataBufferFactory().createDouble(numberOfElements);

      // NOTE(review): init() presumably sets statePointer, which is read below - confirm.
      init();

      this.hostPointer = new LongPointer(this.stateBuffer.addressPointer());

      // Hand the state pointer to the deallocator for tracking.
      this.deallocator = NativeRandomDeallocator.getInstance();
      this.pack = new NativePack(this.statePointer.address(), this.statePointer);
      this.deallocator.trackStatePointer(this.pack);
  }

代码示例来源:origin: deeplearning4j/nd4j

  1. @Override
  2. public void potrf(INDArray A, boolean lower) {
  3. // FIXME: int cast
  4. if (A.columns() > Integer.MAX_VALUE)
  5. throw new ND4JArraySizeException();
  6. byte uplo = (byte) (lower ? 'L' : 'U'); // upper or lower part of the factor desired ?
  7. int n = (int) A.columns();
  8. INDArray INFO = Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createInt(1),
  9. Nd4j.getShapeInfoProvider().createShapeInformation(new int[] {1, 1}).getFirst());
  10. if (A.data().dataType() == DataBuffer.Type.DOUBLE)
  11. dpotrf(uplo, n, A, INFO);
  12. else if (A.data().dataType() == DataBuffer.Type.FLOAT)
  13. spotrf(uplo, n, A, INFO);
  14. else
  15. throw new UnsupportedOperationException();
  16. if (INFO.getInt(0) < 0) {
  17. throw new Error("Parameter #" + INFO.getInt(0) + " to potrf() was not valid");
  18. } else if (INFO.getInt(0) > 0) {
  19. throw new Error("The matrix is not positive definite! (potrf fails @ order " + INFO.getInt(0) + ")");
  20. }
  21. return;
  22. }

代码示例来源:origin: deeplearning4j/nd4j

  1. @Override
  2. public void gesvd(INDArray A, INDArray S, INDArray U, INDArray VT) {
  3. // FIXME: int cast
  4. if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE)
  5. throw new ND4JArraySizeException();
  6. int m = (int) A.rows();
  7. int n = (int) A.columns();
  8. byte jobu = (byte) (U == null ? 'N' : 'A');
  9. byte jobvt = (byte) (VT == null ? 'N' : 'A');
  10. INDArray INFO = Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createInt(1),
  11. Nd4j.getShapeInfoProvider().createShapeInformation(new int[] {1, 1}).getFirst());
  12. if (A.data().dataType() == DataBuffer.Type.DOUBLE)
  13. dgesvd(jobu, jobvt, m, n, A, S, U, VT, INFO);
  14. else if (A.data().dataType() == DataBuffer.Type.FLOAT)
  15. sgesvd(jobu, jobvt, m, n, A, S, U, VT, INFO);
  16. else
  17. throw new UnsupportedOperationException();
  18. if (INFO.getInt(0) < 0) {
  19. throw new Error("Parameter #" + INFO.getInt(0) + " to gesvd() was not valid");
  20. } else if (INFO.getInt(0) > 0) {
  21. log.warn("The matrix contains singular elements. Check S matrix at row " + INFO.getInt(0));
  22. }
  23. }

代码示例来源:origin: deeplearning4j/nd4j

  1. @Override
  2. public void geqrf(INDArray A, INDArray R) {
  3. // FIXME: int cast
  4. if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE)
  5. throw new ND4JArraySizeException();
  6. int m = (int) A.rows();
  7. int n = (int) A.columns();
  8. INDArray INFO = Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createInt(1),
  9. Nd4j.getShapeInfoProvider().createShapeInformation(new int[] {1, 1}).getFirst());
  10. if (R.rows() != A.columns() || R.columns() != A.columns()) {
  11. throw new Error("geqrf: R must be N x N (n = columns in A)");
  12. }
  13. if (A.data().dataType() == DataBuffer.Type.DOUBLE) {
  14. dgeqrf(m, n, A, R, INFO);
  15. } else if (A.data().dataType() == DataBuffer.Type.FLOAT) {
  16. sgeqrf(m, n, A, R, INFO);
  17. } else {
  18. throw new UnsupportedOperationException();
  19. }
  20. if (INFO.getInt(0) < 0) {
  21. throw new Error("Parameter #" + INFO.getInt(0) + " to getrf() was not valid");
  22. }
  23. }

代码示例来源:origin: deeplearning4j/nd4j

  1. public static Cloner newCloner() {
  2. Cloner cloner = new Cloner();
  3. //Implement custom cloning for INDArrays (default can have problems with off-heap and pointers)
  4. //Sadly: the cloner library does NOT support interfaces here, hence we need to use the actual classes
  5. //cloner.registerFastCloner(INDArray.class, new INDArrayFastCloner()); //Does not work due to interface
  6. IFastCloner fc = new INDArrayFastCloner();
  7. cloner.registerFastCloner(Nd4j.getBackend().getNDArrayClass(), fc);
  8. cloner.registerFastCloner(Nd4j.getBackend().getComplexNDArrayClass(), fc);
  9. //Same thing with DataBuffers: off heap -> cloner library chokes on them, but need to know the concrete
  10. // buffer classes, not just the interface
  11. IFastCloner fc2 = new DataBufferFastCloner();
  12. DataBufferFactory d = Nd4j.getDataBufferFactory();
  13. doReg(cloner, fc2, d.intBufferClass());
  14. doReg(cloner, fc2, d.longBufferClass());
  15. doReg(cloner, fc2, d.halfBufferClass());
  16. doReg(cloner, fc2, d.floatBufferClass());
  17. doReg(cloner, fc2, d.doubleBufferClass());
  18. doReg(cloner, fc2, CompressedDataBuffer.class);
  19. return cloner;
  20. }

代码示例来源:origin: org.nd4j/nd4j-cuda-7.5

  1. protected <T extends Aggregate> DataBuffer getBuffer(Batch<T> batch) {
  2. DataBuffer buffer = Nd4j.getDataBufferFactory().createInt(batch.getSample().getRequiredBatchMemorySize() * 4,
  3. false);
  4. batch.setParamsSurface(buffer);
  5. return buffer;
  6. }

代码示例来源:origin: org.nd4j/nd4j-cuda-10.0

  1. protected <T extends Aggregate> DataBuffer getBuffer(Batch<T> batch) {
  2. DataBuffer buffer = Nd4j.getDataBufferFactory().createInt(batch.getSample().getRequiredBatchMemorySize() * 4,
  3. false);
  4. batch.setParamsSurface(buffer);
  5. return buffer;
  6. }

代码示例来源:origin: org.nd4j/nd4j-api

  1. @Override
  2. public INDArray bitmapEncode(INDArray indArray, double threshold) {
  3. DataBuffer buffer = Nd4j.getDataBufferFactory().createInt(indArray.length() / 16 + 5);
  4. INDArray ret = Nd4j.createArrayFromShapeBuffer(buffer, indArray.shapeInfoDataBuffer());
  5. bitmapEncode(indArray, ret, threshold);
  6. return ret;
  7. }

代码示例来源:origin: org.nd4j/nd4j-native-api

  /**
   * Creates a native random generator with the given seed and state size.
   *
   * @param seed             value stored as both the seed and the amplifier
   * @param numberOfElements size of the double state buffer to allocate
   */
  public NativeRandom(long seed, long numberOfElements) {
      // Plain field initialisation.
      this.seed = seed;
      this.amplifier = seed;
      this.generation = 1;
      this.numberOfElements = numberOfElements;

      this.nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
      this.stateBuffer = Nd4j.getDataBufferFactory().createDouble(numberOfElements);

      // NOTE(review): init() presumably sets statePointer, which is read below - confirm.
      init();

      this.hostPointer = new LongPointer(this.stateBuffer.addressPointer());

      // Hand the state pointer to the deallocator for tracking.
      this.deallocator = NativeRandomDeallocator.getInstance();
      this.pack = new NativePack(this.statePointer.address(), this.statePointer);
      this.deallocator.trackStatePointer(this.pack);
  }

相关文章