PyTorch: creating and reading a custom WebDataset

njthzxwz  posted on 2023-03-18  in  Other

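Using the Python package webdataset, how do I create a tar archive from a dataset made up of directories, where each directory is one class and contains files such as JPEG images? And how do I then read such a WebDataset by iterating over a data loader?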

fykwrbwg #1

Use webdataset.TarWriter together with os.walk to write the files, as follows:

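import os
import webdataset as wds

# Placeholder: root directory containing one subdirectory per class
data_dir = "path/to/dataset"

def get_class_label(path):
    # The class label is the name of the file's parent directory
    return os.path.basename(os.path.dirname(path))

# Create a TarWriter object to write the dataset to a tar archive
with wds.TarWriter("dataset.tar") as tar:
    # Walk the dataset directory; each subdirectory is one class
    for root, dirs, files in os.walk(data_dir):
        for filename in sorted(files):
            # Skip anything that is not a JPEG image
            if not filename.lower().endswith((".jpg", ".jpeg")):
                continue
            path = os.path.join(root, filename)
            class_label = get_class_label(path)
            # The key must not contain the file extension: webdataset treats
            # everything after the first "." in the base name as the extension
            rel = os.path.relpath(path, data_dir)
            key = os.path.splitext(rel)[0].replace(os.sep, "/").replace(".", "_")
            # Write the image bytes and the label as one sample
            with open(path, "rb") as f:
                tar.write({"__key__": key, "class": class_label, "jpg": f.read()})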
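
Before reading the archive back, it can help to check how its members group into samples. The following is a minimal sketch using only the standard library; `split_member_name` and `list_samples` are hypothetical helper names, and the split rule (everything after the first "." in the base name is the extension) is my understanding of how webdataset groups files, not code taken from the library.

```python
import tarfile
from collections import defaultdict


def split_member_name(name):
    """Split 'dir/stem.ext1.ext2' into ('dir/stem', 'ext1.ext2').

    Assumption: this mirrors webdataset's grouping rule, where the key ends
    at the first '.' of the base name.
    """
    dirname, _, base = name.rpartition("/")
    stem, _, ext = base.partition(".")
    key = f"{dirname}/{stem}" if dirname else stem
    return key, ext


def list_samples(tar_path):
    """Return {key: [extensions]} for every regular file in the archive."""
    samples = defaultdict(list)
    with tarfile.open(tar_path) as tar:
        for member in tar:
            if member.isfile():
                key, ext = split_member_name(member.name)
                samples[key].append(ext)
    return dict(samples)
```

Calling `list_samples("dataset.tar")` should show one key per image, each with a `jpg` and a `class` entry. If an extension shows up as something like `jpg.jpg`, the key passed to the writer still contained the original file extension, and looking up the `jpg` field when reading will fail.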

To read this tar dataset, you can do the following:

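import io
import torch
import webdataset as wds
from PIL import Image
from torchvision import transforms

# Example transform; a fixed size lets the default collate function stack images
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])

# Define a function to preprocess the data
def preprocess(data):
    # Without .decode(), every field arrives as raw bytes
    image = Image.open(io.BytesIO(data["jpg"])).convert("RGB")
    class_label = data["class"].decode("utf-8")
    return transform(image), class_label

# WebDataset (not WebLoader) opens the tar archive; shuffle through a sample buffer
dataset = wds.WebDataset("dataset.tar").shuffle(1000).map(preprocess)

# WebDataset is an IterableDataset, so DataLoader's shuffle=True is not allowed
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)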
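
The class field stored in the archive is the directory name, not a number, while most PyTorch loss functions expect integer class indices. One possible way to bridge that gap is sketched below; `build_class_index` is a hypothetical helper, and the alphabetical ordering is an assumption borrowed from the convention torchvision's ImageFolder uses.

```python
import os


def build_class_index(data_dir):
    """Map each class subdirectory name to an integer index (alphabetical order)."""
    classes = sorted(
        name
        for name in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, name))
    )
    return {name: index for index, name in enumerate(classes)}
```

Inside the preprocessing function, the label string could then be replaced by `class_to_idx[label]`, where `class_to_idx = build_class_index(data_dir)` is computed once up front. The mapping has to be built from the same directory that was used to write the archive, otherwise the indices will not match.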

For more details, see the webdataset guide.
