Usage and code examples of the org.nd4j.linalg.factory.Nd4j.read() method

Reposted by x33g5p2x on 2022-01-24 under Other

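This article collects Java code examples of the org.nd4j.linalg.factory.Nd4j.read() method to show how it is used in practice. The examples were extracted from selected projects on GitHub, Stack Overflow, Maven and similar platforms, and should serve as a practical reference. Details of Nd4j.read() are as follows:
Package: org.nd4j.linalg.factory
Class: Nd4j
Method: read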

About Nd4j.read

Javadoc: Read in an ndarray from a data input stream.
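
The reader works on a `java.io.DataInputStream`, so that stream's own rules apply: each primitive read (`readInt`, `readShort`, `readDouble` and so on) consumes a fixed number of bytes and decodes them in big-endian order. The JDK-only sketch below (the class `BigEndianDemo` is hypothetical) shows this on a hand-built byte array; it illustrates `DataInputStream` itself and makes no claim about ND4J's own binary layout.

```java
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

// Hypothetical demo class: shows how DataInputStream decodes raw bytes.
final class BigEndianDemo {
    private BigEndianDemo() {}

    // Reads one int (4 bytes) followed by one short (2 bytes) from the given bytes.
    static int[] readIntThenShort(byte[] bytes) {
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(bytes))) {
            int first = in.readInt();      // consumes bytes 0..3, most significant byte first
            short second = in.readShort(); // consumes bytes 4..5
            return new int[] {first, second};
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        byte[] bytes = {0x00, 0x00, 0x01, 0x00, 0x00, 0x2A};
        int[] values = readIntThenShort(bytes);
        System.out.println(values[0] + " " + values[1]); // prints "256 42"
    }
}
```

Any serialized format read this way is only valid if the writer emitted the same primitives in the same order and byte order.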

Code examples

Source: deeplearning4j/nd4j

    /**
     * Read an ndarray from a byte array
     * @param arr the array to read from
     * @return the deserialized ndarray
     * @throws IOException
     */
    public static INDArray fromByteArray(byte[] arr) throws IOException {
        ByteArrayInputStream bis = new ByteArrayInputStream(arr);
        INDArray ret = read(bis);
        return ret;
    }
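
`fromByteArray` only adapts a `byte[]` to the stream-based reader, so whatever produced the bytes must have written them in the order the reader expects. The sketch below pairs a hypothetical `toByteArray`/`fromByteArray` for a toy format (an `int` length followed by that many `double`s); the format and the class name `ToyByteArrayCodec` are assumptions for illustration, not ND4J's serialization format.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

// Hypothetical toy format: int length, then that many big-endian doubles.
final class ToyByteArrayCodec {
    private ToyByteArrayCodec() {}

    static byte[] toByteArray(double[] values) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            out.writeInt(values.length);
            for (double v : values) {
                out.writeDouble(v);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    // Mirrors the fromByteArray pattern: wrap the bytes in a stream, then read in write order.
    static double[] fromByteArray(byte[] arr) {
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(arr))) {
            double[] values = new double[in.readInt()];
            for (int i = 0; i < values.length; i++) {
                values[i] = in.readDouble();
            }
            return values;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

Wrapping `IOException` in `UncheckedIOException` is a choice made here to keep callers simple; the ND4J methods above declare `throws IOException` instead.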

Source: deeplearning4j/nd4j

    /**
     * Read an ndarray from an input stream
     * @param reader the input stream to use
     * @return the deserialized ndarray
     * @throws IOException
     */
    public static INDArray read(InputStream reader) throws IOException {
        return read(new DataInputStream(reader));
    }
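
This overload wraps its argument in a fresh `DataInputStream` on every call. That is safe for consecutive reads from one stream because `DataInputStream` keeps no read-ahead buffer: it pulls only the bytes each primitive needs. A fresh `BufferedInputStream` per call would not be safe, since it fills its whole buffer (8192 bytes by default) from the underlying stream. The JDK-only sketch below (hypothetical class `WrapperDemo`) shows both behaviours.

```java
import java.io.BufferedInputStream;
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;

// Hypothetical demo: how much of the underlying stream each wrapper consumes.
final class WrapperDemo {
    private WrapperDemo() {}

    // Two big-endian ints: 1 and 2 (8 bytes in total).
    static byte[] twoInts() {
        return new byte[] {0, 0, 0, 1, 0, 0, 0, 2};
    }

    // Like read(InputStream): a fresh DataInputStream per call, consuming exactly 4 bytes.
    // The wrapper is deliberately not closed, because closing it would close `raw` too.
    static int readWithFreshDataWrapper(InputStream raw) {
        try {
            return new DataInputStream(raw).readInt();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    // A BufferedInputStream asked for one byte fills its buffer from `raw`,
    // so a second wrapper around `raw` would find nothing left.
    static int bytesLeftAfterOneBufferedRead(InputStream raw) {
        try {
            new BufferedInputStream(raw).read();
            return raw.available();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```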

Source: deeplearning4j/nd4j

    private INDArray[] loadINDArrays(int numArrays, DataInputStream dis, boolean isMask) throws IOException {
        INDArray[] result = null;
        if (numArrays > 0) {
            result = new INDArray[numArrays];
            for (int i = 0; i < numArrays; i++) {
                INDArray arr = Nd4j.read(dis);
                // A mask stored as the placeholder array is restored as null
                result[i] = isMask && arr.equals(EMPTY_MASK_ARRAY_PLACEHOLDER.get()) ? null : arr;
            }
        }
        return result;
    }
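
When reading masks, `loadINDArrays` turns a placeholder array back into `null`, which suggests the writer stored that placeholder wherever a mask was absent. The sketch below shows that convention with a toy length-prefixed `int[]` format. The placeholder used here (a zero-length array) and the class name `MaskPlaceholderCodec` are hypothetical stand-ins for `EMPTY_MASK_ARRAY_PLACEHOLDER`, whose actual contents are not shown in the source.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

// Hypothetical codec: a zero-length array stands in for a missing (null) mask.
final class MaskPlaceholderCodec {
    private MaskPlaceholderCodec() {}

    static byte[] writeMasks(int[][] masks) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            for (int[] mask : masks) {
                int[] toWrite = (mask == null) ? new int[0] : mask; // null -> placeholder
                out.writeInt(toWrite.length);
                for (int v : toWrite) {
                    out.writeInt(v);
                }
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    static int[][] readMasks(byte[] data, int numArrays) {
        int[][] result = new int[numArrays][];
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(data))) {
            for (int i = 0; i < numArrays; i++) {
                int[] arr = new int[in.readInt()];
                for (int j = 0; j < arr.length; j++) {
                    arr[j] = in.readInt();
                }
                result[i] = (arr.length == 0) ? null : arr; // placeholder -> null
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return result;
    }
}
```

As in the original, a real mask that happens to equal the placeholder is indistinguishable from it and also comes back as `null`.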

Source: deeplearning4j/nd4j

    private static NormalizerStats readMinMaxStats(DataInputStream dis) throws IOException {
        return new MinMaxStats(Nd4j.read(dis), Nd4j.read(dis));
    }

Source: deeplearning4j/nd4j

    private static NormalizerStats readDistributionStats(DataInputStream dis) throws IOException {
        return new DistributionStats(Nd4j.read(dis), Nd4j.read(dis));
    }
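
Both helpers pass two `Nd4j.read(dis)` calls directly as constructor arguments. This is correct only because Java evaluates argument expressions strictly left to right (JLS §15.7.4), so the first array in the stream becomes the first argument. A minimal JDK-only illustration, with a hypothetical `Pair` class and plain `int`s standing in for ndarrays:

```java
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

final class ArgumentOrderDemo {
    private ArgumentOrderDemo() {}

    // Hypothetical stand-in for MinMaxStats / DistributionStats.
    static final class Pair {
        final int first;
        final int second;

        Pair(int first, int second) {
            this.first = first;
            this.second = second;
        }
    }

    // Same shape as `new MinMaxStats(Nd4j.read(dis), Nd4j.read(dis))`.
    static Pair readPair(DataInputStream dis) throws IOException {
        return new Pair(dis.readInt(), dis.readInt());
    }

    static Pair readPair(byte[] data) {
        try (DataInputStream dis = new DataInputStream(new ByteArrayInputStream(data))) {
            return readPair(dis);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```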

Source: deeplearning4j/nd4j

    @Override
    public NormalizerStandardize restore(@NonNull InputStream stream) throws IOException {
        DataInputStream dis = new DataInputStream(stream);
        boolean fitLabels = dis.readBoolean();
        NormalizerStandardize result = new NormalizerStandardize(Nd4j.read(dis), Nd4j.read(dis));
        result.fitLabel(fitLabels);
        if (fitLabels) {
            result.setLabelStats(Nd4j.read(dis), Nd4j.read(dis));
        }
        return result;
    }
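
`restore` reads a leading `boolean` and reads the label statistics only when it is `true`, so a matching writer must emit exactly those fields in that order and skip the label fields when the flag is off. The sketch below assumes a hypothetical writer/reader pair (`StandardizeHeaderCodec`) and uses single `double`s in place of the mean and standard-deviation arrays.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

// Hypothetical layout: boolean fitLabels, feature mean, feature std, [label mean, label std].
final class StandardizeHeaderCodec {
    private StandardizeHeaderCodec() {}

    // Writes label stats only when fitLabels is true; labelStats is ignored otherwise.
    static byte[] write(boolean fitLabels, double[] featureStats, double[] labelStats) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            out.writeBoolean(fitLabels);
            out.writeDouble(featureStats[0]);
            out.writeDouble(featureStats[1]);
            if (fitLabels) {
                out.writeDouble(labelStats[0]);
                out.writeDouble(labelStats[1]);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    // Returns {featureMean, featureStd} or {featureMean, featureStd, labelMean, labelStd}.
    static double[] read(byte[] data) {
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(data))) {
            boolean fitLabels = in.readBoolean();
            double[] values = new double[fitLabels ? 4 : 2];
            for (int i = 0; i < values.length; i++) {
                values[i] = in.readDouble();
            }
            return values;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

`writeBoolean` takes one byte, so the two variants are 17 and 33 bytes long.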

Source: deeplearning4j/nd4j

    /**
     * Read a binary ndarray from the given file
     * @param read the file to read from
     * @return the loaded ndarray
     * @throws IOException
     */
    public static INDArray readBinary(File read) throws IOException {
        // try-with-resources closes the file even if Nd4j.read throws
        try (DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(read)))) {
            return Nd4j.read(dis);
        }
    }
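
`readBinary` layers `FileInputStream`, `BufferedInputStream` and `DataInputStream`; the mirror-image output layering writes such a file. The sketch below writes a toy payload (one `int` and one `double`, a hypothetical stand-in for an ndarray) to a temporary file and reads it back, with try-with-resources releasing each handle even if a read fails. The class name `FileRoundTrip` is hypothetical.

```java
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

final class FileRoundTrip {
    private FileRoundTrip() {}

    // Writes an int and a double to a temp file, reads them back, and deletes the file.
    static String roundTrip(int id, double value) {
        File file = null;
        try {
            file = File.createTempFile("toy-array", ".bin");
            try (DataOutputStream out = new DataOutputStream(
                    new BufferedOutputStream(new FileOutputStream(file)))) {
                out.writeInt(id);
                out.writeDouble(value);
            } // closing flushes the BufferedOutputStream to disk
            try (DataInputStream in = new DataInputStream(
                    new BufferedInputStream(new FileInputStream(file)))) {
                return in.readInt() + ":" + in.readDouble();
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } finally {
            if (file != null) {
                file.delete();
            }
        }
    }
}
```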

Source: deeplearning4j/nd4j

    private DataSet read(int idx) throws IOException {
        // Features come from paths.get(idx)[0], labels from paths.get(idx)[1];
        // both streams are closed even if a read fails
        try (DataInputStream dis = new DataInputStream(
                     new BufferedInputStream(new FileInputStream(paths.get(idx)[0])));
             DataInputStream labelDis = new DataInputStream(
                     new BufferedInputStream(new FileInputStream(paths.get(idx)[1])))) {
            return new DataSet(Nd4j.read(dis), Nd4j.read(labelDis));
        }
    }
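
Code that opens several streams, as this method does for features and labels, should release all of them even when a read fails. Declaring them in one try-with-resources statement closes them in reverse order of declaration (JLS §14.20.3), and does so before any `catch` block runs. The sketch below demonstrates this with a hypothetical `TrackedStream` that records when it is closed, in place of real files.

```java
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;

final class CloseOrderDemo {
    private CloseOrderDemo() {}

    // Hypothetical wrapper that logs its name when closed.
    static final class TrackedStream extends FilterInputStream {
        private final String name;
        private final List<String> log;

        TrackedStream(InputStream in, String name, List<String> log) {
            super(in);
            this.name = name;
            this.log = log;
        }

        @Override
        public void close() throws IOException {
            log.add(name);
            super.close();
        }
    }

    // Reads one int from each stream; the labels stream is empty, so its read fails.
    static List<String> readPairAndReportCloses() {
        List<String> log = new ArrayList<>();
        try (DataInputStream features = new DataInputStream(
                     new TrackedStream(new ByteArrayInputStream(new byte[] {0, 0, 0, 1}), "features", log));
             DataInputStream labels = new DataInputStream(
                     new TrackedStream(new ByteArrayInputStream(new byte[0]), "labels", log))) {
            features.readInt();
            labels.readInt(); // throws EOFException
        } catch (IOException e) {
            log.add("caught " + e.getClass().getSimpleName());
        }
        return log;
    }
}
```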

Source: deeplearning4j/nd4j

    @Override
    public void load(InputStream from) {
        try {
            DataInputStream dis = from instanceof BufferedInputStream ? new DataInputStream(from)
                            : new DataInputStream(new BufferedInputStream(from));
            // A single header byte records which arrays follow in the stream
            byte included = dis.readByte();
            boolean hasFeatures = (included & BITMASK_FEATURES_PRESENT) != 0;
            boolean hasLabels = (included & BITMASK_LABELS_PRESENT) != 0;
            boolean hasLabelsSameAsFeatures = (included & BITMASK_LABELS_SAME_AS_FEATURES) != 0;
            boolean hasFeaturesMask = (included & BITMASK_FEATURE_MASK_PRESENT) != 0;
            boolean hasLabelsMask = (included & BITMASK_LABELS_MASK_PRESENT) != 0;
            features = (hasFeatures ? Nd4j.read(dis) : null);
            if (hasLabels) {
                labels = Nd4j.read(dis);
            } else if (hasLabelsSameAsFeatures) {
                labels = features;
            } else {
                labels = null;
            }
            featuresMask = (hasFeaturesMask ? Nd4j.read(dis) : null);
            labelsMask = (hasLabelsMask ? Nd4j.read(dis) : null);
            dis.close();
        } catch (Exception e) {
            throw new RuntimeException("Error loading DataSet", e);
        }
    }
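
`load` packs five presence flags into one header byte and tests each with a bitwise AND. The sketch below adds the encoding side. The real values of the `BITMASK_*` constants are not shown in the source, so the one-bit-per-flag assignments here are assumptions, and `BitmaskHeader` is a hypothetical name.

```java
// Hypothetical flag values: one distinct bit per flag (the real constants may differ).
final class BitmaskHeader {
    static final int FEATURES_PRESENT = 1;
    static final int LABELS_PRESENT = 1 << 1;
    static final int LABELS_SAME_AS_FEATURES = 1 << 2;
    static final int FEATURE_MASK_PRESENT = 1 << 3;
    static final int LABELS_MASK_PRESENT = 1 << 4;

    private BitmaskHeader() {}

    // Sets one bit for each flag that is true.
    static byte encode(boolean features, boolean labels, boolean labelsSameAsFeatures,
                       boolean featureMask, boolean labelsMask) {
        int header = 0;
        if (features) header |= FEATURES_PRESENT;
        if (labels) header |= LABELS_PRESENT;
        if (labelsSameAsFeatures) header |= LABELS_SAME_AS_FEATURES;
        if (featureMask) header |= FEATURE_MASK_PRESENT;
        if (labelsMask) header |= LABELS_MASK_PRESENT;
        return (byte) header;
    }

    // The same test load() applies: is this flag's bit set?
    static boolean isSet(byte header, int flag) {
        return (header & flag) != 0;
    }
}
```

Because "labels are the same array as features" has its own bit, the writer can avoid serializing that array twice; the reader simply aliases `labels = features`.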

Source: deeplearning4j/nd4j

    /**
     * Create an ndarray from a base 64
     * representation
     * @param base64 the base 64 to convert
     * @return the ndarray from base 64
     * @throws IOException
     */
    public static INDArray fromBase64(String base64) throws IOException {
        byte[] arr = Base64.decodeBase64(base64);
        ByteArrayInputStream bis = new ByteArrayInputStream(arr);
        DataInputStream dis = new DataInputStream(bis);
        INDArray predict = Nd4j.read(dis);
        return predict;
    }
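
`fromBase64` is a three-stage pipeline: Base64 text to bytes, bytes to a `DataInputStream`, stream to an array. The source decodes with a `Base64.decodeBase64` helper (most likely Apache Commons Codec's `org.apache.commons.codec.binary.Base64`); the sketch below swaps in the JDK's `java.util.Base64` so it runs without extra dependencies, and uses a single `int` as a hypothetical stand-in payload. The class name `Base64Pipeline` is hypothetical.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Base64;

final class Base64Pipeline {
    private Base64Pipeline() {}

    // int -> big-endian bytes -> Base64 text
    static String encodeInt(int value) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            out.writeInt(value);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return Base64.getEncoder().encodeToString(bytes.toByteArray());
    }

    // Base64 text -> bytes -> DataInputStream -> int (the same stages as fromBase64)
    static int decodeInt(String base64) {
        byte[] arr = Base64.getDecoder().decode(base64);
        try (DataInputStream dis = new DataInputStream(new ByteArrayInputStream(arr))) {
            return dis.readInt();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

For example, the four bytes `00 00 00 01` encode to `"AAAAAQ=="`.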

Source: deeplearning4j/nd4j

    @Override
    public NormalizerMinMaxScaler restore(@NonNull InputStream stream) throws IOException {
        DataInputStream dis = new DataInputStream(stream);
        boolean fitLabels = dis.readBoolean();
        double targetMin = dis.readDouble();
        double targetMax = dis.readDouble();
        NormalizerMinMaxScaler result = new NormalizerMinMaxScaler(targetMin, targetMax);
        result.fitLabel(fitLabels);
        result.setFeatureStats(Nd4j.read(dis), Nd4j.read(dis));
        if (fitLabels) {
            result.setLabelStats(Nd4j.read(dis), Nd4j.read(dis));
        }
        return result;
    }

Source: deeplearning4j/nd4j

    /**
     * Returns a set of arrays
     * from base 64 that is tab delimited.
     * @param base64 the base 64 that's tab delimited
     * @return the set of arrays
     */
    public static INDArray[] arraysFromBase64(String base64) throws IOException {
        String[] base64Arr = base64.split("\t");
        INDArray[] ret = new INDArray[base64Arr.length];
        for (int i = 0; i < base64Arr.length; i++) {
            byte[] decode = Base64.decodeBase64(base64Arr[i]);
            ByteArrayInputStream bis = new ByteArrayInputStream(decode);
            DataInputStream dis = new DataInputStream(bis);
            INDArray predict = Nd4j.read(dis);
            ret[i] = predict;
        }
        return ret;
    }
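
`arraysFromBase64` can split on `\t` safely because the standard Base64 alphabet (A to Z, a to z, 0 to 9, `+`, `/`, plus `=` padding) never contains a tab. The sketch below builds such a string and splits it back. It uses `java.util.Base64` in place of the source's decoder, `ByteBuffer` (big-endian by default, like `DataInputStream`) for the bytes, and one `int` per entry as a hypothetical payload; `TabDelimitedBase64` is a hypothetical name.

```java
import java.nio.ByteBuffer;
import java.util.Base64;
import java.util.StringJoiner;

final class TabDelimitedBase64 {
    private TabDelimitedBase64() {}

    // Encodes each int as 4 big-endian bytes, Base64-encodes it, and joins the entries with tabs.
    static String encodeAll(int[] values) {
        StringJoiner joiner = new StringJoiner("\t");
        for (int v : values) {
            byte[] bytes = ByteBuffer.allocate(Integer.BYTES).putInt(v).array();
            joiner.add(Base64.getEncoder().encodeToString(bytes));
        }
        return joiner.toString();
    }

    // Mirrors arraysFromBase64: split on tabs, then decode each entry independently.
    static int[] decodeAll(String base64) {
        String[] parts = base64.split("\t");
        int[] result = new int[parts.length];
        for (int i = 0; i < parts.length; i++) {
            result[i] = ByteBuffer.wrap(Base64.getDecoder().decode(parts[i])).getInt();
        }
        return result;
    }
}
```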

Source: deeplearning4j/nd4j

    /**
     * Restore a MultiNormalizerStandardize that was previously serialized by this strategy
     *
     * @param stream the input stream to restore from
     * @return the restored MultiNormalizerStandardize
     * @throws IOException
     */
    public MultiNormalizerStandardize restore(@NonNull InputStream stream) throws IOException {
        DataInputStream dis = new DataInputStream(stream);
        boolean fitLabels = dis.readBoolean();
        int numInputs = dis.readInt();
        int numOutputs = dis.readInt();
        MultiNormalizerStandardize result = new MultiNormalizerStandardize();
        result.fitLabel(fitLabels);
        List<DistributionStats> featureStats = new ArrayList<>();
        for (int i = 0; i < numInputs; i++) {
            featureStats.add(new DistributionStats(Nd4j.read(dis), Nd4j.read(dis)));
        }
        result.setFeatureStats(featureStats);
        if (fitLabels) {
            List<DistributionStats> labelStats = new ArrayList<>();
            for (int i = 0; i < numOutputs; i++) {
                labelStats.add(new DistributionStats(Nd4j.read(dis), Nd4j.read(dis)));
            }
            result.setLabelStats(labelStats);
        }
        return result;
    }
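
This reader expects the counts (`numInputs`, `numOutputs`) up front and then one pair of arrays per input or output, so the writer must emit the count before the entries it describes. The sketch below shows the count-then-pairs idea for a single list, with a hypothetical class `PairListCodec` and a `double[2]` per entry standing in for a `DistributionStats` mean/std pair.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;

// Hypothetical layout: int count, then `count` pairs of doubles (mean, std).
final class PairListCodec {
    private PairListCodec() {}

    static byte[] write(List<double[]> pairs) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            out.writeInt(pairs.size());
            for (double[] pair : pairs) {
                out.writeDouble(pair[0]);
                out.writeDouble(pair[1]);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    static List<double[]> read(byte[] data) {
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(data))) {
            int count = in.readInt();
            List<double[]> pairs = new ArrayList<>(count);
            for (int i = 0; i < count; i++) {
                // Array initializers are evaluated left to right, so mean is read before std.
                pairs.add(new double[] {in.readDouble(), in.readDouble()});
            }
            return pairs;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```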

Source: deeplearning4j/nd4j

    /**
     * Restore a MultiNormalizerMinMaxScaler that was previously serialized by this strategy
     *
     * @param stream the input stream to restore from
     * @return the restored MultiNormalizerMinMaxScaler
     * @throws IOException
     */
    public MultiNormalizerMinMaxScaler restore(@NonNull InputStream stream) throws IOException {
        DataInputStream dis = new DataInputStream(stream);
        boolean fitLabels = dis.readBoolean();
        int numInputs = dis.readInt();
        int numOutputs = dis.readInt();
        double targetMin = dis.readDouble();
        double targetMax = dis.readDouble();
        MultiNormalizerMinMaxScaler result = new MultiNormalizerMinMaxScaler(targetMin, targetMax);
        result.fitLabel(fitLabels);
        List<MinMaxStats> featureStats = new ArrayList<>();
        for (int i = 0; i < numInputs; i++) {
            featureStats.add(new MinMaxStats(Nd4j.read(dis), Nd4j.read(dis)));
        }
        result.setFeatureStats(featureStats);
        if (fitLabels) {
            List<MinMaxStats> labelStats = new ArrayList<>();
            for (int i = 0; i < numOutputs; i++) {
                labelStats.add(new MinMaxStats(Nd4j.read(dis), Nd4j.read(dis)));
            }
            result.setLabelStats(labelStats);
        }
        return result;
    }

Source: de.datexis/texoo-core

    public static INDArray getArrayFromBase64String(String encoded) {
        byte[] decodedBytes = Base64.decodeBase64(encoded);
        ByteArrayInputStream bais = new ByteArrayInputStream(decodedBytes);
        BufferedInputStream bis = new BufferedInputStream(bais);
        try (DataInputStream dis = new DataInputStream(bis)) {
            return Nd4j.read(dis);
        } catch (IOException ex) {
            // Pass ex as the cause so the underlying error is not lost
            throw new RuntimeException("Could not create INDArray from Base64 String", ex);
        }
    }
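
This helper converts the checked `IOException` into an unchecked exception, which is convenient for callers but discards the original error unless it is passed along as the cause. The sketch below keeps the cause by throwing `java.io.UncheckedIOException`; a truncated payload shows what a caller then sees. `SafeBase64Reader` and its single-`int` payload are hypothetical, and `java.util.Base64` replaces the source's decoder.

```java
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Base64;

final class SafeBase64Reader {
    private SafeBase64Reader() {}

    // Decodes Base64 and reads one int; I/O failures are rethrown with their cause attached.
    static int readInt(String encoded) {
        byte[] decoded = Base64.getDecoder().decode(encoded);
        try (DataInputStream dis = new DataInputStream(new ByteArrayInputStream(decoded))) {
            return dis.readInt();
        } catch (IOException ex) {
            throw new UncheckedIOException("Could not read value from Base64 string", ex);
        }
    }
}
```

`"AAA="` decodes to only two bytes, so `readInt` hits end of stream and the caller receives an `UncheckedIOException` whose cause is the original `EOFException`.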
