How to fix an input format problem in Keras

vbopmzt1 posted on 2021-09-08 in Java

I am training a CNN with Keras. The model architecture is as follows:

from keras import layers, models

# Binary image classifier on 150x150 RGB images
model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(150, 150, 3), padding='same'))
model.add(layers.MaxPooling2D((2, 2), strides=(2, 2), padding='same'))
model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same'))
model.add(layers.MaxPooling2D((2, 2), strides=(2, 2), padding='same'))
model.add(layers.Conv2D(128, (3, 3), activation='relu', padding='same'))
model.add(layers.MaxPooling2D((2, 2), strides=(2, 2), padding='same'))
model.add(layers.Conv2D(128, (3, 3), activation='relu', padding='same'))
model.add(layers.MaxPooling2D((2, 2), strides=(2, 2), padding='same'))
model.add(layers.Flatten())
model.add(layers.Dense(512, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))
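For reference, here is a small sketch of the output shape each layer above should produce when the input is laid out as channels_last (NHWC). It is plain Python, not Keras: the helper names `same_pool` and `trace_shapes` are hypothetical, and it assumes TensorFlow's 'same' padding rule, where a stride-2 pool outputs ceil(n / 2) along each spatial dimension. Calling `model.summary()` on the real model should list the same shapes, ending in a Flatten output of 10 x 10 x 128 = 12800 features.

```python
import math

# Hypothetical helper (not part of Keras): 'same' padding output size
# for a pooling layer, i.e. ceil(input / stride).
def same_pool(size, stride=2):
    return math.ceil(size / stride)

# Hypothetical helper: trace the (batch, H, W, C) shapes of the CNN above.
def trace_shapes(height=150, width=150, channels=3, filters=(32, 64, 128, 128)):
    shapes = [("input", (None, height, width, channels))]
    for f in filters:
        # Conv2D with a 3x3 kernel, stride 1 and padding='same' keeps H and W
        shapes.append((f"conv{f}", (None, height, width, f)))
        # MaxPooling2D((2, 2), strides=(2, 2), padding='same') halves H and W, rounding up
        height, width = same_pool(height), same_pool(width)
        shapes.append(("maxpool", (None, height, width, f)))
    shapes.append(("flatten", (None, height * width * filters[-1])))
    shapes.append(("dense512", (None, 512)))
    shapes.append(("dense1", (None, 1)))
    return shapes

for name, shape in trace_shapes():
    print(f"{name:10s} {shape}")
```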

The compilation and training code is:

from keras import optimizers
from keras.preprocessing.image import ImageDataGenerator

model.compile(loss='binary_crossentropy', optimizer=optimizers.RMSprop(learning_rate=1e-4), metrics=['acc'])

# train_dir and val_dir are the training/validation image folders (defined elsewhere)
train_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_directory(train_dir, target_size=(150, 150), batch_size=20, class_mode='binary')
val_generator = test_datagen.flow_from_directory(val_dir, target_size=(150, 150), batch_size=20, class_mode='binary')

# fit_generator is deprecated; model.fit also accepts generators
history = model.fit_generator(train_generator, steps_per_epoch=100, epochs=30, validation_data=val_generator, validation_steps=50)
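The generators above produce image batches whose layout follows ImageDataGenerator's `data_format` argument, which defaults to the `image_data_format` setting in the Keras config file (~/.keras/keras.json). The sketch below uses hypothetical helpers `expected_batch_shape` and `images_per_epoch` to show the batch shape one would expect for batch_size=20, target_size=(150, 150) and the default RGB color mode, plus how many images each epoch draws with the steps used here. It assumes every batch is full.

```python
# Hypothetical helper (not a Keras API): the image batch shape a generator
# should yield for a given layout; 3 channels assumes color_mode='rgb'.
def expected_batch_shape(batch_size, target_size, data_format="channels_last", channels=3):
    h, w = target_size
    if data_format == "channels_last":
        return (batch_size, h, w, channels)
    if data_format == "channels_first":
        return (batch_size, channels, h, w)
    raise ValueError(f"unknown data_format: {data_format!r}")

# Hypothetical helper: each step consumes one batch, so this is an upper
# bound that is exact only when every batch is full.
def images_per_epoch(steps, batch_size):
    return steps * batch_size

print(expected_batch_shape(20, (150, 150)))                    # (20, 150, 150, 3)
print(expected_batch_shape(20, (150, 150), "channels_first"))  # (20, 3, 150, 150)
print(images_per_epoch(100, 20), images_per_epoch(50, 20))     # 2000 1000
```

On the real generator, `x, y = next(train_generator)` followed by `print(x.shape)` shows which layout is actually being produced.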

This gives me the following error:

warnings.warn('`Model.fit_generator` is deprecated and '
2021-07-10 13:04:09.841898: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:176] None of the MLIR Optimization Passes are enabled (registered 2)
Epoch 1/30
Traceback (most recent call last):
  File "C:/Users/User/PycharmProjects/dog_vs_cat/model.py", line 33, in <module>
    history = model.fit_generator(train_generator,steps_per_epoch=100,epochs=30,validation_data=val_generator,validation_steps=50)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\keras\engine\training.py", line 1932, in fit_generator
    initial_epoch=initial_epoch)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\keras\engine\training.py", line 1158, in fit
    tmp_logs = self.train_function(iterator)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\def_function.py", line 889, in __call__
    result = self._call(*args,**kwds)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\def_function.py", line 950, in _call
    return self._stateless_fn(*args,**kwds)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\function.py", line 3024, in __call__
    filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\function.py", line 1961, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager))
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\function.py", line 596, in call
    ctx=ctx)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\execute.py", line 60, in quick_execute
    inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.InvalidArgumentError:  Default MaxPoolingOp only supports NHWC on device type CPU
     [[node sequential/max_pooling2d/MaxPool (defined at \Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\keras\layers\pooling.py:355) ]] [Op:__inference_train_function_1187]

Errors may have originated from an input operation.
Input Source operations connected to node sequential/max_pooling2d/MaxPool:
 sequential/conv2d/Relu (defined at \Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\keras\backend.py:4700)

Function call stack:
train_function

Process finished with exit code 1
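The key line of the error uses TensorFlow layout strings: N is the batch axis, H and W are height and width, and C is channels. NHWC is what Keras calls channels_last and NCHW is channels_first; the message says the default CPU MaxPool kernel only handles NHWC. The sketch below (plain Python, with a hypothetical helper `to_layout`) only permutes shape tuples to make the two orders concrete; it does not touch any tensors.

```python
# Keras data_format names and the TensorFlow layout strings they correspond to
LAYOUTS = {"channels_last": "NHWC", "channels_first": "NCHW"}

# Hypothetical helper: reorder a 4-D shape tuple from layout src to layout dst.
def to_layout(shape, src, dst):
    by_axis = dict(zip(src, shape))
    return tuple(by_axis[axis] for axis in dst)

nchw = (20, 3, 150, 150)  # a channels_first batch of 20 RGB images
print(to_layout(nchw, LAYOUTS["channels_first"], LAYOUTS["channels_last"]))  # (20, 150, 150, 3)
```

For actual arrays, the same NCHW to NHWC permutation is `np.transpose(x, (0, 2, 3, 1))`.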

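I searched, and the suggestion was to use data_format="channels_last" for the input shape, which I am already doing. I then tried data_format="channels_first", and it still gives the same error.
I am using Keras 2.4.3.
Any help would be appreciated.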
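One possible reason both attempts fail, offered as an assumption since the question does not show the Keras config: each Keras layer resolves its own layout, and MaxPooling2D has its own `data_format` argument separate from Conv2D's. The hypothetical helpers below model that documented rule: an explicit argument wins; otherwise the layer uses the global `image_data_format` from keras.json, which is channels_last if never set. Passing channels_first explicitly (the second attempt) makes the pooling layers NCHW, and a channels_first global setting would do the same to any pooling layer left without an argument.

```python
# Hypothetical sketch (not a Keras API) of how a layer picks its layout:
# explicit data_format wins, else the global setting, else "channels_last".
def resolve_data_format(layer_arg=None, global_setting=None):
    if layer_arg is not None:
        return layer_arg
    return global_setting if global_setting is not None else "channels_last"

# Per the error message, the default CPU MaxPool kernel accepts only NHWC,
# which is what "channels_last" maps to.
def cpu_maxpool_supported(layer_arg=None, global_setting=None):
    return resolve_data_format(layer_arg, global_setting) == "channels_last"

print(cpu_maxpool_supported("channels_first"))                   # False: explicit NCHW
print(cpu_maxpool_supported(None, "channels_first"))             # False: assumed keras.json value
print(cpu_maxpool_supported("channels_last", "channels_first"))  # True: argument overrides config
print(cpu_maxpool_supported())                                   # True: nothing set anywhere
```

In a real session, `keras.backend.image_data_format()` reports the global setting and `keras.backend.set_image_data_format('channels_last')` changes it; calling it before building the model and the generators, or passing data_format='channels_last' to every Conv2D and MaxPooling2D layer, is one way to test this hypothesis.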
