我正尝试用我自己的数据集重新创建这个作品:https://www.kaggle.com/code/amyjang/tensorflow-pneumonia-classification-on-x-rays/notebook
我对代码做了一些细微的调整以适应我的数据,但我不认为这是导致这里出现问题的原因;当然有可能。
我的代码:
AUTOTUNE = tf.data.AUTOTUNE
BATCH_SIZE = 16
IMAGE_SIZE = [180, 180]
EPOCHS = 25
CLASS_NAMES = np.array(['active', 'inactive'])
train_list_ds = tf.data.Dataset.from_tensor_slices(glob.glob(f'{WORKING}/train/*/*'))
val_list_ds = tf.data.Dataset.from_tensor_slices(glob.glob(f'{WORKING}/val/*/*'))
test_list_ds = tf.data.Dataset.from_tensor_slices(glob.glob(f'{WORKING}/test/*/*'))
#print(next(train_list_ds.batch(60_000).as_numpy_iterator())[:5])
def get_label(path):
# switcher that converts the keywords to 1 or zero
if tf.strings.split(path, os.path.sep) == "inactive":
return 0 # inactive
else:
return 1 # active
def decode_img(img):
# convert the compressed string to a 3D uint8 tensor
img = tf.image.decode_jpeg(img, channels=3)
# Use `convert_image_dtype` to convert to floats in the [0,1] range.
img = tf.image.convert_image_dtype(img, tf.float32)
# resize the image to the desired size.
resizedImage = tf.image.resize(img, IMAGE_SIZE)
return resizedImage
def process_path(file_path):
label = get_label(file_path)
# load the raw data from the file as a string
img = tf.io.read_file(file_path)
img = decode_img(img)
return img, label
train_ds = train_list_ds.map(process_path, num_parallel_calls=AUTOTUNE)
val_ds = val_list_ds.map(process_path, num_parallel_calls=AUTOTUNE)
test_ds = test_list_ds.map(process_path, num_parallel_calls=AUTOTUNE)
test_ds = test_ds.batch(BATCH_SIZE)
image_batch, label_batch = next(iter(train_ds))
和错误:
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-21-d0b00c7e96e2> in <module>
68 test_ds = test_ds.batch(BATCH_SIZE)
69
---> 70 image_batch, label_batch = next(iter(train_ds))
71
72 # Use buffered prefetching so we can yield data from disk without having I/O become blocking.
3 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)
7184 def raise_from_not_ok_status(e, name):
7185 e.message += (" name: " + name if name is not None else "")
-> 7186 raise core._status_to_exception(e) from None # pylint: disable=protected-access
7187
7188
InvalidArgumentError: Input to reshape is a tensor with 10 values, but the requested shape has 1
[[{{node Reshape}}]] [Op:IteratorGetNext]
我可以从错误中收集到,我在调整大小时有一个不匹配,我相信它指向了错误的位置,虽然它是与resizedImage = tf.image.resize(img, IMAGE_SIZE)
行有关
我已经看过了文档和SO,但是我找不到任何可能对我有帮助的东西。我已经试着通过在屏幕上打印东西来尽我所能地调试。我也试着改变IMAGE_SIZE
,但是没有什么区别。
在图像上,它们在磁盘上的大小不同,以JPG格式。我希望这不重要,因为我们可以使用这一步来调整它们的大小,以便稍后由模型处理。
其他有用的信息是,我正在开发专业版的GoogleCollab,文件存储在Google驱动器中。我试着先发制人地写一些人们可能会问的东西。
最后,Amy在Kaggle上的代码中有一个prepare_for_training
函数,它在我在上面的代码中提供的最后一行之前被调用。我可以在不调用该函数的情况下触发同样的错误,所以我故意省略了它,以帮助保持示例代码的简洁。如果你想看它,笔记本中有一个指向它的快速链接:https://www.kaggle.com/code/amyjang/tensorflow-pneumonia-classification-on-x-rays?scriptVersionId=39162263&cellId=26
1条答案
按热度按时间monwx1rj1#
问题可能来自于
get_label(path)
方法,因为你是根据sep拆分路径,它会返回一个包含很多元素的列表,确保你只选择一个元素来执行测试,试试这个:我假设最后一个元素是标签,你只需要用标签在列表中的位置来替换这个索引。