Usage and code examples of the org.nd4j.linalg.factory.Nd4j.write() method

Reposted by x33g5p2x on 2022-01-24 under Other

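This article collects Java code examples of the org.nd4j.linalg.factory.Nd4j.write() method to show how Nd4j.write() is used in practice. The examples come mainly from GitHub, Stack Overflow, Maven and similar platforms and were extracted from selected projects, so they should serve as a useful reference. Details of the Nd4j.write() method:
Fully qualified name: org.nd4j.linalg.factory.Nd4j
Class name: Nd4j
Method name: write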

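Nd4j.write overview

Writes an ndarray to a writer (an output stream). Two overloads appear in the examples, with different argument orders: write(OutputStream, INDArray), which wraps the stream in a DataOutputStream and closes it afterwards, and write(INDArray, DataOutputStream), which writes to the given stream and leaves it open, so several arrays can be written to the same stream in sequence.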

Code examples

Code example source: origin: deeplearning4j/nd4j

```java
/**
 * Write an ndarray to a writer
 * @param writer the writer to write to
 * @param write the ndarray to write
 * @throws IOException
 */
public static void write(OutputStream writer, INDArray write) throws IOException {
    DataOutputStream stream = new DataOutputStream(writer);
    write(write, stream);  // delegates to write(INDArray, DataOutputStream)
    stream.close();        // note: this also closes the caller's OutputStream
}
```
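Because write(OutputStream, INDArray) calls stream.close() on its DataOutputStream wrapper, the caller's OutputStream is closed too, so nothing else can be appended to it afterwards. The stdlib-only sketch below demonstrates that propagation; the writeAndClose helper and its int-length-plus-doubles layout are hypothetical stand-ins, not ND4J's actual binary format.

```java
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

// Sketch (stdlib only): closing a DataOutputStream wrapper also closes the wrapped stream.
class CloseTrackingDemo {
    /** A ByteArrayOutputStream that records whether close() was called. */
    static final class TrackingStream extends ByteArrayOutputStream {
        boolean closed = false;

        @Override
        public void close() throws IOException {
            closed = true;
            super.close();
        }
    }

    /**
     * Hypothetical stand-in for write(OutputStream, INDArray): wraps the target,
     * writes an int length plus raw doubles (NOT ND4J's format), then closes the wrapper.
     */
    static void writeAndClose(TrackingStream target, double[] values) {
        try {
            DataOutputStream stream = new DataOutputStream(target);
            stream.writeInt(values.length);
            for (double v : values) {
                stream.writeDouble(v);
            }
            stream.close(); // FilterOutputStream.close() flushes, then closes target
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

If more data has to follow the array, wrapping the stream yourself and calling the write(INDArray, DataOutputStream) overload, as the later examples do, avoids the implicit close.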

Code example source: origin: deeplearning4j/nd4j

```java
/**
 * Save an ndarray to the given file
 * @param arr the array to save
 * @param saveTo the file to save to
 * @throws IOException
 */
public static void saveBinary(INDArray arr, File saveTo) throws IOException {
    BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(saveTo));
    DataOutputStream dos = new DataOutputStream(bos);
    Nd4j.write(arr, dos);
    dos.flush();
    dos.close();  // also closes bos and the underlying FileOutputStream
    bos.close();  // redundant after dos.close(), but harmless
}
```
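saveBinary closes its streams by hand, so an exception thrown by Nd4j.write would leave the file open. A try-with-resources variant of the same write-then-read pattern is sketched below with plain double[] data; the save/load helpers and their layout are hypothetical and only illustrate the resource handling, not ND4J's file format.

```java
import java.io.*;

class BinaryFileSketch {
    /** Hypothetical layout: int length followed by raw doubles (not ND4J's file format). */
    static void save(double[] arr, File saveTo) throws IOException {
        // try-with-resources closes (and therefore flushes) the whole chain, even on exceptions
        try (DataOutputStream dos = new DataOutputStream(
                new BufferedOutputStream(new FileOutputStream(saveTo)))) {
            dos.writeInt(arr.length);
            for (double v : arr) {
                dos.writeDouble(v);
            }
        }
    }

    /** Reads the data back in exactly the order save() wrote it. */
    static double[] load(File readFrom) throws IOException {
        try (DataInputStream dis = new DataInputStream(
                new BufferedInputStream(new FileInputStream(readFrom)))) {
            double[] out = new double[dis.readInt()];
            for (int i = 0; i < out.length; i++) {
                out[i] = dis.readDouble();
            }
            return out;
        }
    }
}
```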

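Code example source: origin: deeplearning4j/nd4j

```java
// Fragment: features, labels, featuresMask and labelsMask are the arrays being saved;
// dos is a DataOutputStream opened earlier
Nd4j.write(features, dos);
// Labels are skipped when they are the very same object as the features
if (labels != null && labels != features)
    Nd4j.write(labels, dos);
if (featuresMask != null)
    Nd4j.write(featuresMask, dos);
if (labelsMask != null)
    Nd4j.write(labelsMask, dos);
```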

Code example source: origin: deeplearning4j/nd4j

```java
// Fragment: an anonymous Camel Processor inside a route; the final "}).to(kafkaUri);"
// closes that route definition
@Override
public void process(Exchange exchange) throws Exception {
    final INDArray arr = (INDArray) exchange.getIn().getBody();
    ByteArrayOutputStream bos = new ByteArrayOutputStream();
    DataOutputStream dos = new DataOutputStream(bos);
    Nd4j.write(arr, dos);
    byte[] bytes = bos.toByteArray();
    // Base64-encode the binary form so it can travel as a String message body
    String base64 = Base64.encodeBase64String(bytes);
    exchange.getIn().setBody(base64, String.class);
    // A random UUID serves as both the Kafka message key and the partition key
    String id = UUID.randomUUID().toString();
    exchange.getIn().setHeader(KafkaConstants.KEY, id);
    exchange.getIn().setHeader(KafkaConstants.PARTITION_KEY, id);
}
}).to(kafkaUri);
```

Code example source: origin: deeplearning4j/nd4j

```java
/**
 * Returns an ndarray as base 64
 * @param arr the array to write
 * @return the base 64 representation of the binary ndarray
 * @throws IOException
 */
public static String base64String(INDArray arr) throws IOException {
    ByteArrayOutputStream bos = new ByteArrayOutputStream();
    DataOutputStream dos = new DataOutputStream(bos);
    Nd4j.write(arr, dos);
    String base64 = Base64.encodeBase64String(bos.toByteArray());
    return base64;
}
```
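A Base64 string produced this way only has to be decoded and read back in the same order on the receiving side. The sketch below shows that round trip; it substitutes the JDK's java.util.Base64 for the commons-codec Base64 class used in the source, and its int-length-plus-doubles layout is a hypothetical stand-in for Nd4j.write's binary format.

```java
import java.io.*;
import java.util.Base64;

class Base64Sketch {
    /** Serializes with a hypothetical int-length + doubles layout, then Base64-encodes. */
    static String toBase64(double[] arr) {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        try (DataOutputStream dos = new DataOutputStream(bos)) {
            dos.writeInt(arr.length);
            for (double v : arr) {
                dos.writeDouble(v);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return Base64.getEncoder().encodeToString(bos.toByteArray());
    }

    /** Decodes the string and reads the fields back in the order they were written. */
    static double[] fromBase64(String base64) {
        byte[] bytes = Base64.getDecoder().decode(base64);
        try (DataInputStream dis = new DataInputStream(new ByteArrayInputStream(bytes))) {
            double[] out = new double[dis.readInt()];
            for (int i = 0; i < out.length; i++) {
                out[i] = dis.readDouble();
            }
            return out;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```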

Code example source: origin: deeplearning4j/nd4j

```java
private void saveINDArrays(INDArray[] arrays, DataOutputStream dos, boolean isMask) throws IOException {
    if (arrays != null && arrays.length > 0) {
        for (INDArray fm : arrays) {
            if (isMask && fm == null) {
                // A missing mask is replaced by a one-element {-1} placeholder,
                // created lazily and cached in EMPTY_MASK_ARRAY_PLACEHOLDER
                INDArray temp = EMPTY_MASK_ARRAY_PLACEHOLDER.get();
                if (temp == null) {
                    EMPTY_MASK_ARRAY_PLACEHOLDER.set(Nd4j.create(new float[] {-1}));
                    temp = EMPTY_MASK_ARRAY_PLACEHOLDER.get();
                }
                fm = temp;
            }
            Nd4j.write(fm, dos);
        }
    }
}
```
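saveINDArrays writes a one-element {-1} array in place of each missing mask, so the stream keeps one entry per slot. How a reader turns the placeholder back into null is not shown in the source; the sketch below assumes it compares each array it reads against the placeholder. The float[] layout and the leading count written by saveMasks are hypothetical additions that make the example self-contained.

```java
import java.io.*;
import java.util.Arrays;

class MaskPlaceholderSketch {
    // Sentinel written in place of a missing mask, mirroring Nd4j.create(new float[] {-1})
    private static final float[] PLACEHOLDER = {-1f};

    // Hypothetical array layout: int length followed by raw floats
    static void writeArray(float[] a, DataOutputStream dos) throws IOException {
        dos.writeInt(a.length);
        for (float v : a) {
            dos.writeFloat(v);
        }
    }

    static float[] readArray(DataInputStream dis) throws IOException {
        float[] a = new float[dis.readInt()];
        for (int i = 0; i < a.length; i++) {
            a[i] = dis.readFloat();
        }
        return a;
    }

    /** Writes a count, then one array per slot, substituting the placeholder for null. */
    static byte[] saveMasks(float[][] masks) {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        try (DataOutputStream dos = new DataOutputStream(bos)) {
            dos.writeInt(masks.length);
            for (float[] m : masks) {
                writeArray(m == null ? PLACEHOLDER : m, dos);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bos.toByteArray();
    }

    /** Reads the arrays back and maps the placeholder to null (an assumed reader convention). */
    static float[][] loadMasks(byte[] bytes) {
        try (DataInputStream dis = new DataInputStream(new ByteArrayInputStream(bytes))) {
            float[][] masks = new float[dis.readInt()][];
            for (int i = 0; i < masks.length; i++) {
                float[] m = readArray(dis);
                masks[i] = Arrays.equals(m, PLACEHOLDER) ? null : m;
            }
            return masks;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

One consequence of this design: a real mask whose contents happen to be exactly {-1} would be indistinguishable from a missing one.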

Code example source: origin: deeplearning4j/nd4j

```java
@Override
public void write(@NonNull NormalizerStandardize normalizer, @NonNull OutputStream stream) throws IOException {
    try (DataOutputStream dos = new DataOutputStream(stream)) {
        // Header flag first: tells the reader whether label statistics follow
        dos.writeBoolean(normalizer.isFitLabel());
        Nd4j.write(normalizer.getMean(), dos);
        Nd4j.write(normalizer.getStd(), dos);
        if (normalizer.isFitLabel()) {
            Nd4j.write(normalizer.getLabelMean(), dos);
            Nd4j.write(normalizer.getLabelStd(), dos);
        }
        dos.flush();
    }
}
```

Code example source: origin: deeplearning4j/nd4j

```java
private static void writeDistributionStats(DistributionStats normalizerStats, DataOutputStream dos)
        throws IOException {
    // Strategy tag first, then the mean and standard deviation arrays
    dos.writeInt(Strategy.STANDARDIZE.ordinal());
    Nd4j.write(normalizerStats.getMean(), dos);
    Nd4j.write(normalizerStats.getStd(), dos);
}
```

Code example source: origin: deeplearning4j/nd4j

```java
private static void writeMinMaxStats(MinMaxStats normalizerStats, DataOutputStream dos) throws IOException {
    // Strategy tag first, then the lower and upper bound arrays
    dos.writeInt(Strategy.MIN_MAX.ordinal());
    Nd4j.write(normalizerStats.getLower(), dos);
    Nd4j.write(normalizerStats.getUpper(), dos);
}
```

Code example source: origin: deeplearning4j/nd4j

```java
private String[] writeData(DataSet write) throws IOException {
    String[] ret = new String[2];
    String dataSetId = UUID.randomUUID().toString();
    // Features go to <id>.bin
    BufferedOutputStream dataOut =
            new BufferedOutputStream(new FileOutputStream(new File(rootDir, dataSetId + ".bin")));
    DataOutputStream dos = new DataOutputStream(dataOut);
    Nd4j.write(write.getFeatureMatrix(), dos);
    dos.flush();
    dos.close();
    // Labels go to a separate file, <id>.labels.bin
    BufferedOutputStream dataOutLabels =
            new BufferedOutputStream(new FileOutputStream(new File(rootDir, dataSetId + ".labels.bin")));
    DataOutputStream dosLabels = new DataOutputStream(dataOutLabels);
    Nd4j.write(write.getLabels(), dosLabels);
    dosLabels.flush();
    dosLabels.close();
    // Return the absolute paths of both files
    ret[0] = new File(rootDir, dataSetId + ".bin").getAbsolutePath();
    ret[1] = new File(rootDir, dataSetId + ".labels.bin").getAbsolutePath();
    return ret;
}
```

Code example source: origin: deeplearning4j/nd4j

```java
/**
 * Returns a tab delimited base 64
 * representation of the given arrays
 * @param arrays the arrays
 * @return the Base64 strings of the arrays, each followed by a tab
 * @throws IOException
 */
public static String arraysToBase64(INDArray[] arrays) throws IOException {
    StringBuilder sb = new StringBuilder();
    // tab-separate the outputs for deserialization
    for (INDArray outputArr : arrays) {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        DataOutputStream dos = new DataOutputStream(bos);
        Nd4j.write(outputArr, dos);
        String base64 = Base64.encodeBase64String(bos.toByteArray());
        sb.append(base64);
        sb.append("\t");
    }
    return sb.toString();
}
```
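arraysToBase64 appends a tab after every entry, including the last one. Tab is a safe separator because it never occurs in the Base64 alphabet. Decoding the string comes down to splitting on tabs, as in the sketch below; it works on raw byte[] payloads (a hypothetical simplification in place of INDArrays) and relies on String.split discarding trailing empty strings.

```java
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;

class TabBase64Sketch {
    /** Encodes each payload and appends a tab after every entry, as arraysToBase64 does. */
    static String join(List<byte[]> payloads) {
        StringBuilder sb = new StringBuilder();
        for (byte[] p : payloads) {
            sb.append(Base64.getEncoder().encodeToString(p));
            sb.append('\t');
        }
        return sb.toString();
    }

    /** Splits on tabs and decodes; String.split drops the trailing empty entry. */
    static List<byte[]> split(String joined) {
        List<byte[]> out = new ArrayList<>();
        for (String part : joined.split("\t")) {
            if (!part.isEmpty()) { // "".split("\t") yields a single empty string
                out.add(Base64.getDecoder().decode(part));
            }
        }
        return out;
    }
}
```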

Code example source: origin: deeplearning4j/nd4j

```java
@Override
public void write(@NonNull NormalizerMinMaxScaler normalizer, @NonNull OutputStream stream) throws IOException {
    try (DataOutputStream dos = new DataOutputStream(stream)) {
        // Header: label flag and the target range, followed by the min/max arrays
        dos.writeBoolean(normalizer.isFitLabel());
        dos.writeDouble(normalizer.getTargetMin());
        dos.writeDouble(normalizer.getTargetMax());
        Nd4j.write(normalizer.getMin(), dos);
        Nd4j.write(normalizer.getMax(), dos);
        if (normalizer.isFitLabel()) {
            Nd4j.write(normalizer.getLabelMin(), dos);
            Nd4j.write(normalizer.getLabelMax(), dos);
        }
        dos.flush();
    }
}
```

Code example source: origin: deeplearning4j/nd4j

```java
/**
 * Serialize a MultiNormalizerStandardize to an output stream
 *
 * @param normalizer the normalizer
 * @param stream the output stream to write to
 * @throws IOException
 */
public void write(@NonNull MultiNormalizerStandardize normalizer, @NonNull OutputStream stream) throws IOException {
    try (DataOutputStream dos = new DataOutputStream(stream)) {
        dos.writeBoolean(normalizer.isFitLabel());
        dos.writeInt(normalizer.numInputs());
        // -1 signals that no label statistics follow
        dos.writeInt(normalizer.isFitLabel() ? normalizer.numOutputs() : -1);
        for (int i = 0; i < normalizer.numInputs(); i++) {
            Nd4j.write(normalizer.getFeatureMean(i), dos);
            Nd4j.write(normalizer.getFeatureStd(i), dos);
        }
        if (normalizer.isFitLabel()) {
            for (int i = 0; i < normalizer.numOutputs(); i++) {
                Nd4j.write(normalizer.getLabelMean(i), dos);
                Nd4j.write(normalizer.getLabelStd(i), dos);
            }
        }
        dos.flush();
    }
}
```
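The multi-normalizer serializers write a fixed header (the fitLabel flag, the input count, and the output count or -1) and then the statistics in pairs, so a reader has to consume the same fields in the same order. The sketch below mirrors the MultiNormalizerStandardize layout with plain double[] arrays; the Stats container and the length-prefixed array encoding are hypothetical stand-ins for the normalizer and for Nd4j.write.

```java
import java.io.*;

class MultiStatsSketch {
    /** Hypothetical container; labelMean/labelStd are null when labels were not fitted. */
    static final class Stats {
        final double[][] featureMean, featureStd, labelMean, labelStd;

        Stats(double[][] fm, double[][] fs, double[][] lm, double[][] ls) {
            featureMean = fm; featureStd = fs; labelMean = lm; labelStd = ls;
        }

        boolean isFitLabel() { return labelMean != null; }
    }

    private static void writeArr(double[] a, DataOutputStream dos) throws IOException {
        dos.writeInt(a.length);
        for (double v : a) dos.writeDouble(v);
    }

    private static double[] readArr(DataInputStream dis) throws IOException {
        double[] a = new double[dis.readInt()];
        for (int i = 0; i < a.length; i++) a[i] = dis.readDouble();
        return a;
    }

    /** Header (flag, input count, output count or -1), then (mean, std) pairs. */
    static byte[] write(Stats s) {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        try (DataOutputStream dos = new DataOutputStream(bos)) {
            dos.writeBoolean(s.isFitLabel());
            dos.writeInt(s.featureMean.length);
            dos.writeInt(s.isFitLabel() ? s.labelMean.length : -1);
            for (int i = 0; i < s.featureMean.length; i++) {
                writeArr(s.featureMean[i], dos);
                writeArr(s.featureStd[i], dos);
            }
            if (s.isFitLabel()) {
                for (int i = 0; i < s.labelMean.length; i++) {
                    writeArr(s.labelMean[i], dos);
                    writeArr(s.labelStd[i], dos);
                }
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bos.toByteArray();
    }

    /** Reads the fields back in exactly the order write() produced them. */
    static Stats read(byte[] bytes) {
        try (DataInputStream dis = new DataInputStream(new ByteArrayInputStream(bytes))) {
            boolean fitLabel = dis.readBoolean();
            int numInputs = dis.readInt();
            int numOutputs = dis.readInt(); // -1 when fitLabel is false
            double[][] fm = new double[numInputs][];
            double[][] fs = new double[numInputs][];
            for (int i = 0; i < numInputs; i++) {
                fm[i] = readArr(dis);
                fs[i] = readArr(dis);
            }
            double[][] lm = null, ls = null;
            if (fitLabel) {
                lm = new double[numOutputs][];
                ls = new double[numOutputs][];
                for (int i = 0; i < numOutputs; i++) {
                    lm[i] = readArr(dis);
                    ls[i] = readArr(dis);
                }
            }
            return new Stats(fm, fs, lm, ls);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```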

Code example source: origin: deeplearning4j/nd4j

```java
/**
 * Serialize a MultiNormalizerMinMaxScaler to an output stream
 *
 * @param normalizer the normalizer
 * @param stream the output stream to write to
 * @throws IOException
 */
public void write(@NonNull MultiNormalizerMinMaxScaler normalizer, @NonNull OutputStream stream)
        throws IOException {
    try (DataOutputStream dos = new DataOutputStream(stream)) {
        dos.writeBoolean(normalizer.isFitLabel());
        dos.writeInt(normalizer.numInputs());
        // -1 signals that no label statistics follow
        dos.writeInt(normalizer.isFitLabel() ? normalizer.numOutputs() : -1);
        dos.writeDouble(normalizer.getTargetMin());
        dos.writeDouble(normalizer.getTargetMax());
        for (int i = 0; i < normalizer.numInputs(); i++) {
            Nd4j.write(normalizer.getMin(i), dos);
            Nd4j.write(normalizer.getMax(i), dos);
        }
        if (normalizer.isFitLabel()) {
            for (int i = 0; i < normalizer.numOutputs(); i++) {
                Nd4j.write(normalizer.getLabelMin(i), dos);
                Nd4j.write(normalizer.getLabelMax(i), dos);
            }
        }
        dos.flush();
    }
}
```

Code example source: origin: deeplearning4j/nd4j

```java
/**
 * Convert an ndarray to a byte array
 * @param arr the array to convert
 * @return the converted byte array
 * @throws IOException
 */
public static byte[] toByteArray(INDArray arr) throws IOException {
    // A Java byte[] is int-indexed, so reject data larger than Integer.MAX_VALUE bytes
    if (arr.length() * arr.data().getElementSize() > Integer.MAX_VALUE)
        throw new ND4JIllegalStateException("Array is too large to serialize into a single byte[]");
    // The data size is only the initial capacity; the stream grows as needed
    ByteArrayOutputStream bos = new ByteArrayOutputStream((int) (arr.length() * arr.data().getElementSize()));
    DataOutputStream dos = new DataOutputStream(bos);
    write(arr, dos);
    byte[] ret = bos.toByteArray();
    return ret;
}
```
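toByteArray multiplies the element count by the element size before checking it against Integer.MAX_VALUE, because a Java byte[] is indexed by int. The sketch below isolates that check with a descriptive error message; initialCapacity is a hypothetical helper. The product is only an initial capacity: Nd4j.write also emits the array's shape information, and ByteArrayOutputStream grows as needed.

```java
class ByteSizeSketch {
    /**
     * Hypothetical helper: byte count for numElements values of elementSize bytes each,
     * rejected when it cannot fit in a single Java byte[] (arrays are int-indexed).
     */
    static int initialCapacity(long numElements, int elementSize) {
        // Multiply in long arithmetic; multiplyExact also guards against long overflow
        long bytes = Math.multiplyExact(numElements, (long) elementSize);
        if (bytes > Integer.MAX_VALUE) {
            throw new IllegalStateException(
                    "Array needs " + bytes + " bytes, more than a single byte[] can hold");
        }
        return (int) bytes;
    }
}
```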

Code example source: origin: org.nd4j/nd4j-api

```java
private void saveINDArrays(INDArray[] arrays, DataOutputStream dos, boolean isMask) throws IOException {
    if (arrays != null && arrays.length > 0) {
        for (INDArray fm : arrays) {
            if (isMask && fm == null) {
                // In this version the placeholder array is used directly
                fm = EMPTY_MASK_ARRAY_PLACEHOLDER;
            }
            Nd4j.write(fm, dos);
        }
    }
}
```

Code example source: origin: org.nd4j/nd4j-x86

```java
/**
 * Write an ndarray to the output stream
 *
 * @param out the ndarray to write
 * @param to the output stream to write to
 */
@Override
public void write(INDArray out, OutputStream to) throws IOException {
    // Wraps the stream without closing it, so the caller keeps ownership of "to"
    Nd4j.write(out, new DataOutputStream(to));
}
```

Code example source: origin: org.deeplearning4j/cdh4

```java
/**
 * Q: "is compute() called before complete() is called in last epoch?"
 */
@Override
public void complete(DataOutputStream osStream) throws IOException {
    log.info("IR DBN Master Node: Complete!");
    // Write the parameter vector to the output stream supplied to complete()
    Nd4j.write(paramVector, osStream);
}
```

Code example source: origin: de.datexis/texoo-core

```java
public static String getArrayAsBase64String(INDArray arr) {
    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    BufferedOutputStream bos = new BufferedOutputStream(baos);
    try (DataOutputStream dos = new DataOutputStream(bos)) {
        Nd4j.write(arr, dos);
        // Push the BufferedOutputStream's buffer into baos before reading its bytes
        dos.flush();
        byte[] encodedBytes = Base64.encodeBase64(baos.toByteArray());
        return new String(encodedBytes);
    } catch (IOException ex) {
        throw new IllegalArgumentException("Could not encode INDArray as Base64", ex);
    }
}
```
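In getArrayAsBase64String a BufferedOutputStream sits between the DataOutputStream and the ByteArrayOutputStream, so the dos.flush() call before baos.toByteArray() is essential: without it, a small array would still be sitting in the buffer and the encoded string would be empty or truncated. The sketch below makes that visible with a single int instead of an INDArray, and uses java.util.Base64 in place of the commons-codec class from the source.

```java
import java.io.BufferedOutputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Base64;

class BufferedFlushSketch {
    /** Returns {bytes in baos before flush, bytes in baos after flush} for one int. */
    static int[] sizesAroundFlush() {
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        try (DataOutputStream dos = new DataOutputStream(new BufferedOutputStream(baos))) {
            dos.writeInt(42);          // 4 bytes, held in the BufferedOutputStream's buffer
            int before = baos.size();  // nothing has reached baos yet
            dos.flush();               // pushes the buffered bytes through to baos
            int after = baos.size();
            return new int[] {before, after};
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /** Encodes one int the way getArrayAsBase64String encodes an array: flush, then encode. */
    static String encodeIntAsBase64(int value) {
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        try (DataOutputStream dos = new DataOutputStream(new BufferedOutputStream(baos))) {
            dos.writeInt(value);
            dos.flush(); // without this, toByteArray() below would return an empty array
            return Base64.getEncoder().encodeToString(baos.toByteArray());
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```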

Code example source: origin: org.nd4j/nd4j-api

```java
private static void writeDistributionStats(DistributionStats normalizerStats, DataOutputStream dos)
        throws IOException {
    // Strategy tag first, then the mean and standard deviation arrays
    dos.writeInt(Strategy.STANDARDIZE.ordinal());
    Nd4j.write(normalizerStats.getMean(), dos);
    Nd4j.write(normalizerStats.getStd(), dos);
}
```
