deeplearning4j: Expected model class name Model (found Functional) (InvalidKerasConfigurationException)

jaxagkaj posted on 2021-07-06 in Java

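I am trying to use deeplearning4j to import a machine learning model trained in Python into Java (a Maven project). I am using a functional model in tf.keras. But whenever I do (literally) what the documentation tells me to do, I get an error. For completeness, I have added my code below.
Python model: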


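import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import (Bidirectional, Dense, Dropout, Input, LSTM,
                                     Masking, RepeatVector, TimeDistributed)
from tensorflow.keras.regularizers import l2
from tensorflow.keras.utils import plot_model

# Create a MirroredStrategy.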

strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

class lstm_bottleneck(tf.keras.layers.Layer):
    def __init__(self, lstm_units, time_steps, **kwargs):
        # initialize the base Layer first, then create the sub-layers
        super(lstm_bottleneck, self).__init__(**kwargs)
        self.lstm_units = lstm_units
        self.time_steps = time_steps
        self.lstm_layer = Bidirectional(LSTM(lstm_units, return_sequences=False))
        self.repeat_layer = RepeatVector(time_steps)

    def call(self, inputs):
        # just call the two initialized layers
        return self.repeat_layer(self.lstm_layer(inputs))

    def compute_mask(self, inputs, mask=None):
        # return the input_mask directly
        return mask

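    def get_config(self):
        # store the constructor arguments so the layer can be rebuilt from its saved config
        cfg = super().get_config()
        cfg.update({'lstm_units': self.lstm_units, 'time_steps': self.time_steps})
        return cfg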

# `timesteps` (the sequence length) is assumed to be defined earlier in the script
with strategy.scope():

  inp1 = Input(shape=(timesteps, 7), name="inp1")
  mask1 = Masking(mask_value=-1.)(inp1)

  enc = Bidirectional(LSTM(55, activation = 'tanh', return_sequences = True, dropout = 0.1, kernel_regularizer=l2(0.01)))(mask1)
  enc = Dropout(0.2)(enc)
  enc = Bidirectional(LSTM(50, activation = 'tanh', return_sequences = True, kernel_regularizer=l2(0.01)))(enc)
  enc = Dropout(0.1)(enc)

  decode = lstm_bottleneck(lstm_units=45, time_steps=timesteps)(enc)

  decode = Bidirectional(LSTM(50, activation = 'tanh', return_sequences = True, kernel_regularizer=l2(0.01)))(decode)
  decode = Dropout(0.2)(decode)
  decode = Bidirectional(LSTM(55, activation = 'tanh', return_sequences = True, kernel_regularizer=l2(0.01)))(decode)
  decode = TimeDistributed(Dense(6, activation="softmax"), name="dec1")(decode)

  new_model = Model(inputs=inp1, outputs = decode)
  new_model.compile(loss='categorical_crossentropy', optimizer=tf.keras.optimizers.Adam(learning_rate=0.0005), metrics=['categorical_accuracy'])
  plot_model(new_model, to_file='model.png')
  new_model.summary()
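A minimal diagnostic sketch, assuming the model is saved in Keras's HDF5 format: such a file stores the architecture as a JSON string in its root attribute `model_config`, and the top-level `class_name` in that JSON is the value the exception compares against "Model". The hypothetical helper `summarize_model_config` below parses a JSON string of that shape and returns the top-level class name plus each layer's class name, which shows both "Functional" and custom layers such as `lstm_bottleneck`. It uses only the standard library; pulling the attribute out of the .h5 file (for example with h5py) is not shown, and the example config is made up for illustration.

```python
import json
from typing import List, Tuple


def summarize_model_config(config_json: str) -> Tuple[str, List[str]]:
    """Return (top-level class_name, layer class_names) from a Keras model_config JSON string.

    Hypothetical helper; it assumes the usual Keras layout
    {"class_name": ..., "config": {"layers": [{"class_name": ...}, ...]}}.
    """
    config = json.loads(config_json)
    top_level = config.get("class_name", "")
    # functional models list their layers under config["config"]["layers"]
    layers = config.get("config", {}).get("layers", [])
    layer_names = [layer.get("class_name", "") for layer in layers]
    return top_level, layer_names


# Made-up, trimmed-down config in the same shape Keras writes
example = json.dumps({
    "class_name": "Functional",
    "config": {"layers": [{"class_name": "InputLayer"},
                          {"class_name": "Masking"},
                          {"class_name": "lstm_bottleneck"}]},
})
print(summarize_model_config(example))  # -> ('Functional', ['InputLayer', 'Masking', 'lstm_bottleneck'])
```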

pom.xml file:

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>org.example</groupId>
    <artifactId>sampleProject</artifactId>
    <version>1.0-SNAPSHOT</version>

    <dependencies>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-core</artifactId>
            <version>1.0.0-beta6</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-modelimport</artifactId>
            <version>1.0.0-beta6</version>
        </dependency>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-native-platform</artifactId>
            <version>1.0.0-beta6</version>
        </dependency>
    </dependencies>

</project>

Java code:

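import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.nd4j.linalg.io.ClassPathResource;

public class MachineLearningModel {
    private final ComputationGraph thisModel;

    public MachineLearningModel() throws Exception {
        // Load the full model (architecture and weights) from the HDF5 file on the classpath
        String fullModel = new ClassPathResource("val_loss_model.h5").getFile().getPath();
        thisModel = KerasModelImport.importKerasModelAndWeights(fullModel);
    }
}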

Error:

SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
SLF4J: Defaulting to no-operation (NOP) logger implementation
SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.
org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException: Expected model class name Model (found Functional). For more information, see http://deeplearning4j.org/docs/latest/keras-import-overview
    at org.deeplearning4j.nn.modelimport.keras.KerasModel.<init>(KerasModel.java:133)
    at org.deeplearning4j.nn.modelimport.keras.KerasModel.<init>(KerasModel.java:96)
    at org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder.buildModel(KerasModelBuilder.java:307)
    at org.deeplearning4j.nn.modelimport.keras.KerasModelImport.importKerasModelAndWeights(KerasModelImport.java:172)
    at MachineLearningModel.<init>(MachineLearningModel.java:21)
    at SimulatedAnnealing.Optimize(SimulatedAnnealing.java:8)
    at Main.main(Main.java:33)
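The exception text indicates that DL4J 1.0.0-beta6 accepts only the top-level class name "Model", while the saved file records "Functional". One possible workaround, untested here, is to rewrite that name in the saved file before importing it. A minimal sketch, assuming the model was saved in Keras's HDF5 format (architecture JSON in the root attribute `model_config`) and that the third-party h5py package is installed; both function names are hypothetical. It changes only the top-level class name and does nothing about other incompatibilities, such as the custom `lstm_bottleneck` layer, which DL4J has no built-in mapping for. Work on a copy of the .h5 file.

```python
import json


def patch_model_config(config_json: str) -> str:
    """Return the model_config JSON with a top-level class_name of
    "Functional" rewritten to "Model"; any other value is left unchanged."""
    config = json.loads(config_json)
    if config.get("class_name") == "Functional":
        config["class_name"] = "Model"
    return json.dumps(config)


def patch_h5_file(path: str) -> None:
    """Rewrite the root "model_config" attribute of a Keras HDF5 file in place.

    Assumes the file layout written by Keras's HDF5 saver and needs h5py.
    """
    import h5py  # imported lazily so patch_model_config works without h5py

    with h5py.File(path, "r+") as f:
        raw = f.attrs["model_config"]
        if isinstance(raw, bytes):  # fixed-length strings come back as bytes
            raw = raw.decode("utf-8")
        # write it back UTF-8 encoded, the same way Keras stores it
        f.attrs["model_config"] = patch_model_config(raw).encode("utf-8")
```

Calling `patch_h5_file` on the copy of the file that is later passed to `importKerasModelAndWeights` would be the intended usage; whether the import then succeeds depends on the remaining layers.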
