Keras CNN + LSTM video classification

Posted by 2ic8powd on 2022-11-13 in Other

I am trying to build a VGG + LSTM network. In this code, seq_len is 1400.

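# Imports assumed here; the original post omitted them
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import Input, GlobalAveragePooling2D, TimeDistributed, LSTM, Dense
from tensorflow.keras.models import Model

seq_len = 1400  # frames per video
video = Input(shape=(seq_len, 224, 224, 3))
# Frozen VGG16 feature extractor, applied to one frame at a time
cnn_base = VGG16(input_shape=(224, 224, 3), weights='imagenet', include_top=False)
cnn_out = GlobalAveragePooling2D()(cnn_base.output)
cnn = Model(cnn_base.input, cnn_out)
cnn.trainable = False

# Encode every frame with the CNN, then summarize the sequence with an LSTM
encoded_frames = TimeDistributed(cnn)(video)
encoded_sequence = LSTM(256)(encoded_frames)
hidden_layer = Dense(1024, activation='relu')(encoded_sequence)
outputs = Dense(1)(hidden_layer)
model = Model(video, outputs)
model.summary()  # summary() prints by itself; wrapping it in print() adds "None"

history = model.fit(w_train, y_train, epochs=60, batch_size=50, shuffle=True, validation_split=0.2, verbose=10)
print(history.history.keys())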

The error I get is:

ValueError: Input 0 is incompatible with layer model_6: expected shape=(None, 1400, 224, 224, 3), found shape=(None, 224, 224, 3)

Can anyone help me solve this?

2ekbmq32 #1

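Remove seq_len from the input shape. With it, the model expects a 5-dimensional array of shape (None, seq_len, 224, 224, 3), whereas the data passed to fit is 4-dimensional, (None, 224, 224, 3).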
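
An alternative this answer does not propose is to keep seq_len in the model and instead group the frames into fixed-length clips, so that the training array gains the missing sequence axis. The following is only a minimal sketch under stated assumptions: `frames_to_clips` is a hypothetical helper, the frames are assumed to be stored in temporal order in a single array, and small stand-in sizes replace the question's 1400 frames of 224x224x3.

```python
import numpy as np


def frames_to_clips(frames, seq_len):
    """Group a (num_frames, H, W, C) array into (num_clips, seq_len, H, W, C).

    Hypothetical helper: trailing frames that do not fill a whole clip
    are dropped.
    """
    num_clips = frames.shape[0] // seq_len
    usable = frames[: num_clips * seq_len]
    return usable.reshape((num_clips, seq_len) + frames.shape[1:])


# Small stand-in sizes so the sketch runs quickly; the question uses
# seq_len=1400 and 224x224x3 frames.
seq_len = 4
frames = np.zeros((10, 8, 8, 3), dtype=np.float32)  # 4-D, like the array fit() received
clips = frames_to_clips(frames, seq_len)            # 5-D, like the shape the model expects

print(frames.shape)  # (10, 8, 8, 3)
print(clips.shape)   # (2, 4, 8, 8, 3)
```

With this layout the labels would need one entry per clip rather than per frame. Note also that a single float32 clip of 1400 frames at 224x224x3 takes roughly 0.84 GB, so a shorter sequence length or a much smaller batch size than 50 would probably be needed in practice.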

hpcdzsge #2

You can add a TimeDistributed(Flatten()) layer before the LSTM.
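
To illustrate what this layer does to the shapes, here is a NumPy emulation (not Keras code) of TimeDistributed(Flatten()): each frame's feature map is flattened into one vector, which gives the 3-dimensional (batch, timesteps, features) input an LSTM expects. The sizes are assumptions for illustration: 7x7x512 is the VGG16 output for 224x224 input with include_top=False and no pooling, and the batch and sequence sizes are small stand-ins.

```python
import numpy as np

batch, seq_len = 2, 4

# Stand-in for per-frame CNN feature maps, e.g. the 7x7x512 output of
# VGG16 (include_top=False) on 224x224 input when no pooling is applied.
feature_maps = np.zeros((batch, seq_len, 7, 7, 512), dtype=np.float32)

# NumPy emulation of the shape effect of TimeDistributed(Flatten()):
# every frame's (7, 7, 512) map becomes one vector of 7*7*512 values.
flattened = feature_maps.reshape(batch, seq_len, -1)

print(feature_maps.shape)  # (2, 4, 7, 7, 512)
print(flattened.shape)     # (2, 4, 25088)
```

As far as the posted code shows, GlobalAveragePooling2D already reduces each frame to a 512-value vector, so encoded_frames should already be 3-dimensional there. The extra flatten step would matter mainly if that pooling layer were removed.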
