keras 如何识别张流角速度中的问题图像?

lsmd5eda  于 2023-03-18  发布在  其他
关注(0)|答案(1)|浏览(121)

我正在尝试加载一个本Map像数据集并使用它来训练我的模型。我正在像这样加载数据集。

data_load = tk.utils.image_dataset_from_directory(
            dir,
            labels="inferred",
            batch_size=128,
            image_size=image_shape,
            shuffle=True,
            seed=42,
            validation_split=0.2,
            subset="training",
        )

这里 dir 是我的数据存储的本地路径。2当我使用这个数据训练我的模型时,使用 model.fit,我得到这个错误。

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
c:\Users\HP\Desktop\SBU\Courses\spring23\ese577\Labs\Lab3\lenet.ipynb Cell 8 in 2
      1 epoch = 15
----> 2 hist = model.fit(data.train, batch_size=batch_size, epochs=epoch)

File c:\python10\lib\site-packages\keras\utils\traceback_utils.py:70, in filter_traceback..error_handler(*args, **kwargs)
     67     filtered_tb = _process_traceback_frames(e.__traceback__)
     68     # To get the full stack trace, call:
     69     # `tf.debugging.disable_traceback_filtering()`
---> 70     raise e.with_traceback(filtered_tb) from None
     71 finally:
     72     del filtered_tb

File c:\python10\lib\site-packages\tensorflow\python\eager\execute.py:52, in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     50 try:
     51   ctx.ensure_initialized()
---> 52   tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
     53                                       inputs, attrs, num_outputs)
     54 except core._NotOkStatusException as e:
     55   if name is not None:

InvalidArgumentError: Graph execution error:

Number of channels inherent in the image must be 1, 3 or 4, was 2
     [[{{node decode_image/DecodeImage}}]]
     [[IteratorGetNext]] [Op:__inference_train_function_4955]

有趣的是它总是在这个阶段抛出这个错误

Epoch 1/15
  9/124 [=>............................] - ETA: 6:36 - loss: 5.0484 - accuracy: 0.4852

当我搜索类似的错误时,我发现它通常是在阅读bmp图像时观察到的,我所有的图像都是jpg,但我仍然得到这个错误。
如何修复这个错误,或者如何识别坏图像,以便我可以从数据集中删除它并继续我的训练?

sbdsn5lh

sbdsn5lh1#

你将不得不过滤掉任何引起问题的图像。2我使用下面的代码来处理一个目录中的所有图像,并检测有缺陷的图像,这样我就可以从数据集中删除它们。

import os
import cv2
from tqdm import tqdm
datadir=r'c:\datasets\autism\test'# path to  directory with class sub directories holding the image files
bad_img_list=[] # will be a list of defective images
classes=sorted(os.listdir(datadir)) # a list of classes  within the datadir
for klass in classes: # iterate through each class
    classpath=os.path.join(datadir, klass)
    flist=sorted(os.listdir(classpath)) # list of files in the current class   
    for f in tqdm(flist, ncols=100, unit='files', colour='blue', desc=klass): # iterate through the files
        fpath=os.path.join(classpath,f) # path to image file
        try:
            index=f.rfind('.') # find the rightmost . in f
            ext=f[index+1:].lower() # get the files extension and convert to lower case
            good_ext=['jpg', 'jpeg', 'bmp', 'png']# list of allowable extension for image_dataset_from_directory
            if ext not in good_ext:
                raise ValueError('image had improper extension') # create an exception so the file will be appended to bad_img_list                
            img=cv2.imread(fpath) # read in the image
            shape=img.shape # get the image shape (height, width) or (height, width, channels)
            count=len(shape)
            if count == 2: # if shapeis (width, height) image is single channel
                channels=1
            else:
                channels=shape[2] # shapeis (width, height, channels)
            if channels == 2:
                raise ValueError('image had 2 channels') # create an exception so the file will be appended to bad_img_list
        except:
            bad_img_list.append(fpath) # append to bad_img_list if there is an exception
if len(bad_img_list) >0:
    print('below is a list of defective image filepaths')
    for f in bad_img_list:
        print (f)

相关问题