Java: LSTM in DL4J - all output values are the same

pkbketx9, posted on 2023-02-02 in Java

I'm trying to create a simple LSTM with DeepLearning4J, with 2 input features and a time-series length of 1. After training the network, feeding it test data yields the same arbitrary result regardless of the input values. My code is shown below.

(Updated)

import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class LSTMRegression {
    public static final int inputSize = 2,
                            lstmLayerSize = 4,
                            outputSize = 1;
    
    public static final double learningRate = 0.01;

    public static void main(String[] args) {
        int miniBatchSize = 99;
        
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .miniBatch(false)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .updater(new Sgd(learningRate))
                .list()
                .layer(0, new LSTM.Builder().nIn(inputSize).nOut(lstmLayerSize)
                        .weightInit(WeightInit.XAVIER)
                        .activation(Activation.TANH).build())
//                .layer(1, new LSTM.Builder().nIn(lstmLayerSize).nOut(lstmLayerSize)
//                        .weightInit(WeightInit.XAVIER)
//                        .activation(Activation.SIGMOID).build())
//                .layer(2, new LSTM.Builder().nIn(lstmLayerSize).nOut(lstmLayerSize)
//                        .weightInit(WeightInit.XAVIER)
//                        .activation(Activation.SIGMOID).build())
                .layer(1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE)
                        .weightInit(WeightInit.XAVIER)
                        .activation(Activation.IDENTITY)
                        .nIn(lstmLayerSize).nOut(outputSize).build())
                
                .backpropType(BackpropType.TruncatedBPTT)
                .tBPTTForwardLength(miniBatchSize)
                .tBPTTBackwardLength(miniBatchSize)
                .build();
        
        final var network = new MultiLayerNetwork(conf);
        final DataSet train = getTrain();
        final INDArray test = getTest();
        
        final DataNormalization normalizer = new NormalizerMinMaxScaler(0, 1);
//                                          = new NormalizerStandardize();
        
        normalizer.fitLabel(true);
        normalizer.fit(train);

        normalizer.transform(train);
        normalizer.transform(test);
        
        network.init();
        
        for (int i = 0; i < 100; i++)
            network.fit(train);
        
        final INDArray output = network.output(test);
        
        normalizer.revertLabels(output);
        
        System.out.println(output);
    }
    
    public static INDArray getTest() {
        double[][][] test = new double[][][]{
            {{20}, {203}},
            {{16}, {183}},
            {{20}, {190}},
            {{18.6}, {193}},
            {{18.9}, {184}},
            {{17.2}, {199}},
            {{20}, {190}},
            {{17}, {181}},
            {{19}, {197}},
            {{16.5}, {198}},
            ...
        };
        
        INDArray input = Nd4j.create(test);
        
        return input;
    }
    
    public static DataSet getTrain() {
        double[][][] inputArray = {
            {{18.7}, {181}},
            {{17.4}, {186}},
            {{18}, {195}},
            {{19.3}, {193}},
            {{20.6}, {190}},
            {{17.8}, {181}},
            {{19.6}, {195}},
            {{18.1}, {193}},
            {{20.2}, {190}},
            {{17.1}, {186}},
            ...
        };
        
        double[][] outputArray = {
                {3750},
                {3800},
                {3250},
                {3450},
                {3650},
                {3625},
                {4675},
                {3475},
                {4250},
                {3300},
                ...
        };
        
        INDArray input = Nd4j.create(inputArray);
        INDArray labels = Nd4j.create(outputArray);
        
        return new DataSet(input, labels);
    }
}

Here is a sample of the output:

(Updated)

00:06:04.554 [main] WARN  o.d.nn.multilayer.MultiLayerNetwork - Cannot do truncated BPTT with non-3d inputs or labels. Expect input with shape [miniBatchSize,nIn,timeSeriesLength], got [99, 2, 1] and labels with shape [99, 1]
(the same warning is printed ten times)

[[[3198.1614]],
 [[2986.7781]],
 [[3059.7017]],
 [[3105.3828]],
 [[2994.0127]],
 [[3191.4468]],
 [[3059.7017]],
 [[2962.4341]],
 [[3147.4412]],
 [[3183.5991]]]

So far I have tried changing several of the hyperparameters, including the updater (previously Adam), the activation function of the hidden layer (previously ReLU), and the learning rate; none of these fixed the problem.
Thank you.

yjghlzjz 1#

It's always either a tuning problem or the input data. In your case, your input data is wrong.
You almost always need to normalize your input data, otherwise your network won't learn anything. The same is true for your outputs: the output labels should be normalized as well.
Here's a snippet:

// Normalize data, including labels (fitLabel = true)
NormalizerMinMaxScaler normalizer = new NormalizerMinMaxScaler(0, 1);
normalizer.fitLabel(true);
normalizer.fit(trainData);              // Collect training data statistics

normalizer.transform(trainData);
normalizer.transform(testData);

And here's how to revert:

// Revert data back to original values for plotting
normalizer.revert(trainData);
normalizer.revert(testData);
normalizer.revertLabels(predicted);

There are different kinds of normalizers; the one above only scales to the range 0 to 1. Sometimes NormalizerStandardize may be better here: it standardizes the data by subtracting the mean and dividing by the standard deviation. It would look like this:

NormalizerStandardize myNormalizer = new NormalizerStandardize();
myNormalizer.fitLabel(true);
myNormalizer.fit(sampleDataSet);
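
The standardizer is applied the same way as the min-max scaler above; a minimal follow-up sketch, reusing the placeholder trainData/testData names from the earlier snippet:

// After fitting, transform both sets with the same training statistics
// (trainData and testData are the placeholder names used above)
myNormalizer.transform(trainData);
myNormalizer.transform(testData);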

After that, your network should train normally.
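Put together, the order of operations looks roughly like this, sketched with the network, train and test variables from the question:

normalizer.fitLabel(true);
normalizer.fit(train);                // statistics come from the training set only

normalizer.transform(train);          // normalizes features and labels in place
normalizer.transform(test);

network.init();
for (int i = 0; i < 100; i++)
    network.fit(train);

INDArray predicted = network.output(test);
normalizer.revertLabels(predicted);   // back to the original label scale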
Edit: if that doesn't work, dl4j also has a knob for this (I explained it in a comment below), because of the size of your dataset. We normally assume your data is a minibatch, i.e. a sample of a larger dataset rather than just 10 data points; when that assumption doesn't hold, training can be all over the place. You can turn the minibatch assumption off:

ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                .miniBatch(false)

The same works for MultiLayerNetwork.
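For reference, a minimal sketch of the same knob on the question's MultiLayerConfiguration, with everything else unchanged:

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .miniBatch(false)   // treat the whole (small) dataset as a single batch
        .list()
        // ... layers exactly as in the question ...
        .build();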
It's also worth noting that your architecture is aimed at a very small problem that is unrealistic for DL. DL usually needs much more data to work well; that's why you see layers stacked multiple times. For a problem like this, I'd suggest reducing the number of layers to 1.
What effectively happens at each layer is a form of information compression. When the number of data points is small, you eventually lose the signal as the network saturates; subsequent layers then tend not to learn well.
