#!/usr/bin/env python3
"""Run Location ControlNet inference for OSM, Canny, SAM and save to demo_images/.

For each control type ``T`` the script reads
``demo_images/GeoSynth-Location-T/input.jpeg`` and writes the generated image
to ``demo_images/GeoSynth-Location-T/output.jpeg``. Control types whose input
image is missing are skipped.
"""
import gc
import os
import sys
from pathlib import Path

REPO = Path(__file__).resolve().parent
# Make the repo-local ``geosynth_pipeline`` module importable regardless of
# the current working directory; must run before that import below.
sys.path.insert(0, str(REPO))

import torch
from PIL import Image

from geosynth_pipeline import load_geosynth_pipeline_with_location, run_with_location

PROMPT = "Satellite image features a city neighborhood"
LON, LAT = -90.2, 38.6
SEED = 42
# ControlNet variants to run, in order; each maps to
# controlnet/GeoSynth-Location-<type> and demo_images/GeoSynth-Location-<type>.
CONTROL_TYPES = ("OSM", "Canny", "SAM")


def _generate_one(control_type: str, device: str) -> None:
    """Load the GeoSynth-Location-<control_type> pipeline and render one demo image.

    The pipeline is a local of this function, so it becomes unreachable as
    soon as the function returns, before the caller loads the next one.
    """
    subfolder = f"controlnet/GeoSynth-Location-{control_type}"
    in_dir = REPO / "demo_images" / f"GeoSynth-Location-{control_type}"
    in_path = in_dir / "input.jpeg"
    out_path = in_dir / "output.jpeg"
    if not in_path.exists():
        print(f"Skipping {control_type}: {in_path} not found")
        return

    # Read the control image before the (slow) model load so a bad image fails
    # fast. Image.open is lazy and keeps the file open, so close it via `with`;
    # convert()/resize() return independent, fully-loaded images.
    with Image.open(in_path) as src:
        img = src.convert("RGB").resize((512, 512))

    print(f"Loading GeoSynth-Location-{control_type}...")
    pipe = load_geosynth_pipeline_with_location(
        str(REPO),
        controlnet_subfolder=subfolder,
        satclip_path=str(REPO / "satclip_location_encoder"),
        coordnet_path=str(REPO / "coordnet"),
        local_files_only=True,
    )
    pipe = pipe.to(device)

    # torch.manual_seed seeds the global RNGs and returns the default CPU
    # generator, which is passed through for reproducible sampling.
    gen = torch.manual_seed(SEED)
    out = run_with_location(
        pipe,
        PROMPT,
        image=img,
        lon=LON,
        lat=LAT,
        num_inference_steps=20,
        generator=gen,
    )
    # NOTE(review): assumes run_with_location returns an object with an
    # `.images` list of PIL images (diffusers-style output) -- confirm.
    out.images[0].save(out_path)
    print(f"Saved {out_path}")


def main(control_types=CONTROL_TYPES, device: str = "cuda") -> None:
    """Generate a demo output image for each Location ControlNet variant.

    Args:
        control_types: Control type names to run, in order. Defaults to
            ("OSM", "Canny", "SAM").
        device: Torch device the pipeline is moved to before inference.
    """
    if device.startswith("cuda") and not torch.cuda.is_available():
        sys.exit(f"Device {device!r} requested but CUDA is not available.")

    for control_type in control_types:
        _generate_one(control_type, device)
        # Release the finished pipeline before the next one is loaded, so two
        # full pipelines are never resident on the GPU at the same time.
        gc.collect()
        if device.startswith("cuda"):
            torch.cuda.empty_cache()


if __name__ == "__main__":
    main()