Keras: I am trying to fit an LSTM model, but I am getting a mean_squared_error error

yebdmbv4 posted on 2023-06-06 in Other

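Can someone help me? I am new to RNNs, and to LSTM models in particular. I am trying to train an LSTM model for anomaly detection using multivariate time series data. I get this error:
Node: 'mean_squared_error/SquaredDifference' required broadcastable shapes [[{{node mean_squared_error/SquaredDifference}}]] [Op:__inference_train_function_5545]
Here is my model: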

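from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dropout, RepeatVector, TimeDistributed, Dense

# x_train shape: (249989, 10, 8) -> (samples, timesteps, features)
model = Sequential()
# Encoder: compress each window into a single 128-dimensional vector
model.add(LSTM(128, input_shape=(x_train.shape[1], x_train.shape[2])))
model.add(Dropout(rate=0.2))
# Repeat the encoded vector once per timestep for the decoder
model.add(RepeatVector(x_train.shape[1]))
# Decoder: emit one output per timestep
model.add(LSTM(128, return_sequences=True))
model.add(Dropout(rate=0.2))
# Map every timestep back to the original number of features
model.add(TimeDistributed(Dense(x_train.shape[2])))
model.compile(optimizer='adam', loss='mae')
model.summary()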

Model summary:
Model: "sequential_5"
____________________________________________________________________
 Layer (type)                    Output Shape              Param #   
====================================================================
 lstm_22 (LSTM)                  (None, 10, 50)            11800     
                                                                 
 dropout_15 (Dropout)            (None, 10, 50)            0         
                                                                 
 lstm_23 (LSTM)                  (None, 50)                20200     
                                                                 
 dropout_16 (Dropout)            (None, 50)                0         
                                                               
 repeat_vector_7 (RepeatVector)  (None, 10, 50)            0         
                                                                       
 lstm_24 (LSTM)                  (None, 10, 50)            20200     
                                                                 
 dropout_17 (Dropout)            (None, 10, 50)            0         
                                                                 
 dense_9 (Dense)                 (None, 10, 32)            1632      
                                                                 
 dropout_18 (Dropout)            (None, 10, 32)            0         
                                                                 
 dense_10 (Dense)                (None, 10, 1)             33        
                                                                 
====================================================================
Total params: 53,865
Trainable params: 53,865
Non-trainable params: 0

Full error message:

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-64-696fd2112f18> in <module>()
      1 #Fitting the RNN to the training set
----> 2 model.fit(x_train, y_train, validation_split=0.2, callbacks=callbacks, epochs=25, batch_size=72, shuffle=False)

1 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
      53     An exception on error.
      54   """
---> 55   device_name = ctx.device_name
      56   # pylint: disable=protected-access
      57   try:

InvalidArgumentError: Graph execution error:

Detected at node 'mean_squared_error/SquaredDifference' defined at (most recent call last):
    File "/usr/lib/python3.7/runpy.py", line 193, in _run_module_as_main
      "__main__", mod_spec)
    File "/usr/lib/python3.7/runpy.py", line 85, in _run_code
      exec(code, run_globals)
    File "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py", line 16, in <module>
      app.launch_new_instance()
    File "/usr/local/lib/python3.7/dist-packages/traitlets/config/application.py", line 846, in launch_instance
      app.start()
    File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelapp.py", line 499, in start
      self.io_loop.start()
    File "/usr/local/lib/python3.7/dist-packages/tornado/platform/asyncio.py", line 132, in start
      self.asyncio_loop.run_forever()
    File "/usr/lib/python3.7/asyncio/base_events.py", line 541, in run_forever
      self._run_once()
    File "/usr/lib/python3.7/asyncio/base_events.py", line 1786, in _run_once
      handle._run()
    File "/usr/lib/python3.7/asyncio/events.py", line 88, in _run
      self._context.run(self._callback, *self._args)
    File "/usr/local/lib/python3.7/dist-packages/tornado/platform/asyncio.py", line 122, in _handle_events
      handler_func(fileobj, events)
    File "/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py", line 300, in null_wrapper
      return fn(*args, **kwargs)
    File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 452, in _handle_events
      self._handle_recv()
    File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 481, in _handle_recv
      self._run_callback(callback, msg)
    File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 431, in _run_callback
      callback(*args, **kwargs)
    File "/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py", line 300, in null_wrapper
      return fn(*args, **kwargs)
    File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 283, in dispatcher
      return self.dispatch_shell(stream, msg)
    File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 233, in dispatch_shell
      handler(stream, idents, msg)
    File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 399, in execute_request
      user_expressions, allow_stdin)
    File "/usr/local/lib/python3.7/dist-packages/ipykernel/ipkernel.py", line 208, in do_execute
      res = shell.run_cell(code, store_history=store_history, silent=silent)
    File "/usr/local/lib/python3.7/dist-packages/ipykernel/zmqshell.py", line 537, in run_cell
      return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
    File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2718, in run_cell
      interactivity=interactivity, compiler=compiler, result=result)
    File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2828, in run_ast_nodes
      if self.run_code(code, result):
    File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2882, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "<ipython-input-39-696fd2112f18>", line 2, in <module>
      model.fit(x_train, y_train, validation_split=0.2, callbacks=callbacks, epochs=25, batch_size=72, shuffle=False)
    File "/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1384, in fit
      threading. If unspecified, `use_multiprocessing` will default to
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1021, in train_function
      If an integer, specifies how many training epochs to run before a
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1010, in step_function
      the dataset will be consumed, the evaluation will start from the
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1000, in run_step
      `steps_per_epoch` argument. This argument is not supported with
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 860, in train_step
      validation_data=None,
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 919, in compute_loss
      0 = silent, 1 = progress bar, 2 = one line per epoch.
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/compile_utils.py", line 201, in __call__
      loss_value = loss_obj(y_t, y_p, sample_weight=sw)
    File "/usr/local/lib/python3.7/dist-packages/keras/losses.py", line 141, in __call__
      call_fn = tf.__internal__.autograph.tf_convert(self.call, tf.__internal__.autograph.control_status_ctx())
    File "/usr/local/lib/python3.7/dist-packages/keras/losses.py", line 245, in call
      ag_fn = tf.__internal__.autograph.tf_convert(self.fn, tf.__internal__.autograph.control_status_ctx())
    File "/usr/local/lib/python3.7/dist-packages/keras/losses.py", line 1329, in mean_squared_error
      ...     100. * np.mean(np.abs((y_true - y_pred) / y_true), axis=-1))
Node: 'mean_squared_error/SquaredDifference'
required broadcastable shapes
    [[{{node mean_squared_error/SquaredDifference}}]] [Op:__inference_train_function_5545]
vql8enpb1#

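This usually happens because the last LSTM layer has return_sequences=True, which is unexpected behavior for a softmax layer or any other final output layer. With return_sequences=True the LSTM emits one output per timestep, so the model output is 3-D (here (None, 10, 1), according to the model summary), and the loss cannot be computed if the targets have a different shape that does not broadcast against it.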
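A minimal sketch of that fix, assuming each 10-step window has a single target value. The shape of y_train is not shown in the question, so this is an assumption. The layer sizes loosely follow the model summary, and x_demo and y_demo are random placeholder arrays, not the asker's data.

```python
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, LSTM, Dropout, Dense

# Taken from the x_train shape in the question: (249989, 10, 8)
timesteps, n_features = 10, 8

model = Sequential()
model.add(Input(shape=(timesteps, n_features)))
# Intermediate LSTM layers keep the time axis so the next LSTM gets 3-D input
model.add(LSTM(50, return_sequences=True))
model.add(Dropout(rate=0.2))
# Last LSTM: return_sequences defaults to False, so the output is (None, 50)
model.add(LSTM(50))
model.add(Dropout(rate=0.2))
# One value per window -> model output shape (None, 1)
model.add(Dense(1))
model.compile(optimizer="adam", loss="mse")

# Placeholder data (assumption): 16 windows with one target value each
x_demo = np.random.rand(16, timesteps, n_features).astype("float32")
y_demo = np.random.rand(16, 1).astype("float32")

# Check that predictions and targets agree in shape before training
assert model.output_shape[1:] == y_demo.shape[1:]
model.fit(x_demo, y_demo, epochs=1, batch_size=8, verbose=0)
```

If the goal is instead to reconstruct the input window, as the RepeatVector / TimeDistributed code in the question suggests, the output stays 3-D and the target would need to be the 3-D input itself, for example model.fit(x_train, x_train, ...).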
