我正在尝试创建vgg+lstm网络。在此代码中,seq_len是1400。
video = Input(shape=(seq_len, 224, 224,3))
cnn_base = VGG16(input_shape=(224, 224, 3), weights = 'imagenet', include_top=False)
cnn_out = GlobalAveragePooling2D()(cnn_base.output)
cnn = Model(cnn_base.input, cnn_out)
cnn.trainable=False
encoded_frames = TimeDistributed(cnn)(video)
encoded_sequence = LSTM(256)(encoded_frames)
hidden_layer = Dense(1024, activation='relu')(encoded_sequence)
outputs = Dense(1)(hidden_layer)
model = Model(video, outputs)
print(model.summary())
history = model.fit(w_train, y_train, epochs=60, batch_size=50, shuffle=True, validation_split=0.2, verbose=10)
print(history.history.keys())
我的错误是:
ValueError: Input 0 is incompatible with layer model_6: expected shape=(None, 1400, 224, 224, 3), found shape=(None, 224, 224, 3)
有人能帮我解决吗?
2条答案
按热度按时间2ekbmq321#
去掉输入shape中的seq_len,因为它会生成一个5维数组,其中shape=(None,seq_len,224,224,3)。
hpcdzsge2#
您可以在LSTM之前添加TimeDistributed(Flatten())层。