import subprocess
import shlex
import sys
import os
import tempfile
import numpy as np
import io
import base64
import json
import uvicorn
import torch
from PIL import Image

# Install the custom component if needed
subprocess.run(
    shlex.split(
        "pip install ./gradio_magicquillv2-0.0.1-py3-none-any.whl"
    )
)

import gradio as gr
from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from gradio_client import Client, handle_file
from gradio_magicquillv2 import MagicQuillV2
from util import (
    read_base64_image as read_base64_image_utils,
    tensor_to_base64,
    get_mask_bbox
)

# --- Configuration ---
# Set this to the URL of your backend Space (running app_backend.py)
BACKEND_URL = "LiuZichen/MagicQuillV2"
SAM_URL = "LiuZichen/MagicQuillHelper"

print(f"Target Backend URL: {BACKEND_URL}")

# Initialize the SAM client globally: it is a helper (CPU) Space and
# should not consume ZeroGPU quota.
print(f"Connecting to SAM client at: {SAM_URL}")
try:
    sam_client = Client(SAM_URL)
except Exception as e:
    print(f"Failed to connect to SAM client: {e}")
    sam_client = None


def get_zerogpu_headers(request_headers):
    """
    Extract ZeroGPU-specific headers from the incoming request headers.
    These are required to forward the user's quota token to the backend.
    """
    headers = {}
    if request_headers:
        # These are the headers HF injects for ZeroGPU authentication and tracking
        target_headers = [
            "x-ip-token",
            "x-zerogpu-token",
            "x-zerogpu-uuid",
            "authorization",
            "cookie"
        ]
        for h in target_headers:
            val = request_headers.get(h)
            if val:
                headers[h] = val
    return headers


def generate_image_handler(x, negative_prompt, fine_edge, fix_perspective,
                           grow_size, edge_strength, color_strength,
                           local_strength, seed, steps, cfg,
                           request: gr.Request):
    """
    Handler for the Gradio UI.
    Note the 'request: gr.Request' argument: Gradio injects it automatically.
    """
    merged_image = x['from_frontend']['img']
    total_mask = x['from_frontend']['total_mask']
    original_image = x['from_frontend']['original_image']
    add_color_image = x['from_frontend']['add_color_image']
    add_edge_mask = x['from_frontend']['add_edge_mask']
    remove_edge_mask = x['from_frontend']['remove_edge_mask']
    fill_mask = x['from_frontend']['fill_mask']
    add_prop_image = x['from_frontend']['add_prop_image']
    positive_prompt = x['from_backend']['prompt']

    forward_headers = get_zerogpu_headers(request.headers)
    print(f"Debug: Received header keys: {list(request.headers.keys())}")
    print(forward_headers)

    try:
        # Instantiate a client specifically for this request with the forwarded
        # headers. This ensures the backend sees the user's 'x-zerogpu-token',
        # not the server's. gradio_client caches schemas, so re-initialization
        # is relatively cheap, and it is necessary for per-user headers.
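        # Hypothetical guard (not in the original flow): if the user brought no
        # ZeroGPU token, the backend may fall back to the shared server quota,
        # so surface that loudly in the logs.
        if "x-zerogpu-token" not in forward_headers:
            print("Warning: no x-zerogpu-token in forwarded headers; "
                  "generation may draw on the shared server quota.")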
        client = Client(BACKEND_URL, headers=forward_headers)

        # Call the backend API
        res_base64 = client.predict(
            merged_image,
            total_mask,
            original_image,
            add_color_image,
            add_edge_mask,
            remove_edge_mask,
            fill_mask,
            add_prop_image,
            positive_prompt,
            negative_prompt,
            fine_edge,
            fix_perspective,
            grow_size,
            edge_strength,
            color_strength,
            local_strength,
            seed,
            steps,
            cfg,
            api_name="/generate"
        )
        x["from_backend"]["generated_image"] = res_base64
    except Exception as e:
        print(f"Error in generation: {e}")
        x["from_backend"]["generated_image"] = None
    return x


# --- Gradio UI ---
with gr.Blocks(title="MagicQuill V2") as demo:
    with gr.Row(elem_classes="row"):
        text = gr.Markdown(
            """
            # Welcome to MagicQuill V2!
            Give us a [GitHub star](https://github.com/zliucz/magicquillv2) if you are interested.
            Click the [link](https://magicquill.art/v2) to view our demo and tutorial. The paper is on [ArXiv](https://arxiv.org/abs/2512.03046) now.
            The [ZeroGPU](https://huggingface.co/docs/hub/spaces-zerogpu) quota is 4 minutes per day for normal users and 25 minutes per day for pro users.
            """)
    with gr.Row():
        ms = MagicQuillV2()
    with gr.Row():
        with gr.Column():
            btn = gr.Button("Run", variant="primary")
        with gr.Column():
            with gr.Accordion("parameters", open=False):
                negative_prompt = gr.Textbox(label="Negative Prompt", value="", interactive=True)
                fine_edge = gr.Radio(label="Fine Edge", choices=['enable', 'disable'], value='disable', interactive=True)
                fix_perspective = gr.Radio(label="Fix Perspective", choices=['enable', 'disable'], value='disable', interactive=True)
                grow_size = gr.Slider(label="Grow Size", minimum=10, maximum=100, value=50, step=1, interactive=True)
                edge_strength = gr.Slider(label="Edge Strength", minimum=0.0, maximum=5.0, value=0.6, step=0.01, interactive=True)
                color_strength = gr.Slider(label="Color Strength", minimum=0.0, maximum=5.0, value=1.5, step=0.01, interactive=True)
                local_strength = gr.Slider(label="Local Strength", minimum=0.0, maximum=5.0, value=1.0, step=0.01, interactive=True)
                seed = gr.Number(label="Seed", value=-1, precision=0, interactive=True)
                steps = gr.Slider(label="Steps", minimum=0, maximum=50, value=20, interactive=True)
                cfg = gr.Slider(label="CFG", minimum=0.0, maximum=20.0, value=3.5, step=0.1, interactive=True)

    btn.click(
        generate_image_handler,
        # Note: 'request' does NOT need to be listed in inputs here.
        # Gradio fills gr.Request parameters automatically from the type hint.
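        # For instance, a handler defined as
        #     def handler(x, ..., request: gr.Request): ...
        # receives the live request (and its headers) even though 'request'
        # never appears in the 'inputs' list below.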
        inputs=[ms, negative_prompt, fine_edge, fix_perspective, grow_size,
                edge_strength, color_strength, local_strength, seed, steps, cfg],
        outputs=ms
    )

# --- FastAPI App ---
app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def get_root_url(request: Request, route_path: str, root_path: str | None):
    return root_path

# Override Gradio's root-URL resolution so the mounted app uses the
# root_path passed to mount_gradio_app below.
gr.route_utils.get_root_url = get_root_url


@app.post("/magic_quill/process_background_img")
async def process_background_img(request: Request):
    img = await request.json()
    from util import process_background
    # process_background returns a [1, H, W, 3] tensor in uint8 or float
    resized_img_tensor = process_background(img)
    # tensor_to_base64 from util expects a tensor
    resized_img_base64 = "data:image/webp;base64," + tensor_to_base64(
        resized_img_tensor, quality=80, method=6
    )
    return resized_img_base64


@app.post("/magic_quill/segmentation")
async def segmentation(request: Request):
    json_data = await request.json()
    image_base64 = json_data.get("image", None)
    coordinates_positive = json_data.get("coordinates_positive", None)
    coordinates_negative = json_data.get("coordinates_negative", None)
    bboxes = json_data.get("bboxes", None)

    if sam_client is None:
        return {"error": "sam client not initialized"}

    # Round click coordinates to integers and serialize them for the client
    pos_coordinates = None
    if coordinates_positive and len(coordinates_positive) > 0:
        pos_coordinates = json.dumps([
            {'x': int(round(coord['x'])), 'y': int(round(coord['y']))}
            for coord in coordinates_positive
        ])

    neg_coordinates = None
    if coordinates_negative and len(coordinates_negative) > 0:
        neg_coordinates = json.dumps([
            {'x': int(round(coord['x'])), 'y': int(round(coord['y']))}
            for coord in coordinates_negative
        ])

    # Normalize boxes to (x_min, y_min, x_max, y_max), dropping incomplete ones
    bboxes_xyxy = None
    if bboxes and len(bboxes) > 0:
        valid_bboxes = []
        for bbox in bboxes:
            if (bbox.get("startX") is None or bbox.get("startY") is None or
                    bbox.get("endX") is None or bbox.get("endY") is None):
                continue
            x_min = max(min(int(bbox["startX"]), int(bbox["endX"])), 0)
            y_min = max(min(int(bbox["startY"]), int(bbox["endY"])), 0)
            x_max = max(int(bbox["startX"]), int(bbox["endX"]))
            y_max = max(int(bbox["startY"]), int(bbox["endY"]))
            valid_bboxes.append((x_min, y_min, x_max, y_max))
        if valid_bboxes:
            bboxes_xyxy = json.dumps(valid_bboxes)

    print(f"Segmentation request: pos={pos_coordinates}, neg={neg_coordinates}, bboxes={bboxes_xyxy}")

    try:
        # Decode the base64 image
        image_bytes = read_base64_image_utils(image_base64)
        pil_image = Image.open(image_bytes)

        # Resize for faster transmission (short side capped at 512)
        original_size = pil_image.size
        w, h = original_size
        scale = 512 / min(w, h)
        if scale < 1:
            new_w = int(w * scale)
            new_h = int(h * scale)
            pil_image_resized = pil_image.resize((new_w, new_h), Image.LANCZOS)
            print(f"Resized image for segmentation: {original_size} -> {(new_w, new_h)}")

            # Adjust coordinates and bboxes by the same scale
            if pos_coordinates:
                pos_coords_list = json.loads(pos_coordinates)
                for coord in pos_coords_list:
                    coord['x'] = int(coord['x'] * scale)
                    coord['y'] = int(coord['y'] * scale)
                pos_coordinates = json.dumps(pos_coords_list)
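            # Worked example of this mapping (also applied to the negative clicks
            # and boxes below): a click at (1024, 768) on a 2048x1536 image gives
            # scale = 512/1536 ≈ 0.333, so it lands at (341, 256) after resizing.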
            if neg_coordinates:
                neg_coords_list = json.loads(neg_coordinates)
                for coord in neg_coords_list:
                    coord['x'] = int(coord['x'] * scale)
                    coord['y'] = int(coord['y'] * scale)
                neg_coordinates = json.dumps(neg_coords_list)
            if bboxes_xyxy:
                bboxes_list = json.loads(bboxes_xyxy)
                new_bboxes = [
                    (int(bbox[0] * scale), int(bbox[1] * scale),
                     int(bbox[2] * scale), int(bbox[3] * scale))
                    for bbox in bboxes_list
                ]
                bboxes_xyxy = json.dumps(new_bboxes)
        else:
            pil_image_resized = pil_image
            scale = 1.0

        # Write the (possibly resized) image to a temp file for upload
        with tempfile.NamedTemporaryFile(suffix=".webp", delete=False) as temp_in:
            pil_image_resized.save(temp_in.name, format="WEBP", quality=80)
            temp_in_path = temp_in.name

        # Execute segmentation via the helper client; run_in_threadpool keeps
        # the blocking predict() call off the event loop
        result_path = await run_in_threadpool(
            sam_client.predict,
            handle_file(temp_in_path),
            pos_coordinates,
            neg_coordinates,
            bboxes_xyxy,
            api_name="/segment"
        )
        os.unlink(temp_in_path)

        if isinstance(result_path, (list, tuple)):
            result_path = result_path[0]
        if not result_path or not os.path.exists(result_path):
            raise RuntimeError("Client returned invalid result path")

        mask_pil = Image.open(result_path)
        if mask_pil.mode != 'L':
            mask_pil = mask_pil.convert('L')

        # Scale the mask back to the original resolution and attach it
        # as the alpha channel of the source image
        pil_image = pil_image.convert("RGB")
        if pil_image.size != mask_pil.size:
            mask_pil = mask_pil.resize(pil_image.size, Image.NEAREST)
        r, g, b = pil_image.split()
        res_pil = Image.merge("RGBA", (r, g, b, mask_pil))

        mask_tensor = torch.from_numpy(np.array(mask_pil) / 255.0).float().unsqueeze(0)
        mask_bbox = get_mask_bbox(mask_tensor)
        if mask_bbox:
            x_min, y_min, x_max, y_max = mask_bbox
            seg_bbox = {'startX': x_min, 'startY': y_min, 'endX': x_max, 'endY': y_max}
        else:
            seg_bbox = {'startX': 0, 'startY': 0, 'endX': 0, 'endY': 0}

        buffered = io.BytesIO()
        res_pil.save(buffered, format="PNG")
        image_base64_res = base64.b64encode(buffered.getvalue()).decode("utf-8")

        return {
            "error": False,
            "segmentation_image": "data:image/png;base64," + image_base64_res,
            "segmentation_bbox": seg_bbox
        }
    except Exception as e:
        print(f"Error in segmentation: {e}")
        return {"error": str(e)}


# Mount the Gradio app.
# Cap concurrency for ZeroGPU to prevent rate limiting.
demo.queue(default_concurrency_limit=10, max_size=20)
app = gr.mount_gradio_app(app, demo, path="/", root_path="/demo")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
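# Usage sketch (assumptions: a local run on port 7860 and the 'requests'
# package installed; the base64 payload is elided). Shows how a frontend
# might call the segmentation endpoint defined above:
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/magic_quill/segmentation",
#       json={
#           "image": "data:image/png;base64,...",            # source image
#           "coordinates_positive": [{"x": 120, "y": 80}],   # clicks inside the object
#           "coordinates_negative": [],                      # clicks outside the object
#           "bboxes": [{"startX": 50, "startY": 40, "endX": 300, "endY": 260}],
#       },
#   )
#   print(resp.json()["segmentation_bbox"])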