Keras TFRecords: encoding images causes distortion

dfuffjeb, posted 2023-01-30 in Other

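I am trying to generate TFRecords from a dataset created with image_dataset_from_directory, but when I visualize the images to check that the encoding is correct, they come out distorted.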

How I create the TFRecords:

Step 1: create the datasets with image_dataset_from_directory

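from functools import partial

import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.train import BytesList, Example, Feature, Features, FloatList

data_dir = 'path to JPG dataset'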

load_split = partial(
    tf.keras.preprocessing.image_dataset_from_directory,
    data_dir,
    validation_split=0.2,
    shuffle=True,
    seed=123,
    image_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=1,
)

ds_train = load_split(subset='training')
ds_valid = load_split(subset='validation')

Step 2: encoding functions

def process_image(image, label):
    image = tf.image.convert_image_dtype(image, dtype=tf.uint8)
    image = tf.io.encode_jpeg(image)
    
    label = tf.one_hot(label, NUM_CLASSES)
    
    return image, label

def make_example(encoded_image, label):
    image_feature = Feature(
        bytes_list=BytesList(value=[
            encoded_image,
        ]),
    )
    label_feature = Feature(
        float_list=FloatList(value=label)
    )

    features = Features(feature={
        'image': image_feature,
        'label': label_feature,
    })
    
    example = Example(features=features)
    
    return example.SerializeToString()

Step 3: encode the datasets and write the TFRecords

ds_train_encoded = (
    ds_train
    .unbatch()
    .map(process_image)
)

ds_valid_encoded = (
    ds_valid
    .unbatch()
    .map(process_image)
)

ds_train_encoded_iter = (
    ds_train_encoded
    .as_numpy_iterator()
)
with tf.io.TFRecordWriter(path='train.tfrecord') as f: # you can pass gs:// path here :) 
    for encoded_image, label in ds_train_encoded_iter:
        example = make_example(encoded_image, label)
        f.write(example)

ds_valid_encoded_iter = (
    ds_valid_encoded
    .as_numpy_iterator()
)
with tf.io.TFRecordWriter(path='/home/et/medai/images/tfrecords/test.tfrecord') as f:
    for encoded_image, label in ds_valid_encoded_iter:
        example = make_example(encoded_image, label)
        f.write(example)
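
One way to narrow this down is to test the write/read path on its own. The sketch below writes three synthetic images that are already `uint8` (so no `convert_image_dtype` call is involved), reads the records back with `tf.data.TFRecordDataset`, and measures the JPEG error. If this round trip is clean, that points at the pixel values handed to `tf.io.encode_jpeg` rather than at the TFRecord plumbing. The `serialize` helper, the temporary path, and `NUM_CLASSES = 40` are assumptions for illustration only.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

NUM_CLASSES = 40  # assumption: matches FixedLenFeature([40]) in the parser


def serialize(jpeg_bytes, label_vector):
    """Hypothetical stand-in for make_example, built from tf.train.* classes."""
    feature = {
        "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[jpeg_bytes])),
        "label": tf.train.Feature(float_list=tf.train.FloatList(value=label_vector)),
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    return example.SerializeToString()


path = os.path.join(tempfile.mkdtemp(), "roundtrip.tfrecord")
originals = []
with tf.io.TFRecordWriter(path) as writer:
    for class_id, gray in enumerate([30, 120, 220]):
        image = np.full((8, 8, 3), gray, dtype=np.uint8)  # already uint8, 0..255
        jpeg = tf.io.encode_jpeg(image).numpy()  # Python bytes
        label = tf.one_hot(class_id, NUM_CLASSES).numpy().tolist()
        writer.write(serialize(jpeg, label))
        originals.append(image)

# Read the records back and compare the decoded pixels with what was written.
max_errors = []
for raw, original in zip(tf.data.TFRecordDataset(path).as_numpy_iterator(), originals):
    example = tf.train.Example.FromString(raw)
    jpeg = example.features.feature["image"].bytes_list.value[0]
    decoded = tf.io.decode_jpeg(jpeg, channels=3).numpy()
    # JPEG is lossy, so look at the largest per-pixel difference, not equality.
    max_errors.append(int(np.abs(decoded.astype(int) - original.astype(int)).max()))

print(max_errors)  # small values (a few units at most) mean the round trip is fine
```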

How I try to visualize the images from the TFRecords:

Step 1: decoding functions

def _parse_image_function(example):
    image_feature_description = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([40], tf.float32),
    }

    features = tf.io.parse_single_example(example, image_feature_description)
    image = tf.image.decode_jpeg(features['image'], channels=3)
    image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])
    # image = features['image']
    label = features['label']

    return image, label
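
Two details of `_parse_image_function` are worth confirming: `tf.image.resize` (default bilinear method) returns `float32` with values still on the 0 to 255 scale, which is why the plotting code casts back to `uint8`; and the parser resizes to `IMG_SIZE` × `IMG_SIZE`, while loading used `IMG_HEIGHT` × `IMG_WIDTH`. A sketch that parses one synthetic record; `IMG_SIZE = 64`, `NUM_CLASSES = 40`, and the 32×48 test image are assumptions, not values from the question.

```python
import tensorflow as tf

IMG_SIZE = 64     # assumption: stands in for the question's IMG_SIZE
NUM_CLASSES = 40  # assumption: matches FixedLenFeature([40])

feature_description = {
    "image": tf.io.FixedLenFeature([], tf.string),
    "label": tf.io.FixedLenFeature([NUM_CLASSES], tf.float32),
}


def parse(serialized):
    features = tf.io.parse_single_example(serialized, feature_description)
    image = tf.io.decode_jpeg(features["image"], channels=3)  # uint8, 0..255
    image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])       # float32, still 0..255
    return image, features["label"]


# Build one synthetic record (a 32x48 uniform uint8 image) to inspect the output.
jpeg = tf.io.encode_jpeg(tf.fill([32, 48, 3], tf.constant(90, tf.uint8)))
example = tf.train.Example(features=tf.train.Features(feature={
    "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[jpeg.numpy()])),
    "label": tf.train.Feature(float_list=tf.train.FloatList(value=[0.0] * NUM_CLASSES)),
}))
image, label = parse(example.SerializeToString())
print(image.dtype, image.shape, label.shape)
```

If `IMG_SIZE` and `IMG_HEIGHT`/`IMG_WIDTH` have different aspect ratios, this resize stretches the images, which is a geometric change rather than corrupted pixel values.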

def read_dataset(filename, batch_size):
    dataset = tf.data.TFRecordDataset(filename)
    dataset = dataset.map(_parse_image_function, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.shuffle(500)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    # dataset = dataset.repeat()
    dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

    return dataset

Step 2: decode and display

x = read_dataset('/home/et/medai/images/tfrecords/tests_train.tfrecord', 32)

plt.figure(figsize=(10, 10))
batch_size = 32
for images, labels in x.take(1):
    for i in range(batch_size):
        # display.display(display.Image(data=images[i].numpy()))

        ax = plt.subplot(6, 6, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.axis("off")
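
For display, `plt.imshow` expects either `uint8` values in 0 to 255 or floats in 0 to 1. `.astype("uint8")` truncates rather than rounds, and it does not clip out-of-range values (what you get is platform-dependent). That cannot happen with the default bilinear resize, but it can with methods such as bicubic that overshoot. A small numpy sketch of a safer conversion; `to_display` is a hypothetical helper, and passing `images[i] / 255.0` to `imshow` would work as well.

```python
import numpy as np


def to_display(image):
    """Hypothetical helper: float image on a 0..255 scale -> uint8 for imshow.

    Rounds instead of truncating, and clips before casting so values slightly
    outside 0..255 do not wrap around.
    """
    return np.clip(np.rint(image), 0, 255).astype(np.uint8)


sample = np.array([[[-1.2, 127.6, 255.4]]], dtype=np.float32)
print(to_display(sample).ravel())  # values 0, 128, 255
```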

The result is distorted: https://i.stack.imgur.com/tCAik.jpg
I'm not sure where this distortion comes from. The original image looks like this:
https://i.stack.imgur.com/Zi4HG.png
Any ideas?

Answer #1 (lkaoscv7)

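I ran into a similar problem. I fixed it by changing how the image is normalized during preprocessing (in your case, in process_image).
When the pixel data is on a 0 to 255 scale, it tends to break down during operations on the image data, such as converting to bytes or resizing, because those operations round the pixel values. So I suggest normalizing the pixel data to floating-point values between 0.0 and 1.0.
I solved the problem with OpenCV; I hope you can fix yours with an approach similar to the code below.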

# This line distorted my images.
img = cv2.normalize(img, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX)

# I changed to this line, and it worked.
img = cv2.normalize(img, None, alpha=0., beta=1., norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
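
Why a [0, 1] range matters in this particular pipeline: `image_dataset_from_directory` returns `float32` images whose values are still on the 0 to 255 scale (it does not rescale), whereas `tf.image.convert_image_dtype` assumes floating-point images lie in [0, 1] and multiplies them by about 255 when converting to `uint8`. Almost every pixel therefore overflows before `tf.io.encode_jpeg` runs. The sketch below reproduces this on a synthetic tensor and shows two TensorFlow-only fixes in place of OpenCV; `fixed_process_image` is a hypothetical name, and `NUM_CLASSES = 40` is assumed from the `[40]` label length in the parser.

```python
import tensorflow as tf

NUM_CLASSES = 40  # assumption: matches FixedLenFeature([40]) in the parser

# Synthetic stand-in for one unbatched element from image_dataset_from_directory:
# float32 pixels that are still on the 0..255 scale.
pixels = tf.constant([[[0.0, 128.0, 200.0]]], dtype=tf.float32)  # shape (1, 1, 3)

# What process_image does: convert_image_dtype treats floats as [0, 1], so 128.0
# becomes about 128 * 255.5, far outside uint8. The question uses the default
# saturate=False (a plain cast of out-of-range values); saturate=True is used here
# only to make the failure deterministic: everything >= 1.0 clips to 255.
broken = tf.image.convert_image_dtype(pixels, tf.uint8, saturate=True)

# Fix 1: the values are already 0..255, so round and cast directly.
fixed_cast = tf.cast(tf.round(pixels), tf.uint8)

# Fix 2: rescale to [0, 1] first, the range convert_image_dtype expects.
fixed_scaled = tf.image.convert_image_dtype(pixels / 255.0, tf.uint8)


def fixed_process_image(image, label):
    """Hypothetical corrected variant of process_image."""
    image = tf.cast(tf.round(image), tf.uint8)  # keep the 0..255 values as-is
    image = tf.io.encode_jpeg(image)            # needs a uint8 (H, W, C) tensor
    label = tf.one_hot(label, NUM_CLASSES)
    return image, label


print(broken.numpy().ravel())      # expected values: 0, 255, 255
print(fixed_cast.numpy().ravel())  # expected values: 0, 128, 200
```

With either fix in `process_image`, the rest of the writing and parsing code can stay as it is.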

Answer #2 (zxlwwiss)

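When I ran into this problem, I read the original file bytes directly:
image = open(content, "rb").read()
instead of:
image = tf.image.convert_image_dtype(image, dtype=tf.uint8)
image = tf.io.encode_jpeg(image)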
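
This approach skips the decode/convert/re-encode cycle entirely: the original JPEG bytes go into the record unchanged, and the existing `tf.image.decode_jpeg` call in the parser reads them back, so resizing happens only once, at parse time. A minimal sketch, assuming the files are JPEGs (PNGs would need `tf.io.decode_png` or `tf.io.decode_image` when parsing) and using a hypothetical `example_from_file` helper; `tf.io.read_file` is an in-graph alternative to `open(...).read()`.

```python
import os
import tempfile

import tensorflow as tf

NUM_CLASSES = 40  # assumption: matches the parser's FixedLenFeature([40])


def example_from_file(path, class_id):
    """Hypothetical helper: store a file's original bytes without re-encoding."""
    with open(path, "rb") as f:
        raw_bytes = f.read()
    label = [0.0] * NUM_CLASSES
    label[class_id] = 1.0  # one-hot label, same layout as tf.one_hot
    return tf.train.Example(features=tf.train.Features(feature={
        "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw_bytes])),
        "label": tf.train.Feature(float_list=tf.train.FloatList(value=label)),
    })).SerializeToString()


# Write a small JPEG to disk to stand in for one file from data_dir.
jpeg_path = os.path.join(tempfile.mkdtemp(), "sample.jpg")
tf.io.write_file(jpeg_path, tf.io.encode_jpeg(tf.fill([16, 16, 3], tf.constant(200, tf.uint8))))

serialized = example_from_file(jpeg_path, class_id=7)
parsed = tf.io.parse_single_example(serialized, {
    "image": tf.io.FixedLenFeature([], tf.string),
    "label": tf.io.FixedLenFeature([NUM_CLASSES], tf.float32),
})
stored_bytes = parsed["image"].numpy()  # identical to the file's bytes
decoded = tf.io.decode_jpeg(parsed["image"], channels=3)
print(decoded.shape)
```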
