import os
import json
import random
from collections import Counter, defaultdict


def _load_json(path):
    """Read and return the JSON document stored at *path* (decoded as UTF-8)."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _dump_json(data, path):
    """Write *data* to *path* as indented JSON, creating its parent directory."""
    parent = os.path.dirname(path)
    # For a bare filename such as "train.json", dirname() is "" and
    # os.makedirs("") raises FileNotFoundError, so only create real dirs.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)


def get_class_distribution(json_file):
    """Print a per-category annotation count table for a COCO-style file.

    Args:
        json_file: Path to a JSON file containing "annotations" (each with a
            "category_id") and "categories" (each with "id" and "name").

    Returns:
        dict mapping category id -> number of annotations with that id, so
        the counts can also be used programmatically.
    """
    data = _load_json(json_file)
    category_counts = Counter(ann["category_id"] for ann in data["annotations"])
    category_map = {cat["id"]: cat["name"] for cat in data["categories"]}

    print("\n📊 **Class Distribution** 📊")
    print(f"Json File: {os.path.basename(json_file)}")
    print("=" * 40)
    # Most frequent first. Ids absent from "categories" are labelled instead
    # of aborting the whole report with a KeyError.
    for cat_id, count in sorted(category_counts.items(), key=lambda x: -x[1]):
        name = category_map.get(cat_id, f"<unknown id {cat_id}>")
        print(f"{name:<20} → {count:>5} instances")
    print("=" * 40)
    return dict(category_counts)


def split_annotations(
    json_file, output_train, output_val, train_ratio=0.8, seed=None
):
    """Split a COCO-style dataset into train/val files, balancing per-class counts.

    Categories are processed from rarest to most common. For each category,
    the images containing it are shuffled and assigned greedily so that about
    ``train_ratio`` of that category's annotations land in train. Images move
    as a whole: every annotation follows its image. Once placed, an image is
    never revisited, so more common categories can only use the images that
    are still unassigned. Images without any annotation are shuffled and split
    by ``train_ratio`` separately, so they are not dropped.

    Args:
        json_file: Source file with "images", "annotations" and "categories".
            Any other top-level keys (e.g. "info", "licenses") are copied to
            both outputs.
        output_train: Destination path for the train split.
        output_val: Destination path for the val split.
        train_ratio: Target fraction for the train split, within [0, 1].
        seed: Optional seed for a private RNG, which makes the split
            reproducible. When None, the global ``random`` module is used.

    Raises:
        ValueError: If ``train_ratio`` is outside [0, 1].
        KeyError: If an annotation references an image id not in "images".
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be within [0, 1], got {train_ratio!r}")
    # Without a seed, keep using the module-level generator so callers that
    # already call random.seed(...) keep their current behaviour.
    rng = random.Random(seed) if seed is not None else random

    data = _load_json(json_file)
    images = {img["id"]: img for img in data["images"]}
    categories = data["categories"]

    # One pass over the annotations builds:
    #   image_annotations[img_id]       -> that image's annotations
    #   category_counts[cat_id]         -> total annotations of the category
    #   category_images[cat_id][img_id] -> count of cat_id annotations in img_id
    # Counter preserves insertion order, so iteration order is deterministic.
    image_annotations = defaultdict(list)
    category_counts = Counter()
    category_images = defaultdict(Counter)
    for ann in data["annotations"]:
        img_id, cat_id = ann["image_id"], ann["category_id"]
        image_annotations[img_id].append(ann)
        category_counts[cat_id] += 1
        category_images[cat_id][img_id] += 1

    missing = [img_id for img_id in image_annotations if img_id not in images]
    if missing:
        raise KeyError(f"annotations reference unknown image ids: {missing[:10]}")

    train_images = set()
    val_images = set()
    train_ann_counts = Counter()
    val_ann_counts = Counter()

    def _assign(img_id, to_train):
        """Put img_id into one split and credit every annotation it carries."""
        if to_train:
            split, counts = train_images, train_ann_counts
        else:
            split, counts = val_images, val_ann_counts
        split.add(img_id)
        counts.update(ann["category_id"] for ann in image_annotations[img_id])

    # Rarest categories first, so they get first pick of unassigned images.
    for cat_id, total_count in sorted(category_counts.items(), key=lambda x: x[1]):
        image_ids = list(category_images[cat_id])
        rng.shuffle(image_ids)

        target_train = int(total_count * train_ratio)
        target_val = total_count - target_train

        for img_id in image_ids:
            if img_id in train_images or img_id in val_images:
                continue  # already placed while handling a rarer category
            cat_count_in_img = category_images[cat_id][img_id]
            if train_ann_counts[cat_id] + cat_count_in_img <= target_train:
                _assign(img_id, True)
            elif val_ann_counts[cat_id] + cat_count_in_img <= target_val:
                _assign(img_id, False)
            else:
                # Both per-class budgets are full: keep the overall image
                # split close to train_ratio instead.
                n_assigned = len(train_images) + len(val_images)
                _assign(
                    img_id, len(train_images) / (n_assigned + 1e-6) < train_ratio
                )

    # Images with no annotations (e.g. background frames) never appear in the
    # loop above; split them by train_ratio rather than silently losing them.
    unannotated = [img_id for img_id in images if img_id not in image_annotations]
    rng.shuffle(unannotated)
    n_train_unannotated = int(len(unannotated) * train_ratio)
    train_images.update(unannotated[:n_train_unannotated])
    val_images.update(unannotated[n_train_unannotated:])

    def _subset(selected):
        """Return (images, annotations) for *selected*, in source-file order."""
        chosen = [img_id for img_id in images if img_id in selected]
        return (
            [images[img_id] for img_id in chosen],
            [ann for img_id in chosen for ann in image_annotations.get(img_id, [])],
        )

    train_images_list, train_annotations = _subset(train_images)
    val_images_list, val_annotations = _subset(val_images)

    # Carry over extra top-level keys such as "info"/"licenses", which some
    # COCO tooling expects to be present.
    extra = {
        key: value
        for key, value in data.items()
        if key not in ("images", "annotations", "categories")
    }
    train_data = {
        **extra,
        "images": train_images_list,
        "annotations": train_annotations,
        "categories": categories,
    }
    val_data = {
        **extra,
        "images": val_images_list,
        "annotations": val_annotations,
        "categories": categories,
    }

    _dump_json(train_data, output_train)
    _dump_json(val_data, output_val)

    print("✅ Annotations split successfully!")
    print(
        f"Train: {len(train_images_list)} images, {len(train_annotations)} annotations"
    )
    print(f"Val: {len(val_images_list)} images, {len(val_annotations)} annotations")


# Example usage
json_file = "./data/annotations/other_augmented_annotations_v2.json"
output_train = "./data/annotations/other_augmented_annotations_v2_train.json"
output_val = "./data/annotations/other_augmented_annotations_v2_val.json"

if __name__ == "__main__":
    split_annotations(json_file, output_train, output_val)
    get_class_distribution(json_file)
    get_class_distribution(output_train)
    get_class_distribution(output_val)