import spaces
import os
import io
import torch
import gradio as gr
import requests
from diffusers import DiffusionPipeline, ZImagePipeline

# =========================================================
# CONFIG
# =========================================================
SCRIPTS_REPO_API = (
    "https://api.github.com/repos/asomoza/diffusers-recipes/contents/"
    "models/z-image/scripts"
)
LOCAL_SCRIPTS_DIR = "z_image_scripts"
MODEL_ID = "Tongyi-MAI/Z-Image-Turbo"

os.makedirs(LOCAL_SCRIPTS_DIR, exist_ok=True)

# =========================================================
# GLOBAL STATE (CPU SAFE)
# =========================================================
SCRIPT_CODE = {}  # script_name -> code (CPU only)
PIPELINES = {}    # script_name -> pipeline (GPU only, lazy)

log_buffer = io.StringIO()


# =========================================================
# LOGGING
# =========================================================
def log(msg):
    print(msg)
    log_buffer.write(msg + "\n")


def pipeline_technology_info(pipe):
    tech = []

    # Device map
    if hasattr(pipe, "hf_device_map"):
        tech.append("Device map: enabled")
    else:
        tech.append(f"Device: {pipe.device}")

    # Transformer dtype
    if hasattr(pipe, "transformer"):
        try:
            tech.append(f"Transformer dtype: {pipe.transformer.dtype}")
        except Exception:
            pass

        # Layerwise casting (Z-Image specific)
        if hasattr(pipe.transformer, "layerwise_casting"):
            lw = pipe.transformer.layerwise_casting
            tech.append(
                f"Layerwise casting: storage={lw.storage_dtype}, compute={lw.compute_dtype}"
            )

    # VAE dtype
    if hasattr(pipe, "vae"):
        try:
            tech.append(f"VAE dtype: {pipe.vae.dtype}")
        except Exception:
            pass

    # Quantization / GGUF
    if hasattr(pipe, "quantization_config"):
        tech.append(f"Quantization: {pipe.quantization_config}")

    # Attention backend
    if hasattr(pipe, "config"):
        attn = getattr(pipe.config, "attn_implementation", None)
        if attn:
            tech.append(f"Attention: {attn}")

    return "\n".join(f"• {t}" for t in tech)


# =========================================================
# LATENT INFO
# =========================================================
def pipeline_debug_info(pipe):
    return f"""
Pipeline Info
-------------
Device: {pipe.device}
Transformer: {pipe.transformer.__class__.__name__}
VAE: {pipe.vae.__class__.__name__}
"""


def latent_shape_info(height, width, pipe):
    h = height // pipe.vae_scale_factor
    w = width // pipe.vae_scale_factor
    return f"Expected latent size: ({h}, {w})"


# =========================================================
# DOWNLOAD SCRIPTS (CPU ONLY)
# =========================================================
def download_scripts():
    resp = requests.get(SCRIPTS_REPO_API)
    resp.raise_for_status()

    scripts = []
    for item in resp.json():
        if item["name"].endswith(".py"):
            scripts.append(item["name"])
            path = os.path.join(LOCAL_SCRIPTS_DIR, item["name"])
            if not os.path.exists(path):
                content = requests.get(item["download_url"]).text
                with open(path, "w") as f:
                    f.write(content)
    return sorted(scripts)


SCRIPT_NAMES = download_scripts()


# =========================================================
# REGISTER SCRIPTS (CPU ONLY)
# =========================================================
def register_scripts(selected_scripts):
    SCRIPT_CODE.clear()
    for name in selected_scripts:
        path = os.path.join(LOCAL_SCRIPTS_DIR, name)
        with open(path, "r") as f:
            code = f.read()
        SCRIPT_CODE[name] = code

        # Log the .py file and extract pipe lines
        log(f"=== Registering script: {name} ===")
        extract_pipe_lines(code)  # This logs the full script + pipe lines

    return f"{len(SCRIPT_CODE)} script(s) registered ✅"


# =========================================================
# EXTRACT LINES AFTER FROM_PRETRAINED
# =========================================================
def extract_pipe_lines(script_code: str):
    lines = script_code.splitlines()

    # Log full .py file
    log("=== SCRIPT CONTENT START ===")
    for l in lines:
        log(l)
    log("=== SCRIPT CONTENT END ===")

    pipe_lines = []
    found = False
    for line in lines:
        stripped = line.strip()
        if not found and stripped.startswith("pipe = ZImagePipeline.from_pretrained"):
            found = True
            pipe_lines.append(line)
        elif found:
            # Include all subsequent lines that reference 'pipe'
            if "pipe" in stripped:
                pipe_lines.append(line)

    # Log the extracted lines after from_pretrained
    log("🔧 Extracted pipe-related lines:")
    for l in pipe_lines:
        log(f"• {l.strip()}")

    return pipe_lines


# =========================================================
# GPU-ONLY PIPELINE BUILDER (CRITICAL)
# =========================================================
def get_pipeline(script_name):
    if script_name in PIPELINES:
        return PIPELINES[script_name]

    log(f"🔧 Building pipeline from {script_name}")

    namespace = {
        "__file__": script_name,
        "__name__": "__main__",
        # Minimal required globals
        "torch": torch,
    }

    try:
        exec(SCRIPT_CODE[script_name], namespace)
    except Exception as e:
        log(f"❌ Script failed: {script_name}")
        raise RuntimeError(f"Pipeline build failed for {script_name}") from e

    if "pipe" not in namespace:
        raise RuntimeError(
            f"{script_name} did not define `pipe`.\n"
            f"Each script MUST assign a variable named `pipe`."
        )

    PIPELINES[script_name] = namespace["pipe"]
    log(f"✅ Pipeline ready: {script_name}")
    return PIPELINES[script_name]


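# For reference, a minimal sketch of the kind of recipe script get_pipeline() can
# execute. This is an assumption inferred from the checks above (the script must
# assign a variable named `pipe`, and extract_pipe_lines() looks for a
# `pipe = ZImagePipeline.from_pretrained` line); the actual scripts in
# asomoza/diffusers-recipes may configure dtype, offloading, or quantization
# differently:
#
#     import torch
#     from diffusers import ZImagePipeline
#
#     pipe = ZImagePipeline.from_pretrained(
#         "Tongyi-MAI/Z-Image-Turbo",   # same checkpoint as MODEL_ID above
#         torch_dtype=torch.bfloat16,   # assumed dtype, not taken from the recipes
#     )
#     pipe.enable_model_cpu_offload()   # optional standard diffusers offloading

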
# =========================================================
# IMAGE GENERATION
# =========================================================
@spaces.GPU
def generate_image(
    prompt,
    height,
    width,
    num_inference_steps,
    seed,
    randomize_seed,
    num_images,
    pipeline_name,
):
    log_buffer.truncate(0)
    log_buffer.seek(0)

    if pipeline_name not in SCRIPT_CODE:
        raise RuntimeError("Pipeline not registered")

    pipe = get_pipeline(pipeline_name)

    log("=== PIPELINE TECHNOLOGY ===")
    log(pipeline_technology_info(pipe))

    if not hasattr(pipe, "hf_device_map"):
        pipe = pipe.to("cuda")

    log("=== NEW GENERATION REQUEST ===")
    log(f"Pipeline: {pipeline_name}")
    log(f"Prompt: {prompt}")
    log(f"Height: {height}, Width: {width}")
    log(f"Steps: {num_inference_steps}")
    log(f"Images: {num_images}")

    if randomize_seed:
        seed = torch.randint(0, 2**32 - 1, (1,)).item()
        log(f"Random Seed → {seed}")
    else:
        log(f"Seed → {seed}")

    num_images = min(max(1, int(num_images)), 3)
    generator = torch.Generator("cuda").manual_seed(int(seed))

    # Run pipeline
    result = pipe(
        prompt=prompt,
        height=int(height),
        width=int(width),
        num_inference_steps=int(num_inference_steps),
        guidance_scale=0.0,
        generator=generator,
        max_sequence_length=1024,
        num_images_per_prompt=num_images,
        output_type="pil",
    )

    try:
        log(pipeline_debug_info(pipe))
        log(latent_shape_info(height, width, pipe))
    except Exception as e:
        log(f"Diagnostics error: {e}")

    log("✅ Generation complete")
    return result.images, seed, log_buffer.getvalue()


# =========================================================
# GRADIO UI (original layout)
# =========================================================
with gr.Blocks(title="Z-Image-Turbo Multi Image Demo") as demo:
    gr.Markdown("# 🎨 Z-Image-Turbo — Multi Image")

    with gr.Row():
        with gr.Column(scale=1):
            script_selector = gr.CheckboxGroup(
                choices=SCRIPT_NAMES,
                label="Select pipeline scripts"
            )
            register_btn = gr.Button("Register Scripts")
            status = gr.Textbox(label="Status", interactive=False)

            prompt = gr.Textbox(label="Prompt", lines=4)

            with gr.Row():
                height = gr.Slider(512, 2048, 1024, step=64, label="Height")
                width = gr.Slider(512, 2048, 1024, step=64, label="Width")

            num_images = gr.Slider(1, 3, 2, step=1, label="Number of Images")
            num_inference_steps = gr.Slider(
                1, 20, 9, step=1,
                label="Inference Steps",
                info="9 steps = 8 DiT forward passes"
            )

            with gr.Row():
                seed = gr.Number(label="Seed", value=42, precision=0)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)

            generate_btn = gr.Button("🚀 Generate", variant="primary")

        with gr.Column(scale=1):
            pipeline_picker = gr.Dropdown(
                choices=[],
                label="Active Pipeline",
            )
            output_images = gr.Gallery(
                label="Generated Images",
                type="pil",
                columns=2
            )
            used_seed = gr.Number(label="Seed Used", interactive=False)
            debug_log = gr.Textbox(
                label="Debug Log Output",
                lines=25,
                interactive=False
            )

    register_btn.click(
        register_scripts,
        inputs=[script_selector],
        outputs=[status]
    )
    register_btn.click(
        lambda s: gr.update(choices=s, value=s[0] if s else None),
        inputs=[script_selector],
        outputs=[pipeline_picker]
    )

    generate_btn.click(
        generate_image,
        inputs=[prompt, height, width, num_inference_steps,
                seed, randomize_seed, num_images, pipeline_picker],
        outputs=[output_images, used_seed, debug_log]
    )

demo.queue()
demo.launch()