rahul7star committed on
Commit
cd46e64
·
verified ·
1 Parent(s): b7323fd

Create app.py

Files changed (1)
  1. app.py +1088 -0
app.py ADDED
@@ -0,0 +1,1088 @@
import torch
import spaces
import gradio as gr
import sys
import platform
import diffusers
import transformers
import psutil
import os
import time
import traceback

from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from diffusers import ZImagePipeline, AutoModel
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
from PIL import Image  # used by latent_to_image and the preview helpers below

latent_history = []

# ============================================================
# LOGGING BUFFER
# ============================================================
LOGS = ""

def log(msg):
    global LOGS
    print(msg)
    LOGS += msg + "\n"
    return msg

# ============================================================
# SYSTEM METRICS — LIVE GPU + CPU MONITORING
# ============================================================
def log_system_stats(tag=""):
    try:
        log(f"\n===== 🔥 SYSTEM STATS {tag} =====")

        # ============= GPU STATS =============
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated(0) / 1e9
            reserved = torch.cuda.memory_reserved(0) / 1e9
            total = torch.cuda.get_device_properties(0).total_memory / 1e9
            free = total - allocated

            log(f"💠 GPU Total     : {total:.2f} GB")
            log(f"💠 GPU Allocated : {allocated:.2f} GB")
            log(f"💠 GPU Reserved  : {reserved:.2f} GB")
            log(f"💠 GPU Free      : {free:.2f} GB")

        # ============= CPU STATS =============
        cpu = psutil.cpu_percent()
        ram_used = psutil.virtual_memory().used / 1e9
        ram_total = psutil.virtual_memory().total / 1e9

        log(f"🧠 CPU Usage : {cpu}%")
        log(f"🧠 RAM Used  : {ram_used:.2f} GB / {ram_total:.2f} GB")

    except Exception as e:
        log(f"⚠️ Failed to log system stats: {e}")


# ============================================================
# ENVIRONMENT INFO
# ============================================================
log("===================================================")
log("🔍 Z-IMAGE-TURBO DEBUGGING + LIVE METRIC LOGGER")
log("===================================================\n")

log(f"📌 PYTHON VERSION       : {sys.version.replace(chr(10), ' ')}")
log(f"📌 PLATFORM             : {platform.platform()}")
log(f"📌 TORCH VERSION        : {torch.__version__}")
log(f"📌 TRANSFORMERS VERSION : {transformers.__version__}")
log(f"📌 DIFFUSERS VERSION    : {diffusers.__version__}")
log(f"📌 CUDA AVAILABLE       : {torch.cuda.is_available()}")

log_system_stats("AT STARTUP")

if not torch.cuda.is_available():
    raise RuntimeError("❌ CUDA Required")

device = "cuda"
gpu_id = 0

# ============================================================
# MODEL SETTINGS
# ============================================================
model_cache = "./weights/"
model_id = "Tongyi-MAI/Z-Image-Turbo"
torch_dtype = torch.bfloat16
USE_CPU_OFFLOAD = False

log("\n===================================================")
log("🧠 MODEL CONFIGURATION")
log("===================================================")
log(f"Model ID              : {model_id}")
log(f"Model Cache Directory : {model_cache}")
log(f"torch_dtype           : {torch_dtype}")
log(f"USE_CPU_OFFLOAD       : {USE_CPU_OFFLOAD}")

log_system_stats("BEFORE TRANSFORMER LOAD")

# ============================================================
# FUNCTION TO CONVERT LATENTS TO IMAGE
# ============================================================
def latent_to_image(latent):
    """
    Convert a latent tensor to a PIL image using pipe.vae.
    """
    try:
        import torchvision.transforms as T  # local import; only needed by this helper
        img_tensor = pipe.vae.decode(latent).sample  # .sample holds the decoded tensor
        img_tensor = (img_tensor / 2 + 0.5).clamp(0, 1)
        pil_img = T.ToPILImage()(img_tensor[0].cpu())  # single image
        return pil_img
    except Exception as e:
        log(f"⚠️ Failed to decode latent: {e}")
        # fallback blank image
        return Image.new("RGB", (latent.shape[-1] * 8, latent.shape[-2] * 8), color=(255, 255, 255))


# ============================================================
# SAFE TRANSFORMER INSPECTION
# ============================================================
def inspect_transformer(model, name):
    log(f"\n🔍🔍 FULL TRANSFORMER DEBUG DUMP: {name}")
    log("=" * 80)

    try:
        log(f"Model class    : {model.__class__.__name__}")
        log(f"DType          : {getattr(model, 'dtype', 'unknown')}")
        log(f"Device         : {next(model.parameters()).device}")
        log(f"Requires Grad? : {any(p.requires_grad for p in model.parameters())}")

        # Check quantization
        if hasattr(model, "is_loaded_in_4bit"):
            log(f"4bit Quantization : {model.is_loaded_in_4bit}")
        if hasattr(model, "is_loaded_in_8bit"):
            log(f"8bit Quantization : {model.is_loaded_in_8bit}")

        # Find blocks
        candidates = ["transformer_blocks", "blocks", "layers", "encoder", "model"]
        blocks = None
        chosen_attr = None

        for attr in candidates:
            if hasattr(model, attr):
                blocks = getattr(model, attr)
                chosen_attr = attr
                break

        log(f"Block container attr : {chosen_attr}")

        if blocks is None:
            log("⚠️ No valid block container found.")
            return

        if not hasattr(blocks, "__len__"):
            log("⚠️ Blocks exist but are not iterable.")
            return

        total = len(blocks)
        log(f"Total Blocks : {total}")
        log("-" * 80)

        # Inspect the first N blocks
        N = min(20, total)
        for i in range(N):
            block = blocks[i]
            log(f"\n🧩 Block [{i}/{total - 1}]")
            log(f"Class: {block.__class__.__name__}")

            # Print submodules
            for n, m in block.named_children():
                log(f"   ├─ {n}: {m.__class__.__name__}")

            # Print attention-related details
            if hasattr(block, "attn"):
                attn = block.attn
                log(f"   ├─ Attention: {attn.__class__.__name__}")
                log(f"   │    Heads   : {getattr(attn, 'num_heads', 'unknown')}")
                log(f"   │    Dim     : {getattr(attn, 'hidden_size', 'unknown')}")
                log(f"   │    Backend : {getattr(attn, 'attention_backend', 'unknown')}")

            # Device + dtype info
            try:
                dev = next(block.parameters()).device
                log(f"   ├─ Device : {dev}")
            except StopIteration:
                pass

            try:
                dt = next(block.parameters()).dtype
                log(f"   ├─ DType  : {dt}")
            except StopIteration:
                pass

        log("\n🔚 END TRANSFORMER DEBUG DUMP")
        log("=" * 80)

    except Exception as e:
        log(f"❌ ERROR IN INSPECTOR: {e}")

# ---------- UTILITY ----------
def pretty_header(title):
    log("\n\n" + "=" * 80)
    log(f"🎛️ {title}")
    log("=" * 80 + "\n")


# ---------- MEMORY ----------
def get_vram(prefix=""):
    try:
        allocated = torch.cuda.memory_allocated() / 1024**2
        reserved = torch.cuda.memory_reserved() / 1024**2
        log(f"{prefix}Allocated VRAM : {allocated:.2f} MB")
        log(f"{prefix}Reserved VRAM  : {reserved:.2f} MB")
    except Exception:
        log(f"{prefix}VRAM: CUDA not available")


# ---------- MODULE INSPECT ----------
def inspect_module(name, module):
    pretty_header(f"🔬 Inspecting {name}")

    try:
        log(f"📦 Class  : {module.__class__.__name__}")
        log(f"🔢 DType  : {getattr(module, 'dtype', 'unknown')}")
        log(f"💻 Device : {next(module.parameters()).device}")
        log(f"🧮 Params : {sum(p.numel() for p in module.parameters()):,}")

        # Quantization state
        if hasattr(module, "is_loaded_in_4bit"):
            log(f"⚙️ 4-bit load : {module.is_loaded_in_4bit}")
        if hasattr(module, "is_loaded_in_8bit"):
            log(f"⚙️ 8-bit load : {module.is_loaded_in_8bit}")

        # Attention backend (DiT)
        if hasattr(module, "set_attention_backend"):
            try:
                attn = getattr(module, "attention_backend", None)
                log(f"🚀 Attention Backend: {attn}")
            except Exception:
                pass

        # Search for blocks
        candidates = ["transformer_blocks", "blocks", "layers", "encoder", "model"]
        blocks = None
        chosen_attr = None

        for attr in candidates:
            if hasattr(module, attr):
                blocks = getattr(module, attr)
                chosen_attr = attr
                break

        log(f"\n📚 Block Container : {chosen_attr}")

        if blocks is None:
            log("⚠️ No block structure found")
            return

        if not hasattr(blocks, "__len__"):
            log("⚠️ Blocks exist but are not iterable")
            return

        total = len(blocks)
        log(f"🔢 Total Blocks : {total}\n")

        # Inspect the first 15 blocks
        N = min(15, total)

        for i in range(N):
            blk = blocks[i]
            log(f"\n🧩 Block [{i}/{total - 1}] — {blk.__class__.__name__}")

            for n, m in blk.named_children():
                log(f"   ├─ {n:<15} {m.__class__.__name__}")

            # Attention details
            if hasattr(blk, "attn"):
                a = blk.attn
                log("   ├─ Attention")
                log(f"   │    Heads   : {getattr(a, 'num_heads', 'unknown')}")
                log(f"   │    Dim     : {getattr(a, 'hidden_size', 'unknown')}")
                log(f"   │    Backend : {getattr(a, 'attention_backend', 'unknown')}")

            # Device / dtype
            try:
                log(f"   ├─ Device : {next(blk.parameters()).device}")
                log(f"   ├─ DType  : {next(blk.parameters()).dtype}")
            except StopIteration:
                pass

        get_vram("   ▶ ")

    except Exception as e:
        log(f"❌ Module inspect error: {e}")

# ---------- LORA INSPECTION ----------
def inspect_loras(pipe):
    pretty_header("🧩 LoRA ADAPTERS")

    try:
        if not hasattr(pipe, "lora_state_dict") and not hasattr(pipe, "adapter_names"):
            log("⚠️ No LoRA system detected.")
            return

        if hasattr(pipe, "adapter_names"):
            names = pipe.adapter_names
            log(f"Available Adapters: {names}")

        if hasattr(pipe, "active_adapters"):
            log(f"Active Adapters   : {pipe.active_adapters}")

        if hasattr(pipe, "lora_scale"):
            log(f"LoRA Scale        : {pipe.lora_scale}")

        # LoRA modules
        if hasattr(pipe, "transformer") and hasattr(pipe.transformer, "modules"):
            for name, module in pipe.transformer.named_modules():
                if "lora" in name.lower():
                    log(f"   🔧 LoRA Module: {name} ({module.__class__.__name__})")

    except Exception as e:
        log(f"❌ LoRA inspect error: {e}")


# ---------- PIPELINE INSPECTOR ----------
def debug_pipeline(pipe):
    pretty_header("🚀 FULL PIPELINE DEBUGGING")

    try:
        log(f"Pipeline Class : {pipe.__class__.__name__}")
        log(f"Attention Impl : {getattr(pipe, 'attn_implementation', 'unknown')}")
        log(f"Device         : {pipe.device}")
    except Exception:
        pass

    get_vram("▶ ")

    # Inspect TRANSFORMER
    if hasattr(pipe, "transformer"):
        inspect_module("Transformer", pipe.transformer)

    # Inspect TEXT ENCODER
    if hasattr(pipe, "text_encoder") and pipe.text_encoder is not None:
        inspect_module("Text Encoder", pipe.text_encoder)

    # Inspect UNET (if the ZImage pipeline has one)
    if hasattr(pipe, "unet"):
        inspect_module("UNet", pipe.unet)

    # LoRA adapters
    inspect_loras(pipe)

    pretty_header("🎉 END DEBUG REPORT")


# ============================================================
# LOAD TRANSFORMER — WITH LIVE STATS
# ============================================================
log("\n===================================================")
log("🔧 LOADING TRANSFORMER BLOCK")
log("===================================================")

log("📌 Logging memory before load:")
log_system_stats("START TRANSFORMER LOAD")

try:
    quant_cfg = DiffusersBitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch_dtype,
        bnb_4bit_use_double_quant=True,
    )

    transformer = AutoModel.from_pretrained(
        model_id,
        cache_dir=model_cache,
        subfolder="transformer",
        quantization_config=quant_cfg,
        torch_dtype=torch_dtype,
        device_map=device,
    )
    log("✅ Transformer loaded successfully.")

except Exception as e:
    log(f"❌ Transformer load failed: {e}")
    transformer = None

log_system_stats("AFTER TRANSFORMER LOAD")

if transformer:
    inspect_transformer(transformer, "Transformer")


# ============================================================
# LOAD TEXT ENCODER
# ============================================================
log("\n===================================================")
log("🔧 LOADING TEXT ENCODER")
log("===================================================")

log_system_stats("START TEXT ENCODER LOAD")

try:
    quant_cfg2 = TransformersBitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch_dtype,
        bnb_4bit_use_double_quant=True,
    )

    text_encoder = AutoModel.from_pretrained(
        model_id,
        cache_dir=model_cache,
        subfolder="text_encoder",
        quantization_config=quant_cfg2,
        torch_dtype=torch_dtype,
        device_map=device,
    )
    log("✅ Text encoder loaded successfully.")

except Exception as e:
    log(f"❌ Text encoder load failed: {e}")
    text_encoder = None

log_system_stats("AFTER TEXT ENCODER LOAD")

if text_encoder:
    inspect_transformer(text_encoder, "Text Encoder")


# ============================================================
# BUILD PIPELINE
# ============================================================
log("\n===================================================")
log("🔧 BUILDING PIPELINE")
log("===================================================")

log_system_stats("START PIPELINE BUILD")

try:
    pipe = ZImagePipeline.from_pretrained(
        model_id,
        transformer=transformer,
        text_encoder=text_encoder,
        torch_dtype=torch_dtype,
    )

    # If the transformer supports selecting an attention backend, prefer flash-3
    try:
        if hasattr(pipe, "transformer") and hasattr(pipe.transformer, "set_attention_backend"):
            pipe.transformer.set_attention_backend("_flash_3")
            log("✅ transformer.set_attention_backend('_flash_3') called")
    except Exception as _e:
        log(f"⚠️ set_attention_backend failed: {_e}")

    # Default LoRA load (keeps the existing behaviour)
    try:
        pipe.load_lora_weights(
            "rahul7star/ZImageLora",
            weight_name="NSFW/doggystyle_pov.safetensors",
            adapter_name="lora",
        )
        pipe.set_adapters(["lora"], adapter_weights=[1.0])
        pipe.fuse_lora(adapter_names=["lora"], lora_scale=0.75)
    except Exception as _e:
        log(f"⚠️ Default LoRA load failed: {_e}")

    debug_pipeline(pipe)
    # pipe.unload_lora_weights()
    pipe.to("cuda")
    log("✅ Pipeline built successfully.")
    log("Pipeline build completed.")  # log() already appends to LOGS
except Exception as e:
    log(f"❌ Pipeline build failed: {e}")
    log(traceback.format_exc())
    pipe = None

log_system_stats("AFTER PIPELINE BUILD")


# -----------------------------
# Monkey-patch prepare_latents (safe)
# -----------------------------
if pipe is not None and hasattr(pipe, "prepare_latents"):
    original_prepare_latents = pipe.prepare_latents

    def logged_prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        try:
            result_latents = original_prepare_latents(batch_size, num_channels_latents, height, width, dtype, device, generator, latents)
            log_msg = f"🔹 prepare_latents called | shape={result_latents.shape}, dtype={result_latents.dtype}, device={result_latents.device}"
            if hasattr(self, "_latents_log"):
                self._latents_log.append(log_msg)
            else:
                self._latents_log = [log_msg]
            return result_latents
        except Exception as e:
            log(f"⚠️ prepare_latents wrapper failed: {e}")
            raise

    # Apply the patch safely
    try:
        pipe.prepare_latents = logged_prepare_latents.__get__(pipe)
        log("✅ prepare_latents monkey-patched")
    except Exception as e:
        log(f"⚠️ Failed to attach prepare_latents patch: {e}")
else:
    log("❌ WARNING: Pipe not initialized or prepare_latents missing; skipping prepare_latents patch")


from PIL import Image
import torch

# --------------------------
# Helper: Safe latent extractor
# --------------------------
def safe_get_latents(pipe, height, width, generator, device, LOGS):
    """
    Safely prepare latents for any ZImagePipeline variant.
    Returns a latents tensor and logs issues instead of failing.
    """
    try:
        # Determine the number of latent channels
        num_channels = 4  # default fallback
        if hasattr(pipe, "unet") and hasattr(pipe.unet, "in_channels"):
            num_channels = pipe.unet.in_channels
        elif hasattr(pipe, "vae") and hasattr(pipe.vae, "latent_channels"):
            num_channels = pipe.vae.latent_channels  # some pipelines define this
        LOGS.append(f"🔹 Using num_channels={num_channels} for latents")

        latents = pipe.prepare_latents(
            batch_size=1,
            num_channels_latents=num_channels,
            height=height,
            width=width,
            dtype=torch.float32,
            device=device,
            generator=generator,
        )

        LOGS.append(f"🔹 Latents shape: {latents.shape}, dtype: {latents.dtype}, device: {latents.device}")
        return latents
    except Exception as e:
        LOGS.append(f"⚠️ Latent extraction failed: {e}")
        # Fallback: guess a safe shape
        fallback_channels = 16  # standard default for ZImage pipelines
        latents = torch.randn(
            (1, fallback_channels, height // 8, width // 8),
            generator=generator, device=device,
        )
        LOGS.append(f"🔹 Using fallback random latents shape: {latents.shape}")
        return latents

# --------------------------
# Main generation function (kept exactly as the original logic)
# --------------------------
from huggingface_hub import HfApi, HfFolder
import torch
import os

HF_REPO_ID = "rahul7star/Zstudio-latent"  # model repo
HF_TOKEN = HfFolder.get_token()  # make sure you are logged in via `huggingface-cli login`

def upload_latents_to_hf(latent_dict, filename="latents.pt"):
    local_path = f"/tmp/{filename}"
    torch.save(latent_dict, local_path)
    try:
        api = HfApi()
        api.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=filename,
            repo_id=HF_REPO_ID,
            token=HF_TOKEN,
            repo_type="model",  # since this is a model repo
        )
        os.remove(local_path)
        return f"https://huggingface.co/{HF_REPO_ID}/resolve/main/{filename}"
    except Exception as e:
        os.remove(local_path)
        raise e

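# ------------------------------------------------------------
# Added sketch (not part of the original file): the latent files uploaded by
# upload_latents_to_hf() can be pulled back for offline inspection with
# hf_hub_download + torch.load. The helper name and its usage are assumptions.
# ------------------------------------------------------------
from huggingface_hub import hf_hub_download

def download_latents_from_hf(filename="latents.pt"):
    """Best-effort download of a previously uploaded latent dict from HF_REPO_ID."""
    local_path = hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=filename,
        repo_type="model",
        token=HF_TOKEN,
    )
    # The uploaded object is a plain dict of tensors/metadata saved with torch.save
    return torch.load(local_path, map_location="cpu")
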
import asyncio
import torch
from PIL import Image

async def async_upload_latents(latent_dict, filename, LOGS):
    try:
        # upload_latents_to_hf is synchronous, so run it in a worker thread
        hf_url = await asyncio.to_thread(upload_latents_to_hf, latent_dict, filename=filename)
        LOGS.append(f"🔹 All preview latents uploaded: {hf_url}")
    except Exception as e:
        LOGS.append(f"⚠️ Failed to upload all preview latents: {e}")

# This version generates a preview frame for every latent step. It is GPU-expensive
# and the decode sometimes fails, so keep it for later use.
@spaces.GPU
def generate_image_all_latents(prompt, height, width, steps, seed, guidance_scale=0.0):
    LOGS = []
    device = "cpu"  # FORCE CPU
    generator = torch.Generator(device).manual_seed(int(seed))

    placeholder = Image.new("RGB", (width, height), color=(255, 255, 255))
    latent_gallery = []
    final_gallery = []

    last_four_latents = []  # we only upload 4

    # --------------------------------------------------
    # LATENT PREVIEW GENERATION (CPU MODE)
    # --------------------------------------------------
    try:
        latents = safe_get_latents(pipe, height, width, generator, device, LOGS)
        latents = latents.to("cpu")  # keep EVERYTHING on CPU

        timestep_count = len(pipe.scheduler.timesteps)
        preview_every = max(1, timestep_count // 10)

        for i, t in enumerate(pipe.scheduler.timesteps):

            # -------------- decode latent preview --------------
            try:
                with torch.no_grad():
                    latent_cpu = latents.to(pipe.vae.dtype)  # match VAE dtype
                    decoded = pipe.vae.decode(latent_cpu).sample  # [1, 3, H, W]
                    decoded = (decoded / 2 + 0.5).clamp(0, 1)
                    decoded = decoded[0].permute(1, 2, 0).cpu().numpy()
                    latent_img = Image.fromarray((decoded * 255).astype("uint8"))
            except Exception:
                latent_img = placeholder
                LOGS.append("⚠️ Latent preview decode failed.")

            latent_gallery.append(latent_img)

            # store the last 4 latent states
            if len(last_four_latents) >= 4:
                last_four_latents.pop(0)
            last_four_latents.append(latents.cpu().clone())

            # UI preview yields
            if i % preview_every == 0:
                yield None, latent_gallery, LOGS

        # --------------------------------------------------
        # UPLOAD LAST 4 LATENTS (SYNC)
        # --------------------------------------------------
        try:
            upload_dict = {
                "last_4_latents": last_four_latents,
                "prompt": prompt,
                "seed": seed,
            }

            hf_url = upload_latents_to_hf(
                upload_dict,
                filename=f"latents_last4_{seed}.pt",
            )

            LOGS.append(f"🔹 Uploaded last 4 latents: {hf_url}")

        except Exception as e:
            LOGS.append(f"⚠️ Failed to upload latents: {e}")

    except Exception as e:
        LOGS.append(f"⚠️ Latent generation failed: {e}")
        latent_gallery.append(placeholder)
        yield None, latent_gallery, LOGS

    # --------------------------------------------------
    # FINAL IMAGE - UNTOUCHED
    # --------------------------------------------------
    try:
        output = pipe(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=steps,
            guidance_scale=guidance_scale,
            generator=generator,
        )
        final_img = output.images[0]
        LOGS.append("✅ Standard pipeline succeeded.")

        yield final_img, latent_gallery, LOGS

    except Exception as e2:
        LOGS.append(f"❌ Standard pipeline failed: {e2}")
        yield placeholder, latent_gallery, LOGS

@spaces.GPU
def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0):
    LOGS = []
    device = "cuda"
    cpu_device = "cpu"
    generator = torch.Generator(device).manual_seed(int(seed))

    placeholder = Image.new("RGB", (width, height), color=(255, 255, 255))
    latent_gallery = []
    final_gallery = []

    last_latents = []  # store the last 5 preview latents on CPU

    try:
        # --- Initial latents ---
        latents = safe_get_latents(pipe, height, width, generator, device, LOGS)
        latents = latents.float().to(cpu_device)  # move to CPU

        num_previews = min(10, steps)
        preview_indices = torch.linspace(0, steps - 1, num_previews).long()

        for i, step_idx in enumerate(preview_indices):
            denoise_latent = None  # guard in case the preview step fails early
            try:
                with torch.no_grad():
                    # --- Z-Image Turbo-style denoise simulation ---
                    t = 1.0 - (i / num_previews)  # linear decay [1.0 -> 0.0]
                    noise_scale = t ** 0.5        # reduce noise over steps (sqrt for a smoother decay)
                    denoise_latent = latents * t + torch.randn_like(latents) * noise_scale

                    # Move to the VAE device & dtype
                    denoise_latent = denoise_latent.to(pipe.vae.device).to(pipe.vae.dtype)

                    # Decode the latent to an image
                    decoded = pipe.vae.decode(denoise_latent, return_dict=False)[0]
                    decoded = (decoded / 2 + 0.5).clamp(0, 1)
                    decoded = decoded.cpu().permute(0, 2, 3, 1).float().numpy()
                    decoded = (decoded * 255).round().astype("uint8")
                    latent_img = Image.fromarray(decoded[0])

            except Exception as e:
                LOGS.append(f"⚠️ Latent preview decode failed: {e}")
                latent_img = placeholder

            latent_gallery.append(latent_img)

            # Keep only the last 5 latents
            if denoise_latent is not None:
                last_latents.append(denoise_latent.cpu().clone())
            if len(last_latents) > 5:
                last_latents.pop(0)

            # Show only the last 5 previews in the UI
            yield None, latent_gallery[-5:], LOGS

        # Optionally: upload the last 5 latents
        # latent_dict = {"latents": last_latents, "prompt": prompt, "seed": seed}
        # hf_url = upload_latents_to_hf(latent_dict, filename=f"latents_last5_{seed}.pt")
        # LOGS.append(f"🔹 Last 5 latents uploaded: {hf_url}")

    except Exception as e:
        LOGS.append(f"⚠️ Latent generation failed: {e}")
        latent_gallery.append(placeholder)
        yield None, latent_gallery[-5:], LOGS

    # --- Final image on GPU ---
    try:
        output = pipe(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=steps,
            guidance_scale=guidance_scale,
            generator=generator,
        )
        final_img = output.images[0]
        final_gallery.append(final_img)
        latent_gallery.append(final_img)
        LOGS.append("✅ Standard pipeline succeeded.")
        yield final_img, latent_gallery[-5:] + [final_img], LOGS  # last 5 previews + final

    except Exception as e2:
        LOGS.append(f"❌ Standard pipeline failed: {e2}")
        final_gallery.append(placeholder)
        latent_gallery.append(placeholder)
        yield placeholder, latent_gallery[-5:] + [placeholder], LOGS


# This is a stable version that can generate the final image and a noise-to-latent preview.
@spaces.GPU
def generate_image_verygood_realnoise(prompt, height, width, steps, seed, guidance_scale=0.0):
    LOGS = []
    device = "cuda"
    generator = torch.Generator(device).manual_seed(int(seed))

    placeholder = Image.new("RGB", (width, height), color=(255, 255, 255))
    latent_gallery = []
    final_gallery = []

    # --- Generate latent previews ---
    try:
        latents = safe_get_latents(pipe, height, width, generator, device, LOGS)
        latents = latents.float()  # keep float32 until decode

        num_previews = min(10, steps)
        preview_steps = torch.linspace(0, 1, num_previews)

        for alpha in preview_steps:
            try:
                with torch.no_grad():
                    # Simulate denoising progression like Z-Image Turbo
                    preview_latent = latents * alpha + latents * 0  # optional: simple progression

                    # Move to the same device and dtype as the VAE
                    preview_latent = preview_latent.to(pipe.vae.device).to(pipe.vae.dtype)

                    # Decode
                    decoded = pipe.vae.decode(preview_latent, return_dict=False)[0]

                    # Convert to PIL following the same logic as the final image
                    decoded = (decoded / 2 + 0.5).clamp(0, 1)
                    decoded = decoded.cpu().permute(0, 2, 3, 1).float().numpy()
                    decoded = (decoded * 255).round().astype("uint8")
                    latent_img = Image.fromarray(decoded[0])

            except Exception as e:
                LOGS.append(f"⚠️ Latent preview decode failed: {e}")
                latent_img = placeholder

            latent_gallery.append(latent_img)
            yield None, latent_gallery, LOGS

    except Exception as e:
        LOGS.append(f"⚠️ Latent generation failed: {e}")
        latent_gallery.append(placeholder)
        yield None, latent_gallery, LOGS

    # --- Final image: untouched ---
    try:
        output = pipe(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=steps,
            guidance_scale=guidance_scale,
            generator=generator,
        )
        final_img = output.images[0]
        final_gallery.append(final_img)
        latent_gallery.append(final_img)  # fallback preview
        LOGS.append("✅ Standard pipeline succeeded.")
        yield final_img, latent_gallery, LOGS

    except Exception as e2:
        LOGS.append(f"❌ Standard pipeline failed: {e2}")
        final_gallery.append(placeholder)
        latent_gallery.append(placeholder)
        yield placeholder, latent_gallery, LOGS


# DO NOT TOUCH: this is a stable version that can generate the final image plus
# noise-to-latent previews, and it uploads the latents to the repo.
@spaces.GPU
def generate_image_safe(prompt, height, width, steps, seed, guidance_scale=0.0):
    LOGS = []
    device = "cuda"
    generator = torch.Generator(device).manual_seed(int(seed))

    placeholder = Image.new("RGB", (width, height), color=(255, 255, 255))
    latent_gallery = []
    final_gallery = []

    # --- Generate latent previews in a loop ---
    try:
        latents = safe_get_latents(pipe, height, width, generator, device, LOGS)

        # Convert latents to float32 if necessary
        if latents.dtype != torch.float32:
            latents = latents.float()

        # Loop for multiple previews before the final image
        num_previews = min(10, steps)  # show ~10 previews
        preview_steps = torch.linspace(0, 1, num_previews)

        for i, alpha in enumerate(preview_steps):
            try:
                with torch.no_grad():
                    # Simple noise interpolation for the preview (simulates denoising progress)
                    preview_latent = latents * alpha + torch.randn_like(latents) * (1 - alpha)
                    # Decode to PIL
                    latent_img_tensor = pipe.vae.decode(preview_latent).sample  # [1, 3, H, W]
                    latent_img_tensor = (latent_img_tensor / 2 + 0.5).clamp(0, 1)
                    latent_img_tensor = latent_img_tensor.cpu().permute(0, 2, 3, 1)[0]
                    latent_img = Image.fromarray((latent_img_tensor.numpy() * 255).astype("uint8"))
            except Exception as e:
                LOGS.append(f"⚠️ Latent preview decode failed: {e}")
                latent_img = placeholder

            latent_gallery.append(latent_img)
            yield None, latent_gallery, LOGS  # update Gradio with the intermediate preview

        # Save the final latents to HF
        latent_dict = {"latents": latents.cpu(), "prompt": prompt, "seed": seed}
        try:
            hf_url = upload_latents_to_hf(latent_dict, filename=f"latents_{seed}.pt")
            LOGS.append(f"🔹 Latents uploaded: {hf_url}")
        except Exception as e:
            LOGS.append(f"⚠️ Failed to upload latents: {e}")

    except Exception as e:
        LOGS.append(f"⚠️ Latent generation failed: {e}")
        latent_gallery.append(placeholder)
        yield None, latent_gallery, LOGS

    # --- Final image: untouched standard pipeline ---
    try:
        output = pipe(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=steps,
            guidance_scale=guidance_scale,
            generator=generator,
        )
        final_img = output.images[0]
        final_gallery.append(final_img)
        latent_gallery.append(final_img)  # fallback preview if needed
        LOGS.append("✅ Standard pipeline succeeded.")
        yield final_img, latent_gallery, LOGS

    except Exception as e2:
        LOGS.append(f"❌ Standard pipeline failed: {e2}")
        final_gallery.append(placeholder)
        latent_gallery.append(placeholder)
        yield placeholder, latent_gallery, LOGS


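# ------------------------------------------------------------
# Added helper (not present in the original file): the UI below calls
# list_loras_from_repo(), which was never defined. This is a minimal sketch,
# assuming the LoRA repo stores its weights as .safetensors files and using the
# standard huggingface_hub list_repo_files API.
# ------------------------------------------------------------
from huggingface_hub import list_repo_files

def list_loras_from_repo(repo_name):
    """Return the .safetensors file paths available in the given HF repo (best effort)."""
    try:
        files = list_repo_files(repo_name, repo_type="model", token=HF_TOKEN)
        return [f for f in files if f.endswith(".safetensors")]
    except Exception as e:
        log(f"⚠️ list_loras_from_repo failed for {repo_name}: {e}")
        return []

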
with gr.Blocks(title="Z-Image-Turbo") as demo:
    gr.Markdown("# 🎨 DO NOT RUN THIS ")
    with gr.Tabs():
        with gr.TabItem("Image & Latents"):
            with gr.Row():
                with gr.Column(scale=1):
                    prompt = gr.Textbox(label="Prompt", value="boat in Ocean")
                    height = gr.Slider(256, 2048, value=1024, step=8, label="Height")
                    width = gr.Slider(256, 2048, value=1024, step=8, label="Width")
                    steps = gr.Slider(1, 50, value=20, step=1, label="Inference Steps")
                    seed = gr.Number(value=42, label="Seed")
                    run_btn = gr.Button("Generate Image")

                with gr.Column(scale=1):
                    final_image = gr.Image(label="Final Image")
                    latent_gallery = gr.Gallery(
                        label="Latent Steps", columns=4, height=256, preview=True
                    )

        with gr.TabItem("Logs"):
            logs_box = gr.Textbox(label="All Logs", lines=25)

    # New UI: LoRA repo textbox, dropdown, refresh & rebuild
    with gr.Row():
        lora_repo = gr.Textbox(label="LoRA Repo (HF id)", value="rahul7star/ZImageLora", placeholder="e.g. rahul7star/ZImageLora")
        lora_dropdown = gr.Dropdown(choices=[], label="LoRA files (from local cache)")
        refresh_lora_btn = gr.Button("Refresh LoRA List")
        rebuild_pipe_btn = gr.Button("Rebuild pipeline (use selected LoRA)")

    # Refresh callback: repopulate the dropdown from the repo textbox
    def refresh_lora_list(repo_name):
        try:
            files = list_loras_from_repo(repo_name)
            if not files:
                return gr.update(choices=[], value=None)
            return gr.update(choices=files, value=files[0])
        except Exception as e:
            log(f"⚠️ refresh_lora_list failed: {e}")
            return gr.update(choices=[], value=None)

    refresh_lora_btn.click(refresh_lora_list, inputs=[lora_repo], outputs=[lora_dropdown])

    # Rebuild callback: build the pipeline with the selected LoRA file path (if any)
    def rebuild_pipeline_with_lora(lora_path, repo_name):
        global pipe, LOGS
        try:
            log(f"🔄 Rebuilding pipeline using LoRA repo={repo_name} file={lora_path}")
            # Reuse the existing logic: create a new pipeline, then load the LoRA file
            pipe = ZImagePipeline.from_pretrained(
                model_id,
                transformer=transformer,
                text_encoder=text_encoder,
                torch_dtype=torch_dtype,
            )
            # Try to set the attention backend
            try:
                if hasattr(pipe, "transformer") and hasattr(pipe.transformer, "set_attention_backend"):
                    pipe.transformer.set_attention_backend("_flash_3")
            except Exception as _e:
                log(f"⚠️ set_attention_backend failed during rebuild: {_e}")

            # Load the selected LoRA if provided
            if lora_path:
                weight_name_to_use = None

                # If the dropdown provided a relative-style path (contains a slash and is not absolute),
                # use it directly as weight_name (HF expects "path/inside/repo.safetensors")
                if ("/" in lora_path) and not os.path.isabs(lora_path):
                    weight_name_to_use = lora_path
                else:
                    # It might be an absolute path in the cache; try to compute a path relative to the repo cache root
                    abs_path = lora_path if os.path.isabs(lora_path) else None
                    if abs_path and os.path.exists(abs_path):
                        # Attempt to find a repo-root-ish substring in abs_path
                        repo_variants = [
                            repo_name.replace("/", "--"),
                            repo_name.replace("/", "-"),
                            repo_name.replace("/", "_"),
                            repo_name.split("/")[-1],
                        ]
                        chosen_base = None
                        for v in repo_variants:
                            idx = abs_path.find(v)
                            if idx != -1:
                                chosen_base = abs_path[: idx + len(v)]
                                break
                        if chosen_base:
                            try:
                                rel = os.path.relpath(abs_path, chosen_base)
                                if rel and not rel.startswith(".."):
                                    weight_name_to_use = rel.replace(os.sep, "/")
                            except Exception:
                                weight_name_to_use = None

                    # Fall back to the basename
                    if weight_name_to_use is None:
                        weight_name_to_use = os.path.basename(lora_path)

                # Now attempt to load
                try:
                    pipe.load_lora_weights(
                        repo_name or "rahul7star/ZImageLora",
                        weight_name=weight_name_to_use,
                        adapter_name="lora",
                    )
                    pipe.set_adapters(["lora"], adapter_weights=[1.0])
                    pipe.fuse_lora(adapter_names=["lora"], lora_scale=0.75)
                    log(f"✅ Loaded LoRA weight: {weight_name_to_use} from repo {repo_name}")
                except Exception as _e:
                    log(f"⚠️ Failed to load selected LoRA during rebuild using weight_name='{weight_name_to_use}': {_e}")
                    # As a last resort, try loading using the basename
                    try:
                        fallback_name = os.path.basename(lora_path)
                        pipe.load_lora_weights(
                            repo_name or "rahul7star/ZImageLora",
                            weight_name=fallback_name,
                            adapter_name="lora",
                        )
                        pipe.set_adapters(["lora"], adapter_weights=[1.0])
                        pipe.fuse_lora(adapter_names=["lora"], lora_scale=0.75)
                        log(f"✅ Fallback loaded LoRA weight basename: {fallback_name}")
                    except Exception as _e2:
                        log(f"❌ Fallback LoRA load also failed: {_e2}")

            # Finalize
            debug_pipeline(pipe)
            pipe.to("cuda")
            # Re-attach the monkey patch safely
            if pipe is not None and hasattr(pipe, "prepare_latents"):
                try:
                    original_prepare = pipe.prepare_latents

                    def logged_prepare(self, *args, **kwargs):
                        lat = original_prepare(*args, **kwargs)
                        msg = f"🔹 prepare_latents called | shape={lat.shape}, dtype={lat.dtype}"
                        if hasattr(self, "_latents_log"):
                            self._latents_log.append(msg)
                        else:
                            self._latents_log = [msg]
                        return lat

                    pipe.prepare_latents = logged_prepare.__get__(pipe)
                    log("✅ Re-applied prepare_latents monkey patch after rebuild")
                except Exception as _e:
                    log(f"⚠️ Could not re-apply prepare_latents patch: {_e}")
            return "\n".join([LOGS, "Rebuild complete."])
        except Exception as e:
            log(f"❌ Rebuild pipeline failed: {e}")
            log(traceback.format_exc())
            return "\n".join([LOGS, f"Rebuild failed: {e}"])

    rebuild_pipe_btn.click(rebuild_pipeline_with_lora, inputs=[lora_dropdown, lora_repo], outputs=[logs_box])

    # Wire the generate button AFTER all components exist
    run_btn.click(
        generate_image,
        inputs=[prompt, height, width, steps, seed],
        outputs=[final_image, latent_gallery, logs_box],
    )


demo.launch()