org.nd4j.linalg.factory.Nd4j.write()方法的使用及代码示例

x33g5p2x  于2022-01-24 转载在 其他  
字(9.1k)|赞(0)|评价(0)|浏览(127)

本文整理了Java中org.nd4j.linalg.factory.Nd4j.write()方法的一些代码示例，展示了Nd4j.write()的具体用法。这些代码示例主要来源于Github/Stackoverflow/Maven等平台，是从一些精选项目中提取出来的代码，具有较强的参考意义，能在一定程度上帮助到你。Nd4j.write()方法的具体详情如下：
包路径:org.nd4j.linalg.factory.Nd4j
类名称:Nd4j
方法名:write

Nd4j.write介绍

[英]Write an ndarray to a writer
[中]将ndarray写入到给定的输出流（writer）中

代码示例

代码示例来源:origin: deeplearning4j/nd4j

/**
 * Write an ndarray to an output stream.
 *
 * <p>The supplied stream is wrapped in a {@link DataOutputStream} and the array is serialized
 * via {@code write(INDArray, DataOutputStream)}. The wrapper is closed when this method returns,
 * whether or not serialization succeeded. Closing the wrapper also closes {@code writer}.
 *
 * @param writer the stream to write to; closed by this method
 * @param write  the ndarray to write
 * @throws IOException if writing to or closing the stream fails
 */
public static void write(OutputStream writer, INDArray write) throws IOException {
  // try-with-resources so the stream is also closed when serialization throws;
  // previously it was closed only on the success path.
  try (DataOutputStream stream = new DataOutputStream(writer)) {
    write(write, stream);
  }
}

代码示例来源:origin: deeplearning4j/nd4j

/**
 * Save an ndarray to the given file in the binary form produced by
 * {@code Nd4j.write(INDArray, DataOutputStream)}.
 *
 * @param arr    the array to save
 * @param saveTo the file to save to (opened with {@link FileOutputStream}, i.e. created or truncated)
 * @throws IOException if the file cannot be opened or written
 */
public static void saveBinary(INDArray arr, File saveTo) throws IOException {
  // Closing the DataOutputStream flushes and closes the wrapped buffered and file streams.
  // try-with-resources guarantees this also happens when serialization throws,
  // which the previous explicit close() calls did not.
  try (DataOutputStream dos =
      new DataOutputStream(new BufferedOutputStream(new FileOutputStream(saveTo)))) {
    Nd4j.write(arr, dos);
  }
}

代码示例来源:origin: deeplearning4j/nd4j

// Features are always written; each optional array follows only when present.
// Labels are skipped when they are the very same object as the features (reference
// comparison), presumably so a shared array is not stored twice. TODO: confirm against the loader.
Nd4j.write(features, dos);
if (labels != null && labels != features) {
  Nd4j.write(labels, dos);
}
if (featuresMask != null) {
  Nd4j.write(featuresMask, dos);
}
if (labelsMask != null) {
  Nd4j.write(labelsMask, dos);
}

代码示例来源:origin: deeplearning4j/nd4j

@Override
  public void process(Exchange exchange) throws Exception {
    // Turn the ndarray body into its binary serialization, then Base64-encode it
    // so it can be carried as a String message body.
    final INDArray ndArray = (INDArray) exchange.getIn().getBody();
    final ByteArrayOutputStream buffer = new ByteArrayOutputStream();
    Nd4j.write(ndArray, new DataOutputStream(buffer));
    exchange.getIn().setBody(Base64.encodeBase64String(buffer.toByteArray()), String.class);
    // One random UUID serves as both the Kafka message key and the partition key.
    final String messageKey = UUID.randomUUID().toString();
    exchange.getIn().setHeader(KafkaConstants.KEY, messageKey);
    exchange.getIn().setHeader(KafkaConstants.PARTITION_KEY, messageKey);
  }
}).to(kafkaUri);

代码示例来源:origin: deeplearning4j/nd4j

/**
 * Encodes an ndarray as a Base64 string.
 *
 * <p>The array is serialized with {@code Nd4j.write(INDArray, DataOutputStream)} into an
 * in-memory buffer, and those bytes are Base64-encoded.
 *
 * @param arr the array to encode
 * @return the Base64 representation of the array's binary serialization
 * @throws IOException if serialization fails
 */
public static String base64String(INDArray arr) throws IOException {
  final ByteArrayOutputStream buffer = new ByteArrayOutputStream();
  Nd4j.write(arr, new DataOutputStream(buffer));
  return Base64.encodeBase64String(buffer.toByteArray());
}

代码示例来源:origin: deeplearning4j/nd4j

/**
 * Writes every array in {@code arrays} to {@code dos}, in order.
 *
 * <p>When {@code isMask} is true, {@code null} entries are replaced by a placeholder array
 * (a one-element float array holding -1), so each slot still produces output. The placeholder
 * is created on first use and cached via {@code EMPTY_MASK_ARRAY_PLACEHOLDER}'s get/set
 * (presumably a ThreadLocal — confirm). A {@code null} or empty {@code arrays} writes nothing.
 *
 * @param arrays the arrays to write; may be null
 * @param dos    the destination stream
 * @param isMask whether null entries should be written as the mask placeholder
 * @throws IOException if writing fails
 */
private void saveINDArrays(INDArray[] arrays, DataOutputStream dos, boolean isMask) throws IOException {
  if (arrays == null || arrays.length == 0) {
    return;
  }
  for (INDArray current : arrays) {
    INDArray toWrite = current;
    if (isMask && current == null) {
      INDArray placeholder = EMPTY_MASK_ARRAY_PLACEHOLDER.get();
      if (placeholder == null) {
        EMPTY_MASK_ARRAY_PLACEHOLDER.set(Nd4j.create(new float[] {-1}));
        placeholder = EMPTY_MASK_ARRAY_PLACEHOLDER.get();
      }
      toWrite = placeholder;
    }
    Nd4j.write(toWrite, dos);
  }
}

代码示例来源:origin: deeplearning4j/nd4j

@Override
public void write(@NonNull NormalizerStandardize normalizer, @NonNull OutputStream stream) throws IOException {
  // Layout: fitLabel flag, feature mean, feature std, then label mean and label std only
  // when fitLabel is set. Closing the DataOutputStream also closes the caller's stream.
  try (DataOutputStream out = new DataOutputStream(stream)) {
    out.writeBoolean(normalizer.isFitLabel());
    Nd4j.write(normalizer.getMean(), out);
    Nd4j.write(normalizer.getStd(), out);
    if (normalizer.isFitLabel()) {
      Nd4j.write(normalizer.getLabelMean(), out);
      Nd4j.write(normalizer.getLabelStd(), out);
    }
    out.flush();
  }
}

代码示例来源:origin: deeplearning4j/nd4j

/**
 * Writes a standardization record: the STANDARDIZE strategy tag, then the mean and
 * standard-deviation arrays.
 *
 * @param normalizerStats the statistics to write
 * @param dos             the destination stream
 * @throws IOException if writing fails
 */
private static void writeDistributionStats(DistributionStats normalizerStats, DataOutputStream dos)
        throws IOException {
  // The tag is the enum ordinal, so reordering Strategy's constants would change the format.
  final int strategyTag = Strategy.STANDARDIZE.ordinal();
  dos.writeInt(strategyTag);
  Nd4j.write(normalizerStats.getMean(), dos);
  Nd4j.write(normalizerStats.getStd(), dos);
}

代码示例来源:origin: deeplearning4j/nd4j

/**
 * Writes a min-max record: the MIN_MAX strategy tag, then the lower and upper bound arrays.
 *
 * @param normalizerStats the statistics to write
 * @param dos             the destination stream
 * @throws IOException if writing fails
 */
private static void writeMinMaxStats(MinMaxStats normalizerStats, DataOutputStream dos) throws IOException {
  // The tag is the enum ordinal, so reordering Strategy's constants would change the format.
  final int strategyTag = Strategy.MIN_MAX.ordinal();
  dos.writeInt(strategyTag);
  Nd4j.write(normalizerStats.getLower(), dos);
  Nd4j.write(normalizerStats.getUpper(), dos);
}

代码示例来源:origin: deeplearning4j/nd4j

/**
 * Persists the features and labels of {@code write} as two files under {@code rootDir},
 * named {@code <uuid>.bin} and {@code <uuid>.labels.bin}, using a fresh random UUID.
 *
 * @param write the data set to persist
 * @return a two-element array: [0] the absolute path of the features file,
 *         [1] the absolute path of the labels file
 * @throws IOException if either file cannot be created or written
 */
private String[] writeData(DataSet write) throws IOException {
  String dataSetId = UUID.randomUUID().toString();
  // Build each File once and reuse it for both writing and the returned path.
  File featuresFile = new File(rootDir, dataSetId + ".bin");
  File labelsFile = new File(rootDir, dataSetId + ".labels.bin");
  // try-with-resources: the previous version leaked the file handle if Nd4j.write threw.
  try (DataOutputStream dos =
      new DataOutputStream(new BufferedOutputStream(new FileOutputStream(featuresFile)))) {
    Nd4j.write(write.getFeatureMatrix(), dos);
  }
  try (DataOutputStream dosLabels =
      new DataOutputStream(new BufferedOutputStream(new FileOutputStream(labelsFile)))) {
    Nd4j.write(write.getLabels(), dosLabels);
  }
  return new String[] {featuresFile.getAbsolutePath(), labelsFile.getAbsolutePath()};
}

代码示例来源:origin: deeplearning4j/nd4j

/**
 * Encodes each array as Base64 (of its {@code Nd4j.write} binary form) and concatenates
 * the results, appending a tab after every entry, including the last one.
 *
 * @param arrays the arrays to encode
 * @return the tab-terminated Base64 encodings; an empty string for an empty input
 * @throws IOException if serializing an array fails
 */
public static String arraysToBase64(INDArray[] arrays) throws IOException {
  StringBuilder joined = new StringBuilder();
  for (int i = 0; i < arrays.length; i++) {
    // A fresh buffer per array so each Base64 entry encodes exactly one array.
    ByteArrayOutputStream raw = new ByteArrayOutputStream();
    Nd4j.write(arrays[i], new DataOutputStream(raw));
    joined.append(Base64.encodeBase64String(raw.toByteArray())).append('\t');
  }
  return joined.toString();
}

代码示例来源:origin: deeplearning4j/nd4j

@Override
public void write(@NonNull NormalizerMinMaxScaler normalizer, @NonNull OutputStream stream) throws IOException {
  // Layout: fitLabel flag, target min, target max (doubles), feature min and max arrays,
  // then label min and max arrays only when fitLabel is set.
  // Closing the DataOutputStream also closes the caller's stream.
  try (DataOutputStream out = new DataOutputStream(stream)) {
    out.writeBoolean(normalizer.isFitLabel());
    out.writeDouble(normalizer.getTargetMin());
    out.writeDouble(normalizer.getTargetMax());
    Nd4j.write(normalizer.getMin(), out);
    Nd4j.write(normalizer.getMax(), out);
    if (normalizer.isFitLabel()) {
      Nd4j.write(normalizer.getLabelMin(), out);
      Nd4j.write(normalizer.getLabelMax(), out);
    }
    out.flush();
  }
}

代码示例来源:origin: deeplearning4j/nd4j

/**
 * Serializes a MultiNormalizerStandardize to an output stream.
 *
 * <p>Layout, in order:
 * <ol>
 *   <li>the fitLabel flag (boolean)</li>
 *   <li>the number of inputs (int)</li>
 *   <li>the number of outputs (int), or -1 when labels are not fitted</li>
 *   <li>a mean/std array pair for each input</li>
 *   <li>only when labels are fitted, a mean/std array pair for each output</li>
 * </ol>
 * Closing the wrapping DataOutputStream also closes {@code stream}.
 *
 * @param normalizer the normalizer
 * @param stream     the output stream to write to
 * @throws IOException if writing fails
 */
public void write(@NonNull MultiNormalizerStandardize normalizer, @NonNull OutputStream stream) throws IOException {
  try (DataOutputStream out = new DataOutputStream(stream)) {
    out.writeBoolean(normalizer.isFitLabel());
    out.writeInt(normalizer.numInputs());
    if (normalizer.isFitLabel()) {
      out.writeInt(normalizer.numOutputs());
    } else {
      out.writeInt(-1);
    }
    for (int input = 0; input < normalizer.numInputs(); input++) {
      Nd4j.write(normalizer.getFeatureMean(input), out);
      Nd4j.write(normalizer.getFeatureStd(input), out);
    }
    if (normalizer.isFitLabel()) {
      for (int output = 0; output < normalizer.numOutputs(); output++) {
        Nd4j.write(normalizer.getLabelMean(output), out);
        Nd4j.write(normalizer.getLabelStd(output), out);
      }
    }
    out.flush();
  }
}

代码示例来源:origin: deeplearning4j/nd4j

/**
 * Serializes a MultiNormalizerMinMaxScaler to an output stream.
 *
 * <p>Layout, in order:
 * <ol>
 *   <li>the fitLabel flag (boolean)</li>
 *   <li>the number of inputs (int)</li>
 *   <li>the number of outputs (int), or -1 when labels are not fitted</li>
 *   <li>target min and target max (doubles)</li>
 *   <li>a min/max array pair for each input</li>
 *   <li>only when labels are fitted, a min/max array pair for each output</li>
 * </ol>
 * Closing the wrapping DataOutputStream also closes {@code stream}.
 *
 * @param normalizer the normalizer
 * @param stream     the output stream to write to
 * @throws IOException if writing fails
 */
public void write(@NonNull MultiNormalizerMinMaxScaler normalizer, @NonNull OutputStream stream)
        throws IOException {
  try (DataOutputStream out = new DataOutputStream(stream)) {
    out.writeBoolean(normalizer.isFitLabel());
    out.writeInt(normalizer.numInputs());
    if (normalizer.isFitLabel()) {
      out.writeInt(normalizer.numOutputs());
    } else {
      out.writeInt(-1);
    }
    out.writeDouble(normalizer.getTargetMin());
    out.writeDouble(normalizer.getTargetMax());
    for (int input = 0; input < normalizer.numInputs(); input++) {
      Nd4j.write(normalizer.getMin(input), out);
      Nd4j.write(normalizer.getMax(input), out);
    }
    if (normalizer.isFitLabel()) {
      for (int output = 0; output < normalizer.numOutputs(); output++) {
        Nd4j.write(normalizer.getLabelMin(output), out);
        Nd4j.write(normalizer.getLabelMax(output), out);
      }
    }
    out.flush();
  }
}

代码示例来源:origin: deeplearning4j/nd4j

/**
 * Convert an ndarray to a byte array holding its {@code write(INDArray, DataOutputStream)}
 * serialization.
 *
 * @param arr the array to convert
 * @return the serialized bytes
 * @throws ND4JIllegalStateException if the array's raw data alone exceeds
 *         {@link Integer#MAX_VALUE} bytes, which cannot fit in a Java byte array
 * @throws IOException if serialization fails
 */
public static byte[] toByteArray(INDArray arr) throws IOException {
  // Widen to long before multiplying: if length() returns int, an int * int product can
  // overflow and wrap negative, silently bypassing the size check below.
  long dataBytes = (long) arr.length() * arr.data().getElementSize();
  if (dataBytes > Integer.MAX_VALUE) {
    throw new ND4JIllegalStateException("Cannot convert INDArray to byte[]: data size of "
        + dataBytes + " bytes exceeds Integer.MAX_VALUE");
  }
  // Presize with the raw data size; ByteArrayOutputStream grows as needed for any extra
  // bytes the serialization writes beyond the raw data.
  ByteArrayOutputStream bos = new ByteArrayOutputStream((int) dataBytes);
  DataOutputStream dos = new DataOutputStream(bos);
  write(arr, dos);
  return bos.toByteArray();
}

代码示例来源:origin: org.nd4j/nd4j-api

/**
 * Writes every array in {@code arrays} to {@code dos}, in order.
 *
 * <p>When {@code isMask} is true, {@code null} entries are written as
 * {@code EMPTY_MASK_ARRAY_PLACEHOLDER} so each slot still produces output.
 * A {@code null} or empty {@code arrays} writes nothing.
 *
 * @param arrays the arrays to write; may be null
 * @param dos    the destination stream
 * @param isMask whether null entries should be written as the mask placeholder
 * @throws IOException if writing fails
 */
private void saveINDArrays(INDArray[] arrays, DataOutputStream dos, boolean isMask) throws IOException {
  if (arrays == null || arrays.length == 0) {
    return;
  }
  for (INDArray array : arrays) {
    Nd4j.write(isMask && array == null ? EMPTY_MASK_ARRAY_PLACEHOLDER : array, dos);
  }
}

代码示例来源:origin: org.nd4j/nd4j-x86

/**
 * Serializes an ndarray onto the given stream via {@code Nd4j.write}.
 *
 * <p>The stream is wrapped in a {@link DataOutputStream}. This method does not itself
 * flush or close it, so the caller keeps ownership of {@code to}.
 *
 * @param out the ndarray to write
 * @param to  the output stream to write to
 * @throws IOException if writing fails
 */
@Override
public void write(INDArray out, OutputStream to) throws IOException {
  DataOutputStream dataOut = new DataOutputStream(to);
  Nd4j.write(out, dataOut);
}

代码示例来源:origin: org.deeplearning4j/cdh4

/**
 * Logs that the master node has completed, then writes {@code paramVector} to {@code osStream}.
 *
 * <p>Open question from the original author, still unresolved: is {@code compute()} called
 * before {@code complete()} in the last epoch?
 *
 * @param osStream the stream that receives the serialized parameter vector
 * @throws IOException if writing fails
 */
@Override
public void complete(DataOutputStream osStream) throws IOException {
  log.info("IR DBN Master Node: Complete!");
  Nd4j.write(paramVector, osStream);
}

代码示例来源:origin: de.datexis/texoo-core

/**
 * Serializes an ndarray with {@code Nd4j.write} and returns those bytes Base64-encoded.
 *
 * @param arr the array to encode
 * @return the Base64 encoding of the array's binary serialization
 * @throws IllegalArgumentException if serialization fails; the underlying
 *         {@link IOException} is attached as the cause
 */
public static String getArrayAsBase64String(INDArray arr) {
  ByteArrayOutputStream baos = new ByteArrayOutputStream();
  try (DataOutputStream dos = new DataOutputStream(new BufferedOutputStream(baos))) {
    Nd4j.write(arr, dos);
    // Flush so the BufferedOutputStream pushes all bytes into baos before it is read.
    dos.flush();
    byte[] encodedBytes = Base64.encodeBase64(baos.toByteArray());
    // Base64 output is pure ASCII; decode with an explicit charset instead of the
    // platform default.
    return new String(encodedBytes, java.nio.charset.StandardCharsets.US_ASCII);
  } catch (IOException ex) {
    // Keep the original failure as the cause instead of discarding it.
    throw new IllegalArgumentException("Could not encode INDArray as Base64", ex);
  }
}

代码示例来源:origin: org.nd4j/nd4j-api

/**
 * Writes a standardization record: the STANDARDIZE strategy tag, followed by the mean
 * and standard-deviation arrays.
 *
 * @param normalizerStats the statistics to write
 * @param dos             the destination stream
 * @throws IOException if writing fails
 */
private static void writeDistributionStats(DistributionStats normalizerStats, DataOutputStream dos)
        throws IOException {
  // The tag is the enum ordinal, so reordering Strategy's constants would change the format.
  final int strategyTag = Strategy.STANDARDIZE.ordinal();
  dos.writeInt(strategyTag);
  Nd4j.write(normalizerStats.getMean(), dos);
  Nd4j.write(normalizerStats.getStd(), dos);
}

相关文章