Keras: calling model.fit where x is a two-tuple of np.ndarray

mrwjdhj3 posted on 2022-11-24 in Other

I have a regression tf.keras.Model with:

  • x: tuple[np.ndarray, np.ndarray], where the two items have different shapes: (128, 1152) and (1, 256)
  • y: float

I've structured my model and training like this:

import numpy as np
import tensorflow as tf

class MyModel(tf.keras.Model):

    def __init__(self):
        ...  # Omitted for brevity

    def call(self, inputs: tuple[tf.Tensor, tf.Tensor], training=None, mask=None):
        # Unpacks the two-tuple
        weights_1, weights_2 = inputs
        ...  # Omitted for brevity

# NOTE: item 0's shape is (128, 1152), item 1's shape is (1, 256)
datapoint_x: tuple[np.ndarray, np.ndarray]
datapoint_y: float

model = MyModel()
model(inputs=datapoint_x)  # Works fine

However, when I go to fit the model, I get an Exception:

>>> model.fit(x=datapoint_x, y=np.array(datapoint_y))

Traceback (most recent call last):
  File "/path/to/python3.10/site-packages/IPython/core/interactiveshell.py", line 3433, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-5-a5dfb3dd4846>", line 1, in <module>
    model.fit(x=datapoint_x, y=np.array(datapoint_y))
  File "/path/to/python3.10/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/path/to/python3.10/site-packages/tensorflow/python/framework/tensor_shape.py", line 910, in __getitem__
    return self._dims[key]
IndexError: tuple index out of range

我研究了这个,self._dims()key0
在一个有两个元组x的数据集上,调用Model.fit的正确方法是什么?
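A minimal NumPy-only sketch of where an empty shape like self._dims == () can come from, assuming datapoint_y is a plain Python float (0.5 is a placeholder value). np.array on a scalar produces a 0-dimensional array whose shape is (), and asking for axis 0 of an empty shape raises the same IndexError as in the traceback. Since y is the only rank-0 array passed to fit here, it is the likely source.

```python
import numpy as np

# Placeholder for datapoint_y: a plain Python float (assumed, not the asker's value).
datapoint_y = 0.5

# np.array on a scalar gives a 0-dimensional array; its shape is the empty tuple.
y_scalar = np.array(datapoint_y)
print(y_scalar.ndim, y_scalar.shape)  # 0 ()

# Indexing the first (batch) axis of an empty shape fails just like
# self._dims[key] with self._dims == () and key == 0.
try:
    y_scalar.shape[0]
except IndexError as err:
    print("IndexError:", err)  # IndexError: tuple index out of range
```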

Answer #1 (by r6l8ljro)

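The answer is that Model.fit iterates over x and y along their first axis (the batch axis), so I had to prepend a batch dimension to x[0] and y. x[1] already has shape (1, 256), so its leading axis already works as a batch of one.
This is easy to do with np.newaxis or np.expand_dims.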
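A minimal sketch of both options, using zero-filled placeholder arrays with the shapes from the question (x0, x1 and the value 0.5 are illustrative stand-ins, not the asker's data). Model.fit expects every array in x and y to have the same size along axis 0, which is what the last line checks.

```python
import numpy as np

# Placeholder arrays with the shapes from the question.
x0 = np.zeros((128, 1152), dtype=np.float32)
x1 = np.zeros((1, 256), dtype=np.float32)
y = 0.5

# Two equivalent ways to prepend a batch axis of size 1.
x0_newaxis = x0[np.newaxis, ...]        # shape (1, 128, 1152)
x0_expand = np.expand_dims(x0, axis=0)  # shape (1, 128, 1152)

# x1 already has a leading axis of size 1, so it is left unchanged.
batch_x = (x0_expand, x1)
batch_y = np.array([y])                 # shape (1,)

# Every array now agrees on the batch size along axis 0.
print([a.shape[0] for a in (*batch_x, batch_y)])  # [1, 1, 1]
```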

import numpy as np

# NOTE: item 0's shape is (128, 1152), item 1's shape is (1, 256)
datapoint_x: tuple[np.ndarray, np.ndarray]
datapoint_y: float

# NOTE: now item 0's shape is (1, 128, 1152), and item 1's shape remains (1, 256)
batch_x = (datapoint_x[0][np.newaxis, :], datapoint_x[1])
# NOTE: now y's shape is (1,)
batch_y = np.array([datapoint_y])

model.fit(x=batch_x, y=batch_y)
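The answer fits a single datapoint. For a dataset of several datapoints the same idea extends by stacking along the batch axis. A minimal sketch, assuming a hypothetical list datapoints of (datapoint_x, datapoint_y) pairs with the shapes from the question, filled with random placeholder values. The fit call is left commented out because model is defined in the question, not here.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical dataset: 4 (datapoint_x, datapoint_y) pairs with the question's shapes.
datapoints = [
    ((rng.random((128, 1152)), rng.random((1, 256))), float(rng.random()))
    for _ in range(4)
]

# Item 0 gets a new batch axis: (N, 128, 1152).
batch_x0 = np.stack([x[0] for x, _ in datapoints], axis=0)
# Item 1 already has a leading axis of 1, so concatenating along it gives (N, 256).
batch_x1 = np.concatenate([x[1] for x, _ in datapoints], axis=0)
# One target per datapoint: (N,).
batch_y = np.array([y for _, y in datapoints])

print(batch_x0.shape, batch_x1.shape, batch_y.shape)
# (4, 128, 1152) (4, 256) (4,)

# model.fit(x=(batch_x0, batch_x1), y=batch_y)  # `model` as in the question
```

The per-sample shapes stay the same as in the single-datapoint fix: (128, 1152) for item 0 and (256,) for item 1.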
