使用deeplearning4j rnn/lstm进行时间序列预测的数据维度不正确?

ar5n3qh5  于 2021-07-03  发布在  Java
关注(0)|答案(0)|浏览(448)

在deeplearning4j中是否有检查我输入数据的维度?我正在尝试对股票数据集进行时间序列预测,当我尝试编译模型时,总是收到相同的错误消息:

Exception in thread "main" java.lang.IllegalStateException: 3D input expected to RNN layer expected, got 2

我的模型的代码如下所示,但我认为问题出在我的数据上,而不是模型本身:

MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .weightInit(WeightInit.XAVIER)
            .list()
            .layer(0, new LSTM.Builder().activation(Activation.TANH).nIn(105).nOut(600).build())
            .layer(1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE)
                    .activation(Activation.TANH).nIn(600).nOut(1).build())
            .build();

是否有方法确保我输入到模型中的dataset/DataSeterator对象使用了正确的维度?谢谢你的帮助。

暂无答案!

目前还没有任何答案,快来回答吧!

相关问题