keras 硬盘驱动器上图像的自动编码器中的Tensorflow数据集问题

niwlg2el  于 2022-11-13  发布在  其他
关注(0)|答案(1)|浏览(161)

我对Tensorflow中的数据集构造感到困惑,无法让我的自动编码器适合我的数据。我不断收到错误,希望有人能看看这个,看看我哪里出错了。我试着只适合数据,而不是批迭代器,也收到了同样的错误。我甚至试着将自己的数据集构造为一个numpy数组,但我完全不明白它在寻找什么。所以这就是我的地方。m当前为:

import os
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.image import imread
import matplotlib.image as mpimg
import cv2
# Technically not necessary in newest versions of jupyter
%matplotlib inline

from google.colab import drive
drive.mount('/content/gdrive')

my_data_dir = '/content/gdrive/MyDrive/Skyrmion Vision/testFiles/train/'
images = os.listdir(my_data_dir)

data = tf.keras.utils.image_dataset_from_directory('/content/gdrive/MyDrive/Skyrmion Vision/testFiles/train/',batch_size=1,image_size=(171,256))

data_iterator = data.as_numpy_iterator()
batch = data_iterator.next()

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,Flatten,Reshape
from tensorflow.keras.optimizers import SGD

encoder = Sequential()
encoder.add(Flatten(input_shape=[171,256]))
encoder.add(Dense(400,activation='relu'))
encoder.add(Dense(200,activation='relu'))
encoder.add(Dense(100,activation='relu'))
encoder.add(Dense(50,activation='relu'))
encoder.add(Dense(25,activation='relu'))

decoder = Sequential()
decoder.add(Dense(50,input_shape=[25],activation='relu'))
decoder.add(Dense(100,activation='relu'))
decoder.add(Dense(200,activation='relu'))
decoder.add(Dense(400,activation='relu'))
decoder.add(Dense(171*256,activation='sigmoid'))

decoder.add(Reshape([171,256]))
autoencoder = Sequential([encoder,decoder])
autoencoder.compile(loss='binary_crossentropy',optimizer=SGD(lr=1.5),metrics=['accuracy'])
autoencoder.fit(batch,batch,epochs=5)

这给了我一个错误,我不太清楚需要修正什么。显然有形状错误?

Epoch 1/5
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-67-aa659ef4ed20> in <module>
----> 1 autoencoder.fit(batch,batch,epochs=5)

1 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/func_graph.py in autograph_handler(*args, **kwargs)
   1145           except Exception as e:  # pylint:disable=broad-except
   1146             if hasattr(e, "ag_error_metadata"):
-> 1147               raise e.ag_error_metadata.to_exception(e)
   1148             else:
   1149               raise

ValueError: in user code:

    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1021, in train_function  *
        return step_function(self, iterator)
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1010, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1000, in run_step  **
        outputs = model.train_step(data)
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 859, in train_step
        y_pred = self(x, training=True)
    File "/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py", line 67, in error_handler
        raise e.with_traceback(filtered_tb) from None
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/input_spec.py", line 200, in assert_input_compatibility
        raise ValueError(f'Layer "{layer_name}" expects {len(input_spec)} input(s),'

    ValueError: Layer "sequential_2" expects 1 input(s), but it received 2 input tensors. Inputs received: [<tf.Tensor 'IteratorGetNext:0' shape=(None, 171, 256, 3) dtype=float32>, <tf.Tensor 'IteratorGetNext:1' shape=(None,) dtype=int32>]```
uplii1fm

uplii1fm1#

您可以尝试:

...
autoencoder.fit(batch[0],batch[0], epochs=5)

因为你的批数据是一个由图像和标签组成的元组,而你实际上只对图像感兴趣。否则,只需要过滤掉标签,并将数据集馈送到模型中:

data = data.map(lambda x, y: (x, x))
autoencoder.fit(data, epochs=5)

最后两层应该是:

decoder.add(Dense(171*256 * 3,activation='sigmoid'))

decoder.add(Reshape([171,256, 3]))

因为您正在处理3通道图像。您的编码器还需要input_shape:

encoder.add(Flatten(input_shape=[171,256, 3]))

另请参见SO线程是否可以在Keras中使用image_dataset_from_directory()和卷积自动编码器?了解更多信息。

相关问题