org.nd4j.linalg.factory.Nd4j.read()方法的使用及代码示例

x33g5p2x  于2022-01-24 转载在 其他  
字(9.3k)|赞(0)|评价(0)|浏览(113)

本文整理了Java中org.nd4j.linalg.factory.Nd4j.read()方法的一些代码示例,展示了Nd4j.read()的具体用法。这些代码示例主要来源于Github/Stackoverflow/Maven等平台,是从一些精选项目中提取出来的代码,具有较强的参考意义,能在一定程度上帮助到你。Nd4j.read()方法的具体详情如下:
包路径:org.nd4j.linalg.factory.Nd4j
类名称:Nd4j
方法名:read

Nd4j.read介绍

[英]Read in an ndarray from a data input stream
[中]从数据输入流读入数据数组

代码示例

代码示例来源:origin: deeplearning4j/nd4j

/**
 * Deserializes an ndarray from its in-memory binary form.
 *
 * <p>The bytes are wrapped in a {@link ByteArrayInputStream} and handed to
 * {@code read}.
 *
 * @param arr the bytes holding the serialized ndarray
 * @return the ndarray produced by {@code read}
 * @throws IOException propagated from {@code read}
 */
public static INDArray fromByteArray(byte[] arr) throws IOException {
  return read(new ByteArrayInputStream(arr));
}

代码示例来源:origin: deeplearning4j/nd4j

/**
 * Reads an ndarray from an input stream.
 *
 * <p>The stream is wrapped in a {@link DataInputStream}, which is then passed
 * to {@code read}. The stream is not closed here.
 *
 * @param reader the input stream to read from
 * @return the ndarray produced by the delegated {@code read} call
 * @throws IOException propagated from the delegated {@code read} call
 */
public static INDArray read(InputStream reader) throws IOException {
  DataInputStream dataIn = new DataInputStream(reader);
  return read(dataIn);
}

代码示例来源:origin: deeplearning4j/nd4j

/**
 * Reads {@code numArrays} consecutive ndarrays from the stream.
 *
 * @param numArrays how many arrays to read
 * @param dis       the stream to read from
 * @param isMask    when true, any array equal to the empty-mask placeholder
 *                  is stored as {@code null}
 * @return the arrays in read order, or {@code null} when {@code numArrays}
 *         is not positive
 * @throws IOException propagated from {@code Nd4j.read}
 */
private INDArray[] loadINDArrays(int numArrays, DataInputStream dis, boolean isMask) throws IOException {
  if (numArrays <= 0) {
    return null;
  }
  INDArray[] loaded = new INDArray[numArrays];
  for (int idx = 0; idx < loaded.length; idx++) {
    INDArray current = Nd4j.read(dis);
    if (isMask && current.equals(EMPTY_MASK_ARRAY_PLACEHOLDER.get())) {
      // The placeholder stands in for "no mask", so it is not kept.
      loaded[idx] = null;
    } else {
      loaded[idx] = current;
    }
  }
  return loaded;
}

代码示例来源:origin: deeplearning4j/nd4j

/**
 * Reads two consecutive ndarrays from the stream and wraps them in a
 * {@code MinMaxStats}.
 *
 * @param dis the stream to read from
 * @return the stats built from the two arrays, in the order they were read
 * @throws IOException propagated from {@code Nd4j.read}
 */
private static NormalizerStats readMinMaxStats(DataInputStream dis) throws IOException {
  // Constructor arguments are evaluated left to right, so the first array in
  // the stream becomes the first constructor argument.
  NormalizerStats stats = new MinMaxStats(Nd4j.read(dis), Nd4j.read(dis));
  return stats;
}

代码示例来源:origin: deeplearning4j/nd4j

/**
 * Reads two consecutive ndarrays from the stream and wraps them in a
 * {@code DistributionStats}.
 *
 * @param dis the stream to read from
 * @return the stats built from the two arrays, in the order they were read
 * @throws IOException propagated from {@code Nd4j.read}
 */
private static NormalizerStats readDistributionStats(DataInputStream dis) throws IOException {
  // Constructor arguments are evaluated left to right, so the first array in
  // the stream becomes the first constructor argument.
  NormalizerStats stats = new DistributionStats(Nd4j.read(dis), Nd4j.read(dis));
  return stats;
}

代码示例来源:origin: deeplearning4j/nd4j

/**
 * Rebuilds a {@code NormalizerStandardize} from a stream.
 *
 * <p>Stream layout as consumed here: one boolean (whether labels were
 * fitted), two ndarrays for the constructor, and, only when the boolean is
 * true, two more ndarrays for the label stats. The stream is not closed.
 *
 * @param stream the stream to restore from
 * @return the restored normalizer
 * @throws IOException if reading from the stream fails
 */
@Override
public NormalizerStandardize restore(@NonNull InputStream stream) throws IOException {
  DataInputStream in = new DataInputStream(stream);
  boolean labelsFitted = in.readBoolean();

  NormalizerStandardize normalizer = new NormalizerStandardize(Nd4j.read(in), Nd4j.read(in));
  normalizer.fitLabel(labelsFitted);
  if (!labelsFitted) {
    return normalizer;
  }

  normalizer.setLabelStats(Nd4j.read(in), Nd4j.read(in));
  return normalizer;
}

代码示例来源:origin: deeplearning4j/nd4j

/**
 * Reads a binary ndarray from the given file.
 *
 * <p>The file stream is opened with try-with-resources so it is closed even
 * when {@code Nd4j.read} throws.
 *
 * @param read the file to read the ndarray from
 * @return the loaded ndarray
 * @throws IOException if the file cannot be opened or read
 */
public static INDArray readBinary(File read) throws IOException {
  try (DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(read)))) {
    return Nd4j.read(dis);
  }
}

代码示例来源:origin: deeplearning4j/nd4j

/**
 * Loads the data set stored at position {@code idx} of {@code paths}.
 *
 * <p>Element 0 of the path entry is opened as the first ndarray source and
 * element 1 as the second; both are passed to the {@code DataSet}
 * constructor in that order. Both streams are managed by try-with-resources
 * so neither leaks if opening the second file or reading either array fails.
 *
 * @param idx index into {@code paths}
 * @return the data set built from the two ndarrays
 * @throws IOException if either file cannot be opened or read
 */
private DataSet read(int idx) throws IOException {
  try (DataInputStream dis =
          new DataInputStream(new BufferedInputStream(new FileInputStream(paths.get(idx)[0])));
      DataInputStream labelDis =
          new DataInputStream(new BufferedInputStream(new FileInputStream(paths.get(idx)[1])))) {
    return new DataSet(Nd4j.read(dis), Nd4j.read(labelDis));
  }
}

代码示例来源:origin: deeplearning4j/nd4j

/**
 * Populates this object's features, labels and masks from a stream.
 *
 * <p>The first byte is a bit field telling which sections follow; each
 * present section is one ndarray, read in the order features, labels,
 * features mask, labels mask. When the "labels same as features" bit is set
 * (and no separate labels are present) the labels field refers to the same
 * array as the features.
 *
 * <p>The stream is closed when loading finishes, whether or not it succeeded.
 *
 * @param from the stream to load from; wrapped in a
 *             {@link BufferedInputStream} unless it already is one
 * @throws RuntimeException wrapping any exception raised while loading
 */
@Override
public void load(InputStream from) {
  // try-with-resources closes the stream on failure as well as on success.
  try (DataInputStream dis = new DataInputStream(
          from instanceof BufferedInputStream ? from : new BufferedInputStream(from))) {
    byte included = dis.readByte();
    boolean hasFeatures = (included & BITMASK_FEATURES_PRESENT) != 0;
    boolean hasLabels = (included & BITMASK_LABELS_PRESENT) != 0;
    boolean hasLabelsSameAsFeatures = (included & BITMASK_LABELS_SAME_AS_FEATURES) != 0;
    boolean hasFeaturesMask = (included & BITMASK_FEATURE_MASK_PRESENT) != 0;
    boolean hasLabelsMask = (included & BITMASK_LABELS_MASK_PRESENT) != 0;

    features = (hasFeatures ? Nd4j.read(dis) : null);
    if (hasLabels) {
      labels = Nd4j.read(dis);
    } else if (hasLabelsSameAsFeatures) {
      // Labels were not written separately; share the features array.
      labels = features;
    } else {
      labels = null;
    }
    featuresMask = (hasFeaturesMask ? Nd4j.read(dis) : null);
    labelsMask = (hasLabelsMask ? Nd4j.read(dis) : null);
  } catch (Exception e) {
    throw new RuntimeException("Error loading DataSet",e);
  }
}

代码示例来源:origin: deeplearning4j/nd4j

/**
 * Decodes a base 64 string and deserializes the resulting bytes into an
 * ndarray.
 *
 * @param base64 the base 64 text to decode
 * @return the ndarray read from the decoded bytes
 * @throws IOException propagated from {@code Nd4j.read}
 */
public static INDArray fromBase64(String base64) throws IOException {
  byte[] decoded = Base64.decodeBase64(base64);
  return Nd4j.read(new DataInputStream(new ByteArrayInputStream(decoded)));
}

代码示例来源:origin: deeplearning4j/nd4j

/**
 * Rebuilds a {@code NormalizerMinMaxScaler} from a stream.
 *
 * <p>Stream layout as consumed here: one boolean (whether labels were
 * fitted), two doubles (target minimum, then target maximum), two ndarrays
 * for the feature stats, and, only when the boolean is true, two more
 * ndarrays for the label stats. The stream is not closed.
 *
 * @param stream the stream to restore from
 * @return the restored scaler
 * @throws IOException if reading from the stream fails
 */
@Override
public NormalizerMinMaxScaler restore(@NonNull InputStream stream) throws IOException {
  DataInputStream in = new DataInputStream(stream);
  boolean labelsFitted = in.readBoolean();
  double lowerTarget = in.readDouble();
  double upperTarget = in.readDouble();

  NormalizerMinMaxScaler scaler = new NormalizerMinMaxScaler(lowerTarget, upperTarget);
  scaler.fitLabel(labelsFitted);
  scaler.setFeatureStats(Nd4j.read(in), Nd4j.read(in));
  if (!labelsFitted) {
    return scaler;
  }

  scaler.setLabelStats(Nd4j.read(in), Nd4j.read(in));
  return scaler;
}

代码示例来源:origin: deeplearning4j/nd4j

/**
 * Splits a tab-delimited string into base 64 chunks and deserializes each
 * chunk into an ndarray.
 *
 * @param base64 base 64 encoded arrays separated by tab characters
 * @return one ndarray per chunk, in the order the chunks appear
 * @throws IOException propagated from {@code Nd4j.read}
 */
public static INDArray[] arraysFromBase64(String base64) throws IOException {
  String[] encodedParts = base64.split("\t");
  INDArray[] decodedArrays = new INDArray[encodedParts.length];
  int slot = 0;
  for (String part : encodedParts) {
    DataInputStream in = new DataInputStream(new ByteArrayInputStream(Base64.decodeBase64(part)));
    decodedArrays[slot++] = Nd4j.read(in);
  }
  return decodedArrays;
}

代码示例来源:origin: deeplearning4j/nd4j

/**
 * Rebuilds a {@code MultiNormalizerStandardize} from a stream.
 *
 * <p>Stream layout as consumed here: one boolean (whether labels were
 * fitted), two ints (number of inputs, then number of outputs), then two
 * ndarrays per input. When the boolean is true, two ndarrays per output
 * follow. The stream is not closed.
 *
 * @param stream the input stream to restore from
 * @return the restored MultiNormalizerStandardize
 * @throws IOException if reading from the stream fails
 */
public MultiNormalizerStandardize restore(@NonNull InputStream stream) throws IOException {
  DataInputStream in = new DataInputStream(stream);
  boolean labelsFitted = in.readBoolean();
  int inputCount = in.readInt();
  int outputCount = in.readInt();

  MultiNormalizerStandardize normalizer = new MultiNormalizerStandardize();
  normalizer.fitLabel(labelsFitted);

  List<DistributionStats> perInputStats = new ArrayList<>();
  for (int remaining = inputCount; remaining > 0; remaining--) {
    perInputStats.add(new DistributionStats(Nd4j.read(in), Nd4j.read(in)));
  }
  normalizer.setFeatureStats(perInputStats);

  if (!labelsFitted) {
    return normalizer;
  }

  List<DistributionStats> perOutputStats = new ArrayList<>();
  for (int remaining = outputCount; remaining > 0; remaining--) {
    perOutputStats.add(new DistributionStats(Nd4j.read(in), Nd4j.read(in)));
  }
  normalizer.setLabelStats(perOutputStats);
  return normalizer;
}

代码示例来源:origin: deeplearning4j/nd4j

/**
 * Rebuilds a {@code MultiNormalizerMinMaxScaler} from a stream.
 *
 * <p>Stream layout as consumed here: one boolean (whether labels were
 * fitted), two ints (number of inputs, then number of outputs), two doubles
 * (target minimum, then target maximum), then two ndarrays per input. When
 * the boolean is true, two ndarrays per output follow. The stream is not
 * closed.
 *
 * @param stream the input stream to restore from
 * @return the restored MultiNormalizerMinMaxScaler
 * @throws IOException if reading from the stream fails
 */
public MultiNormalizerMinMaxScaler restore(@NonNull InputStream stream) throws IOException {
  DataInputStream in = new DataInputStream(stream);
  boolean labelsFitted = in.readBoolean();
  int inputCount = in.readInt();
  int outputCount = in.readInt();
  double lowerTarget = in.readDouble();
  double upperTarget = in.readDouble();

  MultiNormalizerMinMaxScaler scaler = new MultiNormalizerMinMaxScaler(lowerTarget, upperTarget);
  scaler.fitLabel(labelsFitted);

  List<MinMaxStats> perInputStats = new ArrayList<>();
  for (int remaining = inputCount; remaining > 0; remaining--) {
    perInputStats.add(new MinMaxStats(Nd4j.read(in), Nd4j.read(in)));
  }
  scaler.setFeatureStats(perInputStats);

  if (!labelsFitted) {
    return scaler;
  }

  List<MinMaxStats> perOutputStats = new ArrayList<>();
  for (int remaining = outputCount; remaining > 0; remaining--) {
    perOutputStats.add(new MinMaxStats(Nd4j.read(in), Nd4j.read(in)));
  }
  scaler.setLabelStats(perOutputStats);
  return scaler;
}

代码示例来源:origin: org.nd4j/nd4j-api

/**
 * Reads an ndarray from an input stream.
 *
 * <p>The stream is wrapped in a {@link DataInputStream}, which is then passed
 * to {@code read}. The stream is not closed here.
 *
 * @param reader the input stream to read from
 * @return the ndarray produced by the delegated {@code read} call
 * @throws IOException propagated from the delegated {@code read} call
 */
public static INDArray read(InputStream reader) throws IOException {
  DataInputStream dataIn = new DataInputStream(reader);
  return read(dataIn);
}

代码示例来源:origin: org.nd4j/nd4j-api

/**
 * Deserializes an ndarray from its in-memory binary form.
 *
 * <p>The bytes are wrapped in a {@link ByteArrayInputStream} and handed to
 * {@code read}.
 *
 * @param arr the bytes holding the serialized ndarray
 * @return the ndarray produced by {@code read}
 * @throws IOException propagated from {@code read}
 */
public static INDArray fromByteArray(byte[] arr) throws IOException {
  return read(new ByteArrayInputStream(arr));
}

代码示例来源:origin: de.datexis/texoo-core

/**
 * Decodes a base 64 string and deserializes the resulting bytes into an
 * ndarray.
 *
 * @param encoded the base 64 text to decode
 * @return the ndarray read from the decoded bytes
 * @throws RuntimeException if reading the decoded bytes fails; the
 *         underlying {@link IOException} is attached as the cause
 */
public static INDArray getArrayFromBase64String(String encoded) {
 byte[] decodedBytes = Base64.decodeBase64(encoded);
 ByteArrayInputStream bais = new ByteArrayInputStream(decodedBytes);
 BufferedInputStream bis = new BufferedInputStream(bais);
 try(DataInputStream dis = new DataInputStream(bis)) {
  return Nd4j.read(dis);
 } catch(IOException ex) {
  // Keep the original exception as the cause so the failure stays diagnosable.
  throw new RuntimeException("Could not create INDArray from Base64 String", ex);
 }
}

代码示例来源:origin: org.nd4j/nd4j-api

/**
 * Reads a binary ndarray from the given file.
 *
 * <p>The file stream is opened with try-with-resources so it is closed even
 * when {@code Nd4j.read} throws.
 *
 * @param read the file to read the ndarray from
 * @return the loaded ndarray
 * @throws IOException if the file cannot be opened or read
 */
public static INDArray readBinary(File read) throws IOException {
  try (DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(read)))) {
    return Nd4j.read(dis);
  }
}

代码示例来源:origin: org.nd4j/nd4j-api

/**
 * Rebuilds a {@code NormalizerStandardize} from a stream.
 *
 * <p>Stream layout as consumed here: one boolean (whether labels were
 * fitted), two ndarrays for the constructor, and, only when the boolean is
 * true, two more ndarrays for the label stats. The stream is not closed.
 *
 * @param stream the stream to restore from
 * @return the restored normalizer
 * @throws IOException if reading from the stream fails
 */
@Override
public NormalizerStandardize restore(@NonNull InputStream stream) throws IOException {
  DataInputStream in = new DataInputStream(stream);
  boolean labelsFitted = in.readBoolean();

  NormalizerStandardize normalizer = new NormalizerStandardize(Nd4j.read(in), Nd4j.read(in));
  normalizer.fitLabel(labelsFitted);
  if (!labelsFitted) {
    return normalizer;
  }

  normalizer.setLabelStats(Nd4j.read(in), Nd4j.read(in));
  return normalizer;
}

代码示例来源:origin: org.nd4j/nd4j-api

/**
 * Loads the data set stored at position {@code idx} of {@code paths}.
 *
 * <p>Element 0 of the path entry is opened as the first ndarray source and
 * element 1 as the second; both are passed to the {@code DataSet}
 * constructor in that order. Both streams are managed by try-with-resources
 * so neither leaks if opening the second file or reading either array fails.
 *
 * @param idx index into {@code paths}
 * @return the data set built from the two ndarrays
 * @throws IOException if either file cannot be opened or read
 */
private DataSet read(int idx) throws IOException {
  try (DataInputStream dis =
          new DataInputStream(new BufferedInputStream(new FileInputStream(paths.get(idx)[0])));
      DataInputStream labelDis =
          new DataInputStream(new BufferedInputStream(new FileInputStream(paths.get(idx)[1])))) {
    return new DataSet(Nd4j.read(dis), Nd4j.read(labelDis));
  }
}

相关文章