org.nd4j.linalg.factory.Nd4j.dataType()方法的使用及代码示例

x33g5p2x  于2022-01-24 转载在 其他  
字(5.8k)|赞(0)|评价(0)|浏览(123)

本文整理了Java中org.nd4j.linalg.factory.Nd4j.dataType()方法的一些代码示例,展示了Nd4j.dataType()的具体用法。这些代码示例主要来源于GitHub/Stackoverflow/Maven等平台,是从一些精选项目中提取出来的代码,具有较强的参考意义,能在一定程度上帮助到你。Nd4j.dataType()方法的具体详情如下:
包路径:org.nd4j.linalg.factory.Nd4j
类名称:Nd4j
方法名:dataType

Nd4j.dataType介绍

[英]Returns the data type used for the runtime
[中]返回用于运行时的数据类型

代码示例

代码示例来源:origin: deeplearning4j/nd4j

  /**
   * Describes an {@code int[]} payload.
   *
   * <p>The element kind is recorded as {@code DTYPE.INT}; the buffer type is the global default
   * reported by {@link Nd4j#dataType()} at construction time.
   *
   * @param array the backing values (stored by reference, not copied)
   */
  public ArrayDescriptor(int[] array) {
    bufferType = Nd4j.dataType();
    dtype = DTYPE.INT;
    intArray = array;
  }

代码示例来源:origin: deeplearning4j/nd4j

  /**
   * Describes a {@code long[]} payload.
   *
   * <p>The element kind is recorded as {@code DTYPE.LONG}; the buffer type is the global default
   * reported by {@link Nd4j#dataType()} at construction time.
   *
   * @param array the backing values (stored by reference, not copied)
   */
  public ArrayDescriptor(long[] array) {
    bufferType = Nd4j.dataType();
    dtype = DTYPE.LONG;
    longArray = array;
  }

代码示例来源:origin: deeplearning4j/nd4j

  /**
   * Describes a {@code float[]} payload.
   *
   * <p>The element kind is recorded as {@code DTYPE.FLOAT}; the buffer type is the global default
   * reported by {@link Nd4j#dataType()} at construction time.
   *
   * @param array the backing values (stored by reference, not copied)
   */
  public ArrayDescriptor(float[] array) {
    bufferType = Nd4j.dataType();
    dtype = DTYPE.FLOAT;
    floatArray = array;
  }

代码示例来源:origin: deeplearning4j/nd4j

  1. /**
  2. * Returns the data opType for this ndarray
  3. *
  4. * @return the data opType for this ndarray
  5. */
  6. @Override
  7. public DataBuffer.Type dtype() {
  8. return Nd4j.dataType();
  9. }

代码示例来源:origin: deeplearning4j/nd4j

  /**
   * Describes a {@code double[]} payload.
   *
   * <p>The element kind is recorded as {@code DTYPE.DOUBLE}; the buffer type is the global default
   * reported by {@link Nd4j#dataType()} at construction time.
   *
   * @param array the backing values (stored by reference, not copied)
   */
  public ArrayDescriptor(double[] array) {
    bufferType = Nd4j.dataType();
    dtype = DTYPE.DOUBLE;
    doubleArray = array;
  }

代码示例来源:origin: deeplearning4j/nd4j

  1. @Override
  2. public Boolean apply(Number input) {
  3. if (Nd4j.dataType() == DataBuffer.Type.DOUBLE)
  4. return input.doubleValue() != value.doubleValue();
  5. else
  6. return input.floatValue() != value.floatValue();
  7. }

代码示例来源:origin: deeplearning4j/nd4j

  1. @Override
  2. public Boolean apply(Number input) {
  3. if (Nd4j.dataType() == DataBuffer.Type.DOUBLE)
  4. return input.doubleValue() == value.doubleValue();
  5. else
  6. return input.floatValue() == value.floatValue();
  7. }

代码示例来源:origin: deeplearning4j/nd4j

  1. /**
  2. * This method returns sizeOf(currentDataType), in bytes
  3. *
  4. * @return number of bytes per element
  5. */
  6. public static int sizeOfDataType() {
  7. return sizeOfDataType(Nd4j.dataType());
  8. }

代码示例来源:origin: deeplearning4j/nd4j

  1. protected DataBuffer.TypeEx getGlobalTypeEx() {
  2. DataBuffer.Type type = Nd4j.dataType();
  3. return convertType(type);
  4. }

代码示例来源:origin: deeplearning4j/nd4j

  1. @Override
  2. public void init(INDArray x, INDArray y, INDArray z, long n) {
  3. super.init(x, y, z, n);
  4. if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
  5. this.extraArgs = new Object[] {zeroDouble()};
  6. } else if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
  7. this.extraArgs = new Object[] {zeroFloat()};
  8. } else if (Nd4j.dataType() == DataBuffer.Type.HALF) {
  9. this.extraArgs = new Object[] {zeroHalf()};
  10. }
  11. }

代码示例来源:origin: deeplearning4j/nd4j

  1. /**
  2. * Create double based on real and imaginary
  3. *
  4. * @param real real component
  5. * @param imag imag component
  6. * @return
  7. */
  8. public static IComplexNumber createComplexNumber(Number real, Number imag) {
  9. if (dataType() == DataBuffer.Type.FLOAT)
  10. return INSTANCE.createFloat(real.floatValue(), imag.floatValue());
  11. return INSTANCE.createDouble(real.doubleValue(), imag.doubleValue());
  12. }

代码示例来源:origin: deeplearning4j/nd4j

  1. @Override
  2. public INDArray trueScalar(Number value) {
  3. val dtype = Nd4j.dataType();
  4. switch (dtype) {
  5. case DOUBLE:
  6. return create(new double[] {value.doubleValue()}, new int[] {}, new int[] {}, 0);
  7. case FLOAT:
  8. return create(new float[] {value.floatValue()}, new int[] {}, new int[] {}, 0);
  9. case HALF:
  10. return create(new float[] {value.floatValue()}, new int[] {}, new int[] {}, 0);
  11. default:
  12. throw new UnsupportedOperationException("Unsupported data type: [" + dtype + "]");
  13. }
  14. }

代码示例来源:origin: deeplearning4j/nd4j

  1. /**
  2. * Returns the number of bytes
  3. * for the graph
  4. *
  5. * @return
  6. */
  7. public long memoryForGraph() {
  8. return numElements() * DataTypeUtil.lengthForDtype(Nd4j.dataType());
  9. }

代码示例来源:origin: deeplearning4j/nd4j

  1. /**
  2. * Create a scalar nd array with the specified value and offset
  3. *
  4. * @param value the value of the scalar
  5. * @return the scalar nd array
  6. */
  7. @Override
  8. public INDArray scalar(double value) {
  9. if (Nd4j.dataType() == DataBuffer.Type.DOUBLE)
  10. return create(new double[] {value}, new int[] {1, 1}, new int[] {1, 1}, 0);
  11. else
  12. return scalar((float) value);
  13. }

代码示例来源:origin: deeplearning4j/nd4j

  1. public static DataBuffer createBufferDetached(float[] data) {
  2. DataBuffer ret;
  3. if (dataType() == DataBuffer.Type.FLOAT)
  4. ret = DATA_BUFFER_FACTORY_INSTANCE.createFloat(data);
  5. else if (dataType() == DataBuffer.Type.HALF)
  6. ret = DATA_BUFFER_FACTORY_INSTANCE.createHalf(data);
  7. else
  8. ret = DATA_BUFFER_FACTORY_INSTANCE.createDouble(ArrayUtil.toDoubles(data));
  9. logCreationIfNecessary(ret);
  10. return ret;
  11. }

代码示例来源:origin: deeplearning4j/nd4j

  1. public static DataBuffer createBufferDetached(double[] data) {
  2. DataBuffer ret;
  3. if (dataType() == DataBuffer.Type.DOUBLE)
  4. ret = DATA_BUFFER_FACTORY_INSTANCE.createDouble(data);
  5. else if (dataType() == DataBuffer.Type.HALF)
  6. ret = DATA_BUFFER_FACTORY_INSTANCE.createHalf(ArrayUtil.toFloats(data));
  7. else
  8. ret = DATA_BUFFER_FACTORY_INSTANCE.createFloat(ArrayUtil.toFloats(data));
  9. logCreationIfNecessary(ret);
  10. return ret;
  11. }

代码示例来源:origin: deeplearning4j/nd4j

  1. @Override
  2. public long getRequiredBatchMemorySize() {
  3. long result = maxIntArrays() * maxIntArraySize() * 4;
  4. result += maxArguments() * 8; // pointers
  5. result += maxShapes() * 8; // pointers
  6. result += maxIndexArguments() * 4;
  7. result += maxRealArguments() * (Nd4j.dataType() == DataBuffer.Type.DOUBLE ? 8
  8. : Nd4j.dataType() == DataBuffer.Type.FLOAT ? 4 : 2);
  9. result += 5 * 4; // numArgs
  10. return result * Batch.getBatchLimit();
  11. }
  12. }

代码示例来源:origin: deeplearning4j/nd4j

  1. public static INDArray toNDArray(int[] nums) {
  2. if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
  3. double[] doubles = ArrayUtil.toDoubles(nums);
  4. INDArray create = Nd4j.create(doubles, new int[] {1, nums.length});
  5. return create;
  6. } else {
  7. float[] doubles = ArrayUtil.toFloats(nums);
  8. INDArray create = Nd4j.create(doubles, new int[] {1, nums.length});
  9. return create;
  10. }
  11. }

代码示例来源:origin: deeplearning4j/nd4j

  1. public static INDArray toNDArray(long[] nums) {
  2. if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
  3. double[] doubles = ArrayUtil.toDoubles(nums);
  4. INDArray create = Nd4j.create(doubles, new int[] {1, nums.length});
  5. return create;
  6. } else {
  7. float[] doubles = ArrayUtil.toFloats(nums);
  8. INDArray create = Nd4j.create(doubles, new int[] {1, nums.length});
  9. return create;
  10. }
  11. }

代码示例来源:origin: deeplearning4j/nd4j

  1. public static INDArray toNDArray(int[][] nums) {
  2. if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
  3. double[] doubles = ArrayUtil.toDoubles(nums);
  4. INDArray create = Nd4j.create(doubles, new int[] {nums[0].length, nums.length});
  5. return create;
  6. } else {
  7. float[] doubles = ArrayUtil.toFloats(nums);
  8. INDArray create = Nd4j.create(doubles, new int[] {nums[0].length, nums.length});
  9. return create;
  10. }
  11. }

相关文章