rahul7star committed on
Commit 7665754 · verified
1 Parent(s): d4bd238

Create app_lora1.py

Files changed (1)
  1. app_lora1.py +201 -0
app_lora1.py ADDED
@@ -0,0 +1,201 @@
+ import gradio as gr
+ import torch
+ import os
+ import traceback
+ from diffusers import ZImagePipeline
+ from huggingface_hub import list_repo_files
+ from PIL import Image
+
+ # ============================================================
+ # CONFIG
+ # ============================================================
+
+ MODEL_ID = "Tongyi-MAI/Z-Image-Turbo"
+ DEFAULT_LORA_REPO = "rahul7star/ZImageLora"
+
+ DTYPE = torch.bfloat16
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # ============================================================
+ # GLOBAL STATE
+ # ============================================================
+
+ pipe = None
+ CURRENT_LORA_REPO = None
+ CURRENT_LORA_FILE = None
+
+ # ============================================================
+ # LOGGING
+ # ============================================================
+
+ def log(msg):
+     print(msg)
+     return msg
+
+ # ============================================================
+ # PIPELINE BUILD (ONCE)
+ # ============================================================
+
+ try:
+     pipe = ZImagePipeline.from_pretrained(
+         MODEL_ID,
+         torch_dtype=DTYPE,
+     )
+
+     if hasattr(pipe, "transformer") and hasattr(pipe.transformer, "set_attention_backend"):
+         pipe.transformer.set_attention_backend("_flash_3")
+
+     pipe.to(DEVICE)
+     log("✅ Pipeline built successfully")
+
+ except Exception as e:
+     log("❌ Pipeline build failed")
+     log(traceback.format_exc())
+     pipe = None
+
+ # ============================================================
+ # HELPERS
+ # ============================================================
+
+ def list_loras_from_repo(repo_id: str):
+     try:
+         files = list_repo_files(repo_id)
+         return [f for f in files if f.endswith(".safetensors")]
+     except Exception as e:
+         log(f"❌ Failed to list LoRAs: {e}")
+         return []
+
+ # ============================================================
+ # IMAGE GENERATION (SAFE LORA LOGIC)
+ # ============================================================
+
+ def generate_image(prompt, height, width, steps, seed, guidance_scale):
+     LOGS = []
+
+     if pipe is None:
+         # generate_image is a generator (it yields previews), so the error
+         # must be yielded rather than returned for Gradio to display it
+         yield None, [], "❌ Pipeline not initialized"
+         return
+
+     generator = torch.Generator(DEVICE).manual_seed(int(seed))
+     placeholder = Image.new("RGB", (width, height), (255, 255, 255))
+     previews = []
+
+     # ---- Always start clean ----
+     try:
+         pipe.unload_lora_weights()
+     except Exception:
+         pass
+
+     # ---- Load LoRA for this run only ----
+     if CURRENT_LORA_FILE:
+         try:
+             pipe.load_lora_weights(
+                 CURRENT_LORA_REPO,
+                 weight_name=CURRENT_LORA_FILE
+             )
+             LOGS.append(f"🧩 LoRA loaded: {CURRENT_LORA_FILE}")
+         except Exception as e:
+             LOGS.append(f"❌ LoRA load failed: {e}")
+
+     # ---- Preview steps (lightweight) ----
+     try:
+         num_previews = min(5, steps)
+         for i in range(num_previews):
+             out = pipe(
+                 prompt=prompt,
+                 height=height // 4,
+                 width=width // 4,
+                 num_inference_steps=i + 1,
+                 guidance_scale=guidance_scale,
+                 generator=generator,
+             )
+             img = out.images[0].resize((width, height))
+             previews.append(img)
+             yield None, previews, "\n".join(LOGS)
+     except Exception as e:
+         LOGS.append(f"⚠️ Preview failed: {e}")
+
+     # ---- Final image ----
+     try:
+         out = pipe(
+             prompt=prompt,
+             height=height,
+             width=width,
+             num_inference_steps=steps,
+             guidance_scale=guidance_scale,
+             generator=generator,
+         )
+         final_img = out.images[0]
+         previews.append(final_img)
+         LOGS.append("✅ Image generated")
+
+         yield final_img, previews, "\n".join(LOGS)
+
+     except Exception as e:
+         LOGS.append(f"❌ Generation failed: {e}")
+         yield placeholder, previews, "\n".join(LOGS)
+
+     finally:
+         # ---- CRITICAL: unload after run ----
+         try:
+             pipe.unload_lora_weights()
+             LOGS.append("🧹 LoRA unloaded")
+         except Exception:
+             pass
+
+ # ============================================================
+ # GRADIO UI
+ # ============================================================
+
+ with gr.Blocks(title="Z-Image-Turbo (LoRA Safe)") as demo:
+     gr.Markdown("# 🎨 Z-Image-Turbo — Runtime LoRA (SAFE MODE)")
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             prompt = gr.Textbox(label="Prompt", value="boat in ocean")
+             height = gr.Slider(256, 2048, value=1024, step=8, label="Height")
+             width = gr.Slider(256, 2048, value=1024, step=8, label="Width")
+             steps = gr.Slider(1, 50, value=20, step=1, label="Inference Steps")
+             guidance = gr.Slider(0, 10, value=0.0, step=0.5, label="Guidance Scale")
+             seed = gr.Number(value=42, label="Seed")
+             run_btn = gr.Button("🚀 Generate")
+
+         with gr.Column(scale=1):
+             final_image = gr.Image(label="Final Image")
+             gallery = gr.Gallery(label="Steps", columns=4, height=256)
+
+     gr.Markdown("## 🧩 LoRA Controls")
+
+     with gr.Row():
+         lora_repo = gr.Textbox(label="LoRA Repo", value=DEFAULT_LORA_REPO)
+         lora_dropdown = gr.Dropdown(label="LoRA File", choices=[])
+         refresh_btn = gr.Button("🔄 Refresh")
+
+     logs_box = gr.Textbox(label="Logs", lines=18)
+
+     # ---- Callbacks ----
+
+     def refresh_loras(repo):
+         files = list_loras_from_repo(repo)
+         return gr.update(choices=files, value=files[0] if files else None)
+
+     refresh_btn.click(refresh_loras, inputs=[lora_repo], outputs=[lora_dropdown])
+
+     def select_lora(lora_file, repo):
+         global CURRENT_LORA_FILE, CURRENT_LORA_REPO
+         CURRENT_LORA_FILE = lora_file
+         CURRENT_LORA_REPO = repo
+         return f"🧩 Selected LoRA: {lora_file}"
+
+     lora_dropdown.change(
+         select_lora,
+         inputs=[lora_dropdown, lora_repo],
+         outputs=[logs_box],
+     )
+
+     run_btn.click(
+         generate_image,
+         inputs=[prompt, height, width, steps, seed, guidance],
+         outputs=[final_image, gallery, logs_box],
+     )
+
+ demo.launch()
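
For reference, the per-run LoRA pattern the app relies on (load the selected LoRA, generate, then unload so the next run starts from the base weights) can also be exercised without the Gradio UI. The sketch below is illustrative only: it assumes the same diffusers ZImagePipeline and LoRA-loading calls the committed file already uses, and the LoRA file name is a placeholder, not a file known to exist in rahul7star/ZImageLora.

import torch
from diffusers import ZImagePipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = ZImagePipeline.from_pretrained("Tongyi-MAI/Z-Image-Turbo", torch_dtype=torch.bfloat16)
pipe.to(device)

def generate_with_lora(prompt, lora_repo=None, lora_file=None, **gen_kwargs):
    """Load the requested LoRA, run one generation, then unload it again."""
    # Start from a clean pipeline in case a LoRA is still active from a previous call.
    try:
        pipe.unload_lora_weights()
    except Exception:
        pass
    if lora_repo and lora_file:
        pipe.load_lora_weights(lora_repo, weight_name=lora_file)
    try:
        return pipe(prompt=prompt, **gen_kwargs).images[0]
    finally:
        # Mirror the app's "unload after run" step so LoRA state never leaks.
        try:
            pipe.unload_lora_weights()
        except Exception:
            pass

image = generate_with_lora(
    "boat in ocean",
    lora_repo="rahul7star/ZImageLora",
    lora_file="example.safetensors",  # placeholder name, replace with a real file from the repo
    height=1024,
    width=1024,
    num_inference_steps=20,
    guidance_scale=0.0,
    generator=torch.Generator(device).manual_seed(42),
)
image.save("boat.png")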