python 如何平均划分COCO数据集?

sy5wg1nm  于 2023-09-29  发布在  Python
关注(0)|答案(1)|浏览(120)

我从here下载了完整的COCO数据集。我需要修改instances_train2017.json文件以获得以下内容:

  • 修改注解文件,使其将整个训练数据集平均划分为所有类的1/3
  • 意思是假设我有100个class_1的图像/注解,所以我希望修改后的注解文件在json文件中保存图像对象/dict的100/3

我已经写了这样的代码,但它需要太多的时间和结果/修改的文件是错误的:

import json
from collections import defaultdict
import random
# Path to your local COCO-format JSON annotation file
original_annotation_file = 'coco/annotations/instances_train2017.json'
output_annotation_file = 'evenly.json'
# Load the local COCO-format dataset from your JSON file
with open(original_annotation_file, 'r') as f:
    coco_data = json.load(f)

class_counts = defaultdict(int)
target_class_counts = defaultdict(int)

# Calculate the target count for each class
for ann in coco_data['annotations']:
    class_id = ann['category_id']
    class_counts[class_id] += 1

for class_id, count in class_counts.items():
    target_class_counts[class_id] = count // 3

# Create a list to hold the selected annotations
selected_annotations = []

# Iterate through the annotations and select the subset
for ann in coco_data['annotations']:
    class_id = ann['category_id']

    # Only include this annotation if we haven't reached the target count for this class
    if class_counts[class_id] <= target_class_counts[class_id]:
        selected_annotations.append(ann)

        # Update the count for this class
        class_counts[class_id] += 1

# Create a new COCO-format JSON data structure for the subset
subset_data = {
    'info': coco_data['info'],
    'licenses': coco_data['licenses'],
    'categories': coco_data['categories'],
    'images': coco_data['images'],
    'annotations': selected_annotations
}

# Shuffle the selected annotations to mix up the classes if desired
random.shuffle(subset_data['annotations'])

# Write the subset data to a new JSON file
with open(output_annotation_file, 'w') as f:
    json.dump(subset_data, f)
k7fdbhmy

k7fdbhmy1#

我相信如果在将示例附加到selected_annotations之后减少target_class_counts[class_id],就可以解决问题。所以替换下面的代码,它应该可以解决你的问题。

# Only include this annotation if we haven't reached the target count for this class
if target_class_counts[class_id] > 1:
    selected_annotations.append(ann)

    # Update the count for this class
    target_class_counts[class_id] -= 1

但还是会很慢。我建议创建一个列表字典,其中键是category_id,值是标记为该键的annotations数据

data = defaultdict(list)
for ann in coco_data["annotation"]:
    class_id = ann["category_id"]
    data[class_id].append(ann)

现在,您可以迭代data并打乱每个列表,然后使用切片来获取您不想要的数据部分。

相关问题