Spaces:
Running
on
Zero
Running
on
Zero
| import spaces | |
| import os | |
| import io | |
| import torch | |
| import gradio as gr | |
| import requests | |
| from diffusers import DiffusionPipeline, ZImagePipeline | |
# =========================================================
# CONFIG
# =========================================================
# Model repository id for Z-Image-Turbo.
MODEL_ID = "Tongyi-MAI/Z-Image-Turbo"

# GitHub "contents" API endpoint that lists the recipe scripts to offer.
SCRIPTS_REPO_API = (
    "https://api.github.com/repos/asomoza/diffusers-recipes/contents/"
    + "models/z-image/scripts"
)

# Local cache directory for downloaded scripts; created eagerly at import time.
LOCAL_SCRIPTS_DIR = "z_image_scripts"
os.makedirs(LOCAL_SCRIPTS_DIR, exist_ok=True)
# =========================================================
# GLOBAL STATE (CPU SAFE)
# =========================================================
# Registered script source, keyed by script file name (CPU only).
SCRIPT_CODE = dict()
# Built pipelines keyed by script file name; filled lazily on first use (GPU only).
PIPELINES = dict()
# In-memory copy of everything passed to log(); reset at the start of each generation.
log_buffer = io.StringIO()


# =========================================================
# LOGGING
# =========================================================
def log(msg):
    """Print *msg* to stdout and append it, newline-terminated, to ``log_buffer``."""
    print(msg)
    # Two writes instead of one concatenation; the buffer contents are identical.
    log_buffer.write(msg)
    log_buffer.write("\n")
def pipeline_technology_info(pipe):
    """Return a bulleted summary of how *pipe* is placed and configured.

    Every probe is best-effort: attributes the pipeline does not expose are
    skipped rather than raising. The only attribute that is required is
    ``device``, and only when ``hf_device_map`` is absent.

    Args:
        pipe: A diffusers-style pipeline object.

    Returns:
        One bulleted line per detected fact, joined with newlines.
    """
    tech = []

    # Device placement
    if hasattr(pipe, "hf_device_map"):
        tech.append("Device map: enabled")
    else:
        tech.append(f"Device: {pipe.device}")

    # Transformer dtype and layerwise casting (Z-Image specific). The
    # transformer is fetched once and guarded so that pipelines without one
    # (or with it set to None) are still summarised instead of raising.
    transformer = getattr(pipe, "transformer", None)
    if transformer is not None:
        try:
            tech.append(f"Transformer dtype: {transformer.dtype}")
        except Exception:
            pass  # best-effort: dtype not readable on this module
        lw = getattr(transformer, "layerwise_casting", None)
        if lw is not None:
            tech.append(
                f"Layerwise casting: storage={lw.storage_dtype}, compute={lw.compute_dtype}"
            )

    # VAE dtype
    if hasattr(pipe, "vae"):
        try:
            tech.append(f"VAE dtype: {pipe.vae.dtype}")
        except Exception:
            pass  # best-effort: dtype not readable (e.g. vae is None)

    # Quantization / GGUF
    if hasattr(pipe, "quantization_config"):
        tech.append(f"Quantization: {pipe.quantization_config}")

    # Attention backend, if the pipeline config records one
    if hasattr(pipe, "config"):
        attn = getattr(pipe.config, "attn_implementation", None)
        if attn:
            tech.append(f"Attention: {attn}")

    return "\n".join(f"β’ {t}" for t in tech)
# =========================================================
# DIAGNOSTICS (pipeline + latent info)
# =========================================================
def pipeline_debug_info(pipe):
    """Return a short text report of the pipeline's device and component classes.

    The report starts and ends with a newline and lists ``pipe.device`` plus
    the class names of ``pipe.transformer`` and ``pipe.vae``.
    """
    rows = [
        "",
        "Pipeline Info",
        "-------------",
        f"Device: {pipe.device}",
        f"Transformer: {pipe.transformer.__class__.__name__}",
        f"VAE: {pipe.vae.__class__.__name__}",
        "",
    ]
    return "\n".join(rows)
def latent_shape_info(height, width, pipe):
    """Describe the latent grid size implied by an output resolution.

    ``generate_image`` may pass the raw UI slider values here, which are not
    guaranteed to be ints. All values are therefore int-cast, so the report
    reads ``(128, 128)`` rather than ``(128.0, 128.0)``. This matches how the
    pipeline call itself receives ``int(height)`` / ``int(width)``.

    Args:
        height: Output image height in pixels (int-like).
        width: Output image width in pixels (int-like).
        pipe: Object exposing ``vae_scale_factor``.

    Returns:
        ``"Expected latent size: (h, w)"``, where h and w are the pixel sizes
        floor-divided by ``pipe.vae_scale_factor``.
    """
    scale = int(pipe.vae_scale_factor)
    h = int(height) // scale
    w = int(width) // scale
    return f"Expected latent size: ({h}, {w})"
# =========================================================
# DOWNLOAD SCRIPTS (CPU ONLY)
# =========================================================
def download_scripts(timeout=30):
    """Fetch the recipe ``.py`` scripts listed at SCRIPTS_REPO_API into LOCAL_SCRIPTS_DIR.

    The directory listing comes from the GitHub contents API. Each ``.py``
    entry is downloaded only if no file with that name exists locally yet;
    cached files are never refreshed.

    Args:
        timeout: Seconds to wait on each HTTP request. This bounds how long
            app start-up can block, because this runs at import time.

    Returns:
        Sorted list of the ``.py`` file names found in the listing.

    Raises:
        requests.HTTPError: If the listing or any script download returns an
            HTTP error status.
        requests.RequestException: On connection failures or timeouts.
    """
    resp = requests.get(SCRIPTS_REPO_API, timeout=timeout)
    resp.raise_for_status()

    scripts = []
    for item in resp.json():
        name = item["name"]
        if not name.endswith(".py"):
            continue
        scripts.append(name)

        path = os.path.join(LOCAL_SCRIPTS_DIR, name)
        if os.path.exists(path):
            continue  # already cached from an earlier run

        file_resp = requests.get(item["download_url"], timeout=timeout)
        # Without this check, an error body (404, rate limit, ...) would be
        # saved as the script and, because of the exists() check above,
        # reused on every later start.
        file_resp.raise_for_status()
        with open(path, "w") as f:
            f.write(file_resp.text)

    return sorted(scripts)


SCRIPT_NAMES = download_scripts()
# =========================================================
# REGISTER SCRIPTS (CPU ONLY)
# =========================================================
def register_scripts(selected_scripts):
    """Make *selected_scripts* the registered set, reading each one from disk.

    SCRIPT_CODE is emptied first. Each named file is then read from
    LOCAL_SCRIPTS_DIR, stored under its name, and echoed to the log (full
    source plus its pipe-related lines) via extract_pipe_lines().

    Args:
        selected_scripts: Iterable of script file names (from the checkbox group).

    Returns:
        Status text with the number of registered scripts.
    """
    def read_source(script_name):
        with open(os.path.join(LOCAL_SCRIPTS_DIR, script_name), "r") as handle:
            return handle.read()

    SCRIPT_CODE.clear()
    for script_name in selected_scripts:
        source = read_source(script_name)
        SCRIPT_CODE[script_name] = source
        log(f"=== Registering script: {script_name} ===")
        extract_pipe_lines(source)
    return f"{len(SCRIPT_CODE)} script(s) registered β "
# =========================================================
# EXTRACT LINES AFTER FROM_PRETRAINED
# =========================================================
def extract_pipe_lines(script_code: str):
    """Log a script's full source, then log and return its ``pipe`` set-up lines.

    Selection starts at the first line whose stripped text begins with
    ``pipe = ZImagePipeline.from_pretrained``. It then keeps every later line
    containing the substring ``"pipe"``. This is a plain substring test, so
    e.g. ``pipeline`` or comments mentioning pipe also match.

    Args:
        script_code: Python source text of a recipe script.

    Returns:
        The selected lines, unstripped, in source order; empty if the
        ``from_pretrained`` anchor is never found.
    """
    source_lines = script_code.splitlines()

    log("=== SCRIPT CONTENT START ===")
    for text in source_lines:
        log(text)
    log("=== SCRIPT CONTENT END ===")

    anchor = "pipe = ZImagePipeline.from_pretrained"
    start = next(
        (i for i, text in enumerate(source_lines) if text.strip().startswith(anchor)),
        None,
    )
    if start is None:
        pipe_lines = []
    else:
        pipe_lines = [source_lines[start]] + [
            text for text in source_lines[start + 1:] if "pipe" in text.strip()
        ]

    log("π§ Extracted pipe-related lines:")
    for text in pipe_lines:
        log(f"β’ {text.strip()}")
    return pipe_lines
def extract_pipe_lines0(script_code: str):
    """Variant of extract_pipe_lines with the same selection but different output.

    It prints the raw list of source lines to stdout and logs only the
    selected list, not the full script. It is not called anywhere in this
    file as shown.

    Args:
        script_code: Python source text of a recipe script.

    Returns:
        The anchor line ``pipe = ZImagePipeline.from_pretrained...`` followed
        by every later line containing ``"pipe"``, unstripped; empty if no
        anchor is found.
    """
    source_lines = script_code.splitlines()
    print(source_lines)

    pipe_lines = []
    capturing = False
    for text in source_lines:
        body = text.strip()
        if capturing:
            if "pipe" in body:
                pipe_lines.append(text)
        elif body.startswith("pipe = ZImagePipeline.from_pretrained"):
            capturing = True
            pipe_lines.append(text)

    log(f"π§ Building pipeline from {pipe_lines}")
    return pipe_lines
# =========================================================
# GPU-ONLY PIPELINE BUILDER (CRITICAL)
# =========================================================
def get_pipeline(script_name):
    """Return the cached pipeline for *script_name*, building it on first request.

    Building means running the registered script source with ``exec`` in a
    fresh namespace that pre-binds only ``torch``. Its ``__name__`` is
    ``"__main__"``, so main-guarded sections of the script run too. The script
    must leave a top-level ``pipe`` behind, which is cached in PIPELINES.

    SECURITY NOTE(review): the executed source is the script text downloaded
    from GitHub into LOCAL_SCRIPTS_DIR. Whoever controls that repository can
    run arbitrary code in this process.

    Raises:
        RuntimeError: if the script is not registered, raises while running,
            or does not define ``pipe``.
    """
    if script_name not in PIPELINES:
        log(f"π§ Building pipeline from {script_name}")
        namespace = {"__file__": script_name, "__name__": "__main__", "torch": torch}
        try:
            # An unregistered name raises KeyError here and is wrapped below.
            exec(SCRIPT_CODE[script_name], namespace)
        except Exception as e:
            log(f"β Script failed: {script_name}")
            raise RuntimeError(f"Pipeline build failed for {script_name}") from e
        if "pipe" not in namespace:
            raise RuntimeError(
                f"{script_name} did not define `pipe`.\n"
                f"Each script MUST assign a variable named `pipe`."
            )
        PIPELINES[script_name] = namespace["pipe"]
        log(f"β Pipeline ready: {script_name}")
    return PIPELINES[script_name]
def get_pipeline_fallback(script_name):
    """Fetch-or-build the pipeline for *script_name*; same contract as get_pipeline.

    This is an independent copy of that logic and shares the PIPELINES cache.
    It is not called anywhere in this file as shown. The registered script is
    run with ``exec`` (``__name__ == "__main__"``, only ``torch`` pre-bound)
    and must define ``pipe``.

    SECURITY NOTE(review): this executes script text downloaded from GitHub.

    Raises:
        RuntimeError: if the script is not registered, raises while running,
            or does not define ``pipe``.
    """
    if script_name in PIPELINES:
        return PIPELINES[script_name]

    log(f"π§ Building pipeline from {script_name}")
    scope = dict(__file__=script_name, __name__="__main__", torch=torch)

    try:
        exec(SCRIPT_CODE[script_name], scope)
    except Exception as err:
        log(f"β Script failed: {script_name}")
        raise RuntimeError(f"Pipeline build failed for {script_name}") from err

    if "pipe" not in scope:
        message = (
            f"{script_name} did not define `pipe`.\n"
            "Each script MUST assign a variable named `pipe`."
        )
        raise RuntimeError(message)

    built = scope["pipe"]
    PIPELINES[script_name] = built
    log(f"β Pipeline ready: {script_name}")
    return built
# =========================================================
# IMAGE GENERATION
# =========================================================
try:
    import spaces as _spaces
except ImportError:  # `spaces` package not installed (e.g. plain local run)
    _spaces = None


def _on_gpu(fn):
    """Wrap *fn* with ``spaces.GPU`` when the ``spaces`` package is available."""
    return fn if _spaces is None else _spaces.GPU(fn)


@_on_gpu
def generate_image(
    prompt,
    height,
    width,
    num_inference_steps,
    seed,
    randomize_seed,
    num_images,
    pipeline_name,
):
    """Generate images with a registered pipeline; return them with the run log.

    Decorated with ``spaces.GPU`` because on ZeroGPU hardware CUDA is only
    usable inside such functions. The lazy pipeline build, ``.to("cuda")`` and
    ``torch.Generator("cuda")`` below all need it.
    NOTE(review): the first call per script also builds the pipeline inside
    the GPU window. If that overruns the default allocation, pass
    ``duration=`` to ``spaces.GPU`` — TODO confirm.

    Args:
        prompt: Text prompt.
        height, width: Output size in pixels (int-cast before use).
        num_inference_steps: Denoising steps (int-cast).
        seed: Seed used when *randomize_seed* is false.
        randomize_seed: If true, draw a fresh seed with ``torch.randint``.
        num_images: Images per prompt, clamped to 1..3.
        pipeline_name: Name of a registered script.

    Returns:
        ``(images, seed_used, log_text)``.

    Raises:
        RuntimeError: if *pipeline_name* is not registered or its build fails.
    """
    # Each request starts with an empty debug log.
    log_buffer.truncate(0)
    log_buffer.seek(0)
    print(prompt)

    if pipeline_name not in SCRIPT_CODE:
        raise RuntimeError("Pipeline not registered")

    pipe = get_pipeline(pipeline_name)
    # Pipelines with an hf_device_map are left in place; others go to CUDA.
    # The move happens before the report so the logged device is the one used.
    if not hasattr(pipe, "hf_device_map"):
        pipe = pipe.to("cuda")

    log("=== PIPELINE TECHNOLOGY ===")
    log(pipeline_technology_info(pipe))

    # Clamp to the 1..3 range the UI offers, before logging the value.
    num_images = min(max(1, int(num_images)), 3)

    log("=== NEW GENERATION REQUEST ===")
    log(f"Pipeline: {pipeline_name}")
    log(f"Prompt: {prompt}")
    log(f"Height: {height}, Width: {width}")
    log(f"Steps: {num_inference_steps}")
    log(f"Images: {num_images}")

    if randomize_seed:
        seed = torch.randint(0, 2**32 - 1, (1,)).item()
        log(f"Random Seed β {seed}")
    else:
        log(f"Seed β {seed}")

    generator = torch.Generator("cuda").manual_seed(int(seed))

    # Run pipeline
    result = pipe(
        prompt=prompt,
        height=int(height),
        width=int(width),
        num_inference_steps=int(num_inference_steps),
        guidance_scale=0.0,
        generator=generator,
        max_sequence_length=1024,
        num_images_per_prompt=num_images,
        output_type="pil",
    )

    # Post-run diagnostics are best-effort and must not fail the request.
    try:
        log(pipeline_debug_info(pipe))
        log(latent_shape_info(int(height), int(width), pipe))
    except Exception as e:
        log(f"Diagnostics error: {e}")

    log("β Generation complete")
    return result.images, seed, log_buffer.getvalue()
# =========================================================
# GRADIO UI (original layout)
# =========================================================
with gr.Blocks(title="Z-Image-Turbo Multi Image Demo") as demo:
    gr.Markdown("# π¨ Z-Image-Turbo β Multi Image ")
    with gr.Row():
        # Left column: script registration and generation controls.
        with gr.Column(scale=1):
            script_selector = gr.CheckboxGroup(
                choices=SCRIPT_NAMES,
                label="Select pipeline scripts",
            )
            register_btn = gr.Button("Register Scripts")
            status = gr.Textbox(label="Status", interactive=False)
            prompt = gr.Textbox(label="Prompt", lines=4)
            with gr.Row():
                height = gr.Slider(512, 2048, 1024, step=64, label="Height")
                width = gr.Slider(512, 2048, 1024, step=64, label="Width")
            num_images = gr.Slider(1, 3, 2, step=1, label="Number of Images")
            num_inference_steps = gr.Slider(
                1, 20, 9, step=1, label="Inference Steps",
                info="9 steps = 8 DiT forward passes",
            )
            with gr.Row():
                seed = gr.Number(label="Seed", value=42, precision=0)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)
            generate_btn = gr.Button("π Generate", variant="primary")
        # Right column: active pipeline choice and outputs.
        with gr.Column(scale=1):
            pipeline_picker = gr.Dropdown(
                choices=[],
                label="Active Pipeline",
            )
            output_images = gr.Gallery(label="Generated Images", type="pil", columns=2)
            used_seed = gr.Number(label="Seed Used", interactive=False)
            debug_log = gr.Textbox(
                label="Debug Log Output",
                lines=25,
                interactive=False,
            )

    # Fill the picker only after registration succeeded. If register_scripts
    # raises (e.g. a missing file), the picker must not offer scripts that
    # never made it into SCRIPT_CODE.
    register_btn.click(
        register_scripts,
        inputs=[script_selector],
        outputs=[status],
    ).success(
        lambda s: gr.update(choices=s, value=s[0] if s else None),
        inputs=[script_selector],
        outputs=[pipeline_picker],
    )
    generate_btn.click(
        generate_image,
        inputs=[prompt, height, width, num_inference_steps, seed, randomize_seed, num_images, pipeline_picker],
        outputs=[output_images, used_seed, debug_log],
    )

demo.queue()
demo.launch()