KeyError: 'img' when applying augmentation with Keras and Transformers for image classification

i5desfxk  posted on 2023-01-21 in Other

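I want to apply ViT to image classification, but I have hit a problem I don't know how to solve. The error is "KeyError: 'img'". It appears when I run the last command, and I can't tell where my mistake is. The images in the dataset are .png files, but I don't think that is the cause. The script is below: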

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
import os
import cv2
import matplotlib.pyplot as plt
from transformers import ViTFeatureExtractor, ViTForImageClassification
from transformers import TrainingArguments, Trainer
from datasets import load_metric
from PIL import Image as img
from IPython.display import Image, display
from datasets import load_dataset 
import torch
dataset = load_dataset("imagefolder", data_dir="Datasets")

dataset
example = dataset["train"][10]
example
dataset["train"].features
example['image']
example['image'].resize((200, 200))
example['label']
dataset["train"].features["label"]

img_class_labels = dataset["train"].features["label"].names

model_id = "google/vit-base-patch16-224-in21k"
feature_extractor = ViTFeatureExtractor.from_pretrained(model_id)

# learn more about data augmentation here: https://www.tensorflow.org/tutorials/images/data_augmentation
data_augmentation = keras.Sequential(
    [
        layers.Resizing(feature_extractor.size, feature_extractor.size),
        layers.Rescaling(1./255),
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(factor=0.02),
        layers.RandomZoom(
            height_factor=0.2, width_factor=0.2
        ),
    ],
    name="data_augmentation",
)
# use keras image data augmentation processing
def augmentation(examples):
    # print(examples["img"])
    examples["pixel_values"] = [data_augmentation(image) for image in examples["img"]]
    return examples

# basic processing (only resizing)
def process(examples):
    examples.update(feature_extractor(examples['img'], ))
    return examples

# we are also renaming our label col to labels to use `.to_tf_dataset` later
dataset_ds = dataset["train"].rename_column("label", "labels")

processed_dataset = dataset_ds.map(augmentation, batched=True)
processed_dataset

Answer 1 (ukxgm1gy):

I suspect the error is here:

def augmentation(examples):
    # print(examples["img"])
    examples["pixel_values"] = [data_augmentation(image) for image in examples["img"]]
    return examples

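You are reading the `examples` dictionary with the key "img". Earlier in your script the same data is accessed as example['image'], so the column is named "image", not "img". The `process` function reads examples['img'] as well and will fail the same way. The line should be: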

examples["pixel_values"] = [data_augmentation(image) for image in examples["image"]]
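
To make the failure mode easier to see, here is a minimal sketch in plain Python. It assumes that a batched map function receives a dictionary mapping column names to lists of values; the `batch` contents, `fake_transform`, and `make_augmentation` are hypothetical stand-ins for illustration, not part of the original script or of any library.

```python
from typing import Any, Callable, Dict, List

Batch = Dict[str, List[Any]]


def make_augmentation(
    transform: Callable[[Any], Any], image_key: str = "image"
) -> Callable[[Batch], Batch]:
    """Build a batched map function that reads images from `image_key`."""

    def augmentation(examples: Batch) -> Batch:
        if image_key not in examples:
            # Fail with a message that lists the columns actually present.
            raise KeyError(
                f"column {image_key!r} not found; "
                f"available columns: {sorted(examples)}"
            )
        examples["pixel_values"] = [
            transform(image) for image in examples[image_key]
        ]
        return examples

    return augmentation


def fake_transform(image: Any) -> str:
    # Hypothetical placeholder for `data_augmentation`; it only tags the item.
    return f"augmented({image})"


# Stand-in batch: column name -> list of values (assumed shape, see lead-in).
batch: Batch = {"image": ["img_0", "img_1"], "labels": [0, 1]}

# Correct key: the new column is added.
fixed = make_augmentation(fake_transform, image_key="image")(dict(batch))
print(fixed["pixel_values"])

# Wrong key: reproduces the KeyError, now with the available columns listed.
try:
    make_augmentation(fake_transform, image_key="img")(dict(batch))
except KeyError as err:
    print("KeyError:", err)
```

In the real script, printing `dataset["train"].features` (already present in the question) shows the actual column names, which is a quick way to confirm the right key before calling `map`.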
