我试图从image_dataset_from_directory
生成一个tfrecords
;但是当我试图可视化图像以检查编码是否正确时,图像却出现了某种失真。
如何创建tfrecord:
第1步:使用image_dataset_from_directory
创建数据集
data_dir = 'path to JPG dataset'
load_split = partial(
tf.keras.preprocessing.image_dataset_from_directory,
data_dir,
validation_split=0.2,
shuffle=True,
seed=123,
image_size=(IMG_HEIGHT, IMG_WIDTH),
batch_size=1,
)
ds_train = load_split(subset='training')
ds_valid = load_split(subset='validation')
步骤2:编码函数
def process_image(image, label):
image = tf.image.convert_image_dtype(image, dtype=tf.uint8)
image = tf.io.encode_jpeg(image)
label = tf.one_hot(label, NUM_CLASSES)
return image, label
def make_example(encoded_image, label):
image_feature = Feature(
bytes_list=BytesList(value=[
encoded_image,
]),
)
label_feature = Feature(
float_list=FloatList(value=label)
)
features = Features(feature={
'image': image_feature,
'label': label_feature,
})
example = Example(features=features)
return example.SerializeToString()
步骤3:编码和创建tfrecord
ds_train_encoded = (
ds_train
.unbatch()
.map(process_image)
)
ds_valid_encoded = (
ds_valid
.unbatch()
.map(process_image)
)
ds_train_encoded_iter = (
ds_train_encoded
.as_numpy_iterator()
)
with tf.io.TFRecordWriter(path='train.tfrecord') as f: # you can pass gs:// path here :)
for encoded_image, label in ds_train_encoded_iter:
example = make_example(encoded_image, label)
f.write(example)
ds_valid_encoded_iter = (
ds_valid_encoded
.as_numpy_iterator()
)
with tf.io.TFRecordWriter(path='/home/et/medai/images/tfrecords/test.tfrecord') as f:
for encoded_image, label in ds_valid_encoded_iter:
example = make_example(encoded_image, label)
f.write(example)
我如何尝试将tfrecords中的图像可视化
步骤1:解码函数
def _parse_image_function(example):
image_feature_description = {
'image': tf.io.FixedLenFeature([], tf.string),
'label': tf.io.FixedLenFeature([40], tf.float32),
}
features = tf.io.parse_single_example(example, image_feature_description)
image = tf.image.decode_jpeg(features['image'], channels=3)
image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])
# image = features['image']
label = features['label']
return image, label
def read_dataset(filename, batch_size):
dataset = tf.data.TFRecordDataset(filename)
dataset = dataset.map(_parse_image_function, num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.shuffle(500)
dataset = dataset.batch(batch_size, drop_remainder=True)
# dataset = dataset.repeat()
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
return dataset
第二步:解码并显示
x = read_dataset('/home/et/medai/images/tfrecords/tests_train.tfrecord', 32)
plt.figure(figsize=(10, 10))
batch_size = 32
for images, labels in x.take(1):
for i in range(batch_size):
# display.display(display.Image(data=images[i].numpy()))
ax = plt.subplot(6, 6, i + 1)
plt.imshow(images[i].numpy().astype("uint8"))
plt.axis("off")
结果是扭曲的东西:https://i.stack.imgur.com/tCAik.jpg
我不太确定这种失真是从哪里来的。原始图像看起来像这样:
https://i.stack.imgur.com/Zi4HG.png
有什么想法吗?
2条答案
按热度按时间lkaoscv71#
我也遇到过类似的问题。我修复了图像处理中的图像归一化问题(对于您的情况,在
process_image
中)。当你使用0~255作为像素数据时,在操作图像数据时,比如转换为字节和调整大小,它往往会分解,因为这些操作会舍入它的像素值。所以,我希望你尝试将图像像素数据归一化为0。到1。的浮点值。
我使用OpenCV解决了这个问题,我希望您可以用类似于我在下面发布的代码的方法来解决您的问题。
zxlwwiss2#
当我面对这个问题时,我使用了:
image =打开(内容,"rb"). read()
而不是:
图像= tf.图像.转换_图像_数据类型(图像,数据类型= tf.uint8)图像= tf. io.编码_jpeg(图像)