Keras输入层和输出层的维度问题

uidvcgyl  于 2023-05-23  发布在  其他
关注(0)|答案(1)|浏览(233)

我在用Keras编写网络时遇到了问题。这个网络必须,给定一组属于一个圆的点作为输入,确定这个圆。基本上,该网络具有维度阵列(num circles,num points,2)作为输入,其在2维中表示多个“num circles”的圆的多个“num points”的坐标。网络必须有一个向量(num circles,3)作为输出,其中包含圆心的2个坐标和相应的半径。该网络是:

import numpy as np
import keras
from keras import layers

x_train = array_of_circles  
# Input vector of shape (100, 1000, 2) i.e., 100 circles of 1000 points each

y_train = np.concatenate((array_of_centers, array_of_radius.reshape(-1,1)), axis=1)   
# Output: [x_center, y_center, radius]

# Define the neural network model
model = keras.Sequential([
        layers.InputLayer(input_shape=(num_points, 2)),
        layers.Dense(64, activation='relu'),
        layers.Dense(64, activation='relu'),
        layers.Dense(32, activation='relu'),
        layers.Dense(3)                                                               # Output layer with 3 units for x_center, y_center, and radius
])

# Compile the model
model.compile(optimizer='adam', loss='mean_squared_error')

# Train the model
model.fit(x_train, y_train, epochs=10, batch_size=32)

我试着编译代码,得到了这样的消息:

Incompatible shapes: [32,1000,3] vs. [32,3]
     [[{{node gradient_tape/mean_squared_error/BroadcastGradientArgs}}]] [Op:__inference_train_function_6137]

如何匹配输入和输出形状?

k2fxgqgv

k2fxgqgv1#

不知道这是不是你要找的,但下面的工作对我来说没有错误:

import numpy as np
import keras
from keras import layers

x_train= np.zeros( (100, 1000, 2))
y_train = np.zeros((100, 1000, 3))

print(x_train.shape, y_train.shape)

# Define the neural network model
model = keras.Sequential(
    [
        layers.InputLayer(input_shape=(1000, 2)),
        layers.Dense(64, activation="relu"),
        layers.Dense(64, activation="relu"),
        layers.Dense(32, activation="relu"),
        layers.Dense(3),  # Output layer with 3 units for x_center, y_center, and radius
    ]
)

# Compile the model
model.compile(optimizer="adam", loss="mean_squared_error")

# Train the model
model.fit(x_train, y_train, epochs=10, batch_size=32)

输出:

(100, 1000, 2) (100, 1000, 3)
Epoch 1/10
4/4 [==============================] - 1s 21ms/step - loss: 0.0000e+00
Epoch 2/10
4/4 [==============================] - 0s 20ms/step - loss: 0.0000e+00
Epoch 3/10
4/4 [==============================] - 0s 20ms/step - loss: 0.0000e+00
Epoch 4/10
4/4 [==============================] - 0s 20ms/step - loss: 0.0000e+00
Epoch 5/10
4/4 [==============================] - 0s 19ms/step - loss: 0.0000e+00
Epoch 6/10
4/4 [==============================] - 0s 19ms/step - loss: 0.0000e+00
Epoch 7/10
4/4 [==============================] - 0s 19ms/step - loss: 0.0000e+00
Epoch 8/10
4/4 [==============================] - 0s 19ms/step - loss: 0.0000e+00
Epoch 9/10
4/4 [==============================] - 0s 19ms/step - loss: 0.0000e+00
Epoch 10/10
4/4 [==============================] - 0s 19ms/step - loss: 0.0000e+00
<keras.callbacks.History at 0x170fc8c5bb0>

相关问题