rahul7star committed
Commit 0e866d9 · verified · 1 Parent(s): ab410e0

Update app_lora1.py

Files changed (1)
  1. app_lora1.py +94 -90
app_lora1.py CHANGED
@@ -21,10 +21,10 @@ MODEL_ID = "Tongyi-MAI/Z-Image-Turbo"
 os.makedirs(LOCAL_SCRIPTS_DIR, exist_ok=True)
 
 # =========================================================
-# GLOBAL STATE (CPU SAFE)
+# GLOBAL STATE
 # =========================================================
-SCRIPT_CODE = {}   # script_name -> code (CPU only)
-PIPELINES = {}     # script_name -> pipeline (GPU only, lazy)
+SCRIPT_CODE = {}
+PIPELINES = {}
 log_buffer = io.StringIO()
 
 
@@ -39,39 +39,29 @@ def log(msg):
 def pipeline_technology_info(pipe):
     tech = []
 
-    # Device map
     if hasattr(pipe, "hf_device_map"):
         tech.append("Device map: enabled")
     else:
         tech.append(f"Device: {pipe.device}")
 
-    # Transformer dtype
     if hasattr(pipe, "transformer"):
         try:
             tech.append(f"Transformer dtype: {pipe.transformer.dtype}")
         except Exception:
             pass
-
-    # Layerwise casting (Z-Image specific)
     if hasattr(pipe.transformer, "layerwise_casting"):
         lw = pipe.transformer.layerwise_casting
-        tech.append(
-            f"Layerwise casting: storage={lw.storage_dtype}, "
-            f"compute={lw.compute_dtype}"
-        )
+        tech.append(f"Layerwise casting: storage={lw.storage_dtype}, compute={lw.compute_dtype}")
 
-    # VAE dtype
     if hasattr(pipe, "vae"):
         try:
             tech.append(f"VAE dtype: {pipe.vae.dtype}")
         except Exception:
             pass
 
-    # GGUF / quantization
     if hasattr(pipe, "quantization_config"):
         tech.append(f"Quantization: {pipe.quantization_config}")
 
-    # Attention backend
     if hasattr(pipe, "config"):
         attn = pipe.config.get("attn_implementation", None)
         if attn:
@@ -106,7 +96,23 @@ def register_pipeline_feature(pipe, text: str):
 
 
 # =========================================================
-# DOWNLOAD SCRIPTS (CPU ONLY)
+# WRAPPER TO LOG ANY METHOD CALL ON PIPE OR TRANSFORMER
+# =========================================================
+def log_pipe_calls(obj, obj_name="pipe"):
+    for attr_name in dir(obj):
+        attr = getattr(obj, attr_name)
+        if callable(attr) and not attr_name.startswith("_"):
+            def make_wrapper(f, name):
+                def wrapper(*args, **kwargs):
+                    log(f"• {obj_name}.{name} called with args={args}, kwargs={kwargs}")
+                    return f(*args, **kwargs)
+                return wrapper
+            setattr(obj, attr_name, make_wrapper(attr, attr_name))
+    return obj
+
+
+# =========================================================
+# DOWNLOAD SCRIPTS
 # =========================================================
 def download_scripts():
     resp = requests.get(SCRIPTS_REPO_API)
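
Note on the `log_pipe_calls` helper introduced above: it walks `dir(obj)` and swaps every public callable for a logging wrapper. The inner `make_wrapper` factory is what keeps each wrapper bound to its own method; closing directly over the loop variables would make every wrapper report the last attribute visited. A minimal standalone sketch of the same pattern (the `Dummy` class is hypothetical, only there to make the snippet runnable):

    class Dummy:
        def encode(self, x):
            return x * 2

        def decode(self, x):
            return x - 1

    def trace_calls(obj, obj_name="obj"):
        # Same shape as log_pipe_calls: wrap every public callable with a logger.
        for attr_name in dir(obj):
            attr = getattr(obj, attr_name)
            if callable(attr) and not attr_name.startswith("_"):
                def make_wrapper(f, name):
                    def wrapper(*args, **kwargs):
                        print(f"{obj_name}.{name} called with args={args}, kwargs={kwargs}")
                        return f(*args, **kwargs)
                    return wrapper
                setattr(obj, attr_name, make_wrapper(attr, attr_name))
        return obj

    d = trace_calls(Dummy(), "dummy")
    d.encode(3)   # prints: dummy.encode called with args=(3,), kwargs={}
    d.decode(10)  # prints: dummy.decode called with args=(10,), kwargs={}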
@@ -129,21 +135,19 @@ SCRIPT_NAMES = download_scripts()
 
 
 # =========================================================
-# REGISTER SELECTED SCRIPTS (NO CUDA)
+# REGISTER SCRIPTS
 # =========================================================
 def register_scripts(selected_scripts):
     SCRIPT_CODE.clear()
-
     for name in selected_scripts:
         path = os.path.join(LOCAL_SCRIPTS_DIR, name)
         with open(path, "r") as f:
             SCRIPT_CODE[name] = f.read()
-
     return f"{len(SCRIPT_CODE)} script(s) registered ✅"
 
 
 # =========================================================
-# GPU-ONLY PIPELINE BUILDER (CRITICAL)
+# BUILD PIPELINE (GPU)
 # =========================================================
 def get_pipeline(script_name):
     if script_name in PIPELINES:
@@ -154,10 +158,9 @@ def get_pipeline(script_name):
     namespace = {
         "__file__": script_name,
         "__name__": "__main__",
-
-        # Minimal required globals
         "torch": torch,
         "register_pipeline_feature": register_pipeline_feature,
+        "log_pipe_calls": log_pipe_calls,
     }
 
     try:
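
For orientation, `namespace` above is the globals dict handed to the selected script when it is executed (the execution call itself sits in the elided `try:` block, presumably an `exec` of the registered script code). Whatever the script assigns at module level, including the required `pipe`, ends up as a key in that dict. A small sketch of the mechanism, using a stand-in script string rather than a real downloaded script:

    import torch

    script_code = 'pipe = ("stand-in pipeline", torch.float16)'   # stand-in for a downloaded script
    namespace = {"torch": torch}                                   # injected globals, as in get_pipeline()
    exec(script_code, namespace)                                   # module-level assignments land in the dict

    assert "pipe" in namespace
    print(namespace["pipe"])   # ('stand-in pipeline', torch.float16)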
@@ -167,19 +170,27 @@ def get_pipeline(script_name):
         raise RuntimeError(f"Pipeline build failed for {script_name}") from e
 
     if "pipe" not in namespace:
-        raise RuntimeError(
-            f"{script_name} did not define `pipe`.\n"
-            f"Each script MUST assign a variable named `pipe`."
-        )
+        raise RuntimeError(f"{script_name} did not define `pipe`.")
+
+    pipe = namespace["pipe"]
+
+    # Wrap transformer and pipe to log method calls (post-pretrained modifications)
+    if hasattr(pipe, "transformer"):
+        pipe.transformer = log_pipe_calls(pipe.transformer, "pipe.transformer")
+    pipe = log_pipe_calls(pipe, "pipe")
 
-    PIPELINES[script_name] = namespace["pipe"]
+    # ZeroGPU-safe
+    if not hasattr(pipe, "hf_device_map"):
+        pipe = pipe.to("cuda")
+
+    PIPELINES[script_name] = pipe
     log(f"✅ Pipeline ready: {script_name}")
 
-    return PIPELINES[script_name]
+    return pipe
 
 
 # =========================================================
-# IMAGE GENERATION (LOGIC UNCHANGED)
+# IMAGE GENERATION
 # =========================================================
 @spaces.GPU
 def generate_image(
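
Each registered script is expected to leave a `pipe` object behind; `get_pipeline` then wraps it, moves it to CUDA when no device map is present, and caches it in PIPELINES. A hypothetical minimal script for illustration, assuming a plain diffusers load of the MODEL_ID used by this Space (the dtype and the feature text are illustrative, not taken from the real scripts):

    # hypothetical_zimage_script.py -- illustrative only, not one of the repo's scripts
    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "Tongyi-MAI/Z-Image-Turbo",
        torch_dtype=torch.bfloat16,
    )

    # register_pipeline_feature and log_pipe_calls are injected into the script's
    # globals by app_lora1.py, so no import is needed inside the script itself.
    register_pipeline_feature(pipe, "bf16 weights via DiffusionPipeline.from_pretrained")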
@@ -200,13 +211,6 @@ def generate_image(
 
     pipe = get_pipeline(pipeline_name)
 
-    # ✅ Correct, universal, ZeroGPU-safe
-    if not hasattr(pipe, "hf_device_map"):
-        pipe = pipe.to("cuda")
-
-    # =========================================================
-    # LOG PIPELINE TECHNOLOGY AND REGISTERED FEATURES
-    # =========================================================
     log("=== PIPELINE TECHNOLOGY ===")
     log(pipeline_technology_info(pipe))
 
@@ -217,9 +221,6 @@ def generate_image(
 
     else:
         log("✔ No explicit pipeline features registered")
 
-    # =========================================================
-    # GENERATION LOG
-    # =========================================================
     log("=== NEW GENERATION REQUEST ===")
     log(f"Pipeline: {pipeline_name}")
     log(f"Prompt: {prompt}")
@@ -248,11 +249,15 @@ def generate_image(
         output_type="pil",
     )
 
-    # Resize images to 512x512 for Gradio
+    # Optional: scale down very large images for UI display
+    max_display_size = 1024
     fixed_images = []
     for img in result.images:
        if isinstance(img, Image.Image):
-            img = img.resize((512, 512), Image.BICUBIC)
+            w, h = img.size
+            scale = min(max_display_size / max(w, h), 1.0)
+            if scale < 1.0:
+                img = img.resize((int(w * scale), int(h * scale)), Image.BICUBIC)
            fixed_images.append(img)
 
     try:
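
The reworked resize above keeps aspect ratio and only ever shrinks: `scale` is clamped to 1.0, so a 2048×1024 render is scaled by min(1024/2048, 1.0) = 0.5 down to 1024×512, while a 768×768 render passes through untouched. The same arithmetic in isolation (Pillow, as already used by the app):

    from PIL import Image

    def shrink_for_display(img, max_display_size=1024):
        # Clamp the scale so images are only ever reduced, never enlarged.
        w, h = img.size
        scale = min(max_display_size / max(w, h), 1.0)
        if scale < 1.0:
            img = img.resize((int(w * scale), int(h * scale)), Image.BICUBIC)
        return img

    print(shrink_for_display(Image.new("RGB", (2048, 1024))).size)  # (1024, 512)
    print(shrink_for_display(Image.new("RGB", (768, 768))).size)    # (768, 768)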
@@ -269,68 +274,67 @@
 
 # =========================================================
 # GRADIO UI
 # =========================================================
-with gr.Blocks(title="Z-Image Turbo ZeroGPU") as demo:
-    gr.Markdown("## Z-Image Turbo (Script-Driven · ZeroGPU Safe)")
+with gr.Blocks(title="Z-Image-Turbo Multi Image Demo") as demo:
+    gr.Markdown("# 🎨 Z-Image-Turbo Multi Image")
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            prompt = gr.Textbox(label="Prompt", lines=4)
+
+            with gr.Row():
+                height = gr.Slider(512, 2048, 1024, step=64, label="Height")
+                width = gr.Slider(512, 2048, 1024, step=64, label="Width")
+
+            num_images = gr.Slider(1, 3, 2, step=1, label="Number of Images")
+
+            num_inference_steps = gr.Slider(
+                1, 20, 9, step=1, label="Inference Steps",
+                info="9 steps = 8 DiT forward passes",
+            )
+
+            with gr.Row():
+                seed = gr.Number(label="Seed", value=42, precision=0)
+                randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)
+
+            # Select pipeline script
+            pipeline_picker = gr.Dropdown(
+                choices=SCRIPT_NAMES,
+                value=SCRIPT_NAMES[0] if SCRIPT_NAMES else None,
+                label="Active Pipeline Script",
+            )
+
+            generate_btn = gr.Button("🚀 Generate", variant="primary")
+
+        with gr.Column(scale=1):
+            output_images = gr.Gallery(
+                label="Generated Images",
+                height=512,
+                object_fit="contain"
+            )
+            used_seed = gr.Number(label="Seed Used", interactive=False)
+            debug_log = gr.Textbox(
+                label="Debug Log Output",
+                lines=25,
+                interactive=False
+            )
+
+    register_btn = gr.Button("Register Scripts")
+    status = gr.Textbox(label="Status", interactive=False)
     script_selector = gr.CheckboxGroup(
         choices=SCRIPT_NAMES,
         label="Select pipeline scripts",
     )
 
-    register_btn = gr.Button("Register Scripts")
-    status = gr.Textbox(label="Status", interactive=False)
-
     register_btn.click(
         register_scripts,
         inputs=[script_selector],
         outputs=[status],
     )
 
-    pipeline_picker = gr.Dropdown(
-        choices=[],
-        label="Active Pipeline",
-    )
-
-    register_btn.click(
-        lambda s: gr.update(choices=s, value=s[0] if s else None),
-        inputs=[script_selector],
-        outputs=[pipeline_picker],
-    )
-
-    gr.Markdown("---")
-
-    prompt = gr.Textbox(label="Prompt", lines=3)
-    height = gr.Slider(256, 1024, 512, step=64, label="Height")
-    width = gr.Slider(256, 1024, 512, step=64, label="Width")
-    steps = gr.Slider(1, 8, 4, step=1, label="Inference Steps")
-    images = gr.Slider(1, 3, 1, step=1, label="Images")
-    seed = gr.Number(value=0, label="Seed")
-    random_seed = gr.Checkbox(value=True, label="Randomize Seed")
-
-    run_btn = gr.Button("Generate")
-
-    gallery = gr.Gallery(
-        columns=3,
-        height=512,
-        object_fit="contain",
-        label="Output (512×512)"
-    )
-    used_seed = gr.Number(label="Used Seed")
-    logs = gr.Textbox(lines=12, label="Logs")
-
-    run_btn.click(
-        generate_image,
-        inputs=[
-            prompt,
-            height,
-            width,
-            steps,
-            seed,
-            random_seed,
-            images,
-            pipeline_picker,
-        ],
-        outputs=[gallery, used_seed, logs],
+    generate_btn.click(
+        fn=generate_image,
+        inputs=[prompt, height, width, num_inference_steps, seed, randomize_seed, num_images, pipeline_picker],
+        outputs=[output_images, used_seed, debug_log],
     )
 
 demo.queue()
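
On the rewired click handler: Gradio passes the `inputs` list to the callback positionally, so the order prompt, height, width, steps, seed, randomize flag, image count, pipeline script has to match `generate_image`'s parameter order (the full signature is outside the hunks shown here). A reduced, self-contained sketch of that wiring pattern, not the app's actual components:

    import gradio as gr

    def handler(prompt, steps):
        # `inputs` below is matched to this signature by position, not by name
        return f"{prompt!r} rendered in {steps} steps"

    with gr.Blocks() as sketch:
        prompt = gr.Textbox(label="Prompt")
        steps = gr.Slider(1, 20, 9, step=1, label="Inference Steps")
        result = gr.Textbox(label="Result")
        run = gr.Button("Run")
        run.click(fn=handler, inputs=[prompt, steps], outputs=[result])

    # sketch.launch()  # uncomment to serve the sketch locally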
 