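I am trying to do deep learning with an LSTM in Keras. My input is a set of signals (nb_sig), whose number may change during training, each with a fixed number of samples (nb_sample). I want to do parameter identification, so the size of my output layer is the number of parameters (nb_param).
I therefore created a training set of size (nb_sig x nb_sample) and labels of size (nb_param x nb_sample).
My problem is that I cannot find the right dimensions for the deep learning model. I tried this: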
import numpy as np
from keras.models import Sequential
from keras.layers import Dense, LSTM
nb_sample = 500
nb_sig = 100 # number that may change during the training
nb_param = 10
train = np.random.rand(nb_sig,nb_sample)
label = np.random.rand(nb_sig,nb_param)
print(train.shape,label.shape)
DLmodel = Sequential()
DLmodel.add(LSTM(units=nb_sample, return_sequences=True, input_shape =(None,nb_sample), activation='tanh'))
DLmodel.add(Dense(nb_param, activation="linear", kernel_initializer="uniform"))
DLmodel.compile(loss='mean_squared_error', optimizer='RMSprop', metrics=['accuracy', 'mse'], run_eagerly=True)
print(DLmodel.summary())
DLmodel.fit(train, label, epochs=10, batch_size=nb_sig)
But I get the following error message:
Traceback (most recent call last):
File "C:\Users\maxime\Desktop\SESAME\PycharmProjects\LargeScale_2022_09_07\di3.py", line 22, in <module>
DLmodel.fit(train, label, epochs=10, batch_size=nb_sig)
File "C:\Python310\lib\site-packages\keras\utils\traceback_utils.py", line 70, in error_handler
raise e.with_traceback(filtered_tb) from None
File "C:\Python310\lib\site-packages\keras\engine\input_spec.py", line 232, in assert_input_compatibility
raise ValueError(
ValueError: Exception encountered when calling layer "sequential" " f"(type Sequential).
Input 0 of layer "lstm" is incompatible with the layer: expected ndim=3, found ndim=2. Full shape received: (100, 500)
Call arguments received by layer "sequential" " f"(type Sequential):
• inputs=tf.Tensor(shape=(100, 500), dtype=float32)
• training=True
• mask=None
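I don't understand what I should pass as the input_shape of the LSTM layer, given that the number of signals I use will change during training. This is not very clear to me.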
2 Answers
1yjd4xko #1
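The input to an LSTM should be 3D, with the first dimension being the number of samples (500 in your case). If the input has shape (500, x, y), then input_shape should be (x, y).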
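As a rough illustration of that rule, the sketch below turns the 2D array from the question into a 3D one using NumPy only. It assumes, as this answer does, that the 500 samples form the first (batch) dimension, and further assumes that the signals play the role of timesteps with a single feature each. That reading is one interpretation of the question, not something the asker confirmed, and the names train_3d and input_shape are introduced here purely for illustration.

```python
import numpy as np

nb_sample = 500  # number of training examples (batch dimension)
nb_sig = 100     # number of signals; may vary between runs
nb_param = 10

# 2D array laid out as described in the question: (nb_sig, nb_sample)
train = np.random.rand(nb_sig, nb_sample)

# Assumption: each column is one training example, the signals are the
# timesteps, and every timestep carries exactly one feature.
# Transpose to (nb_sample, nb_sig), then append a feature axis of size 1.
train_3d = train.T[..., np.newaxis]

# input_shape is the array shape without the leading batch dimension.
input_shape = train_3d.shape[1:]

print(train_3d.shape)  # (500, 100, 1)
print(input_shape)     # (100, 1)
```

If the data is instead organised as one sequence per signal, the same idea applies with the axes swapped: whatever is left after dropping the first dimension is what goes into input_shape.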
bwitn5fc #2
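According to the Keras documentation, the LSTM layer takes a 3D tensor as input and needs one dimension dedicated to timesteps. Since you are using the default parameter time_major=False, the input should be laid out as [batch, timesteps, feature].
This related question may help you better understand LSTM input shapes.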
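A minimal sketch of a model that accepts the [batch, timesteps, feature] layout is shown below. It is not the asker's final model: it assumes that each training example is a sequence of nb_sig values with one feature per timestep, that one parameter vector is predicted per example, and that 32 LSTM units are enough for a demonstration. Leaving the timesteps dimension as None is what lets the number of signals differ between calls. The data is random, so the fit only demonstrates that the shapes line up.

```python
import numpy as np
from keras import Input
from keras.models import Sequential
from keras.layers import Dense, LSTM

nb_sample = 500  # number of training examples (batch dimension)
nb_sig = 100     # timesteps per example; allowed to vary
nb_param = 10    # size of the regression output

# Random placeholder data in [batch, timesteps, feature] layout.
x = np.random.rand(nb_sample, nb_sig, 1).astype("float32")
y = np.random.rand(nb_sample, nb_param).astype("float32")

model = Sequential([
    # None leaves the number of timesteps unspecified; 1 feature per step.
    Input(shape=(None, 1)),
    # return_sequences defaults to False, so the LSTM emits one vector
    # per example rather than one per timestep.
    LSTM(32),
    Dense(nb_param, activation="linear"),
])

# Only MSE is tracked; accuracy is not meaningful for a regression target.
model.compile(loss="mean_squared_error", optimizer="rmsprop", metrics=["mse"])
model.fit(x, y, epochs=1, batch_size=50, verbose=0)

# The same model handles a different number of timesteps at inference.
pred_same = model.predict(x[:4], verbose=0)
pred_shorter = model.predict(np.random.rand(4, 60, 1).astype("float32"), verbose=0)
print(pred_same.shape, pred_shorter.shape)
```

Within a single batch every sequence still has to have the same length, so varying nb_sig inside one call to fit would require padding or batching by length.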