Spaces:
Runtime error
Runtime error
| import os | |
| import json | |
| import random | |
| from collections import defaultdict | |
def get_class_distribution(json_file):
    """Print how many annotations each category has in a COCO-style JSON file.

    The file must contain an ``"annotations"`` list whose items carry a
    ``"category_id"`` and a ``"categories"`` list whose items carry ``"id"``
    and ``"name"``. Categories are printed from most to least frequent;
    ties keep the order in which the categories first appear in the
    annotations.

    Args:
        json_file: Path to the annotation file to summarise.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If ``"annotations"`` or ``"categories"`` is missing.
    """
    with open(json_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    category_counts = defaultdict(int)
    for ann in data["annotations"]:
        category_counts[ann["category_id"]] += 1

    category_map = {cat["id"]: cat["name"] for cat in data["categories"]}

    print("\n📊 **Class Distribution** 📊")
    print(f"Json File: {os.path.basename(json_file)}")
    print("=" * 40)
    for cat_id, count in sorted(category_counts.items(), key=lambda x: -x[1]):
        # An annotation may reference a category id that is absent from
        # "categories"; report it under a placeholder instead of crashing
        # halfway through the table.
        name = category_map.get(cat_id, f"<unknown id {cat_id}>")
        print(f"{name:<20} → {count:>5} instances")
    print("=" * 40)
def split_annotations(json_file, output_train, output_val, train_ratio=0.8, *, seed=None):
    """Split a COCO-style annotation file into train and val annotation files.

    Images are assigned whole (all of an image's annotations go to the same
    split). Categories are processed from rarest to most frequent; for each
    category the images containing it are shuffled and placed in train until
    ``int(count * train_ratio)`` annotations of that category are reached,
    then in val until the remainder is reached. Images that fit neither
    quota, and images without any annotations, go to whichever split keeps
    the overall image ratio closest to ``train_ratio``.

    Args:
        json_file: Path to the source JSON with ``"images"``,
            ``"annotations"`` and ``"categories"`` lists.
        output_train: Path of the train JSON to write.
        output_val: Path of the val JSON to write.
        train_ratio: Target fraction for the train split, in ``[0, 1]``.
        seed: Optional seed for a private random generator. When ``None``
            the module-level ``random`` generator is used, so callers that
            seed it themselves keep their existing behaviour.

    Raises:
        ValueError: If ``train_ratio`` is outside ``[0, 1]``.
        OSError: If a file cannot be read or written.
        KeyError: If a required key is missing, or an annotation references
            an image id that is not listed in ``"images"``.
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be between 0 and 1, got {train_ratio!r}")

    with open(json_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    images = {img["id"]: img for img in data["images"]}
    categories = data["categories"]
    rng = random if seed is None else random.Random(seed)

    # One pass builds every index needed later. category_images uses dicts as
    # ordered sets so the pre-shuffle order is the first-appearance order and
    # a fixed seed therefore yields a fixed split.
    image_annotations = defaultdict(list)
    category_counts = defaultdict(int)
    category_images = defaultdict(dict)
    for ann in data["annotations"]:
        image_annotations[ann["image_id"]].append(ann)
        category_counts[ann["category_id"]] += 1
        category_images[ann["category_id"]][ann["image_id"]] = None

    # Sort categories by frequency to prioritize rare classes
    sorted_categories = sorted(category_counts.items(), key=lambda x: x[1])

    train_images = []
    val_images = []
    assigned = set()
    train_ann_counts = defaultdict(int)
    val_ann_counts = defaultdict(int)

    def assign(img_id, to_train):
        """Place an image in a split and account for all its annotations."""
        if to_train:
            train_images.append(img_id)
            counts = train_ann_counts
        else:
            val_images.append(img_id)
            counts = val_ann_counts
        assigned.add(img_id)
        for ann in image_annotations.get(img_id, ()):
            counts[ann["category_id"]] += 1

    def train_is_underfilled():
        """Tell whether train holds less than its share of assigned images."""
        # The epsilon avoids a division by zero before anything is assigned.
        placed = len(train_images) + len(val_images)
        return len(train_images) / (placed + 1e-6) < train_ratio

    for cat_id, total_count in sorted_categories:  # Start with rarest classes
        image_ids = list(category_images[cat_id])
        rng.shuffle(image_ids)

        # Per-category annotation quotas for each split
        target_train = int(total_count * train_ratio)
        target_val = total_count - target_train

        for img_id in image_ids:
            if img_id in assigned:
                continue  # Already placed while handling a rarer category
            cat_count_in_img = sum(
                1 for ann in image_annotations[img_id] if ann["category_id"] == cat_id
            )
            if train_ann_counts[cat_id] + cat_count_in_img <= target_train:
                assign(img_id, True)
            elif val_ann_counts[cat_id] + cat_count_in_img <= target_val:
                assign(img_id, False)
            else:
                # Both quotas are full for this category: fall back to the
                # overall image ratio.
                assign(img_id, train_is_underfilled())

    # Images without annotations are never reached through a category, so
    # place them explicitly instead of dropping them from both splits.
    unannotated = [img_id for img_id in images if img_id not in assigned]
    rng.shuffle(unannotated)
    for img_id in unannotated:
        assign(img_id, train_is_underfilled())

    # Compile final lists
    train_images_list = [images[img_id] for img_id in train_images]
    val_images_list = [images[img_id] for img_id in val_images]
    train_annotations = [
        ann for img_id in train_images for ann in image_annotations.get(img_id, ())
    ]
    val_annotations = [
        ann for img_id in val_images for ann in image_annotations.get(img_id, ())
    ]

    # Dataset-level metadata applies to both subsets, so carry it over.
    metadata = {key: data[key] for key in ("info", "licenses") if key in data}
    train_data = {
        **metadata,
        "images": train_images_list,
        "annotations": train_annotations,
        "categories": categories,
    }
    val_data = {
        **metadata,
        "images": val_images_list,
        "annotations": val_annotations,
        "categories": categories,
    }

    # Save JSONs. A bare file name has an empty dirname, which os.makedirs
    # rejects, so only create a directory when there is one to create.
    for path in (output_train, output_val):
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
    with open(output_train, "w", encoding="utf-8") as f:
        json.dump(train_data, f, indent=4)
    with open(output_val, "w", encoding="utf-8") as f:
        json.dump(val_data, f, indent=4)

    print("✅ Annotations split successfully!")
    print(
        f"Train: {len(train_images_list)} images, {len(train_annotations)} annotations"
    )
    print(f"Val: {len(val_images_list)} images, {len(val_annotations)} annotations")
# Example usage: default input file and the two split files derived from it.
json_file = "./data/annotations/other_augmented_annotations_v2.json"
output_train = "./data/annotations/other_augmented_annotations_v2_train.json"
output_val = "./data/annotations/other_augmented_annotations_v2_val.json"

if __name__ == "__main__":
    # Write the train/val files first, then report the class distribution of
    # the source file followed by each of the two generated files.
    split_annotations(json_file, output_train, output_val)
    for _report_path in (json_file, output_train, output_val):
        get_class_distribution(_report_path)