我想应用VIT进行图像分类。但我有一个问题,我不知道解决它。我的错误是这个“KeyError:'img'“。错误显示时,我应用了最后一个命令,我不知道我的错误在哪里。数据集内的图像是在.png,但我不认为这是错误的。下面有脚本:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
import os
import cv2
import matplotlib.pyplot as plt
from transformers import ViTFeatureExtractor, ViTForImageClassification
from transformers import TrainingArguments, Trainer
from tensorflow import keras
from tensorflow.keras import layers
from datasets import load_metric
from PIL import Image as img
from IPython.display import Image, display
from datasets import load_dataset
import torch
dataset = load_dataset("imagefolder", data_dir="Datasets")
dataset
example = dataset["train"][10]
example
dataset["train"].features
example['image']
example['image'].resize((200, 200))
example['label']
dataset["train"].features["label"]
img_class_labels = dataset["train"].features["label"].names
from transformers import ViTFeatureExtractor
from tensorflow import keras
from tensorflow.keras import layers
model_id = "google/vit-base-patch16-224-in21k"
feature_extractor = ViTFeatureExtractor.from_pretrained(model_id)
# learn more about data augmentation here: https://www.tensorflow.org/tutorials/images/data_augmentation
data_augmentation = keras.Sequential(
[
layers.Resizing(feature_extractor.size, feature_extractor.size),
layers.Rescaling(1./255),
layers.RandomFlip("horizontal"),
layers.RandomRotation(factor=0.02),
layers.RandomZoom(
height_factor=0.2, width_factor=0.2
),
],
name="data_augmentation",
)
# use keras image data augementation processing
def augmentation(examples):
# print(examples["img"])
examples["pixel_values"] = [data_augmentation(image) for image in examples["img"]]
return examples
# basic processing (only resizing)
def process(examples):
examples.update(feature_extractor(examples['img'], ))
return examples
# we are also renaming our label col to labels to use `.to_tf_dataset` later
dataset_ds = dataset["train"].rename_column("label", "labels")
processed_dataset = dataset_ds.map(augmentation, batched=True)
processed_dataset
1条答案
按热度按时间ukxgm1gy1#
我猜错误就在这里:
您正在尝试使用“img”键访问“examples”字典。从上面的一些代码来看,键应该是“image”: