Zimg-script / app_lora1.py
rahul7star's picture
Update app_lora1.py
8ed8e09 verified
raw
history blame
11.7 kB
import spaces
import os
import io
import torch
import gradio as gr
import requests
from diffusers import DiffusionPipeline, ZImagePipeline
# =========================================================
# CONFIG
# =========================================================
# GitHub contents-API endpoint that lists the Z-Image recipe scripts.
SCRIPTS_REPO_API = "/".join((
    "https://api.github.com/repos",
    "asomoza/diffusers-recipes",
    "contents",
    "models/z-image/scripts",
))

# Local cache directory for downloaded scripts (created at import time).
LOCAL_SCRIPTS_DIR = "z_image_scripts"
os.makedirs(LOCAL_SCRIPTS_DIR, exist_ok=True)

# NOTE(review): not referenced elsewhere in this file.
MODEL_ID = "Tongyi-MAI/Z-Image-Turbo"

# =========================================================
# GLOBAL STATE (CPU SAFE)
# =========================================================
# Registered script name -> script source text (filled by register_scripts).
SCRIPT_CODE = {}
# Script name -> pipeline object, built lazily by get_pipeline.
PIPELINES = {}
# Per-request debug log text returned to the UI.
log_buffer = io.StringIO()
# =========================================================
# LOGGING
# =========================================================
def log(msg):
    """Print *msg* to stdout and append it, plus a newline, to ``log_buffer``."""
    print(msg)
    log_buffer.writelines((msg, "\n"))
def pipeline_technology_info(pipe):
    """Summarise how *pipe* is placed and configured, as a bullet list.

    Every probe is defensive: pipelines built by different scripts expose
    different attributes, and a missing (or None) attribute simply omits
    the corresponding bullet instead of raising.

    Args:
        pipe: Pipeline object built by a registered script.

    Returns:
        The findings joined by newlines, each prefixed with a bullet marker.
        The transformer entry spans several lines (it embeds the module repr).
    """
    tech = []

    # Device placement. Test the attribute's value, not its mere presence:
    # the attribute may exist with a None/empty value when no device map is
    # in use, which hasattr() would misreport as "enabled".
    if getattr(pipe, "hf_device_map", None):
        tech.append("Device map: enabled")
    else:
        tech.append(f"Device: {pipe.device}")

    transformer = getattr(pipe, "transformer", None)
    if transformer is not None:
        # Full module repr (architecture listing), not just the dtype.
        try:
            tech.append(f"Transformer full info:\n{transformer}")
        except Exception as e:
            tech.append(f"Transformer full info unavailable: {e}")

        # Layerwise casting (Z-Image specific) - only reported when the
        # transformer exposes a non-None `layerwise_casting` attribute.
        lw = getattr(transformer, "layerwise_casting", None)
        if lw is not None:
            storage = getattr(lw, "storage_dtype", "?")
            compute = getattr(lw, "compute_dtype", "?")
            tech.append(f"Layerwise casting: storage={storage}, compute={compute}")

    vae = getattr(pipe, "vae", None)
    if vae is not None:
        # Best-effort: reading .dtype can fail for unusual components; the
        # summary is informational, so just omit the line in that case.
        try:
            tech.append(f"VAE dtype: {vae.dtype}")
        except Exception:
            pass

    # Quantization / GGUF
    if hasattr(pipe, "quantization_config"):
        tech.append(f"Quantization: {pipe.quantization_config}")

    # Attention backend
    if hasattr(pipe, "config"):
        attn = getattr(pipe.config, "attn_implementation", None)
        if attn:
            tech.append(f"Attention: {attn}")

    return "\n".join(f"β€’ {t}" for t in tech)
# =========================================================
# LATENT INFO
# =========================================================
def pipeline_debug_info(pipe):
    """Return a short text banner naming the pipeline's device and component classes."""
    rows = [
        "",
        "Pipeline Info",
        "-------------",
        f"Device: {pipe.device}",
        f"Transformer: {pipe.transformer.__class__.__name__}",
        f"VAE: {pipe.vae.__class__.__name__}",
        "",
    ]
    return "\n".join(rows)
def latent_shape_info(height, width, pipe):
    """Describe the latent grid size implied by an output resolution.

    Args:
        height: Requested image height in pixels. May arrive as a float
            (the generation code casts UI values with int() elsewhere).
        width: Requested image width in pixels; same as *height*.
        pipe: Object exposing an integer-like ``vae_scale_factor``.

    Returns:
        ``"Expected latent size: (h, w)"`` with integer h and w.

    NOTE(review): this is pixel size // vae_scale_factor only; it does not
    account for any further patching/packing the pipeline may apply.
    """
    # Cast first so float inputs don't render as "(128.0, 128.0)".
    scale = int(pipe.vae_scale_factor)
    h = int(height) // scale
    w = int(width) // scale
    return f"Expected latent size: ({h}, {w})"
# =========================================================
# DOWNLOAD SCRIPTS (CPU ONLY)
# =========================================================
def download_scripts(timeout=30):
    """Mirror the Z-Image recipe scripts from GitHub into LOCAL_SCRIPTS_DIR.

    Lists SCRIPTS_REPO_API through the GitHub contents API and downloads
    every ``.py`` file that is not cached locally yet. Files already cached
    are never re-downloaded.

    If the listing request itself fails (network error, HTTP error such as
    an API rate limit, or a non-JSON body) while scripts are already cached
    locally, the cached set is returned so the app can still start.

    Args:
        timeout: Seconds to wait on each HTTP request before giving up.

    Returns:
        Sorted list of script file names.

    Raises:
        requests.RequestException / ValueError: if listing fails and nothing
            is cached, or if downloading a not-yet-cached script fails.
    """
    try:
        resp = requests.get(SCRIPTS_REPO_API, timeout=timeout)
        resp.raise_for_status()
        listing = resp.json()
    except (requests.RequestException, ValueError) as e:
        cached = sorted(
            name for name in os.listdir(LOCAL_SCRIPTS_DIR) if name.endswith(".py")
        )
        if not cached:
            raise
        log(f"Script listing failed ({e}); using {len(cached)} cached script(s)")
        return cached

    scripts = []
    for item in listing:
        name = item.get("name", "")
        # The contents API can also list directories; only files have a
        # usable download_url.
        if item.get("type", "file") != "file" or not name.endswith(".py"):
            continue
        scripts.append(name)
        path = os.path.join(LOCAL_SCRIPTS_DIR, name)
        if os.path.exists(path):
            continue
        file_resp = requests.get(item["download_url"], timeout=timeout)
        # Without this check an HTTP error body would be cached as the script.
        file_resp.raise_for_status()
        # Write to a side file and rename, so an interrupted download never
        # leaves a truncated .py that would be treated as cached next time.
        tmp_path = path + ".part"
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.write(file_resp.text)
        os.replace(tmp_path, path)
    return sorted(scripts)


SCRIPT_NAMES = download_scripts()
# =========================================================
# REGISTER SCRIPTS (CPU ONLY)
# =========================================================
def register_scripts(selected_scripts):
    """Load the selected cached scripts into SCRIPT_CODE, replacing the previous set.

    Each script's source is read from LOCAL_SCRIPTS_DIR, stored under its
    file name, and passed to extract_pipe_lines, which logs the full source
    and its pipe-related lines.

    Args:
        selected_scripts: Iterable of script file names; None or empty
            registers nothing.

    Returns:
        Status text reporting how many scripts are registered.
    """
    SCRIPT_CODE.clear()
    for name in selected_scripts or []:
        path = os.path.join(LOCAL_SCRIPTS_DIR, name)
        # Explicit encoding: don't depend on the host locale.
        with open(path, "r", encoding="utf-8") as f:
            code = f.read()
        SCRIPT_CODE[name] = code

        log(f"=== Registering script: {name} ===")
        extract_pipe_lines(code)

    return f"{len(SCRIPT_CODE)} script(s) registered βœ…"
# =========================================================
# EXTRACT LINES AFTER FROM_PRETRAINED
# =========================================================
def extract_pipe_lines(script_code: str):
    """Log a script's full source, then the lines that configure ``pipe``.

    Collection starts at the first line of the form
    ``pipe = <Something>.from_pretrained`` (any spacing around ``=``). From
    there on, every line that references the identifier ``pipe`` as a whole
    word is kept, so ``pipeline``/``pipes``/``my_pipe`` no longer match.

    Args:
        script_code: Full source text of a pipeline script.

    Returns:
        list[str]: the collected lines, unstripped, in source order.
    """
    import re

    start_re = re.compile(r"^pipe\s*=\s*[\w.]+\.from_pretrained\b")
    ref_re = re.compile(r"\bpipe\b")

    lines = script_code.splitlines()

    # Log full .py file
    log("=== SCRIPT CONTENT START ===")
    for line in lines:
        log(line)
    log("=== SCRIPT CONTENT END ===")

    pipe_lines = []
    found = False
    for line in lines:
        stripped = line.strip()
        if not found:
            if start_re.match(stripped):
                found = True
                pipe_lines.append(line)
        elif ref_re.search(stripped):
            pipe_lines.append(line)

    log("πŸ”§ Extracted pipe-related lines:")
    for line in pipe_lines:
        log(f"β€’ {line.strip()}")
    return pipe_lines
def extract_pipe_lines0(script_code: str):
    """Older variant of the pipe-line extractor (not called anywhere in this file).

    Prints the raw list of lines. It then collects the first line starting with
    ``pipe = ZImagePipeline.from_pretrained`` and every later line containing the
    substring "pipe". Finally it logs and returns that collection.
    """
    all_lines = script_code.splitlines()
    print(all_lines)

    marker = "pipe = ZImagePipeline.from_pretrained"
    collected = []
    started = False
    for raw in all_lines:
        text = raw.strip()
        if started:
            if "pipe" in text:
                collected.append(raw)
        elif text.startswith(marker):
            started = True
            collected.append(raw)

    log(f"πŸ”§ Building pipeline from {collected}")
    return collected
# =========================================================
# GPU-ONLY PIPELINE BUILDER (CRITICAL)
# =========================================================
def get_pipeline(script_name):
    """Return the pipeline built by a registered script, building it on first use.

    The source stored in SCRIPT_CODE is compiled and executed in a fresh
    namespace. That namespace pre-binds only ``torch`` and sets ``__name__``
    to ``"__main__"``, so any ``if __name__ == "__main__":`` section of the
    script runs too. The script must leave a variable named ``pipe``; that
    object is cached in PIPELINES and returned on later calls.

    SECURITY: the executed code is whatever was downloaded from the GitHub
    recipes repository. It runs with this process's full privileges, so only
    register scripts from sources you trust.

    Args:
        script_name: File name used as the key in SCRIPT_CODE and PIPELINES.

    Returns:
        The object the script bound to ``pipe``.

    Raises:
        RuntimeError: if the script is not registered, fails to compile or
            run (original exception chained), or does not define ``pipe``.
    """
    if script_name in PIPELINES:
        return PIPELINES[script_name]

    log(f"πŸ”§ Building pipeline from {script_name}")

    namespace = {
        "__file__": script_name,
        "__name__": "__main__",
        # Minimal required globals
        "torch": torch,
    }

    try:
        # Compiling against the on-disk path makes tracebacks show the real
        # script lines instead of "<string>".
        code = compile(
            SCRIPT_CODE[script_name],
            os.path.join(LOCAL_SCRIPTS_DIR, script_name),
            "exec",
        )
        exec(code, namespace)
    except Exception as e:
        log(f"❌ Script failed: {script_name}")
        # Surface the cause in the UI debug log too, not only the console.
        log(f"{type(e).__name__}: {e}")
        raise RuntimeError(f"Pipeline build failed for {script_name}") from e

    if "pipe" not in namespace:
        raise RuntimeError(
            f"{script_name} did not define `pipe`.\n"
            f"Each script MUST assign a variable named `pipe`."
        )

    PIPELINES[script_name] = namespace["pipe"]
    log(f"βœ… Pipeline ready: {script_name}")
    return PIPELINES[script_name]
def get_pipeline_fallback(script_name):
    """Alias of get_pipeline, kept so existing callers keep working.

    Delegates instead of duplicating the build logic line for line, so the
    two code paths cannot drift apart. Both share the PIPELINES cache.

    Args:
        script_name: Registered script name.

    Returns:
        The cached or freshly built pipeline.

    Raises:
        RuntimeError: propagated from get_pipeline.
    """
    return get_pipeline(script_name)
# =========================================================
# IMAGE GENERATION
# =========================================================
@spaces.GPU
def generate_image(
    prompt,
    height,
    width,
    num_inference_steps,
    seed,
    randomize_seed,
    num_images,
    pipeline_name,
):
    """Generate images with a registered pipeline script.

    Args:
        prompt: Text prompt.
        height: Output height in pixels (cast to int).
        width: Output width in pixels (cast to int).
        num_inference_steps: Denoising steps (cast to int).
        seed: Seed for the CUDA generator. Ignored when randomize_seed is
            set; a random seed is also drawn when it is None (cleared field).
        randomize_seed: Draw a fresh random seed instead of using ``seed``.
        num_images: Images per prompt, clamped to the range 1..3.
        pipeline_name: Name of a script loaded by register_scripts.

    Returns:
        (PIL images from the pipeline, seed actually used as int,
        debug log text for this request).

    Raises:
        RuntimeError: if pipeline_name is not registered or building fails.
    """
    # Each request starts with an empty debug log.
    log_buffer.seek(0)
    log_buffer.truncate(0)

    if pipeline_name not in SCRIPT_CODE:
        raise RuntimeError("Pipeline not registered")

    pipe = get_pipeline(pipeline_name)

    # Move to GPU unless a device map placed the components already. Test the
    # attribute's value: it may exist with a None value, and hasattr() would
    # then skip the move and leave the pipeline on CPU.
    if not getattr(pipe, "hf_device_map", None):
        try:
            pipe = pipe.to("cuda")
        except ValueError as e:
            # Some offloading setups reject whole-pipeline moves with a
            # ValueError; keep the placement the script configured.
            log(f"Keeping script-configured placement: {e}")

    # Logged after the move so the reported device is the real one.
    log("=== PIPELINE TECHNOLOGY ===")
    log(pipeline_technology_info(pipe))

    # Clamp before logging so the log reports the count actually generated.
    num_images = min(max(1, int(num_images)), 3)

    log("=== NEW GENERATION REQUEST ===")
    log(f"Pipeline: {pipeline_name}")
    log(f"Prompt: {prompt}")
    log(f"Height: {height}, Width: {width}")
    log(f"Steps: {num_inference_steps}")
    log(f"Images: {num_images}")

    if randomize_seed or seed is None:
        # randint's upper bound is exclusive; 2**32 covers the full uint32 range.
        seed = torch.randint(0, 2**32, (1,)).item()
        log(f"Random Seed β†’ {seed}")
    else:
        log(f"Seed β†’ {seed}")
    seed = int(seed)

    generator = torch.Generator("cuda").manual_seed(seed)

    # Run pipeline
    result = pipe(
        prompt=prompt,
        height=int(height),
        width=int(width),
        num_inference_steps=int(num_inference_steps),
        guidance_scale=0.0,
        generator=generator,
        max_sequence_length=1024,
        num_images_per_prompt=num_images,
        output_type="pil",
    )

    # Diagnostics are best-effort: a failure here must not discard the images.
    try:
        log(pipeline_debug_info(pipe))
        log(latent_shape_info(int(height), int(width), pipe))
    except Exception as e:
        log(f"Diagnostics error: {e}")

    log("βœ… Generation complete")
    return result.images, seed, log_buffer.getvalue()
# =========================================================
# GRADIO UI (original layout)
# =========================================================
with gr.Blocks(title="Z-Image-Turbo Multi Image Demo") as demo:
    gr.Markdown("# 🎨 Z-Image-Turbo β€” Multi Image ")

    with gr.Row():
        # Left column: script registration and generation controls.
        with gr.Column(scale=1):
            script_selector = gr.CheckboxGroup(
                choices=SCRIPT_NAMES,
                label="Select pipeline scripts",
            )
            register_btn = gr.Button("Register Scripts")
            status = gr.Textbox(label="Status", interactive=False)

            prompt = gr.Textbox(label="Prompt", lines=4)
            with gr.Row():
                height = gr.Slider(512, 2048, 1024, step=64, label="Height")
                width = gr.Slider(512, 2048, 1024, step=64, label="Width")
            num_images = gr.Slider(1, 3, 2, step=1, label="Number of Images")
            num_inference_steps = gr.Slider(
                1, 20, 9, step=1, label="Inference Steps",
                info="9 steps = 8 DiT forward passes",
            )
            with gr.Row():
                seed = gr.Number(label="Seed", value=42, precision=0)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)
            generate_btn = gr.Button("πŸš€ Generate", variant="primary")

        # Right column: active pipeline choice, outputs and the request log.
        with gr.Column(scale=1):
            pipeline_picker = gr.Dropdown(
                choices=[],
                label="Active Pipeline",
            )
            output_images = gr.Gallery(label="Generated Images", type="pil", columns=2)
            used_seed = gr.Number(label="Seed Used", interactive=False)
            debug_log = gr.Textbox(
                label="Debug Log Output",
                lines=25,
                interactive=False,
            )

    # Offer the selected scripts in the picker only after registration has
    # succeeded. As two independent listeners, the picker was updated even
    # when register_scripts raised (e.g. a missing file).
    register_btn.click(
        register_scripts,
        inputs=[script_selector],
        outputs=[status],
    ).success(
        lambda s: gr.update(choices=s, value=s[0] if s else None),
        inputs=[script_selector],
        outputs=[pipeline_picker],
    )

    generate_btn.click(
        generate_image,
        inputs=[prompt, height, width, num_inference_steps, seed, randomize_seed, num_images, pipeline_picker],
        outputs=[output_images, used_seed, debug_log],
    )

demo.queue()
demo.launch()