Usage and code examples of the org.nd4j.linalg.factory.Nd4j.getDistributions() method

x33g5p2x, reposted on 2022-01-24 under Other

This article collects Java code examples for org.nd4j.linalg.factory.Nd4j.getDistributions() and shows how the method is used in practice. The examples come mainly from GitHub, Stack Overflow, and Maven, and were extracted from selected projects, so they should work well as a reference. Details of Nd4j.getDistributions():
Package: org.nd4j.linalg.factory
Class name: Nd4j
Method name: getDistributions

Nd4j.getDistributions overview

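Gets the primary distributions factory. The examples below use the returned factory to create uniform, normal, and binomial distributions (`createUniform`, `createNormal`, `createBinomial`). They then call `sample(shape)` on each distribution to get an array of that shape.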

Code examples

Example source: deeplearning4j/nd4j

```java
@Override
public INDArray doCreate(long[] shape, INDArray paramsView) {
    // As per Glorot and Bengio 2010: uniform distribution U(-s, s) with s = sqrt(6 / (fanIn + fanOut))
    // Eq. 16: http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
    double s = Math.sqrt(6.0) / Math.sqrt(fanIn + fanOut);
    return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-s, s));
}
```
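Why sqrt(6)? A uniform distribution U(-s, s) has variance s²/3. Choosing s = sqrt(6 / (fanIn + fanOut)) therefore gives a weight variance of 2 / (fanIn + fanOut), which is the target in Glorot and Bengio's paper. The plain-Java sketch below checks this numerically with `java.util.Random`. It does not call ND4J, and `XavierUniformCheck` and its methods are hypothetical names used only for illustration.

```java
import java.util.Random;

/** Hypothetical helper: Xavier/Glorot uniform bound plus a numerical variance check. */
final class XavierUniformCheck {

    /** Half-width s of U(-s, s), written as in the snippet above: sqrt(6) / sqrt(fanIn + fanOut). */
    static double bound(double fanIn, double fanOut) {
        return Math.sqrt(6.0) / Math.sqrt(fanIn + fanOut);
    }

    /** Empirical variance of n draws from U(-s, s) using java.util.Random (not ND4J's generator). */
    static double empiricalVariance(double s, int n, long seed) {
        Random rng = new Random(seed);
        double sum = 0.0;
        double sumSq = 0.0;
        for (int i = 0; i < n; i++) {
            double x = -s + 2.0 * s * rng.nextDouble(); // uniform between -s and s
            sum += x;
            sumSq += x * x;
        }
        double mean = sum / n;
        return sumSq / n - mean * mean;
    }

    public static void main(String[] args) {
        double fanIn = 100;
        double fanOut = 50;
        double s = bound(fanIn, fanOut);           // sqrt(6 / 150) = 0.2
        double target = 2.0 / (fanIn + fanOut);    // equals s * s / 3
        double measured = empiricalVariance(s, 1_000_000, 42L);
        System.out.printf("s = %.4f, target variance = %.6f, measured = %.6f%n", s, target, measured);
    }
}
```

With `fanIn = 100` and `fanOut = 50`, the bound is 0.2 and the target variance is 2/150 ≈ 0.0133. The measured variance should land close to that value.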

Example source: deeplearning4j/nd4j

```java
@Override
public INDArray doCreate(long[] shape, INDArray paramsView) {
    // U(-a, a) with a = 1 / sqrt(fanIn)
    double a = 1.0 / Math.sqrt(fanIn);
    return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-a, a));
}
```

Example source: deeplearning4j/nd4j

```java
@Override
public INDArray doCreate(long[] shape, INDArray paramsView) {
    // U(-r, r) with r = 4 * sqrt(6 / (fanIn + fanOut)), four times the Xavier bound
    double r = 4.0 * Math.sqrt(6.0 / (fanIn + fanOut));
    return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-r, r));
}
```

Example source: deeplearning4j/nd4j

```java
@Override
public INDArray doCreate(long[] shape, INDArray paramsView) {
    // U(-b, b) with b = 3 / sqrt(fanIn)
    double b = 3.0 / Math.sqrt(fanIn);
    return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-b, b));
}
```

Example source: deeplearning4j/nd4j

```java
@Override
public INDArray doCreate(long[] shape, INDArray paramsView) {
    // Variance scaling by fan-in: U(-b, b) with b = 3 / sqrt(fanIn)
    double scalingFanIn = 3.0 / Math.sqrt(fanIn);
    return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-scalingFanIn, scalingFanIn));
}
```

Example source: deeplearning4j/nd4j

```java
@Override
public INDArray doCreate(long[] shape, INDArray paramsView) {
    double u = Math.sqrt(6.0 / fanIn);
    return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-u, u)); // U(-sqrt(6/fanIn), sqrt(6/fanIn))
}
```

Example source: deeplearning4j/nd4j

```java
@Override
public INDArray doCreate(long[] shape, INDArray paramsView) {
    // Variance scaling by the average fan: U(-b, b) with b = 3 / sqrt((fanIn + fanOut) / 2)
    double scalingFanAvg = 3.0 / Math.sqrt((fanIn + fanOut) / 2);
    return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-scalingFanAvg, scalingFanAvg));
}
```

Example source: deeplearning4j/nd4j

```java
@Override
public INDArray doCreate(long[] shape, INDArray paramsView) {
    // Variance scaling by fan-out: U(-b, b) with b = 3 / sqrt(fanOut)
    double scalingFanOut = 3.0 / Math.sqrt(fanOut);
    return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-scalingFanOut, scalingFanOut));
}
```
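The uniform schemes above differ only in how they compute the half-width of U(-b, b). The plain-Java sketch below collects those formulas, copied from the snippets, so they can be compared for a single layer shape. The enum and its constant names are hypothetical labels, not DL4J's `WeightInit` values, and no ND4J call is involved.

```java
import java.util.function.DoubleBinaryOperator;

/** Hypothetical labels for the half-width b of U(-b, b) used by the snippets above. */
enum UniformBound {
    XAVIER_UNIFORM((in, out) -> Math.sqrt(6.0) / Math.sqrt(in + out)),
    UNIFORM((in, out) -> 1.0 / Math.sqrt(in)),
    SIGMOID_UNIFORM((in, out) -> 4.0 * Math.sqrt(6.0 / (in + out))),
    FAN_IN_3((in, out) -> 3.0 / Math.sqrt(in)),
    RELU_UNIFORM((in, out) -> Math.sqrt(6.0 / in)),
    FAN_AVG_3((in, out) -> 3.0 / Math.sqrt((in + out) / 2)),
    FAN_OUT_3((in, out) -> 3.0 / Math.sqrt(out));

    private final DoubleBinaryOperator formula;

    UniformBound(DoubleBinaryOperator formula) {
        this.formula = formula;
    }

    /** Half-width b for a layer with the given fan-in and fan-out. */
    double bound(double fanIn, double fanOut) {
        return formula.applyAsDouble(fanIn, fanOut);
    }

    public static void main(String[] args) {
        // Example layer: 9 inputs, 16 outputs
        for (UniformBound b : values()) {
            System.out.printf("%-16s %.4f%n", b, b.bound(9, 16));
        }
    }
}
```

Take a layer with 9 inputs and 16 outputs as an example. `FAN_IN_3` gives exactly 1.0 and `FAN_OUT_3` gives 0.75. `SIGMOID_UNIFORM` is always four times `XAVIER_UNIFORM`.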

Example source: deeplearning4j/nd4j

```java
@Override
public INDArray rand(long[] shape, float min, float max, org.nd4j.linalg.api.rng.Random rng) {
    // Ensure shapes that wind up being scalar end up with the right shape
    if (shape.length == 1 && shape[0] == 0) {
        shape = new long[] {1, 1};
    }
    // Note: the rng argument is not used in this overload
    return Nd4j.getDistributions().createUniform(min, max).sample(shape);
}
```

Example source: deeplearning4j/nd4j

```java
/**
 * Generates a random matrix between min and max
 *
 * @param shape the shape of the matrix
 * @param min the minimum number
 * @param max the maximum number
 * @param rng the rng to use (not used by this implementation)
 * @return a random matrix of the specified shape and range
 */
@Override
public INDArray rand(int[] shape, float min, float max, org.nd4j.linalg.api.rng.Random rng) {
    // Ensure shapes that wind up being scalar end up with the right shape
    if (shape.length == 1 && shape[0] == 0) {
        shape = new int[] {1, 1};
    }
    return Nd4j.getDistributions().createUniform(min, max).sample(shape);
}
```
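The plain-Java sketch below shows what the `float` overloads above return. It assumes that `createUniform(min, max).sample(shape)` fills every element of the requested shape with an independent value between `min` and `max`. `UniformRand` is a hypothetical name. The result is a flat, row-major `double[]` rather than an `INDArray`. The ND4J overloads ignore their `rng` argument, but this sketch draws from the `Random` it is given.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical plain-Java analogue of rand(shape, min, max, rng); no ND4J dependency. */
final class UniformRand {

    /** Mirrors the snippets above: a shape of {0} is treated as a 1x1 matrix. */
    static long[] normalizeShape(long[] shape) {
        if (shape.length == 1 && shape[0] == 0) {
            return new long[] {1, 1};
        }
        return shape;
    }

    /** Returns product(shape) values between min and max, stored flat in row-major order. */
    static double[] rand(long[] shape, double min, double max, Random rng) {
        long[] s = normalizeShape(shape);
        long length = 1;
        for (long d : s) {
            length = Math.multiplyExact(length, d); // fail loudly on overflow
        }
        double[] out = new double[Math.toIntExact(length)];
        for (int i = 0; i < out.length; i++) {
            out[i] = min + (max - min) * rng.nextDouble();
        }
        return out;
    }

    public static void main(String[] args) {
        double[] m = rand(new long[] {2, 3}, -1.0, 1.0, new Random(7));
        System.out.println(Arrays.toString(m)); // six values between -1 and 1
    }
}
```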

Example source: deeplearning4j/nd4j

```java
@Override
public INDArray rand(int[] shape, double min, double max, org.nd4j.linalg.api.rng.Random rng) {
    // Re-seed the global ND4J generator with rng's seed before sampling
    Nd4j.getRandom().setSeed(rng.getSeed());
    return Nd4j.getDistributions().createUniform(min, max).sample(shape);
}
```

Example source: deeplearning4j/nd4j

```java
@Override
public INDArray rand(long[] shape, double min, double max, org.nd4j.linalg.api.rng.Random rng) {
    // Re-seed the global ND4J generator with rng's seed before sampling
    Nd4j.getRandom().setSeed(rng.getSeed());
    return Nd4j.getDistributions().createUniform(min, max).sample(shape);
}
```
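The `double` overloads re-seed the global generator (`Nd4j.getRandom()`) from `rng.getSeed()` before sampling. Assuming the uniform distribution then draws from that global generator, the same seed yields the same matrix. The cost is that shared state other code may rely on gets reset. The plain-Java illustration below shows that trade-off. It uses a hypothetical static `GLOBAL` generator as a stand-in for `Nd4j.getRandom()`. It is not ND4J's implementation.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical illustration of re-seeding a shared generator before each draw. */
final class GlobalSeedDemo {

    /** Stand-in for a process-wide generator such as Nd4j.getRandom(). */
    static final Random GLOBAL = new Random();

    /** Re-seeds the shared generator, then draws n values between min and max. */
    static double[] rand(int n, double min, double max, long seed) {
        GLOBAL.setSeed(seed); // side effect: later GLOBAL draws are now fixed by this seed
        double[] out = new double[n];
        for (int i = 0; i < n; i++) {
            out[i] = min + (max - min) * GLOBAL.nextDouble();
        }
        return out;
    }

    public static void main(String[] args) {
        double[] a = rand(4, 0.0, 1.0, 123L);
        double[] b = rand(4, 0.0, 1.0, 123L);
        System.out.println(Arrays.equals(a, b)); // true: same seed, same draws
    }
}
```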

Example source: deeplearning4j/nd4j

```java
/**
 * Generate random shapes for broadcasting tests: numShapes shapes of the given rank,
 * where each dimension is either 1 or a random size from 1 to 9
 *
 * @param seed the random seed
 * @param rank the number of dimensions per shape
 * @param numShapes the number of shapes to generate
 * @return a numShapes x rank array of dimension sizes
 */
public static int[][] getRandomBroadCastShape(long seed, int rank, int numShapes) {
    Nd4j.getRandom().setSeed(seed);
    // One coin flip per dimension: 1 with probability 0.5, otherwise 0
    INDArray coinFlip = Nd4j.getDistributions().createBinomial(1, 0.5).sample(new int[] {numShapes, rank});
    int[][] ret = new int[(int) coinFlip.rows()][(int) coinFlip.columns()];
    for (int i = 0; i < coinFlip.rows(); i++) {
        for (int j = 0; j < coinFlip.columns(); j++) {
            int set = coinFlip.getInt(i, j);
            if (set > 0) {
                ret[i][j] = set; // size-1 dimension
            } else {
                // anything from 1 to 9
                ret[i][j] = Nd4j.getRandom().nextInt(9) + 1;
            }
        }
    }
    return ret;
}
```
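The generator above gives every dimension a coin flip. Heads produces size 1, and tails produces a size from 1 to 9. Size-1 dimensions matter because a dimension of size 1 stretches to match the other operand under NumPy-style broadcasting. This sketch assumes that is the rule these test shapes target. Below is a plain-Java sketch of that compatibility check. The class and method names are hypothetical, shapes are right-aligned as in NumPy, and all sizes are assumed to be positive.

```java
import java.util.Arrays;

/** Hypothetical NumPy-style broadcasting check for positive dimension sizes. */
final class BroadcastCheck {

    /** Align shapes on the right; each aligned pair must be equal or contain a 1. */
    static boolean compatible(int[] a, int[] b) {
        for (int i = 1; i <= Math.min(a.length, b.length); i++) {
            int da = a[a.length - i];
            int db = b[b.length - i];
            if (da != db && da != 1 && db != 1) {
                return false;
            }
        }
        return true;
    }

    /** Broadcast result shape, or null if the shapes are incompatible. */
    static int[] broadcastShape(int[] a, int[] b) {
        if (!compatible(a, b)) {
            return null;
        }
        int n = Math.max(a.length, b.length);
        int[] out = new int[n];
        for (int i = 1; i <= n; i++) {
            int da = i <= a.length ? a[a.length - i] : 1; // missing leading dims act as 1
            int db = i <= b.length ? b[b.length - i] : 1;
            out[n - i] = Math.max(da, db);                // equal, or the non-1 size wins
        }
        return out;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(broadcastShape(new int[] {1, 3}, new int[] {4, 3}))); // [4, 3]
        System.out.println(compatible(new int[] {2, 3}, new int[] {4, 3}));                      // false
    }
}
```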

Example source: org.deeplearning4j/deeplearning4j-nn

```java
@Override
public INDArray preProcess(INDArray input, int miniBatchSize) {
    // Binomial with one trial: each element of input is the probability of sampling a 1
    return Nd4j.getDistributions().createBinomial(1, input).sample(input.shape());
}
```
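With `n = 1`, `createBinomial(1, input)` acts as a Bernoulli distribution. The snippet uses `input` for both the probabilities and the output shape, which suggests each element supplies its own success probability. The pre-processor therefore turns probabilities into 0/1 samples of the same shape. Below is a plain-Java sketch of that per-element sampling. The names are hypothetical, `java.util.Random` stands in for ND4J's generator, and the inputs are assumed to lie in [0, 1].

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical element-wise Bernoulli sampler; no ND4J dependency. */
final class BernoulliSampling {

    /** Returns 1.0 with probability p[i] and 0.0 otherwise, independently per element. */
    static double[] sample(double[] p, Random rng) {
        double[] out = new double[p.length];
        for (int i = 0; i < p.length; i++) {
            // nextDouble() is in [0, 1), so p = 0 never fires and p = 1 always does
            out[i] = rng.nextDouble() < p[i] ? 1.0 : 0.0;
        }
        return out;
    }

    public static void main(String[] args) {
        double[] probs = {0.0, 0.25, 0.5, 0.75, 1.0};
        System.out.println(Arrays.toString(sample(probs, new Random(3))));
    }
}
```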

Example source: org.deeplearning4j/deeplearning4j-nn

```java
// Converts a DL4J configuration distribution into the corresponding ND4J distribution
public static org.nd4j.linalg.api.rng.distribution.Distribution createDistribution(Distribution dist) {
    if (dist == null)
        return null;
    if (dist instanceof NormalDistribution) {
        NormalDistribution nd = (NormalDistribution) dist;
        return Nd4j.getDistributions().createNormal(nd.getMean(), nd.getStd());
    }
    if (dist instanceof GaussianDistribution) {
        GaussianDistribution nd = (GaussianDistribution) dist;
        return Nd4j.getDistributions().createNormal(nd.getMean(), nd.getStd());
    }
    if (dist instanceof UniformDistribution) {
        UniformDistribution ud = (UniformDistribution) dist;
        return Nd4j.getDistributions().createUniform(ud.getLower(), ud.getUpper());
    }
    if (dist instanceof BinomialDistribution) {
        BinomialDistribution bd = (BinomialDistribution) dist;
        return Nd4j.getDistributions().createBinomial(bd.getNumberOfTrials(), bd.getProbabilityOfSuccess());
    }
    throw new RuntimeException("unknown distribution type: " + dist.getClass());
}
```

Example source: org.deeplearning4j/deeplearning4j-nn

```java
/**
 * Corrupts the given input by binomial sampling
 * at the given corruption level
 * @param x the input to corrupt
 * @param corruptionLevel the corruption value
 * @return the binomial sampled corrupted input
 */
public INDArray getCorruptedInput(INDArray x, double corruptionLevel) {
    // 0/1 mask: each entry is 1 with probability 1 - corruptionLevel
    INDArray corrupted = Nd4j.getDistributions().createBinomial(1, 1 - corruptionLevel).sample(x.shape());
    // Multiply the mask by x in place, zeroing the corrupted entries
    corrupted.muli(x);
    return corrupted;
}
```
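This is masking-style corruption, as used by denoising autoencoders. Because the 0/1 mask is multiplied element-wise into `x`, each input value is kept with probability `1 - corruptionLevel` and set to zero otherwise. Below is a plain-Java sketch of the same idea. The names are hypothetical, and `java.util.Random` stands in for ND4J's generator.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical plain-Java masking noise; no ND4J dependency. */
final class MaskingNoise {

    /** Keeps each x[i] with probability 1 - corruptionLevel, otherwise replaces it with 0. */
    static double[] corrupt(double[] x, double corruptionLevel, Random rng) {
        double keep = 1.0 - corruptionLevel;
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            double mask = rng.nextDouble() < keep ? 1.0 : 0.0; // Bernoulli(keep), like createBinomial(1, keep)
            out[i] = mask * x[i];
        }
        return out;
    }

    public static void main(String[] args) {
        double[] x = {0.9, 0.1, 0.4, 0.7, 0.3, 0.8};
        System.out.println(Arrays.toString(corrupt(x, 0.3, new Random(11))));
    }
}
```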

Example source: org.nd4j/canova-nd4j-image

```java
public static void drawMnist(DataSet mnist, INDArray reconstruct) throws InterruptedException {
    for (int j = 0; j < mnist.numExamples(); j++) {
        // Original digit, multiplied by 255 for display
        INDArray draw1 = mnist.get(j).getFeatureMatrix().mul(255);
        // Reconstruction: sample 0/1 pixels from the reconstructed values, then scale by 255
        INDArray reconstructed2 = reconstruct.getRow(j);
        INDArray draw2 = Nd4j.getDistributions().createBinomial(1, reconstructed2).sample(reconstructed2.shape()).mul(255);
        DrawReconstruction d = new DrawReconstruction(draw1);
        d.title = "REAL";
        d.draw();
        DrawReconstruction d2 = new DrawReconstruction(draw2, 1000, 1000);
        d2.title = "TEST";
        d2.draw();
        Thread.sleep(1000);
        d.frame.dispose();
        d2.frame.dispose();
    }
}
```

Example source: org.datavec/datavec-data-image

```java
public static void drawMnist(DataSet mnist, INDArray reconstruct) throws InterruptedException {
    for (int j = 0; j < mnist.numExamples(); j++) {
        // Original digit, multiplied by 255 for display
        INDArray draw1 = mnist.get(j).getFeatureMatrix().mul(255);
        // Reconstruction: sample 0/1 pixels from the reconstructed values, then scale by 255
        INDArray reconstructed2 = reconstruct.getRow(j);
        INDArray draw2 = Nd4j.getDistributions().createBinomial(1, reconstructed2).sample(reconstructed2.shape())
                .mul(255);
        DrawReconstruction d = new DrawReconstruction(draw1);
        d.title = "REAL";
        d.draw();
        DrawReconstruction d2 = new DrawReconstruction(draw2, 1000, 1000);
        d2.title = "TEST";
        d2.draw();
        Thread.sleep(1000);
        d.frame.dispose();
        d2.frame.dispose();
    }
}
```
