rahul7star committed on
Commit
ece8251
·
verified ·
1 Parent(s): 01ffa33

Create app_lora.py

Files changed (1)
  1. app_lora.py +1091 -0
app_lora.py ADDED
@@ -0,0 +1,1091 @@
+ import torch
+ import spaces
+ import gradio as gr
+ import sys
+ import platform
+ import diffusers
+ import transformers
+ import psutil
+ import os
+ import time
+ import traceback
+
+ from PIL import Image
+ from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
+ from diffusers import ZImagePipeline, AutoModel
+ from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
+
+ latent_history = []
17
+
18
+ # ============================================================
19
+ # LOGGING BUFFER
20
+ # ============================================================
21
+ LOGS = ""
22
+ def log(msg):
23
+ global LOGS
24
+ print(msg)
25
+ LOGS += msg + "\n"
26
+ return msg
27
+
28
+
29
+ # ============================================================
30
+ # SYSTEM METRICS — LIVE GPU + CPU MONITORING
31
+ # ============================================================
32
+ def log_system_stats(tag=""):
33
+ try:
34
+ log(f"\n===== 🔥 SYSTEM STATS {tag} =====")
35
+
36
+ # ============= GPU STATS =============
37
+ if torch.cuda.is_available():
38
+ allocated = torch.cuda.memory_allocated(0) / 1e9
39
+ reserved = torch.cuda.memory_reserved(0) / 1e9
40
+ total = torch.cuda.get_device_properties(0).total_memory / 1e9
41
+ free = total - allocated
42
+
43
+ log(f"💠 GPU Total : {total:.2f} GB")
44
+ log(f"💠 GPU Allocated : {allocated:.2f} GB")
45
+ log(f"💠 GPU Reserved : {reserved:.2f} GB")
46
+ log(f"💠 GPU Free : {free:.2f} GB")
47
+
48
+ # ============= CPU STATS ============
49
+ cpu = psutil.cpu_percent()
50
+ ram_used = psutil.virtual_memory().used / 1e9
51
+ ram_total = psutil.virtual_memory().total / 1e9
52
+
53
+ log(f"🧠 CPU Usage : {cpu}%")
54
+ log(f"🧠 RAM Used : {ram_used:.2f} GB / {ram_total:.2f} GB")
55
+
56
+ except Exception as e:
57
+ log(f"⚠️ Failed to log system stats: {e}")
58
+
59
+
60
+ # ============================================================
61
+ # ENVIRONMENT INFO
62
+ # ============================================================
63
+ log("===================================================")
64
+ log("🔍 Z-IMAGE-TURBO DEBUGGING + LIVE METRIC LOGGER")
65
+ log("===================================================\n")
66
+
67
+ log(f"📌 PYTHON VERSION : {sys.version.replace(chr(10),' ')}")
68
+ log(f"📌 PLATFORM : {platform.platform()}")
69
+ log(f"📌 TORCH VERSION : {torch.__version__}")
70
+ log(f"📌 TRANSFORMERS VERSION : {transformers.__version__}")
71
+ log(f"📌 DIFFUSERS VERSION : {diffusers.__version__}")
72
+ log(f"📌 CUDA AVAILABLE : {torch.cuda.is_available()}")
73
+
74
+ log_system_stats("AT STARTUP")
75
+
76
+ if not torch.cuda.is_available():
77
+ raise RuntimeError("❌ CUDA Required")
78
+
79
+ device = "cuda"
80
+ gpu_id = 0
81
+
82
+ # ============================================================
83
+ # MODEL SETTINGS
84
+ # ============================================================
85
+ model_cache = "./weights/"
86
+ model_id = "Tongyi-MAI/Z-Image-Turbo"
87
+ torch_dtype = torch.bfloat16
88
+ USE_CPU_OFFLOAD = False
89
+
90
+ log("\n===================================================")
91
+ log("🧠 MODEL CONFIGURATION")
92
+ log("===================================================")
93
+ log(f"Model ID : {model_id}")
94
+ log(f"Model Cache Directory : {model_cache}")
95
+ log(f"torch_dtype : {torch_dtype}")
96
+ log(f"USE_CPU_OFFLOAD : {USE_CPU_OFFLOAD}")
97
+
98
+ log_system_stats("BEFORE TRANSFORMER LOAD")
99
+
100
+
101
+ # ============================================================
102
+ # FUNCTION TO CONVERT LATENTS TO IMAGE
103
+ # ============================================================
104
+ def latent_to_image(latent):
+     """
+     Convert a latent tensor to a PIL image using pipe.vae
+     """
+     try:
+         with torch.no_grad():
+             img_tensor = pipe.vae.decode(latent).sample  # decode returns an output object; .sample holds the image tensor
+         img_tensor = (img_tensor / 2 + 0.5).clamp(0, 1)
+         arr = (img_tensor[0].permute(1, 2, 0).cpu().float().numpy() * 255).astype("uint8")
+         return Image.fromarray(arr)  # single image
+     except Exception as e:
+         log(f"⚠️ Failed to decode latent: {e}")
+         # fallback blank image
+         return Image.new("RGB", (latent.shape[-1]*8, latent.shape[-2]*8), color=(255, 255, 255))
117
+
118
+
119
+
120
+ # ============================================================
121
+ # SAFE TRANSFORMER INSPECTION
122
+ # ============================================================
123
+ def inspect_transformer(model, name):
124
+ log(f"\n🔍🔍 FULL TRANSFORMER DEBUG DUMP: {name}")
125
+ log("=" * 80)
126
+
127
+ try:
128
+ log(f"Model class : {model.__class__.__name__}")
129
+ log(f"DType : {getattr(model, 'dtype', 'unknown')}")
130
+ log(f"Device : {next(model.parameters()).device}")
131
+ log(f"Requires Grad? : {any(p.requires_grad for p in model.parameters())}")
132
+
133
+ # Check quantization
134
+ if hasattr(model, "is_loaded_in_4bit"):
135
+ log(f"4bit Quantization : {model.is_loaded_in_4bit}")
136
+ if hasattr(model, "is_loaded_in_8bit"):
137
+ log(f"8bit Quantization : {model.is_loaded_in_8bit}")
138
+
139
+ # Find blocks
140
+ candidates = ["transformer_blocks", "blocks", "layers", "encoder", "model"]
141
+ blocks = None
142
+ chosen_attr = None
143
+
144
+ for attr in candidates:
145
+ if hasattr(model, attr):
146
+ blocks = getattr(model, attr)
147
+ chosen_attr = attr
148
+ break
149
+
150
+ log(f"Block container attr : {chosen_attr}")
151
+
152
+ if blocks is None:
153
+ log("⚠️ No valid block container found.")
154
+ return
155
+
156
+ if not hasattr(blocks, "__len__"):
157
+ log("⚠️ Blocks exist but not iterable.")
158
+ return
159
+
160
+ total = len(blocks)
161
+ log(f"Total Blocks : {total}")
162
+ log("-" * 80)
163
+
164
+ # Inspect first N blocks
165
+ N = min(20, total)
166
+ for i in range(N):
167
+ block = blocks[i]
168
+ log(f"\n🧩 Block [{i}/{total-1}]")
169
+ log(f"Class: {block.__class__.__name__}")
170
+
171
+ # Print submodules
172
+ for n, m in block.named_children():
173
+ log(f" ├─ {n}: {m.__class__.__name__}")
174
+
175
+ # Print attention related
176
+ if hasattr(block, "attn"):
177
+ attn = block.attn
178
+ log(f" ├─ Attention: {attn.__class__.__name__}")
179
+ log(f" │ Heads : {getattr(attn, 'num_heads', 'unknown')}")
180
+ log(f" │ Dim : {getattr(attn, 'hidden_size', 'unknown')}")
181
+ log(f" │ Backend : {getattr(attn, 'attention_backend', 'unknown')}")
182
+
183
+ # Device + dtype info
184
+ try:
185
+ dev = next(block.parameters()).device
186
+ log(f" ├─ Device : {dev}")
187
+ except StopIteration:
188
+ pass
189
+
190
+ try:
191
+ dt = next(block.parameters()).dtype
192
+ log(f" ├─ DType : {dt}")
193
+ except StopIteration:
194
+ pass
195
+
196
+ log("\n🔚 END TRANSFORMER DEBUG DUMP")
197
+ log("=" * 80)
198
+
199
+ except Exception as e:
200
+ log(f"❌ ERROR IN INSPECTOR: {e}")
201
203
+
204
+ # ---------- UTILITY ----------
205
+ def pretty_header(title):
206
+ log("\n\n" + "=" * 80)
207
+ log(f"🎛️ {title}")
208
+ log("=" * 80 + "\n")
209
+
210
+
211
+ # ---------- MEMORY ----------
212
+ def get_vram(prefix=""):
213
+ try:
214
+ allocated = torch.cuda.memory_allocated() / 1024**2
215
+ reserved = torch.cuda.memory_reserved() / 1024**2
216
+ log(f"{prefix}Allocated VRAM : {allocated:.2f} MB")
217
+ log(f"{prefix}Reserved VRAM : {reserved:.2f} MB")
218
+ except Exception:
219
+ log(f"{prefix}VRAM: CUDA not available")
220
+
221
+
222
+ # ---------- MODULE INSPECT ----------
223
+ def inspect_module(name, module):
224
+ pretty_header(f"🔬 Inspecting {name}")
225
+
226
+ try:
227
+ log(f"📦 Class : {module.__class__.__name__}")
228
+ log(f"🔢 DType : {getattr(module, 'dtype', 'unknown')}")
229
+ log(f"💻 Device : {next(module.parameters()).device}")
230
+ log(f"🧮 Params : {sum(p.numel() for p in module.parameters()):,}")
231
+
232
+ # Quantization state
233
+ if hasattr(module, "is_loaded_in_4bit"):
234
+ log(f"⚙️ 4-bit QLoRA : {module.is_loaded_in_4bit}")
235
+ if hasattr(module, "is_loaded_in_8bit"):
236
+ log(f"⚙️ 8-bit load : {module.is_loaded_in_8bit}")
237
+
238
+ # Attention backend (DiT)
239
+ if hasattr(module, "set_attention_backend"):
240
+ try:
241
+ attn = getattr(module, "attention_backend", None)
242
+ log(f"🚀 Attention Backend: {attn}")
243
+ except Exception:
244
+ pass
245
+
246
+ # Search for blocks
247
+ candidates = ["transformer_blocks", "blocks", "layers", "encoder", "model"]
248
+ blocks = None
249
+ chosen_attr = None
250
+
251
+ for attr in candidates:
252
+ if hasattr(module, attr):
253
+ blocks = getattr(module, attr)
254
+ chosen_attr = attr
255
+ break
256
+
257
+ log(f"\n📚 Block Container : {chosen_attr}")
258
+
259
+ if blocks is None:
260
+ log("⚠️ No block structure found")
261
+ return
262
+
263
+ if not hasattr(blocks, "__len__"):
264
+ log("⚠️ Blocks exist but are not iterable")
265
+ return
266
+
267
+ total = len(blocks)
268
+ log(f"🔢 Total Blocks : {total}\n")
269
+
270
+ # Inspect first 15 blocks
271
+ N = min(15, total)
272
+
273
+ for i in range(N):
274
+ blk = blocks[i]
275
+ log(f"\n🧩 Block [{i}/{total-1}] — {blk.__class__.__name__}")
276
+
277
+ for n, m in blk.named_children():
278
+ log(f" ├─ {n:<15} {m.__class__.__name__}")
279
+
280
+ # Attention details
281
+ if hasattr(blk, "attn"):
282
+ a = blk.attn
283
+ log(f" ├─ Attention")
284
+ log(f" │ Heads : {getattr(a, 'num_heads', 'unknown')}")
285
+ log(f" │ Dim : {getattr(a, 'hidden_size', 'unknown')}")
286
+ log(f" │ Backend : {getattr(a, 'attention_backend', 'unknown')}")
287
+
288
+ # Device / dtype
289
+ try:
290
+ log(f" ├─ Device : {next(blk.parameters()).device}")
291
+ log(f" ├─ DType : {next(blk.parameters()).dtype}")
292
+ except StopIteration:
293
+ pass
294
+
295
+ get_vram(" ▶ ")
296
+
297
+ except Exception as e:
298
+ log(f"❌ Module inspect error: {e}")
299
+
300
+
301
+ # ---------- LORA INSPECTION ----------
302
+ def inspect_loras(pipe):
303
+ pretty_header("🧩 LoRA ADAPTERS")
304
+
305
+ try:
306
+ if not hasattr(pipe, "lora_state_dict") and not hasattr(pipe, "adapter_names"):
307
+ log("⚠️ No LoRA system detected.")
308
+ return
309
+
310
+ if hasattr(pipe, "adapter_names"):
311
+ names = pipe.adapter_names
312
+ log(f"Available Adapters: {names}")
313
+
314
+ if hasattr(pipe, "active_adapters"):
315
+ log(f"Active Adapters : {pipe.active_adapters}")
316
+
317
+ if hasattr(pipe, "lora_scale"):
318
+ log(f"LoRA Scale : {pipe.lora_scale}")
319
+
320
+ # LoRA modules
321
+ if hasattr(pipe, "transformer") and hasattr(pipe.transformer, "modules"):
322
+ for name, module in pipe.transformer.named_modules():
323
+ if "lora" in name.lower():
324
+ log(f" 🔧 LoRA Module: {name} ({module.__class__.__name__})")
325
+
326
+ except Exception as e:
327
+ log(f"❌ LoRA inspect error: {e}")
328
+
329
+
330
+ # ---------- PIPELINE INSPECTOR ----------
331
+ def debug_pipeline(pipe):
332
+ pretty_header("🚀 FULL PIPELINE DEBUGGING")
333
+
334
+ try:
335
+ log(f"Pipeline Class : {pipe.__class__.__name__}")
336
+ log(f"Attention Impl : {getattr(pipe, 'attn_implementation', 'unknown')}")
337
+ log(f"Device : {pipe.device}")
338
+ except Exception:
339
+ pass
340
+
341
+ get_vram("▶ ")
342
+
343
+ # Inspect TRANSFORMER
344
+ if hasattr(pipe, "transformer"):
345
+ inspect_module("Transformer", pipe.transformer)
346
+
347
+ # Inspect TEXT ENCODER
348
+ if hasattr(pipe, "text_encoder") and pipe.text_encoder is not None:
349
+ inspect_module("Text Encoder", pipe.text_encoder)
350
+
351
+ # Inspect UNET (if ZImage pipeline has it)
352
+ if hasattr(pipe, "unet"):
353
+ inspect_module("UNet", pipe.unet)
354
+
355
+ # LoRA adapters
356
+ inspect_loras(pipe)
357
+
358
+ pretty_header("🎉 END DEBUG REPORT")
359
+
360
+
361
+
362
+ # ============================================================
363
+ # LOAD TRANSFORMER — WITH LIVE STATS
364
+ # ============================================================
365
+ log("\n===================================================")
366
+ log("🔧 LOADING TRANSFORMER BLOCK")
367
+ log("===================================================")
368
+
369
+ log("📌 Logging memory before load:")
370
+ log_system_stats("START TRANSFORMER LOAD")
371
+
372
+ try:
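+     # NF4 4-bit quantization via bitsandbytes: weights are stored as 4-bit NF4
+     # values with double quantization while compute runs in bfloat16, which
+     # roughly quarters the transformer's load-time VRAM footprint compared to
+     # a plain bf16 load (approximate figure, not measured here).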
373
+ quant_cfg = DiffusersBitsAndBytesConfig(
374
+ load_in_4bit=True,
375
+ bnb_4bit_quant_type="nf4",
376
+ bnb_4bit_compute_dtype=torch_dtype,
377
+ bnb_4bit_use_double_quant=True,
378
+ )
379
+
380
+ transformer = AutoModel.from_pretrained(
381
+ model_id,
382
+ cache_dir=model_cache,
383
+ subfolder="transformer",
384
+ quantization_config=quant_cfg,
385
+ torch_dtype=torch_dtype,
386
+ device_map=device,
387
+ )
388
+ log("✅ Transformer loaded successfully.")
389
+
390
+ except Exception as e:
391
+ log(f"❌ Transformer load failed: {e}")
392
+ transformer = None
393
+
394
+ log_system_stats("AFTER TRANSFORMER LOAD")
395
+
396
+ if transformer:
397
+ inspect_transformer(transformer, "Transformer")
398
+
399
+
400
+ # ============================================================
401
+ # LOAD TEXT ENCODER
402
+ # ============================================================
403
+ log("\n===================================================")
404
+ log("🔧 LOADING TEXT ENCODER")
405
+ log("===================================================")
406
+
407
+ log_system_stats("START TEXT ENCODER LOAD")
408
+
409
+ try:
410
+ quant_cfg2 = TransformersBitsAndBytesConfig(
411
+ load_in_4bit=True,
412
+ bnb_4bit_quant_type="nf4",
413
+ bnb_4bit_compute_dtype=torch_dtype,
414
+ bnb_4bit_use_double_quant=True,
415
+ )
416
+
417
+ text_encoder = AutoModel.from_pretrained(
418
+ model_id,
419
+ cache_dir=model_cache,
420
+ subfolder="text_encoder",
421
+ quantization_config=quant_cfg2,
422
+ torch_dtype=torch_dtype,
423
+ device_map=device,
424
+ )
425
+ log("✅ Text encoder loaded successfully.")
426
+
427
+ except Exception as e:
428
+ log(f"❌ Text encoder load failed: {e}")
429
+ text_encoder = None
430
+
431
+ log_system_stats("AFTER TEXT ENCODER LOAD")
432
+
433
+ if text_encoder:
434
+ inspect_transformer(text_encoder, "Text Encoder")
435
+
436
+
437
+ # ============================================================
438
+ # BUILD PIPELINE
439
+ # ============================================================
440
+ log("\n===================================================")
441
+ log("🔧 BUILDING PIPELINE")
442
+ log("===================================================")
443
+
444
+ log_system_stats("START PIPELINE BUILD")
445
+
446
+ try:
447
+ pipe = ZImagePipeline.from_pretrained(
448
+ model_id,
449
+ transformer=transformer,
450
+ text_encoder=text_encoder,
451
+ torch_dtype=torch_dtype,
452
+ )
453
+
454
+ # Prefer flash attention if supported
455
+ try:
456
+ if hasattr(pipe, "transformer") and hasattr(pipe.transformer, "set_attention_backend"):
457
+ pipe.transformer.set_attention_backend("_flash_3")
458
+ log("✅ transformer.set_attention_backend('_flash_3') called")
459
+ except Exception as _e:
460
+ log(f"⚠️ set_attention_backend failed: {_e}")
461
+
462
+ # 🚫 NO default LoRA here
463
+ # 🚫 NO fuse
464
+ # 🚫 NO unload
465
+
466
+ pipe.to("cuda")
467
+ log("✅ Pipeline built successfully.")
468
+ log("Pipeline build completed.")
469
+
470
+ except Exception as e:
471
+ log(f"❌ Pipeline build failed: {e}")
472
+ log(traceback.format_exc())
473
+ pipe = None
474
+
475
+
476
+ log_system_stats("AFTER PIPELINE BUILD")
477
+
478
+
479
+ # -----------------------------
480
+ # Monkey-patch prepare_latents (safe)
481
+ # -----------------------------
482
+ if pipe is not None and hasattr(pipe, "prepare_latents"):
483
+ original_prepare_latents = pipe.prepare_latents
484
+
485
+ def logged_prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
486
+ try:
487
+ result_latents = original_prepare_latents(batch_size, num_channels_latents, height, width, dtype, device, generator, latents)
488
+ log_msg = f"🔹 prepare_latents called | shape={result_latents.shape}, dtype={result_latents.dtype}, device={result_latents.device}"
489
+ if hasattr(self, "_latents_log"):
490
+ self._latents_log.append(log_msg)
491
+ else:
492
+ self._latents_log = [log_msg]
493
+ return result_latents
494
+ except Exception as e:
495
+ log(f"⚠️ prepare_latents wrapper failed: {e}")
496
+ raise
497
+
498
+ # apply patch safely
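+     # __get__(pipe) binds the plain function to this pipe instance, so the
+     # wrapper's `self` parameter receives the pipeline object when it is called.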
499
+ try:
500
+ pipe.prepare_latents = logged_prepare_latents.__get__(pipe)
501
+ log("✅ prepare_latents monkey-patched")
502
+ except Exception as e:
503
+ log(f"⚠️ Failed to attach prepare_latents patch: {e}")
504
+ else:
505
+ log("❌ WARNING: Pipe not initialized or prepare_latents missing; skipping prepare_latents patch")
506
+
507
+
508
+ from PIL import Image
509
+ import torch
510
+
511
+ # --------------------------
512
+ # Helper: Safe latent extractor
513
+ # --------------------------
514
+ def safe_get_latents(pipe, height, width, generator, device, LOGS):
515
+ """
516
+ Safely prepare latents for any ZImagePipeline variant.
517
+ Returns latents tensor, logs issues instead of failing.
518
+ """
519
+ try:
520
+ # Determine number of channels
521
+ num_channels = 4 # default fallback
522
+ if hasattr(pipe, "unet") and hasattr(pipe.unet, "in_channels"):
523
+ num_channels = pipe.unet.in_channels
524
+ elif hasattr(pipe, "vae") and hasattr(pipe.vae, "latent_channels"):
525
+ num_channels = pipe.vae.latent_channels # some pipelines define this
526
+ LOGS.append(f"🔹 Using num_channels={num_channels} for latents")
527
+
528
+ latents = pipe.prepare_latents(
529
+ batch_size=1,
530
+ num_channels_latents=num_channels,
531
+ height=height,
532
+ width=width,
533
+ dtype=torch.float32,
534
+ device=device,
535
+ generator=generator,
536
+ )
537
+
538
+ LOGS.append(f"🔹 Latents shape: {latents.shape}, dtype: {latents.dtype}, device: {latents.device}")
539
+ return latents
540
+ except Exception as e:
541
+ LOGS.append(f"⚠️ Latent extraction failed: {e}")
542
+ # fallback: guess a safe shape
543
+ fallback_channels = 16 # try standard default for ZImage pipelines
544
+ latents = torch.randn((1, fallback_channels, height // 8, width // 8),
545
+ generator=generator, device=device)
546
+ LOGS.append(f"🔹 Using fallback random latents shape: {latents.shape}")
547
+ return latents
548
+
549
+ # --------------------------
550
+ # Main generation function (logic kept as-is)
551
+ # --------------------------
552
+ from huggingface_hub import HfApi, HfFolder
553
+ import torch
554
+ import os
555
+
556
+ HF_REPO_ID = "rahul7star/Zstudio-latent" # Model repo
557
+ HF_TOKEN = HfFolder.get_token() # Make sure you are logged in via `huggingface-cli login`
558
+
559
+ def upload_latents_to_hf(latent_dict, filename="latents.pt"):
560
+ local_path = f"/tmp/{filename}"
561
+ torch.save(latent_dict, local_path)
562
+ try:
563
+ api = HfApi()
564
+ api.upload_file(
565
+ path_or_fileobj=local_path,
566
+ path_in_repo=filename,
567
+ repo_id=HF_REPO_ID,
568
+ token=HF_TOKEN,
569
+ repo_type="model" # since this is a model repo
570
+ )
571
+ os.remove(local_path)
572
+ return f"https://huggingface.co/{HF_REPO_ID}/resolve/main/{filename}"
573
+ except Exception as e:
574
+ os.remove(local_path)
575
+ raise e
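+
+ # Assuming HF_TOKEN has write access to HF_REPO_ID, a saved latent file can later be
+ # pulled back and inspected like this (illustrative sketch, not used by this app):
+ #
+ #     from huggingface_hub import hf_hub_download
+ #     path = hf_hub_download(HF_REPO_ID, "latents.pt", repo_type="model")
+ #     saved = torch.load(path, map_location="cpu")
+ #     print(list(saved.keys()))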
576
+
577
+
578
+
579
+ import asyncio
580
+ import torch
581
+ from PIL import Image
582
+
583
+ async def async_upload_latents(latent_dict, filename, LOGS):
+     try:
+         # upload_latents_to_hf is synchronous, so run it in a worker thread
+         # (asyncio.to_thread, Python 3.9+) instead of awaiting it directly
+         hf_url = await asyncio.to_thread(upload_latents_to_hf, latent_dict, filename=filename)
+         LOGS.append(f"🔹 All preview latents uploaded: {hf_url}")
+     except Exception as e:
+         LOGS.append(f"⚠️ Failed to upload all preview latents: {e}")
589
+
590
+
591
+ # This generates a preview frame for every latent step; it is GPU-expensive and the decode sometimes fails, so keep it for later use.
592
+ @spaces.GPU
593
+ def generate_image_all_latents(prompt, height, width, steps, seed, guidance_scale=0.0):
594
+ LOGS = []
595
+ device = "cpu" # FORCE CPU
596
+ generator = torch.Generator(device).manual_seed(int(seed))
597
+
598
+ placeholder = Image.new("RGB", (width, height), color=(255, 255, 255))
599
+ latent_gallery = []
600
+ final_gallery = []
601
+
602
+ last_four_latents = [] # we only upload 4
603
+
604
+ # --------------------------------------------------
605
+ # LATENT PREVIEW GENERATION (CPU MODE)
606
+ # --------------------------------------------------
607
+ try:
608
+ latents = safe_get_latents(pipe, height, width, generator, device, LOGS)
609
+ latents = latents.to("cpu") # keep EVERYTHING CPU
610
+
611
+ timestep_count = len(pipe.scheduler.timesteps)
612
+ preview_every = max(1, timestep_count // 10)
613
+
614
+ for i, t in enumerate(pipe.scheduler.timesteps):
615
+
616
+ # -------------- decode latent preview --------------
617
+ try:
618
+ with torch.no_grad():
619
+ latent_cpu = latents.to(pipe.vae.dtype) # match VAE dtype
620
+ decoded = pipe.vae.decode(latent_cpu).sample # [1,3,H,W]
621
+ decoded = (decoded / 2 + 0.5).clamp(0, 1)
622
+ decoded = decoded[0].permute(1,2,0).cpu().numpy()
623
+ latent_img = Image.fromarray((decoded * 255).astype("uint8"))
624
+ except Exception:
625
+ latent_img = placeholder
626
+ LOGS.append("⚠️ Latent preview decode failed.")
627
+
628
+ latent_gallery.append(latent_img)
629
+
630
+ # store last 4 latent states
631
+ if len(last_four_latents) >= 4:
632
+ last_four_latents.pop(0)
633
+ last_four_latents.append(latents.cpu().clone())
634
+
635
+ # UI preview yields
636
+ if i % preview_every == 0:
637
+ yield None, latent_gallery, LOGS
638
+
639
+ # --------------------------------------------------
640
+ # UPLOAD LAST 4 LATENTS (SYNC)
641
+ # --------------------------------------------------
642
+ try:
643
+ upload_dict = {
644
+ "last_4_latents": last_four_latents,
645
+ "prompt": prompt,
646
+ "seed": seed
647
+ }
648
+
649
+ hf_url = upload_latents_to_hf(
650
+ upload_dict,
651
+ filename=f"latents_last4_{seed}.pt"
652
+ )
653
+
654
+ LOGS.append(f"🔹 Uploaded last 4 latents: {hf_url}")
655
+
656
+ except Exception as e:
657
+ LOGS.append(f"⚠️ Failed to upload latents: {e}")
658
+
659
+ except Exception as e:
660
+ LOGS.append(f"⚠️ Latent generation failed: {e}")
661
+ latent_gallery.append(placeholder)
662
+ yield None, latent_gallery, LOGS
663
+
664
+ # --------------------------------------------------
665
+ # FINAL IMAGE - UNTOUCHED
666
+ # --------------------------------------------------
667
+ try:
668
+ output = pipe(
669
+ prompt=prompt,
670
+ height=height,
671
+ width=width,
672
+ num_inference_steps=steps,
673
+ guidance_scale=guidance_scale,
674
+ generator=generator,
675
+ )
676
+ final_img = output.images[0]
677
+ LOGS.append("✅ Standard pipeline succeeded.")
678
+
679
+ yield final_img, latent_gallery, LOGS
680
+
681
+ except Exception as e2:
682
+ LOGS.append(f"❌ Standard pipeline failed: {e2}")
683
+ yield placeholder, latent_gallery, LOGS
684
+
685
+ @spaces.GPU
686
+ def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0):
687
+ LOGS = []
688
+ device = "cuda"
689
+ cpu_device = "cpu"
690
+ generator = torch.Generator(device).manual_seed(int(seed))
691
+
692
+ placeholder = Image.new("RGB", (width, height), color=(255, 255, 255))
693
+ latent_gallery = []
694
+ final_gallery = []
695
+
696
+ last_latents = [] # store last 5 preview latents on CPU
697
+
698
+ try:
699
+ # --- Initial latents ---
700
+ latents = safe_get_latents(pipe, height, width, generator, device, LOGS)
701
+ latents = latents.float().to(cpu_device) # move to CPU
702
+
703
+ num_previews = min(10, steps)
704
+ preview_indices = torch.linspace(0, steps - 1, num_previews).long()
705
+
706
+ for i, step_idx in enumerate(preview_indices):
707
+ try:
708
+ with torch.no_grad():
709
+ # --- Z-Image Turbo-style denoise simulation ---
710
+ t = 1.0 - (i / num_previews) # linear decay [1.0 -> 0.0]
711
+ noise_scale = t ** 0.5 # reduce noise over steps (sqrt for smoother)
712
+ denoise_latent = latents * t + torch.randn_like(latents) * noise_scale
713
+
714
+ # Move to VAE device & dtype
715
+ denoise_latent = denoise_latent.to(pipe.vae.device).to(pipe.vae.dtype)
716
+
717
+ # Decode latent to image
718
+ decoded = pipe.vae.decode(denoise_latent, return_dict=False)[0]
719
+ decoded = (decoded / 2 + 0.5).clamp(0, 1)
720
+ decoded = decoded.cpu().permute(0, 2, 3, 1).float().numpy()
721
+ decoded = (decoded * 255).round().astype("uint8")
722
+ latent_img = Image.fromarray(decoded[0])
723
+
724
+ except Exception as e:
725
+ LOGS.append(f"⚠️ Latent preview decode failed: {e}")
726
+ latent_img = placeholder
727
+
728
+ latent_gallery.append(latent_img)
729
+
730
+ # Keep last 5 latents only
731
+ last_latents.append(denoise_latent.cpu().clone())
732
+ if len(last_latents) > 5:
733
+ last_latents.pop(0)
734
+
735
+ # Show only last 5 previews in UI
736
+ yield None, latent_gallery[-5:], LOGS
737
+
738
+ # Optionally: upload last 5 latents
739
+ # latent_dict = {"latents": last_latents, "prompt": prompt, "seed": seed}
740
+ # hf_url = upload_latents_to_hf(latent_dict, filename=f"latents_last5_{seed}.pt")
741
+ # LOGS.append(f"🔹 Last 5 latents uploaded: {hf_url}")
742
+
743
+ except Exception as e:
744
+ LOGS.append(f"⚠️ Latent generation failed: {e}")
745
+ latent_gallery.append(placeholder)
746
+ yield None, latent_gallery[-5:], LOGS
747
+
748
+ # --- Final image on GPU ---
749
+ try:
750
+ output = pipe(
751
+ prompt=prompt,
752
+ height=height,
753
+ width=width,
754
+ num_inference_steps=steps,
755
+ guidance_scale=guidance_scale,
756
+ generator=generator,
757
+ )
758
+ final_img = output.images[0]
759
+ final_gallery.append(final_img)
760
+ latent_gallery.append(final_img)
761
+ LOGS.append("✅ Standard pipeline succeeded.")
762
+ yield final_img, latent_gallery[-5:] + [final_img], LOGS # last 5 previews + final
763
+
764
+ except Exception as e2:
765
+ LOGS.append(f"❌ Standard pipeline failed: {e2}")
766
+ final_gallery.append(placeholder)
767
+ latent_gallery.append(placeholder)
768
+ yield placeholder, latent_gallery[-5:] + [placeholder], LOGS
769
+
770
+
771
+
772
+ # This is a stable version that can generate the final image plus a noise-to-latent preview.
773
+ @spaces.GPU
774
+ def generate_image_verygood_realnoise(prompt, height, width, steps, seed, guidance_scale=0.0):
775
+ LOGS = []
776
+ device = "cuda"
777
+ generator = torch.Generator(device).manual_seed(int(seed))
778
+
779
+ placeholder = Image.new("RGB", (width, height), color=(255, 255, 255))
780
+ latent_gallery = []
781
+ final_gallery = []
782
+
783
+ # --- Generate latent previews ---
784
+ try:
785
+ latents = safe_get_latents(pipe, height, width, generator, device, LOGS)
786
+ latents = latents.float() # keep float32 until decode
787
+
788
+ num_previews = min(10, steps)
789
+ preview_steps = torch.linspace(0, 1, num_previews)
790
+
791
+ for alpha in preview_steps:
792
+ try:
793
+ with torch.no_grad():
794
+ # Simulate denoising progression like Z-Image Turbo
795
+ preview_latent = latents * alpha # simple linear progression toward the full latent
796
+
797
+ # Move to same device and dtype as VAE
798
+ preview_latent = preview_latent.to(pipe.vae.device).to(pipe.vae.dtype)
799
+
800
+ # Decode
801
+ decoded = pipe.vae.decode(preview_latent, return_dict=False)[0]
802
+
803
+ # Convert to PIL following same logic as final image
804
+ decoded = (decoded / 2 + 0.5).clamp(0, 1)
805
+ decoded = decoded.cpu().permute(0, 2, 3, 1).float().numpy()
806
+ decoded = (decoded * 255).round().astype("uint8")
807
+ latent_img = Image.fromarray(decoded[0])
808
+
809
+ except Exception as e:
810
+ LOGS.append(f"⚠️ Latent preview decode failed: {e}")
811
+ latent_img = placeholder
812
+
813
+ latent_gallery.append(latent_img)
814
+ yield None, latent_gallery, LOGS
815
+
816
+ except Exception as e:
817
+ LOGS.append(f"⚠️ Latent generation failed: {e}")
818
+ latent_gallery.append(placeholder)
819
+ yield None, latent_gallery, LOGS
820
+
821
+ # --- Final image: untouched ---
822
+ try:
823
+ output = pipe(
824
+ prompt=prompt,
825
+ height=height,
826
+ width=width,
827
+ num_inference_steps=steps,
828
+ guidance_scale=guidance_scale,
829
+ generator=generator,
830
+ )
831
+ final_img = output.images[0]
832
+ final_gallery.append(final_img)
833
+ latent_gallery.append(final_img) # fallback preview
834
+ LOGS.append("✅ Standard pipeline succeeded.")
835
+ yield final_img, latent_gallery, LOGS
836
+
837
+ except Exception as e2:
838
+ LOGS.append(f"❌ Standard pipeline failed: {e2}")
839
+ final_gallery.append(placeholder)
840
+ latent_gallery.append(placeholder)
841
+ yield placeholder, latent_gallery, LOGS
842
+
843
+
844
+
845
+
846
+ # DO NOT TOUCH: this is a stable version that can generate the final image plus a noise-to-latent preview, with latent upload to the repo.
847
+ @spaces.GPU
848
+ def generate_image_safe(prompt, height, width, steps, seed, guidance_scale=0.0):
849
+ LOGS = []
850
+ device = "cuda"
851
+ generator = torch.Generator(device).manual_seed(int(seed))
852
+
853
+ placeholder = Image.new("RGB", (width, height), color=(255, 255, 255))
854
+ latent_gallery = []
855
+ final_gallery = []
856
+
857
+ # --- Generate latent previews in a loop ---
858
+ try:
859
+ latents = safe_get_latents(pipe, height, width, generator, device, LOGS)
860
+
861
+ # Convert latents to float32 if necessary
862
+ if latents.dtype != torch.float32:
863
+ latents = latents.float()
864
+
865
+ # Loop for multiple previews before final image
866
+ num_previews = min(10, steps) # show ~10 previews
867
+ preview_steps = torch.linspace(0, 1, num_previews)
868
+
869
+ for i, alpha in enumerate(preview_steps):
870
+ try:
871
+ with torch.no_grad():
872
+ # Simple noise interpolation for preview (simulate denoising progress)
873
+ preview_latent = latents * alpha + torch.randn_like(latents) * (1 - alpha)
874
+ # Decode to PIL
875
+ latent_img_tensor = pipe.vae.decode(preview_latent).sample # [1,3,H,W]
876
+ latent_img_tensor = (latent_img_tensor / 2 + 0.5).clamp(0, 1)
877
+ latent_img_tensor = latent_img_tensor.cpu().permute(0, 2, 3, 1)[0]
878
+ latent_img = Image.fromarray((latent_img_tensor.numpy() * 255).astype('uint8'))
879
+ except Exception as e:
880
+ LOGS.append(f"⚠️ Latent preview decode failed: {e}")
881
+ latent_img = placeholder
882
+
883
+ latent_gallery.append(latent_img)
884
+ yield None, latent_gallery, LOGS # update Gradio with intermediate preview
885
+
886
+ # Save final latents to HF
887
+ latent_dict = {"latents": latents.cpu(), "prompt": prompt, "seed": seed}
888
+ try:
889
+ hf_url = upload_latents_to_hf(latent_dict, filename=f"latents_{seed}.pt")
890
+ LOGS.append(f"🔹 Latents uploaded: {hf_url}")
891
+ except Exception as e:
892
+ LOGS.append(f"⚠️ Failed to upload latents: {e}")
893
+
894
+ except Exception as e:
895
+ LOGS.append(f"⚠️ Latent generation failed: {e}")
896
+ latent_gallery.append(placeholder)
897
+ yield None, latent_gallery, LOGS
898
+
899
+ # --- Final image: untouched standard pipeline ---
900
+ try:
901
+ output = pipe(
902
+ prompt=prompt,
903
+ height=height,
904
+ width=width,
905
+ num_inference_steps=steps,
906
+ guidance_scale=guidance_scale,
907
+ generator=generator,
908
+ )
909
+ final_img = output.images[0]
910
+ final_gallery.append(final_img)
911
+ latent_gallery.append(final_img) # fallback preview if needed
912
+ LOGS.append("✅ Standard pipeline succeeded.")
913
+ yield final_img, latent_gallery, LOGS
914
+
915
+ except Exception as e2:
916
+ LOGS.append(f"❌ Standard pipeline failed: {e2}")
917
+ final_gallery.append(placeholder)
918
+ latent_gallery.append(placeholder)
919
+ yield placeholder, latent_gallery, LOGS
920
+
921
+
922
+
923
+
924
+
925
+
926
+
928
+
929
+ with gr.Blocks(title="Z-Image-Turbo") as demo:
930
+ gr.Markdown("# 🎨 Z-Image-Turbo (LoRA-enabled UI)")
931
+
932
+ # =========================
933
+ # MAIN TABS
934
+ # =========================
935
+ with gr.Tabs():
936
+
937
+ # -------- Image Tab --------
938
+ with gr.TabItem("Image & Latents"):
939
+ with gr.Row():
940
+ with gr.Column(scale=1):
941
+ prompt = gr.Textbox(
942
+ label="Prompt",
943
+ value="boat in Ocean"
944
+ )
945
+ height = gr.Slider(
946
+ 256, 2048, value=1024, step=8, label="Height"
947
+ )
948
+ width = gr.Slider(
949
+ 256, 2048, value=1024, step=8, label="Width"
950
+ )
951
+ steps = gr.Slider(
952
+ 1, 50, value=20, step=1, label="Inference Steps"
953
+ )
954
+ seed = gr.Number(
955
+ value=42, label="Seed"
956
+ )
957
+ run_btn = gr.Button("🚀 Generate Image")
958
+
959
+ with gr.Column(scale=1):
960
+ final_image = gr.Image(label="Final Image")
961
+ latent_gallery = gr.Gallery(
962
+ label="Latent Steps",
963
+ columns=4,
964
+ height=256,
965
+ preview=True,
966
+ )
967
+
968
+ # -------- Logs Tab --------
969
+ with gr.TabItem("Logs"):
970
+ logs_box = gr.Textbox(
971
+ label="Logs",
972
+ lines=25,
973
+ interactive=False
974
+ )
975
+
976
+ # =========================
977
+ # LoRA CONTROLS
978
+ # =========================
979
+ gr.Markdown("## 🧩 LoRA Controls")
980
+
981
+ with gr.Row():
982
+ lora_repo = gr.Textbox(
983
+ label="LoRA Repo (HF)",
984
+ value="rahul7star/ZImageLora",
985
+ placeholder="username/repo"
986
+ )
987
+
988
+ lora_file = gr.Dropdown(
989
+ label="LoRA file (.safetensors)",
990
+ choices=[]
991
+ )
992
+
993
+ lora_strength = gr.Slider(
994
+ 0.0, 2.0, value=1.0, step=0.05, label="LoRA strength"
995
+ )
996
+
997
+ with gr.Row():
998
+ refresh_lora_btn = gr.Button("🔄 Refresh LoRA List")
999
+ apply_lora_btn = gr.Button("✅ Apply LoRA")
1000
+ clear_lora_btn = gr.Button("❌ Clear LoRA")
1001
+
1002
+ # =========================
1003
+ # CALLBACKS
1004
+ # =========================
1005
+
1006
+
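+     # The callbacks below use list_loras_from_repo(), which is not defined anywhere
+     # in this file. A minimal sketch of such a helper, assuming the repo is readable
+     # with HF_TOKEN and that only .safetensors files should be offered in the dropdown:
+     def list_loras_from_repo(repo_name):
+         """Return the .safetensors files found in a Hugging Face repo."""
+         try:
+             files = HfApi().list_repo_files(repo_name, repo_type="model", token=HF_TOKEN)
+             return [f for f in files if f.endswith(".safetensors")]
+         except Exception as e:
+             log(f"⚠️ Could not list files in {repo_name}: {e}")
+             return []
+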
1007
+ def refresh_lora_list(repo_name):
1008
+ try:
1009
+ files = list_loras_from_repo(repo_name)
1010
+ if not files:
1011
+ log(f"⚠️ No LoRA files found in {repo_name}")
1012
+ return gr.update(choices=[], value=None)
1013
+
1014
+ log(f"📦 Found {len(files)} LoRA files in {repo_name}")
1015
+ return gr.update(choices=files, value=files[0])
1016
+
1017
+ except Exception as e:
1018
+ log(f"❌ Failed to list LoRA files: {e}")
1019
+ return gr.update(choices=[], value=None)
1020
+
1021
+ refresh_lora_btn.click(
1022
+ refresh_lora_list,
1023
+ inputs=[lora_repo],
1024
+ outputs=[lora_file]
1025
+ )
1026
+
1027
+ def apply_lora(repo_name, lora_filename, strength):
1028
+ global pipe
1029
+
1030
+ if pipe is None:
1031
+ return "❌ Pipeline not initialized"
1032
+
1033
+ if not lora_filename:
1034
+ return "⚠️ No LoRA file selected"
1035
+
1036
+ try:
1037
+ pipe.load_lora_weights(
1038
+ repo_name,
1039
+ weight_name=lora_filename,
1040
+ adapter_name="ui_lora"
1041
+ )
1042
+ pipe.set_adapters(["ui_lora"], [strength])
1043
+
1044
+ log(f"✅ Applied LoRA: {repo_name}/{lora_filename} (strength={strength})")
1045
+
1046
+ if hasattr(pipe, "peft_config"):
1047
+ log(f"🎯 Active adapters: {list(pipe.peft_config.keys())}")
1048
+
1049
+ return "LoRA applied"
1050
+
1051
+ except Exception as e:
1052
+ log(f"❌ Failed to apply LoRA: {e}")
1053
+ return f"Failed: {e}"
1054
+
1055
+ apply_lora_btn.click(
1056
+ apply_lora,
1057
+ inputs=[lora_repo, lora_file, lora_strength],
1058
+ outputs=[logs_box]
1059
+ )
1060
+
1061
+ def clear_lora():
1062
+ global pipe
1063
+ if pipe is None:
1064
+ return "❌ Pipeline not initialized"
1065
+
1066
+ try:
1067
+ pipe.set_adapters([], [])
1068
+ log("🧹 LoRA cleared")
1069
+ return "LoRA cleared"
1070
+ except Exception as e:
1071
+ log(f"❌ Failed to clear LoRA: {e}")
1072
+ return f"Failed: {e}"
1073
+
1074
+ clear_lora_btn.click(
1075
+ clear_lora,
1076
+ outputs=[logs_box]
1077
+ )
1078
+
1079
+ # =========================
1080
+ # GENERATION
1081
+ # =========================
1082
+ run_btn.click(
1083
+ generate_image,
1084
+ inputs=[prompt, height, width, steps, seed],
1085
+ outputs=[final_image, latent_gallery, logs_box]
1086
+ )
1087
+
1088
+
1089
+
1090
+
1091
+ demo.launch()
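+ # Note: generate_image is a generator, so Gradio streams the intermediate latent
+ # previews as they are yielded. On Spaces, demo.queue().launch() could be used instead
+ # if explicit request queuing is wanted (assumption: the default launch settings suffice here).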