# ai_reads / split_train_val.py
# Uploaded by nedeljkovignjevic via huggingface_hub ("Upload folder using huggingface_hub"), commit 333322c (verified).
import os
import json
import random
from collections import defaultdict
def get_class_distribution(json_file):
    """Print how many annotation instances each category has in a COCO-style file.

    Args:
        json_file: path to a JSON file containing "annotations" (each with a
            "category_id") and "categories" (each with "id" and "name").

    Returns:
        dict mapping category id -> instance count, ordered most frequent
        first (ties keep first-seen order). Categories that have no
        annotations are not included.
    """
    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with open(json_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    category_counts = defaultdict(int)
    for ann in data["annotations"]:
        category_counts[ann["category_id"]] += 1

    category_map = {cat["id"]: cat["name"] for cat in data["categories"]}
    ranked = sorted(category_counts.items(), key=lambda x: -x[1])

    print("\n📊 **Class Distribution** 📊")
    print(f"Json File: {os.path.basename(json_file)}")
    print("=" * 40)
    for cat_id, count in ranked:
        # An annotation may reference a category id that is missing from
        # "categories"; report it instead of crashing with KeyError.
        name = str(category_map.get(cat_id, f"<unknown {cat_id}>"))
        print(f"{name:<20}{count:>5} instances")
    print("=" * 40)
    return dict(ranked)
def _index_annotations(annotations):
    """Group annotations by image id and count instances per (category, image).

    Returns:
        (by_image, per_category) where ``by_image[image_id]`` is the list of
        that image's annotations and ``per_category[cat_id][image_id]`` is how
        many instances of ``cat_id`` the image holds. Dicts keep first-seen
        order, so iteration is deterministic.
    """
    by_image = defaultdict(list)
    per_category = defaultdict(lambda: defaultdict(int))
    for ann in annotations:
        by_image[ann["image_id"]].append(ann)
        per_category[ann["category_id"]][ann["image_id"]] += 1
    return by_image, per_category


def _write_json(path, payload):
    """Write *payload* as indented JSON to *path*, creating its parent directory."""
    parent = os.path.dirname(path)
    # os.makedirs("") raises FileNotFoundError, so skip it for bare file names.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=4)


def split_annotations(
    json_file,
    output_train,
    output_val,
    train_ratio=0.8,
    *,
    seed=None,
    include_unannotated=True,
):
    """Split a COCO-style annotation file into train/val files, balancing classes.

    Categories are processed from rarest to most frequent. For each category,
    the images containing it that are not yet assigned are shuffled. They are
    then greedily placed so that about ``train_ratio`` of that category's
    instances go to train and the rest go to val.

    An image always moves as a whole, with all of its annotations. Counts
    contributed by earlier (rarer) categories are therefore visible to later
    ones. When neither split has room left for the category, the image goes
    to train while the overall train image fraction is below ``train_ratio``,
    and to val otherwise.

    Args:
        json_file: source JSON with "images", "annotations" and "categories".
        output_train: path of the train JSON to write (parent dirs are created).
        output_val: path of the val JSON to write (parent dirs are created).
        train_ratio: target fraction of each category's instances in train;
            must be within [0, 1].
        seed: optional seed for a private ``random.Random``, making the split
            reproducible. When None, the module-level ``random`` functions are
            used, so an earlier ``random.seed(...)`` call still applies.
        include_unannotated: when True, images without any annotation are kept
            and spread by the overall image ratio instead of being dropped.

    Returns:
        (train_data, val_data): the dicts written to ``output_train`` and
        ``output_val``. Every top-level key other than "images",
        "annotations" and "categories" (e.g. "info", "licenses") is copied
        into both.

    Raises:
        ValueError: if ``train_ratio`` is outside [0, 1].
        KeyError: if an annotation references an image id missing from "images".
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be within [0, 1], got {train_ratio}")
    rng = random if seed is None else random.Random(seed)

    with open(json_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Last entry wins for duplicate ids, while positions follow file order.
    images_by_id = {img["id"]: img for img in data["images"]}
    image_annotations, per_category = _index_annotations(data["annotations"])

    orphans = [img_id for img_id in image_annotations if img_id not in images_by_id]
    if orphans:
        raise KeyError(
            f"annotations reference image ids missing from 'images': {orphans[:10]}"
        )

    train_images, val_images = set(), set()
    train_ann_counts, val_ann_counts = defaultdict(int), defaultdict(int)

    def assign(img_id, to_train):
        # Move the whole image; each of its annotations counts toward that split.
        if to_train:
            ids, counts = train_images, train_ann_counts
        else:
            ids, counts = val_images, val_ann_counts
        ids.add(img_id)
        for ann in image_annotations.get(img_id, ()):
            counts[ann["category_id"]] += 1

    def train_is_under_ratio():
        # The epsilon makes the very first pick well defined (0 / 1e-6 == 0).
        assigned = len(train_images) + len(val_images)
        return len(train_images) / (assigned + 1e-6) < train_ratio

    totals = {cat_id: sum(imgs.values()) for cat_id, imgs in per_category.items()}
    # Rarest categories first so they get first pick of images (stable on ties).
    for cat_id in sorted(totals, key=totals.get):
        target_train = int(totals[cat_id] * train_ratio)
        target_val = totals[cat_id] - target_train

        candidates = list(per_category[cat_id])
        rng.shuffle(candidates)
        for img_id in candidates:
            if img_id in train_images or img_id in val_images:
                continue  # already placed while handling a rarer category
            n_here = per_category[cat_id][img_id]
            if train_ann_counts[cat_id] + n_here <= target_train:
                assign(img_id, True)
            elif val_ann_counts[cat_id] + n_here <= target_val:
                assign(img_id, False)
            else:
                # Both per-class quotas are full: fall back to the image ratio.
                assign(img_id, train_is_under_ratio())

    if include_unannotated:
        # Images with no annotations carry no class signal; previously they
        # vanished from both splits. Spread them by the overall image ratio.
        empty_ids = [img_id for img_id in images_by_id if img_id not in image_annotations]
        rng.shuffle(empty_ids)
        for img_id in empty_ids:
            assign(img_id, train_is_under_ratio())

    # Carry over top-level metadata such as "info" and "licenses".
    extra = {
        key: value
        for key, value in data.items()
        if key not in ("images", "annotations", "categories")
    }

    def build(ids):
        # Keep the source file's order so the outputs are stable and diffable.
        return {
            **extra,
            "images": [img for img_id, img in images_by_id.items() if img_id in ids],
            "annotations": [
                ann for ann in data["annotations"] if ann["image_id"] in ids
            ],
            "categories": data["categories"],
        }

    train_data = build(train_images)
    val_data = build(val_images)
    _write_json(output_train, train_data)
    _write_json(output_val, val_data)

    print("✅ Annotations split successfully!")
    print(
        f"Train: {len(train_data['images'])} images, {len(train_data['annotations'])} annotations"
    )
    print(f"Val: {len(val_data['images'])} images, {len(val_data['annotations'])} annotations")
    return train_data, val_data
# Default paths used when the script is run without arguments. The names are
# kept for backward compatibility with anything importing them.
json_file = "./data/annotations/other_augmented_annotations_v2.json"
output_train = "./data/annotations/other_augmented_annotations_v2_train.json"
output_val = "./data/annotations/other_augmented_annotations_v2_val.json"


def main(argv=None):
    """Split an annotation file, then print the class distribution of all three files.

    Every path defaults to the module-level constants defined just above, so
    running the script with no arguments does exactly what it did before.

    Args:
        argv: argument list to parse; None means ``sys.argv[1:]``.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Class-aware train/val split of a COCO-style annotation file."
    )
    parser.add_argument("--input", default=json_file, help="source annotation JSON")
    parser.add_argument(
        "--train-out", default=output_train, help="where to write the train split"
    )
    parser.add_argument(
        "--val-out", default=output_val, help="where to write the val split"
    )
    parser.add_argument(
        "--train-ratio",
        type=float,
        default=0.8,
        help="target fraction of each class's instances placed in train",
    )
    args = parser.parse_args(argv)

    split_annotations(
        args.input, args.train_out, args.val_out, train_ratio=args.train_ratio
    )
    for path in (args.input, args.train_out, args.val_out):
        get_class_distribution(path)


if __name__ == "__main__":
    main()