型号组
model_gru = Sequential()
model_gru.add(GRU(128, activation='relu', input_shape=(50, 1), return_sequences=True))
model_gru.add(BatchNormalization())
model_gru.add(Dropout(0.2))
model_gru.add(GRU(64, activation='relu'))
model_gru.add(BatchNormalization())
model_gru.add(Dropout(0.2))
model_gru.add(Dense(32,activation='relu'))
model_gru.add(Dense(5,activation='softmax'))
print(model_gru.summary())
这是我的模型,我得到了错误
model_gru.load_weights(r"C:\Users\Admin/gru_model.h5")
q_test_gru = model_gru.predict(test_data, verbose=0)
y_test_gru = q_test_gru.argmax(1)
你能帮助解决这个错误吗?因为输入的大小有问题。
1条答案
按热度按时间bwleehnv1#
第一层应定义如下: