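I'm trying to do text classification with an ANN. I'm using Keras in Python, and I got the approach from the internet. My data has a word count of 1148, but I rounded that up to 1200 for input_shape. The code is as follows: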
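# ANN architecture
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam

model = Sequential()
model.add(Dense(units=100, activation='sigmoid', input_shape=(32, 1200)))
model.add(Dense(units=2, activation='sigmoid'))
opt = Adam(learning_rate=0.001)
model.compile(loss='binary_crossentropy', optimizer=opt,
              metrics=['accuracy'])
model.summary()  # summary() prints by itself and returns None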
Next, I set the hyperparameters with the following code:
from tensorflow.keras.callbacks import EarlyStopping

# Hyperparameters
epochs = 100
batch_size = 32
es = EarlyStopping(monitor="val_loss", mode='min', patience=10)
model_prediction = model.fit(arr_Train_X_Tfidf, Train_Y, epochs=epochs,
                             batch_size=batch_size, verbose=1,
                             validation_split=0.1, callbacks=[es])
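The 32 already appears here as batch_size, which is where it belongs: fit slices the leading (sample) axis of the data into batches, so it is not part of each sample's shape. A small NumPy illustration of that slicing, assuming 90 training rows of 1148 TF-IDF features (random stand-in data, not the real arr_Train_X_Tfidf). It mimics the idea only, not Keras's actual data pipeline:

```python
import numpy as np

# Hypothetical stand-in for arr_Train_X_Tfidf: 90 samples x 1148 features
X = np.random.rand(90, 1148)
batch_size = 32

# Cut the sample axis into consecutive batches of at most batch_size rows
batch_shapes = [X[i:i + batch_size].shape for i in range(0, len(X), batch_size)]
print(batch_shapes)  # [(32, 1148), (32, 1148), (26, 1148)]

# Batches differ only in their first dimension; the per-sample shape is fixed
per_sample_shape = X.shape[1:]
print(per_sample_shape)  # (1148,)
```

Only the per-sample shape, (1148,) here, is what a model's input_shape should describe.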
But I get the following error:
/usr/local/lib/python3.8/dist-packages/keras/engine/training.py in tf__train_function(iterator)
13 try:
14 do_return = True
---> 15 retval_ = ag__.converted_call(ag__.ld(step_function), (ag__.ld(self), ag__.ld(iterator)), None, fscope)
16 except:
17 do_return = False
ValueError: in user code:
File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 1051, in train_function *
return step_function(self, iterator)
File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 1040, in step_function **
outputs = model.distribute_strategy.run(run_step, args=(data,))
File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 1030, in run_step **
outputs = model.train_step(data)
File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 889, in train_step
y_pred = self(x, training=True)
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 67, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/usr/local/lib/python3.8/dist-packages/keras/engine/input_spec.py", line 264, in assert_input_compatibility
raise ValueError(f'Input {input_index} of layer "{layer_name}" is '
ValueError: Input 0 of layer "sequential_2" is incompatible with the layer: expected shape=(None, 32, 1200), found shape=(None, 1148)
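The two shapes in this message can be checked directly. Keras prepends None for the batch axis to whatever input_shape declares, so input_shape=(32, 1200) makes every single sample a 32×1200 matrix, while the TF-IDF data is rank 2 with 1148 features per sample. A minimal sketch, assuming TensorFlow 2.x with tf.keras, that rebuilds the question's model and prints what it expects:

```python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

model = Sequential()
# input_shape describes ONE sample; Keras adds the batch axis (None) itself
model.add(Dense(units=100, activation='sigmoid', input_shape=(32, 1200)))
model.add(Dense(units=2, activation='sigmoid'))

print(model.input_shape)   # (None, 32, 1200): each sample would be a 32x1200 matrix
print(model.output_shape)  # (None, 32, 2): Dense acts only on the last axis
```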
Does anyone know what the problem is, and how to fix it? Thanks.
1 Answer
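The input shape should not include the batch size: input_shape describes a single sample, and Keras adds the batch dimension itself. So don't put "32" in input_shape; the batch size is set only through batch_size in fit. The feature dimension also has to match the data exactly. The error shows each sample has 1148 features, so rounding up to 1200 would still fail; use input_shape=(1148,) or, equivalently, input_shape=(arr_Train_X_Tfidf.shape[1],).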
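A minimal end-to-end sketch of the corrected model, assuming TensorFlow 2.x (tf.keras) and that Train_Y is one-hot encoded with two columns, as the 2-unit sigmoid output with binary_crossentropy implies. The data below are random hypothetical stand-ins for arr_Train_X_Tfidf and Train_Y, and epochs is cut to 3 only to keep the example fast:

```python
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping

# Hypothetical random stand-ins for arr_Train_X_Tfidf and Train_Y
n_samples, n_features = 200, 1148
arr_Train_X_Tfidf = np.random.rand(n_samples, n_features).astype("float32")
labels = np.random.randint(0, 2, size=n_samples)
Train_Y = np.eye(2, dtype="float32")[labels]  # one-hot labels, shape (200, 2)

model = Sequential()
# Per-sample shape only: no batch size, and the exact feature count (not 1200)
model.add(Dense(units=100, activation='sigmoid',
                input_shape=(arr_Train_X_Tfidf.shape[1],)))
model.add(Dense(units=2, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.001),
              metrics=['accuracy'])

es = EarlyStopping(monitor="val_loss", mode='min', patience=10)
# The batch size of 32 is passed here, not in input_shape
history = model.fit(arr_Train_X_Tfidf, Train_Y, epochs=3, batch_size=32,
                    verbose=0, validation_split=0.1, callbacks=[es])
print(model.input_shape)  # (None, 1148)
```

If Train_Y is instead a single 0/1 column, a common alternative is Dense(units=1, activation='sigmoid') as the output layer with the same binary_crossentropy loss; that is a swap from the question's setup, not what the answer proposed.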