如何使用python中训练的模型在java/tensorflow中添加training=true?

disbfnqx  于 2021-07-12  发布在  Java
关注(0)|答案(0)|浏览(186)

这个问题已经被问了好几次,但没有一个有用的答案。我打开这个,希望能得到一个明确的答案。这里有同样的问题之前;link1、link2、link3、link4、link5、link6、link7等等,只需搜索 training = True 在操作系统上,您将看到几个问题。
主要的问题来自于当一个模型有一个规范化层或退出层时,必须提供 training = True ,执行预测。
其中一个简单的模型是gan。在提供的链接中,生成器模型在python中如下所示:

def make_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Reshape((7, 7, 256)))
    #assert model.output_shape == (None, 7, 7, 256)  # Note: None is the batch size
    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    #assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    #assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)
    return model

我已将模型另存为 generator.save("ganModel") 我们可以简单地将模型加载为:; generator = tf.keras.models.load_model("ganModel") .
现在可以做如下预测:;

test_input = tf.random.normal([1,100])

prediction = generator(test_input, training = True) # training has to be set True, otherwise all values are nan or zeros!

现在,当您尝试使用下面的模型在java中执行预测时,问题就开始了。java代码;

TFloat32 input = TFloat32.tensorOf(Shape.of(1,100));
System.out.println(input.shape());

SavedModelBundle theModel = SavedModelBundle.load("ganModel", "serve");

Graph gp = theModel.graph();
java.util.Iterator<Operation> theOps = gp.operations();
while(theOps.hasNext()) {
    Operation theOp = theOps.next();
    System.out.println(theOp);
}

Session theSess = theModel.session();
TFloat32 result = (TFloat32) theSess.runner().feed("serving_default_dense_input", input).fetch("StatefulPartitionedCall").run().get(0);
float[][][][] flt = StdArrays.array4dCopyOf(result);
BufferedImage bfImage = new BufferedImage(28,28, BufferedImage.TYPE_INT_RGB);
for(int i = 0; i < 28; i++) {
    for(int j = 0; j < 28; j++) {
        int RdC = (int) ((int) (flt[0][i][j][0]+1)*127.5);
        int GrC = 0;
        int BlC = 0;
        Color theColor = new Color(RdC, GrC, BlC);
        bfImage.setRGB(i, j, theColor.getRGB());
    }
}

File output = new File("bfImage.png");
try {
    ImageIO.write(bfImage, "PNG", output);
} catch (IOException e) {
    e.printStackTrace();
}

正如我在上面分享的几个链接一样,有几个关于这个问题的问题,但是没有答案,特别是在tf2中。问题是如何在tensorflow/java中为预测设置训练true?我认为,可以提供一个标量布尔Tensor,但如何提供呢? TBool tfBool = TBool.scalarOf(true);

暂无答案!

目前还没有任何答案,快来回答吧!

相关问题