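Can someone help me? I'm new to RNNs, and to LSTM models in particular. I'm trying to train an LSTM model for anomaly detection on multivariate time-series data, and I get this error: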
Node: 'mean_squared_error/SquaredDifference' required broadcastable shapes [[{{node mean_squared_error/SquaredDifference}}]] [Op:__inference_train_function_5545]
Here is my model:
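from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dropout, RepeatVector, TimeDistributed, Dense

# x_train shape: (249989, 10, 8) -> (samples, timesteps, features)
model = Sequential()
# Encoder: compress each 10-step window into a single 128-dim vector
model.add(LSTM(128, input_shape=(x_train.shape[1], x_train.shape[2])))
model.add(Dropout(rate=0.2))
# Repeat that vector once per timestep so the decoder can unroll it
model.add(RepeatVector(x_train.shape[1]))
# Decoder: return one output per timestep
model.add(LSTM(128, return_sequences=True))
model.add(Dropout(rate=0.2))
# Predict all features at every timestep -> output shape (None, 10, 8)
model.add(TimeDistributed(Dense(x_train.shape[2])))
model.compile(optimizer='adam', loss='mae')
model.summary()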
Model summary:
Model: "sequential_5"
____________________________________________________________________
Layer (type)                     Output Shape          Param #
====================================================================
lstm_22 (LSTM)                   (None, 10, 50)        11800
dropout_15 (Dropout)             (None, 10, 50)        0
lstm_23 (LSTM)                   (None, 50)            20200
dropout_16 (Dropout)             (None, 50)            0
repeat_vector_7 (RepeatVector)   (None, 10, 50)        0
lstm_24 (LSTM)                   (None, 10, 50)        20200
dropout_17 (Dropout)             (None, 10, 50)        0
dense_9 (Dense)                  (None, 10, 32)        1632
dropout_18 (Dropout)             (None, 10, 32)        0
dense_10 (Dense)                 (None, 10, 1)         33
====================================================================
Total params: 53,865
Trainable params: 53,865
Non-trainable params: 0
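The Param # column above follows from the standard parameter counts of these layers: an LSTM has four gates, each with an input kernel, a recurrent kernel and a bias, giving 4 × (units × (input_dim + units) + units) weights, and a Dense layer has input_dim × units + units. The plain-Python sketch below recomputes the column for the summarized model, assuming 8 input features as stated for x_train; the helper names are hypothetical and are not part of the Keras API.

```python
def lstm_params(input_dim: int, units: int) -> int:
    # Four gates, each with an input kernel, a recurrent kernel and a bias vector
    return 4 * (units * (input_dim + units) + units)


def dense_params(input_dim: int, units: int) -> int:
    # Kernel plus bias; on a 3D input Dense acts on the last axis only,
    # so the number of timesteps does not change the count
    return input_dim * units + units


# Layers that own weights, in summary order (Dropout and RepeatVector have none)
counts = {
    "lstm_22": lstm_params(8, 50),
    "lstm_23": lstm_params(50, 50),
    "lstm_24": lstm_params(50, 50),
    "dense_9": dense_params(50, 32),
    "dense_10": dense_params(32, 1),
}

for name, n in counts.items():
    print(f"{name:10s} {n:>7,}")

total = sum(counts.values())
print(f"{'total':10s} {total:>7,}")
```

The recomputed total, 53,865, matches the summary, which confirms that the summarized model was built for 8 input features and ends in a single output unit per timestep.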
Full error message:
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-64-696fd2112f18> in <module>()
1 #Fitting the RNN to the training set
----> 2 model.fit(x_train, y_train, validation_split=0.2, callbacks=callbacks, epochs=25, batch_size=72, shuffle=False)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
53 An exception on error.
54 """
---> 55 device_name = ctx.device_name
56 # pylint: disable=protected-access
57 try:
InvalidArgumentError: Graph execution error:
Detected at node 'mean_squared_error/SquaredDifference' defined at (most recent call last):
File "/usr/lib/python3.7/runpy.py", line 193, in _run_module_as_main
"__main__", mod_spec)
File "/usr/lib/python3.7/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py", line 16, in <module>
app.launch_new_instance()
File "/usr/local/lib/python3.7/dist-packages/traitlets/config/application.py", line 846, in launch_instance
app.start()
File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelapp.py", line 499, in start
self.io_loop.start()
File "/usr/local/lib/python3.7/dist-packages/tornado/platform/asyncio.py", line 132, in start
self.asyncio_loop.run_forever()
File "/usr/lib/python3.7/asyncio/base_events.py", line 541, in run_forever
self._run_once()
File "/usr/lib/python3.7/asyncio/base_events.py", line 1786, in _run_once
handle._run()
File "/usr/lib/python3.7/asyncio/events.py", line 88, in _run
self._context.run(self._callback, *self._args)
File "/usr/local/lib/python3.7/dist-packages/tornado/platform/asyncio.py", line 122, in _handle_events
handler_func(fileobj, events)
File "/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py", line 300, in null_wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 452, in _handle_events
self._handle_recv()
File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 481, in _handle_recv
self._run_callback(callback, msg)
File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 431, in _run_callback
callback(*args, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py", line 300, in null_wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 283, in dispatcher
return self.dispatch_shell(stream, msg)
File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 233, in dispatch_shell
handler(stream, idents, msg)
File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 399, in execute_request
user_expressions, allow_stdin)
File "/usr/local/lib/python3.7/dist-packages/ipykernel/ipkernel.py", line 208, in do_execute
res = shell.run_cell(code, store_history=store_history, silent=silent)
File "/usr/local/lib/python3.7/dist-packages/ipykernel/zmqshell.py", line 537, in run_cell
return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2718, in run_cell
interactivity=interactivity, compiler=compiler, result=result)
File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2828, in run_ast_nodes
if self.run_code(code, result):
File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2882, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-39-696fd2112f18>", line 2, in <module>
model.fit(x_train, y_train, validation_split=0.2, callbacks=callbacks, epochs=25, batch_size=72, shuffle=False)
File "/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
return fn(*args, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1384, in fit
threading. If unspecified, `use_multiprocessing` will default to
File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1021, in train_function
If an integer, specifies how many training epochs to run before a
File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1010, in step_function
the dataset will be consumed, the evaluation will start from the
File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1000, in run_step
`steps_per_epoch` argument. This argument is not supported with
File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 860, in train_step
validation_data=None,
File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 919, in compute_loss
0 = silent, 1 = progress bar, 2 = one line per epoch.
File "/usr/local/lib/python3.7/dist-packages/keras/engine/compile_utils.py", line 201, in __call__
loss_value = loss_obj(y_t, y_p, sample_weight=sw)
File "/usr/local/lib/python3.7/dist-packages/keras/losses.py", line 141, in __call__
call_fn = tf.__internal__.autograph.tf_convert(self.call, tf.__internal__.autograph.control_status_ctx())
File "/usr/local/lib/python3.7/dist-packages/keras/losses.py", line 245, in call
ag_fn = tf.__internal__.autograph.tf_convert(self.fn, tf.__internal__.autograph.control_status_ctx())
File "/usr/local/lib/python3.7/dist-packages/keras/losses.py", line 1329, in mean_squared_error
... 100. * np.mean(np.abs((y_true - y_pred) / y_true), axis=-1))
Node: 'mean_squared_error/SquaredDifference'
required broadcastable shapes
[[{{node mean_squared_error/SquaredDifference}}]] [Op:__inference_train_function_5545]
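"Required broadcastable shapes" means the loss received y_true (a batch of y_train) and y_pred (the model output) whose shapes cannot be combined element-wise. TensorFlow follows NumPy-style broadcasting: shapes are aligned from the last axis, and each pair of dimensions must be equal or contain a 1. A minimal pure-Python sketch of that rule follows; the y_train batch shapes used here are assumptions, since the post does not show y_train.shape.

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    """Return the broadcast shape of a and b, or None if they are incompatible.

    NumPy/TensorFlow rule: align shapes from the right; each pair of dimensions
    must be equal or one of them must be 1. Missing leading dims count as 1.
    """
    result = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x == y or y == 1:
            result.append(x)
        elif x == 1:
            result.append(y)
        else:
            return None  # e.g. 10 vs 72: neither equal nor 1
    return tuple(reversed(result))


batch = 72  # batch_size passed to model.fit

# Output of the summarized model (dense_10) vs a target of the same shape
print(broadcast_shape((batch, 10, 1), (batch, 10, 1)))  # (72, 10, 1)
# Output of TimeDistributed(Dense(8)) vs a hypothetical 2D target batch
print(broadcast_shape((batch, 10, 8), (batch, 8)))      # None
# Broadcastable, but to a meaningless shape: broadcastable is not "correct"
print(broadcast_shape((batch, 10, 1), (batch,)))        # (72, 10, 72)
```

Comparing model.output_shape with y_train.shape against this rule before calling fit usually pinpoints which axis is mismatched.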
Answer:
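This usually happens because the last LSTM layer has return_sequences = True. That layer then emits an output for every timestep, so the model's output is 3D, (batch, timesteps, units). This is unexpected behavior for a softmax layer or any other final probability-computation layer placed after it, and the prediction's shape may no longer match y_train.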
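To illustrate the answer, the sketch below traces the output shape of a simplified two-LSTM stack with and without return_sequences=True on the last LSTM. It is plain Python, not Keras: the helpers lstm and dense are hypothetical and only mimic the shape rules visible in the model summary (an LSTM returns (batch, units), or (batch, timesteps, units) when return_sequences=True; Dense replaces only the last axis).

```python
def lstm(shape, units, return_sequences=False):
    # shape = (batch, timesteps, features); mimics the Keras LSTM output shape
    batch, timesteps, _ = shape
    return (batch, timesteps, units) if return_sequences else (batch, units)


def dense(shape, units):
    # Dense replaces only the last axis
    return shape[:-1] + (units,)


x = (None, 10, 8)  # x_train batches: (batch, timesteps, features)

# Last LSTM keeps return_sequences=True: one prediction per timestep
seq = dense(lstm(lstm(x, 50, True), 50, True), 1)
# Last LSTM uses return_sequences=False: one prediction per window
last = dense(lstm(lstm(x, 50, True), 50, False), 1)

print(seq)   # (None, 10, 1)
print(last)  # (None, 1)
```

With return_sequences=False the output (None, 1) lines up with a per-window target of shape (N, 1); with return_sequences=True the target would instead need shape (N, 10, 1).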