K1Z3M1112 committed on
Commit a625bc0 · verified · 1 Parent(s): 43dcc74

Update app.py

Files changed (1):
app.py +306 -183
app.py CHANGED
@@ -3,33 +3,70 @@ import numpy as np
  from PIL import Image
  import torch
  import gc

  # Device
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32

  # Lazy import (to avoid long startup if unused)
  from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, StableDiffusionPipeline
- from diffusers import StableDiffusionInstructPix2PixPipeline
  from controlnet_aux import LineartDetector, LineartAnimeDetector

  # ===== Model & Config =====
- PIPELINES = {}  # key: (model_name, is_anime) -> pipeline
  LINEART_DETECTOR = None
  LINEART_ANIME_DETECTOR = None
  CURRENT_T2I_PIPE = None
  CURRENT_T2I_MODEL = None
- CURRENT_PIX2PIX_PIPE = None
- CURRENT_PIX2PIX_MODEL = None

  def get_pipeline(model_name: str, anime_model: bool = False):
      """Get or create a ControlNet pipeline for the given model and anime flag"""
      key = (model_name, anime_model)

-     if key in PIPELINES:
-         return PIPELINES[key]

-     print(f"Loading ControlNet pipeline for model: {model_name}, anime: {anime_model}")

      try:
          # Load the appropriate ControlNet

@@ -50,20 +87,111 @@ def get_pipeline(model_name: str, anime_model: bool = False):
              controlnet=controlnet,
              torch_dtype=dtype,
              safety_checker=None,
-             requires_safety_checker=False
          ).to(device)

-         pipe.enable_attention_slicing()

          if device.type == "cuda":
              pipe.enable_model_cpu_offload()

-         # Store the pipeline in the cache
-         PIPELINES[key] = pipe
          return pipe

      except Exception as e:
          print(f"Error loading ControlNet pipeline: {e}")
-         raise

  def load_lineart_detectors():
      """Load lineart detectors if not already loaded"""

@@ -83,56 +211,97 @@ def load_t2i_model(model_name: str):
      if CURRENT_T2I_MODEL == model_name and CURRENT_T2I_PIPE is not None:
          return
      if CURRENT_T2I_PIPE is not None:
          del CURRENT_T2I_PIPE
          CURRENT_T2I_PIPE = None
          gc.collect()
          if torch.cuda.is_available():
              torch.cuda.empty_cache()

-     print(f"Loading T2I model: {model_name}")
      CURRENT_T2I_PIPE = StableDiffusionPipeline.from_pretrained(
-         model_name, torch_dtype=dtype, safety_checker=None, requires_safety_checker=False
      ).to(device)
-     CURRENT_T2I_PIPE.enable_attention_slicing()
      if device.type == "cuda":
          CURRENT_T2I_PIPE.enable_model_cpu_offload()
      CURRENT_T2I_MODEL = model_name

  except Exception as e:
      print(f"Error loading T2I model {model_name}: {e}")
-     # Reset the variables when loading fails
-     CURRENT_T2I_PIPE = None
-     CURRENT_T2I_MODEL = None
-     raise
-
- def load_pix2pix_model():
-     """Load Instruct-Pix2Pix model for image editing"""
-     global CURRENT_PIX2PIX_PIPE, CURRENT_PIX2PIX_MODEL
-
-     if CURRENT_PIX2PIX_PIPE is not None:
-         return CURRENT_PIX2PIX_PIPE
-
-     try:
-         print("Loading Instruct-Pix2Pix model...")
-         CURRENT_PIX2PIX_PIPE = StableDiffusionInstructPix2PixPipeline.from_pretrained(
-             "timbrooks/instruct-pix2pix",
-             torch_dtype=dtype,
-             safety_checker=None,
-             requires_safety_checker=False
-         ).to(device)

-         CURRENT_PIX2PIX_PIPE.enable_attention_slicing()
-         if device.type == "cuda":
-             CURRENT_PIX2PIX_PIPE.enable_model_cpu_offload()
-
-         CURRENT_PIX2PIX_MODEL = "timbrooks/instruct-pix2pix"
-         return CURRENT_PIX2PIX_PIPE
-
-     except Exception as e:
-         print(f"Error loading Instruct-Pix2Pix model: {e}")
-         CURRENT_PIX2PIX_PIPE = None
-         CURRENT_PIX2PIX_MODEL = None
-         raise

  # ===== Utils =====
  def is_lineart(img: Image.Image) -> bool:

@@ -161,97 +330,82 @@ def resize_image(image, max_size=512):
  # ===== Functions =====
  def colorize(sketch, base_model, anime_model, prompt, seed, steps, scale, cn_weight):
      try:
-         # Load the appropriate pipeline
          pipe = get_pipeline(base_model, anime_model)

          # Extract the lineart
          lineart = extract_lineart(sketch, anime_model)

-         # Generate the image
          gen = torch.Generator(device=device).manual_seed(int(seed))
-         out = pipe(
-             prompt,
-             image=lineart,
-             num_inference_steps=int(steps),
-             guidance_scale=float(scale),
-             controlnet_conditioning_scale=float(cn_weight),
-             generator=gen
-         ).images[0]

          return out, lineart
      except Exception as e:
-         print(f"Error in colorize: {e}")
-         # Return a blank image and report the error
          error_img = Image.new('RGB', (512, 512), color='red')
-         error_text = f"Error: {str(e)[:50]}..."
          return error_img, Image.new('RGB', (512, 512), color='gray')

  def t2i(prompt, model, seed, steps, scale, w, h):
      try:
          load_t2i_model(model)
-         gen = torch.Generator(device=device).manual_seed(int(seed))
-         return CURRENT_T2I_PIPE(
-             prompt,
-             width=int(w),
-             height=int(h),
-             num_inference_steps=int(steps),
-             guidance_scale=float(scale),
-             generator=gen
-         ).images[0]
-     except Exception as e:
-         print(f"Error in t2i: {e}")
-         # Return a blank image and report the error
-         error_img = Image.new('RGB', (int(w), int(h)), color='red')
-         return error_img
-
- def pix2pix_edit(image, instruction, seed, steps, scale, image_scale):
-     """Edit image using Instruct-Pix2Pix"""
-     try:
-         # Load the model
-         pipe = load_pix2pix_model()

-         # Resize the image
-         image = resize_image(image, max_size=768)
-
-         # Create a generator
          gen = torch.Generator(device=device).manual_seed(int(seed))

-         # Edit the image
-         result = pipe(
-             instruction,
-             image=image,
-             num_inference_steps=int(steps),
-             guidance_scale=float(scale),
-             image_guidance_scale=float(image_scale),
-             generator=gen
-         ).images[0]

-         return result

      except Exception as e:
-         print(f"Error in pix2pix_edit: {e}")
-         # Return an error image
-         if image:
-             error_img = Image.new('RGB', image.size, color='red')
-         else:
-             error_img = Image.new('RGB', (512, 512), color='red')
          return error_img

  # ===== Function to unload all models =====
  def unload_all_models():
-     global PIPELINES, LINEART_DETECTOR, LINEART_ANIME_DETECTOR
      global CURRENT_T2I_PIPE, CURRENT_T2I_MODEL
-     global CURRENT_PIX2PIX_PIPE, CURRENT_PIX2PIX_MODEL

      print("Unloading all models from memory...")

-     # Unload ControlNet pipelines
-     for key, pipe in list(PIPELINES.items()):
-         try:
-             del pipe
-         except:
-             pass
-     PIPELINES.clear()

      # Unload lineart detectors
      try:

@@ -275,39 +429,43 @@ def unload_all_models():
          CURRENT_T2I_PIPE = None
      except:
          pass
-
      CURRENT_T2I_MODEL = None

-     # Unload Pix2Pix model
-     try:
-         if CURRENT_PIX2PIX_PIPE is not None:
-             del CURRENT_PIX2PIX_PIPE
-             CURRENT_PIX2PIX_PIPE = None
-     except:
-         pass
-
-     CURRENT_PIX2PIX_MODEL = None
-
      # Force garbage collection
      gc.collect()
      if torch.cuda.is_available():
          torch.cuda.empty_cache()
          allocated = torch.cuda.memory_allocated() / 1024**3
-         print(f"GPU memory cleared. Allocated: {allocated:.2f} GB")

      return "✅ All models unloaded from memory!"

  # ===== Gradio UI =====
- with gr.Blocks() as demo:
      gr.Markdown("# 🎨 Advanced Image Generation & Editing Suite")

      # Add unload button at the top
      with gr.Row():
-         unload_btn = gr.Button("🗑️ Unload All Models", variant="stop")
-         status_text = gr.Textbox(label="Status", interactive=False)
          unload_btn.click(unload_all_models, outputs=status_text)

      with gr.Tab("🎨 Colorize Sketch"):
          with gr.Row():
              inp = gr.Image(label="Input Sketch/Image", type="pil")
              out = gr.Image(label="Colored Output")

@@ -316,15 +474,14 @@ with gr.Blocks() as demo:
          sketch_out = gr.Image(label="Detected Lineart", type="pil")

          with gr.Row():
-             # Removed the stabilityai/stable-diffusion-2-1 and runwayml/stable-diffusion-v1-5 models
              base_model = gr.Dropdown(
                  choices=[
-                     "admruul/anything-v3.0",
                      "digiplay/ChikMix_V3",
                      "digiplay/chilloutmix_NiPrunedFp16Fix",
-                     "gsdf/Counterfeit-V2.5"
                  ],
-                 value="admruul/anything-v3.0",
                  label="Base Model"
              )
              anime_chk = gr.Checkbox(label="Use Anime ControlNet", value=True)

@@ -350,20 +507,28 @@ with gr.Blocks() as demo:
          )

      with gr.Tab("🖼️ Text-to-Image"):
          with gr.Row():
              t2i_out = gr.Image(label="Output", type="pil")

          with gr.Row():
-             t2i_prompt = gr.Textbox(label="Prompt", lines=3)
-             # Removed the stabilityai/stable-diffusion-2-1 and runwayml/stable-diffusion-v1-5 models
              t2i_model = gr.Dropdown(
                  choices=[
-                     "admruul/anything-v3.0",
                      "digiplay/ChikMix_V3",
                      "digiplay/chilloutmix_NiPrunedFp16Fix",
-                     "gsdf/Counterfeit-V2.5"
                  ],
-                 value="admruul/anything-v3.0",
                  label="Model"
              )
 
@@ -382,56 +547,14 @@ with gr.Blocks() as demo:
          [t2i_prompt, t2i_model, t2i_seed, t2i_steps, t2i_scale, w, h],
          t2i_out
      )
-
-     with gr.Tab("🔄 Instruct-Pix2Pix"):
-         gr.Markdown("### Edit Images with Text Instructions")
-         gr.Markdown("Example instructions: 'make it winter', 'turn day into night', 'add sunglasses', 'make it look like a painting'")
-
-         with gr.Row():
-             with gr.Column():
-                 pix2pix_input = gr.Image(label="Input Image", type="pil")
-                 pix2pix_instruction = gr.Textbox(
-                     label="Edit Instruction",
-                     placeholder="e.g., make it winter, turn day into night, add sunglasses...",
-                     lines=2
-                 )
-
-                 with gr.Row():
-                     pix2pix_seed = gr.Number(value=42, label="Seed")
-                     pix2pix_steps = gr.Slider(10, 100, 50, step=5, label="Steps")
-
-                 with gr.Row():
-                     pix2pix_scale = gr.Slider(1, 20, 7.5, step=0.5, label="Text Guidance Scale")
-                     pix2pix_image_scale = gr.Slider(1, 5, 1.5, step=0.1, label="Image Guidance Scale")
-
-                 pix2pix_btn = gr.Button("🔄 Edit Image", variant="primary")
-
-             with gr.Column():
-                 pix2pix_output = gr.Image(label="Edited Image", type="pil")
-
-         # Common example instructions
-         with gr.Row():
-             gr.Examples(
-                 examples=[
-                     ["make it winter", 42, 50, 7.5, 1.5],
-                     ["turn day into night", 42, 50, 7.5, 1.5],
-                     ["make it look like a painting", 42, 50, 7.5, 1.5],
-                     ["add sunglasses", 42, 50, 7.5, 1.5],
-                     ["make it cyberpunk style", 42, 50, 7.5, 1.5],
-                     ["change hair color to blue", 42, 50, 7.5, 1.5],
-                 ],
-                 inputs=[pix2pix_instruction, pix2pix_seed, pix2pix_steps, pix2pix_scale, pix2pix_image_scale],
-                 label="Quick Examples"
-             )
-
-         pix2pix_btn.click(
-             pix2pix_edit,
-             [pix2pix_input, pix2pix_instruction, pix2pix_seed, pix2pix_steps, pix2pix_scale, pix2pix_image_scale],
-             pix2pix_output
-         )

- # Add error handling to the launch
  try:
-     demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
  except Exception as e:
-     print(f"Error launching Gradio app: {e}")
 
  from PIL import Image
  import torch
  import gc
+ import os
+ import warnings
+
+ # Suppress specific warnings
+ warnings.filterwarnings('ignore', category=FutureWarning)
+ warnings.filterwarnings('ignore', category=UserWarning)
+ warnings.filterwarnings('ignore', message='.*torch_dtype.*deprecated.*')
+ warnings.filterwarnings('ignore', message='.*CLIPFeatureExtractor.*deprecated.*')
+
+ # Performance optimizations
+ if torch.cuda.is_available():
+     torch.backends.cudnn.benchmark = True
+     torch.backends.cuda.matmul.allow_tf32 = True
+     torch.backends.cudnn.allow_tf32 = True

  # Device
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32

+ print(f"🖥️ Device: {device} | dtype: {dtype}")
+
  # Lazy import (to avoid long startup if unused)
  from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, StableDiffusionPipeline
  from controlnet_aux import LineartDetector, LineartAnimeDetector

+ # Memory optimization
+ if torch.cuda.is_available():
+     torch.cuda.empty_cache()
+     # Set memory fraction to prevent OOM
+     torch.cuda.set_per_process_memory_fraction(0.95)
+     print(f"🔥 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
+ else:
+     print("⚠️ Running on CPU - Image generation will be significantly slower")
+
  # ===== Model & Config =====
+ CURRENT_CONTROLNET_PIPE = None
+ CURRENT_CONTROLNET_KEY = None  # (model_name, is_anime)
  LINEART_DETECTOR = None
  LINEART_ANIME_DETECTOR = None
  CURRENT_T2I_PIPE = None
  CURRENT_T2I_MODEL = None

  def get_pipeline(model_name: str, anime_model: bool = False):
      """Get or create a ControlNet pipeline for the given model and anime flag"""
+     global CURRENT_CONTROLNET_PIPE, CURRENT_CONTROLNET_KEY
+
      key = (model_name, anime_model)

+     # If the same model is requested, keep using it
+     if CURRENT_CONTROLNET_KEY == key and CURRENT_CONTROLNET_PIPE is not None:
+         print(f"✅ Reusing existing ControlNet pipeline: {model_name}, anime: {anime_model}")
+         return CURRENT_CONTROLNET_PIPE
+
+     # If a new model is requested, remove the old one first
+     if CURRENT_CONTROLNET_PIPE is not None:
+         print(f"🗑️ Unloading old ControlNet pipeline: {CURRENT_CONTROLNET_KEY}")
+         del CURRENT_CONTROLNET_PIPE
+         CURRENT_CONTROLNET_PIPE = None
+         CURRENT_CONTROLNET_KEY = None
+         gc.collect()
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()

+     print(f"📥 Loading ControlNet pipeline for model: {model_name}, anime: {anime_model}")

      try:
          # Load the appropriate ControlNet
 
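Note: this hunk replaces the unbounded PIPELINES dict cache with a single resident pipeline that is evicted whenever the requested (model_name, anime) key changes. A minimal sketch of that pattern, with illustrative names (get_cached and loader are not part of the commit):

import gc
import torch

_PIPE = None   # the one resident pipeline
_KEY = None    # identity of what is loaded, e.g. (model_name, is_anime)

def get_cached(key, loader):
    """Return the pipeline for `key`, evicting any other pipeline first."""
    global _PIPE, _KEY
    if _KEY == key and _PIPE is not None:
        return _PIPE                  # cache hit: reuse the resident pipeline
    if _PIPE is not None:             # cache miss: free the old one before loading
        _PIPE, _KEY = None, None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # release cached CUDA blocks
    _PIPE, _KEY = loader(), key
    return _PIPE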
              controlnet=controlnet,
              torch_dtype=dtype,
              safety_checker=None,
+             requires_safety_checker=False,
+             use_safetensors=True,
+             variant="fp16" if dtype == torch.float16 else None
          ).to(device)

+         # Aggressive memory optimizations
+         pipe.enable_attention_slicing(slice_size="max")
+
+         # Use new API for VAE slicing
+         if hasattr(pipe, 'vae') and hasattr(pipe.vae, 'enable_slicing'):
+             pipe.vae.enable_slicing()
+         else:
+             # Fallback for older versions
+             try:
+                 pipe.enable_vae_slicing()
+             except:
+                 pass
+
          if device.type == "cuda":
+             # Use xformers if available for better performance
+             try:
+                 pipe.enable_xformers_memory_efficient_attention()
+                 print("✅ xFormers enabled for ControlNet")
+             except:
+                 print("⚠️ xFormers not available, using standard attention")
+
+             # Enable model CPU offload for memory efficiency
              pipe.enable_model_cpu_offload()

+         # Compile model for faster inference (PyTorch 2.0+)
+         if hasattr(torch, 'compile') and device.type == "cuda":
+             try:
+                 pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+                 print("✅ Model compiled with torch.compile")
+             except Exception as e:
+                 print(f"⚠️ torch.compile not available: {e}")
+
+         # Keep the current pipeline
+         CURRENT_CONTROLNET_PIPE = pipe
+         CURRENT_CONTROLNET_KEY = key
          return pipe

      except Exception as e:
          print(f"Error loading ControlNet pipeline: {e}")
+         print("⚠️ Trying to load without use_safetensors...")
+
+         # Retry without use_safetensors for models that don't support it
+         try:
+             if anime_model:
+                 controlnet = ControlNetModel.from_pretrained(
+                     "lllyasviel/control_v11p_sd15s2_lineart_anime",
+                     torch_dtype=dtype
+                 ).to(device)
+             else:
+                 controlnet = ControlNetModel.from_pretrained(
+                     "lllyasviel/control_v11p_sd15_lineart",
+                     torch_dtype=dtype
+                 ).to(device)
+
+             pipe = StableDiffusionControlNetPipeline.from_pretrained(
+                 model_name,
+                 controlnet=controlnet,
+                 torch_dtype=dtype,
+                 safety_checker=None,
+                 requires_safety_checker=False
+             ).to(device)
+
+             # Optimizations
+             pipe.enable_attention_slicing(slice_size="max")
+             if hasattr(pipe, 'vae') and hasattr(pipe.vae, 'enable_slicing'):
+                 pipe.vae.enable_slicing()
+             else:
+                 try:
+                     pipe.enable_vae_slicing()
+                 except:
+                     pass
+
+             if device.type == "cuda":
+                 try:
+                     pipe.enable_xformers_memory_efficient_attention()
+                     print("✅ xFormers enabled for ControlNet")
+                 except:
+                     print("⚠️ xFormers not available, using standard attention")
+                 pipe.enable_model_cpu_offload()
+
+             if hasattr(torch, 'compile') and device.type == "cuda":
+                 try:
+                     pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+                     print("✅ Model compiled with torch.compile")
+                 except Exception as compile_err:
+                     print(f"⚠️ torch.compile not available: {compile_err}")
+
+             CURRENT_CONTROLNET_PIPE = pipe
+             CURRENT_CONTROLNET_KEY = key
+             return pipe
+
+         except Exception as retry_e:
+             print(f"❌ Error loading ControlNet pipeline (retry): {retry_e}")
+             CURRENT_CONTROLNET_PIPE = None
+             CURRENT_CONTROLNET_KEY = None
+             raise

  def load_lineart_detectors():
      """Load lineart detectors if not already loaded"""
 
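The load-then-retry logic above amounts to one strategy: prefer safetensors weights and the fp16 variant, and fall back to the repo's default weights when a model ships neither. A condensed sketch under that assumption (load_sd_pipeline is an illustrative helper, not the commit's API):

import torch
from diffusers import StableDiffusionPipeline

def load_sd_pipeline(model_name: str, dtype: torch.dtype = torch.float16):
    common = dict(torch_dtype=dtype, safety_checker=None, requires_safety_checker=False)
    try:
        # Preferred: safetensors weights, fp16 variant when running in half precision
        return StableDiffusionPipeline.from_pretrained(
            model_name,
            use_safetensors=True,
            variant="fp16" if dtype == torch.float16 else None,
            **common,
        )
    except Exception:
        # Older repos may only publish .bin weights with no fp16 variant
        return StableDiffusionPipeline.from_pretrained(model_name, **common)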
      if CURRENT_T2I_MODEL == model_name and CURRENT_T2I_PIPE is not None:
          return
      if CURRENT_T2I_PIPE is not None:
+         print(f"🗑️ Unloading old T2I model: {CURRENT_T2I_MODEL}")
          del CURRENT_T2I_PIPE
          CURRENT_T2I_PIPE = None
          gc.collect()
          if torch.cuda.is_available():
              torch.cuda.empty_cache()

+     print(f"📥 Loading T2I model: {model_name}")
      CURRENT_T2I_PIPE = StableDiffusionPipeline.from_pretrained(
+         model_name,
+         torch_dtype=dtype,
+         safety_checker=None,
+         requires_safety_checker=False,
+         use_safetensors=True,
+         variant="fp16" if dtype == torch.float16 else None
      ).to(device)
+
+     # Optimizations
+     CURRENT_T2I_PIPE.enable_attention_slicing(slice_size="max")
+
+     # Use new API for VAE slicing
+     if hasattr(CURRENT_T2I_PIPE, 'vae') and hasattr(CURRENT_T2I_PIPE.vae, 'enable_slicing'):
+         CURRENT_T2I_PIPE.vae.enable_slicing()
+     else:
+         try:
+             CURRENT_T2I_PIPE.enable_vae_slicing()
+         except:
+             pass
+
      if device.type == "cuda":
+         try:
+             CURRENT_T2I_PIPE.enable_xformers_memory_efficient_attention()
+             print("✅ xFormers enabled for T2I")
+         except:
+             pass
          CURRENT_T2I_PIPE.enable_model_cpu_offload()
+
+     # Compile if available
+     if hasattr(torch, 'compile') and device.type == "cuda":
+         try:
+             CURRENT_T2I_PIPE.unet = torch.compile(CURRENT_T2I_PIPE.unet, mode="reduce-overhead", fullgraph=True)
+             print("✅ T2I model compiled")
+         except:
+             pass
+
      CURRENT_T2I_MODEL = model_name

  except Exception as e:
      print(f"Error loading T2I model {model_name}: {e}")
+     print("⚠️ Trying to load without use_safetensors...")

+     # Retry without use_safetensors
+     try:
+         CURRENT_T2I_PIPE = StableDiffusionPipeline.from_pretrained(
+             model_name,
+             torch_dtype=dtype,
+             safety_checker=None,
+             requires_safety_checker=False
+         ).to(device)
+
+         CURRENT_T2I_PIPE.enable_attention_slicing(slice_size="max")
+         if hasattr(CURRENT_T2I_PIPE, 'vae') and hasattr(CURRENT_T2I_PIPE.vae, 'enable_slicing'):
+             CURRENT_T2I_PIPE.vae.enable_slicing()
+         else:
+             try:
+                 CURRENT_T2I_PIPE.enable_vae_slicing()
+             except:
+                 pass
+
+         if device.type == "cuda":
+             try:
+                 CURRENT_T2I_PIPE.enable_xformers_memory_efficient_attention()
+                 print("✅ xFormers enabled for T2I")
+             except:
+                 pass
+             CURRENT_T2I_PIPE.enable_model_cpu_offload()
+
+         if hasattr(torch, 'compile') and device.type == "cuda":
+             try:
+                 CURRENT_T2I_PIPE.unet = torch.compile(CURRENT_T2I_PIPE.unet, mode="reduce-overhead", fullgraph=True)
+                 print("✅ T2I model compiled")
+             except:
+                 pass
+
+         CURRENT_T2I_MODEL = model_name
+
+     except Exception as retry_e:
+         print(f"❌ Error loading T2I model (retry): {retry_e}")
+         CURRENT_T2I_PIPE = None
+         CURRENT_T2I_MODEL = None
+         raise

  # ===== Utils =====
  def is_lineart(img: Image.Image) -> bool:
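The same optimization ladder now appears four times (ControlNet and T2I, primary and retry paths). A possible refactor is a shared helper that applies each step only when the installed diffusers/torch build supports it; this is a sketch, not code from the commit:

import torch

def apply_optimizations(pipe, device_type: str):
    """Best-effort memory/speed tuning; every step degrades gracefully."""
    pipe.enable_attention_slicing(slice_size="max")
    if hasattr(pipe, 'vae') and hasattr(pipe.vae, 'enable_slicing'):
        pipe.vae.enable_slicing()          # newer diffusers: method on the VAE
    else:
        try:
            pipe.enable_vae_slicing()      # older diffusers: method on the pipeline
        except Exception:
            pass
    if device_type == "cuda":
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception:
            pass                           # xformers not installed
        pipe.enable_model_cpu_offload()
        if hasattr(torch, 'compile'):
            try:
                pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
            except Exception:
                pass                       # torch < 2.0 or unsupported backend
    return pipe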
 
  # ===== Functions =====
  def colorize(sketch, base_model, anime_model, prompt, seed, steps, scale, cn_weight):
      try:
+         # Load the appropriate pipeline (the old one is removed automatically when the model changes)
          pipe = get_pipeline(base_model, anime_model)

+         # Show which model is being used
+         controlnet_type = "Anime" if anime_model else "Standard"
+         status_msg = f"🎨 Using: {base_model} + {controlnet_type} ControlNet"
+         print(status_msg)
+
          # Extract the lineart
          lineart = extract_lineart(sketch, anime_model)

+         # Generate the image with optimizations
          gen = torch.Generator(device=device).manual_seed(int(seed))
+
+         with torch.inference_mode():
+             out = pipe(
+                 prompt,
+                 image=lineart,
+                 num_inference_steps=int(steps),
+                 guidance_scale=float(scale),
+                 controlnet_conditioning_scale=float(cn_weight),
+                 generator=gen
+             ).images[0]
+
+         # Clear cache after generation
+         if device.type == "cuda":
+             torch.cuda.empty_cache()

          return out, lineart
      except Exception as e:
+         print(f"Error in colorize: {e}")
          error_img = Image.new('RGB', (512, 512), color='red')
          return error_img, Image.new('RGB', (512, 512), color='gray')

  def t2i(prompt, model, seed, steps, scale, w, h):
      try:
          load_t2i_model(model)
+         print(f"🖼️ Using T2I model: {model}")

          gen = torch.Generator(device=device).manual_seed(int(seed))

+         with torch.inference_mode():
+             result = CURRENT_T2I_PIPE(
+                 prompt,
+                 width=int(w),
+                 height=int(h),
+                 num_inference_steps=int(steps),
+                 guidance_scale=float(scale),
+                 generator=gen
+             ).images[0]

+         if device.type == "cuda":
+             torch.cuda.empty_cache()

+         return result
      except Exception as e:
+         print(f"Error in t2i: {e}")
+         error_img = Image.new('RGB', (int(w), int(h)), color='red')
          return error_img

  # ===== Function to unload all models =====
  def unload_all_models():
+     global CURRENT_CONTROLNET_PIPE, CURRENT_CONTROLNET_KEY
+     global LINEART_DETECTOR, LINEART_ANIME_DETECTOR
      global CURRENT_T2I_PIPE, CURRENT_T2I_MODEL

      print("Unloading all models from memory...")

+     # Unload ControlNet pipeline
+     try:
+         if CURRENT_CONTROLNET_PIPE is not None:
+             del CURRENT_CONTROLNET_PIPE
+             CURRENT_CONTROLNET_PIPE = None
+     except:
+         pass
+     CURRENT_CONTROLNET_KEY = None

      # Unload lineart detectors
      try:
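Both generation paths above now follow the same shape: run the pipeline under torch.inference_mode() to skip autograd bookkeeping, then release cached CUDA memory. A minimal sketch of that shape (run_pipeline is an illustrative name, not part of the commit):

import torch

def run_pipeline(pipe, device, *args, **kwargs):
    with torch.inference_mode():       # no gradient tracking during sampling
        image = pipe(*args, **kwargs).images[0]
    if device.type == "cuda":
        torch.cuda.empty_cache()       # hand cached blocks back to the allocator
    return image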
 
          CURRENT_T2I_PIPE = None
      except:
          pass
      CURRENT_T2I_MODEL = None

      # Force garbage collection
      gc.collect()
      if torch.cuda.is_available():
          torch.cuda.empty_cache()
          allocated = torch.cuda.memory_allocated() / 1024**3
+         reserved = torch.cuda.memory_reserved() / 1024**3
+         print(f"💾 GPU memory - Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB")

      return "✅ All models unloaded from memory!"

  # ===== Gradio UI =====
+ with gr.Blocks(title="🎨 Advanced Image Generation Suite", theme=gr.themes.Soft()) as demo:
      gr.Markdown("# 🎨 Advanced Image Generation & Editing Suite")
+     gr.Markdown("### Powered by Stable Diffusion & ControlNet")
+
+     # Add system info
+     if torch.cuda.is_available():
+         gpu_name = torch.cuda.get_device_name(0)
+         gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
+         gr.Markdown(f"**GPU:** {gpu_name} ({gpu_memory:.1f} GB)")
+     else:
+         gr.Markdown("**⚠️ Running on CPU** - Generation will be slower")

      # Add unload button at the top
      with gr.Row():
+         unload_btn = gr.Button("🗑️ Unload All Models", variant="stop", scale=1)
+         status_text = gr.Textbox(label="Status", interactive=False, scale=3)
          unload_btn.click(unload_all_models, outputs=status_text)

      with gr.Tab("🎨 Colorize Sketch"):
+         gr.Markdown("""
+         ### Convert your sketches to colored images using ControlNet
+         Upload a sketch or line art, and the AI will automatically colorize it based on your prompt.
+         """)
+
          with gr.Row():
              inp = gr.Image(label="Input Sketch/Image", type="pil")
              out = gr.Image(label="Colored Output")
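For reference, the themed Blocks layout and the unload wiring introduced above reduce to this self-contained sketch (gr.themes.Soft() and the scale argument assume a reasonably recent Gradio release):

import gradio as gr

def unload_all_models():
    return "✅ All models unloaded from memory!"   # stub for illustration

with gr.Blocks(title="Demo", theme=gr.themes.Soft()) as demo:
    with gr.Row():
        unload_btn = gr.Button("🗑️ Unload All Models", variant="stop", scale=1)
        status_text = gr.Textbox(label="Status", interactive=False, scale=3)
    unload_btn.click(unload_all_models, outputs=status_text)

demo.launch()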
 
          sketch_out = gr.Image(label="Detected Lineart", type="pil")

          with gr.Row():
              base_model = gr.Dropdown(
                  choices=[
                      "digiplay/ChikMix_V3",
                      "digiplay/chilloutmix_NiPrunedFp16Fix",
+                     "gsdf/Counterfeit-V2.5",
+                     "stablediffusionapi/anything-v5"
                  ],
+                 value="digiplay/ChikMix_V3",
                  label="Base Model"
              )
              anime_chk = gr.Checkbox(label="Use Anime ControlNet", value=True)
 
          )

      with gr.Tab("🖼️ Text-to-Image"):
+         gr.Markdown("""
+         ### Generate images from text descriptions
+         Describe what you want to see, and the AI will create it for you.
+         """)
+
          with gr.Row():
              t2i_out = gr.Image(label="Output", type="pil")

          with gr.Row():
+             t2i_prompt = gr.Textbox(
+                 label="Prompt",
+                 lines=3,
+                 placeholder="e.g., a beautiful landscape with mountains and a lake at sunset, highly detailed, 4k"
+             )
              t2i_model = gr.Dropdown(
                  choices=[
                      "digiplay/ChikMix_V3",
                      "digiplay/chilloutmix_NiPrunedFp16Fix",
+                     "gsdf/Counterfeit-V2.5",
+                     "stablediffusionapi/anything-v5"
                  ],
+                 value="digiplay/ChikMix_V3",
                  label="Model"
              )
 
          [t2i_prompt, t2i_model, t2i_seed, t2i_steps, t2i_scale, w, h],
          t2i_out
      )

  try:
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         share=False,
+         show_error=True,
+         quiet=False
+     )
  except Exception as e:
+     print(f"Error launching Gradio app: {e}")