K1Z3M1112 committed on
Commit 96afbed · verified · 1 Parent(s): c8108cb

Update app.py

Files changed (1)
  1. app.py +373 -682
app.py CHANGED
@@ -20,20 +20,18 @@ if torch.cuda.is_available():
 
 # Device
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 
-print(f"🖥️ Device: {device} | dtype: {torch_dtype}")
 
-# Lazy import (to avoid long startup if unused)
 from diffusers import (
     StableDiffusionControlNetPipeline,
     ControlNetModel,
     StableDiffusionPipeline,
-    StableDiffusionXLPipeline,
-    DiffusionPipeline,
-    StableDiffusionImg2ImgPipeline
 )
-from diffusers import UniPCMultistepScheduler, DPMSolverMultistepScheduler, EulerDiscreteScheduler
 from controlnet_aux import (
     LineartDetector,
     LineartAnimeDetector,
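Despite the `# Lazy import` comment, a module-level `from diffusers import ...` still executes at startup; only an import placed inside a function is deferred until first use. A minimal sketch of a genuinely lazy import, with a hypothetical helper name that is not part of this commit:

```python
def _sdxl_pipeline_cls():
    # Deferred: diffusers' SDXL machinery is only imported when an SDXL
    # model is actually requested.
    from diffusers import StableDiffusionXLPipeline
    return StableDiffusionXLPipeline
```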
@@ -44,7 +42,6 @@ from controlnet_aux import (
     HEDdetector,
     PidiNetDetector,
     NormalBaeDetector,
-    ContentShuffleDetector,
     ZoeDetector,
     MediapipeFaceDetector
 )
@@ -52,7 +49,6 @@ from controlnet_aux import (
 # Memory optimization
 if torch.cuda.is_available():
     torch.cuda.empty_cache()
-    # Set memory fraction to prevent OOM
     torch.cuda.set_per_process_memory_fraction(0.95)
     print(f"🔥 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
 else:
@@ -60,29 +56,33 @@ else:
 
 # ===== Model & Config =====
 CURRENT_CONTROLNET_PIPE = None
-CURRENT_CONTROLNET_KEY = None  # (model_name, controlnet_type)
 CURRENT_T2I_PIPE = None
 CURRENT_T2I_MODEL = None
 CURRENT_SDXL_REFINER = None
 
-# Define model types with expanded lists
 SDXL_MODELS = [
     "stabilityai/stable-diffusion-xl-base-1.0",
     "stabilityai/stable-diffusion-xl-refiner-1.0",
     "Laxhar/noobai-XL-1.1",
     "RunDiffusion/Juggernaut-XL-v9",
     "dataautogpt3/ProteusV0.4",
-    "thibaud/sdxl_dpo",
     "playgroundai/playground-v2.5-1024px-aesthetic",
-    "stablediffusionapi/sdxl-unstable-diffusers-y"
 ]
 
 SD15_MODELS = [
     "digiplay/ChikMix_V3",
     "digiplay/chilloutmix_NiPrunedFp16Fix",
     "gsdf/Counterfeit-V2.5",
     "stablediffusionapi/anything-v5",
-    "digiplay/CleanLinearMix_nsfw",
     "runwayml/stable-diffusion-v1-5",
     "stablediffusionapi/realistic-vision-v51",
     "stablediffusionapi/dreamshaper-v8",
@@ -90,35 +90,41 @@ SD15_MODELS = [
     "stablediffusionapi/rev-animated-v122",
     "stablediffusionapi/cyberrealistic-v33",
     "stablediffusionapi/meinamix-meina-v11",
-    "stablediffusionapi/epicphotogasm-x",
-    "stablediffusionapi/absolute-realism-v16",
-    "stablediffusionapi/flat-2d-animerge",
     "prompthero/openjourney-v4",
     "wavymulder/Analog-Diffusion",
     "dreamlike-art/dreamlike-photoreal-2.0",
-    "nitrosocke/redshift-diffusion",
-    "segmind/SSD-1B",  # smaller model
     "SG161222/Realistic_Vision_V5.1_noVAE",
     "Lykon/dreamshaper-8",
     "hakurei/waifu-diffusion",
     "andite/anything-v4.0",
-    "Linaqruf/animagine-xl"  # anime-specific
 ]
 
-# Newly added Chinese models
 CHINESE_MODELS = [
-    "AI-Chen/Chinese-Stable-Diffusion",  # Chinese model
-    "IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1",  # Taiyi Chinese model
-    "AI-ModelScope/stable-diffusion-v1-5-chinese",  # Chinese adaptation
-    "YeungNLP/fusionnet_img2text_chinese"  # Chinese image-text
 ]
 
-# Newly added Florence-2 models
-FLORENCE2_MODELS = [
-    "microsoft/Florence-2-base"
-]
-
-ALL_MODELS = SD15_MODELS + SDXL_MODELS + CHINESE_MODELS + FLORENCE2_MODELS
 
 # ControlNet models
 CONTROLNET_MODELS = {
@@ -128,47 +134,77 @@ CONTROLNET_MODELS = {
     "depth": "lllyasviel/control_v11p_sd15_depth",
     "normal": "lllyasviel/control_v11p_sd15_normalbae",
     "openpose": "lllyasviel/control_v11p_sd15_openpose",
-    "scribble": "lllyasviel/control_v11p_sd15_scribble",
     "softedge": "lllyasviel/control_v11p_sd15_softedge",
     "segmentation": "lllyasviel/control_v11p_sd15_seg",
     "mlsd": "lllyasviel/control_v11p_sd15_mlsd",
     "shuffle": "lllyasviel/control_v11p_sd15_shuffle",
-    "inpaint": "lllyasviel/control_v11p_sd15_inpaint",
-    "tile": "lllyasviel/control_v11p_sd15_tile",
-    "ip2p": "lllyasviel/control_v11p_sd15_ip2p",
-    "color": "lllyasviel/control_v11p_sd15_color"
 }
 
-# SDXL ControlNet models (limited availability)
 SDXL_CONTROLNET_MODELS = {
     "canny_sdxl": "diffusers/controlnet-canny-sdxl-1.0",
-    "depth_sdxl": "diffusers/controlnet-depth-sdxl-1.0"
 }
 
-# Popular LoRA models list
 LORA_MODELS = {
     "None": None,
     "Lowpoly Game Character": "nerijs/lowpoly-game-character-lora",
-    "Japanese Doll": "Norod78/sd15-JapaneseDollLikeness_lora",
-    "Korean Doll": "Norod78/sd15-KoreanDollLikeness_lora",
-    "Detail Tweaker": "nitrosocke/detail-tweaker-lora",
     "Pixel Art": "nerijs/pixel-art-xl",
     "Watercolor Style": "OedoSoldier/watercolor-style-lora",
     "Manga Style": "raemikk/Animerge_V3.0_LoRA",
     "Photorealistic": "microsoft/lora-photorealistic",
-    "Cyberpunk": "microsoft/lora-cyberpunk",
-    "Fantasy Art": "microsoft/lora-fantasy-art",
-    "Chinese Style": "yfszzx/Chinese_style_xl_LoRA",  # Chinese style
-    "Traditional Painting": "artificialguybr/Traditional-Painting-Style-LoRA"
 }
 
 # Detector instances
 DETECTORS = {}
 
-# Florence-2 model cache
-FLORENCE2_PROCESSOR = None
-FLORENCE2_MODEL = None
-
 def is_sdxl_model(model_name: str) -> bool:
     """Check if model is SDXL"""
     return model_name in SDXL_MODELS or "xl" in model_name.lower() or "XL" in model_name
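These registries map UI labels to Hugging Face repo ids, and `is_sdxl_model` combines list membership with a name heuristic. A minimal sketch of how a dropdown choice is resolved against them (assuming the dicts above; the `print` is illustrative only):

```python
# Resolve a LoRA dropdown label to a repo id before loading it.
lora_repo = LORA_MODELS.get("Pixel Art")  # -> "nerijs/pixel-art-xl"
if lora_repo is not None:
    print(f"would call pipe.load_lora_weights({lora_repo!r})")

# Note: in is_sdxl_model, '"XL" in model_name' is already implied by
# '"xl" in model_name.lower()', so the second clause is redundant.
assert is_sdxl_model("Laxhar/noobai-XL-1.1")
```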
@@ -222,134 +258,6 @@ def get_controlnet_model(controlnet_type: str):
     else:
         raise ValueError(f"Unknown ControlNet type: {controlnet_type}")
 
-def load_florence2():
-    """Lazy load Florence-2 model"""
-    global FLORENCE2_PROCESSOR, FLORENCE2_MODEL
-
-    if FLORENCE2_PROCESSOR is not None and FLORENCE2_MODEL is not None:
-        return FLORENCE2_PROCESSOR, FLORENCE2_MODEL
-
-    try:
-        from transformers import AutoProcessor, AutoModelForCausalLM
-
-        print("📥 Loading Microsoft/Florence-2-base...")
-
-        # Load the model as described in the official docs
-        FLORENCE2_MODEL = AutoModelForCausalLM.from_pretrained(
-            "microsoft/Florence-2-base",
-            torch_dtype=torch_dtype,
-            trust_remote_code=True
-        ).to(device)
-
-        FLORENCE2_PROCESSOR = AutoProcessor.from_pretrained(
-            "microsoft/Florence-2-base",
-            trust_remote_code=True
-        )
-
-        print("✅ Florence-2 model loaded successfully")
-        return FLORENCE2_PROCESSOR, FLORENCE2_MODEL
-
-    except Exception as e:
-        print(f"❌ Error loading Florence-2: {e}")
-        import traceback
-        traceback.print_exc()
-        return None, None
-
-def analyze_with_florence2(image, task_prompt):
-    """Analyze image using Florence-2"""
-    try:
-        processor, model = load_florence2()
-
-        if processor is None or model is None:
-            return "❌ Failed to load Florence-2 model. Please check installation."
-
-        # Validate the image
-        if image is None:
-            return "❌ No image provided for analysis."
-
-        # Make sure the image is a PIL Image
-        if not isinstance(image, Image.Image):
-            try:
-                if isinstance(image, np.ndarray):
-                    image = Image.fromarray(image)
-                else:
-                    return "❌ Invalid image format. Please upload a valid image."
-            except Exception as e:
-                return f"❌ Error converting image: {str(e)}"
-
-        # Make sure the image is in RGB mode
-        if image.mode != 'RGB':
-            image = image.convert('RGB')
-
-        # Resize the image for faster processing (optional)
-        max_size = 512
-        if max(image.size) > max_size:
-            ratio = max_size / max(image.size)
-            new_size = (int(image.width * ratio), int(image.height * ratio))
-            image = image.resize(new_size, Image.Resampling.LANCZOS)
-
-        # Prepare the inputs as described in the official docs
-        try:
-            inputs = processor(
-                text=task_prompt,
-                images=image,
-                return_tensors="pt"
-            ).to(device, torch_dtype)
-        except Exception as e:
-            print(f"❌ Error processing image: {e}")
-            return f"❌ Error processing image: {str(e)}"
-
-        # Generate as described in the official docs
-        try:
-            generated_ids = model.generate(
-                input_ids=inputs["input_ids"],
-                pixel_values=inputs["pixel_values"],
-                max_new_tokens=1024,
-                do_sample=False,
-                num_beams=3,
-            )
-        except Exception as e:
-            print(f"❌ Error generating text: {e}")
-            return f"❌ Error during analysis: {str(e)}"
-
-        # Decode
-        try:
-            generated_text = processor.batch_decode(
-                generated_ids,
-                skip_special_tokens=False
-            )[0]
-        except Exception as e:
-            print(f"❌ Error decoding text: {e}")
-            return f"❌ Error decoding result: {str(e)}"
-
-        # Parse the result with post_process_generation
-        try:
-            parsed_answer = processor.post_process_generation(
-                generated_text,
-                task=task_prompt,
-                image_size=(image.width, image.height)
-            )
-
-            # Convert the result into a readable string
-            if isinstance(parsed_answer, dict):
-                result_str = ""
-                for key, value in parsed_answer.items():
-                    result_str += f"{key}:\n{value}\n\n"
-                return result_str.strip()
-            else:
-                return str(parsed_answer)
-
-        except Exception as e:
-            print(f"❌ Error in post-processing: {e}")
-            # If post-processing fails, return the raw generated text
-            return f"Raw output: {generated_text}"
-
-    except Exception as e:
-        print(f"❌ Error in Florence-2 analysis: {e}")
-        import traceback
-        traceback.print_exc()
-        return f"❌ Analysis error: {str(e)}"
-
 def prepare_condition_image(image, controlnet_type):
     """Prepare condition image for ControlNet"""
     if controlnet_type in ["lineart", "lineart_anime"]:
@@ -382,21 +290,19 @@ def prepare_condition_image(image, controlnet_type):
         result = detector(image, detect_resolution=512, image_resolution=512)
         return Image.fromarray(result) if isinstance(result, np.ndarray) else result
 
-    # For other types, return original image or processed version
     return image
 
-def get_pipeline(model_name: str, controlnet_type: str = "lineart", lora_model: str = None, lora_weight: float = 0.8):
-    """Get or create a ControlNet pipeline with optional LoRA"""
     global CURRENT_CONTROLNET_PIPE, CURRENT_CONTROLNET_KEY
 
-    key = (model_name, controlnet_type, lora_model, lora_weight)
 
-    # Reuse existing pipeline
     if CURRENT_CONTROLNET_KEY == key and CURRENT_CONTROLNET_PIPE is not None:
         print(f"✅ Reusing existing ControlNet pipeline: {model_name}, type: {controlnet_type}")
         return CURRENT_CONTROLNET_PIPE
 
-    # Unload old pipeline
     if CURRENT_CONTROLNET_PIPE is not None:
         print(f"🗑️ Unloading old ControlNet pipeline: {CURRENT_CONTROLNET_KEY}")
         del CURRENT_CONTROLNET_PIPE
@@ -409,65 +315,64 @@ def get_pipeline(model_name: str, controlnet_type: str = "lineart", lora_model:
     print(f"📥 Loading ControlNet pipeline for model: {model_name}, type: {controlnet_type}")
 
     try:
-        # Check if SDXL with ControlNet
         if is_sdxl_model(model_name):
-            if controlnet_type in ["canny_sdxl", "depth_sdxl"]:
                 controlnet_model_name = get_controlnet_model(controlnet_type)
                 controlnet = ControlNetModel.from_pretrained(
                     controlnet_model_name,
-                    torch_dtype=torch_dtype
                 ).to(device)
 
                 pipe = StableDiffusionXLPipeline.from_pretrained(
                     model_name,
                     controlnet=controlnet,
-                    torch_dtype=torch_dtype,
-                    safety_checker=None,
                     requires_safety_checker=False,
                     use_safetensors=True,
-                    variant="fp16" if torch_dtype == torch.float16 else None
                 ).to(device)
             else:
-                raise ValueError(f"SDXL model {model_name} only supports limited ControlNet types: {list(SDXL_CONTROLNET_MODELS.keys())}")
         else:
-            # SD1.5 ControlNet
             controlnet_model_name = get_controlnet_model(controlnet_type)
             controlnet = ControlNetModel.from_pretrained(
                 controlnet_model_name,
-                torch_dtype=torch_dtype
             ).to(device)
 
             pipe = StableDiffusionControlNetPipeline.from_pretrained(
                 model_name,
                 controlnet=controlnet,
-                torch_dtype=torch_dtype,
-                safety_checker=None,
                 requires_safety_checker=False,
                 use_safetensors=True,
-                variant="fp16" if torch_dtype == torch.float16 else None
            ).to(device)
 
         # Apply LoRA if specified
         if lora_model and lora_model != "None":
             print(f"🔄 Applying LoRA: {lora_model} with weight: {lora_weight}")
             try:
-                pipe.load_lora_weights(lora_model, weight_name=None if "safetensors" in lora_model else "pytorch_lora_weights.safetensors")
                 pipe.fuse_lora(lora_scale=lora_weight)
             except Exception as e:
                 print(f"⚠️ Error loading LoRA: {e}")
-                print("Trying alternative LoRA loading method...")
-                try:
-                    from safetensors.torch import load_file
-                    from huggingface_hub import hf_hub_download
-                    lora_path = hf_hub_download(lora_model, "pytorch_lora_weights.safetensors")
-                    pipe.unet.load_state_dict(load_file(lora_path), strict=False)
-                except Exception as e2:
-                    print(f"❌ Failed to load LoRA: {e2}")
 
         # Optimizations
         pipe.enable_attention_slicing(slice_size="max")
 
-        # VAE slicing
         if hasattr(pipe, 'vae') and hasattr(pipe.vae, 'enable_slicing'):
             pipe.vae.enable_slicing()
         else:
@@ -477,34 +382,19 @@ def get_pipeline(model_name: str, controlnet_type: str = "lineart", lora_model:
                 pass
 
         if device.type == "cuda":
-            # xFormers
             try:
                 pipe.enable_xformers_memory_efficient_attention()
-                print("✅ xFormers enabled for ControlNet")
             except:
-                print("⚠️ xFormers not available, using standard attention")
                 pass
-
-            # Model CPU offload
             pipe.enable_model_cpu_offload()
 
-        # Compile model for faster inference
-        if hasattr(torch, 'compile') and device.type == "cuda":
-            try:
-                pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
-                print("✅ Model compiled with torch.compile")
-            except Exception as e:
-                print(f"⚠️ torch.compile not available: {e}")
-                pass
-
-        # Change scheduler for better quality
         try:
-            pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
-            print("✅ Using UniPC scheduler for faster convergence")
         except:
             try:
                 pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
-                print("✅ Using DPM++ scheduler")
             except:
                 pass
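The scheduler swap rebuilds a new scheduler from the current one's config; the old code preferred UniPC, while the new version of this function (later in the diff) prefers Euler Ancestral with the same DPM++ fallback. A minimal standalone sketch of the pattern, assuming a pipeline `pipe` is already loaded:

```python
from diffusers import EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler

try:
    # from_config builds a fresh scheduler compatible with the pipeline.
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
except Exception:
    # Fall back to DPM++ multistep if the config is incompatible.
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
```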
 
@@ -518,19 +408,17 @@ def get_pipeline(model_name: str, controlnet_type: str = "lineart", lora_model:
         CURRENT_CONTROLNET_KEY = None
         raise
 
-def load_t2i_model(model_name: str, lora_model: str = None, lora_weight: float = 0.8):
-    """Load text-to-image model with optional LoRA"""
     global CURRENT_T2I_PIPE, CURRENT_T2I_MODEL, CURRENT_SDXL_REFINER
 
-    # Check if we need to load refiner for SDXL
     use_refiner = "refiner" in model_name.lower()
-
-    key = (model_name, lora_model, lora_weight, use_refiner)
 
     if CURRENT_T2I_MODEL == key and CURRENT_T2I_PIPE is not None:
         return
 
-    # Unload old model
     if CURRENT_T2I_PIPE is not None:
         print(f"🗑️ Unloading old T2I model: {CURRENT_T2I_MODEL}")
         del CURRENT_T2I_PIPE
@@ -546,64 +434,67 @@ def load_t2i_model(model_name: str, lora_model: str = None, lora_weight: float =
 
     try:
         if is_sdxl_model(model_name):
-            # Load SDXL model
             if use_refiner:
-                # Load base and refiner
                 CURRENT_T2I_PIPE = StableDiffusionXLPipeline.from_pretrained(
                     "stabilityai/stable-diffusion-xl-base-1.0",
-                    torch_dtype=torch_dtype,
-                    safety_checker=None,
                     requires_safety_checker=False,
                     use_safetensors=True,
-                    variant="fp16" if torch_dtype == torch.float16 else None
                 ).to(device)
 
                 CURRENT_SDXL_REFINER = StableDiffusionXLPipeline.from_pretrained(
                     model_name,
-                    torch_dtype=torch_dtype,
                     safety_checker=None,
                     requires_safety_checker=False,
                     use_safetensors=True,
-                    variant="fp16" if torch_dtype == torch.float16 else None,
                     text_encoder_2=CURRENT_T2I_PIPE.text_encoder_2,
                     vae=CURRENT_T2I_PIPE.vae
                 ).to(device)
-                print(f"✅ Loaded SDXL with refiner: {model_name}")
             else:
                 CURRENT_T2I_PIPE = StableDiffusionXLPipeline.from_pretrained(
-                    model_name,
-                    torch_dtype=torch_dtype,
-                    safety_checker=None,
                     requires_safety_checker=False,
                     use_safetensors=True,
-                    variant="fp16" if torch_dtype == torch.float16 else None
                 ).to(device)
-                print(f"✅ Loaded SDXL model: {model_name}")
         else:
-            # Load SD1.5 model
             CURRENT_T2I_PIPE = StableDiffusionPipeline.from_pretrained(
-                model_name,
-                torch_dtype=torch_dtype,
-                safety_checker=None,
                 requires_safety_checker=False,
                 use_safetensors=True,
-                variant="fp16" if torch_dtype == torch.float16 else None
            ).to(device)
-            print(f"✅ Loaded SD1.5 model: {model_name}")
 
-        # Apply LoRA if specified
        if lora_model and lora_model != "None":
-            print(f"🔄 Applying LoRA to T2I: {lora_model} with weight: {lora_weight}")
            try:
                CURRENT_T2I_PIPE.load_lora_weights(lora_model)
                CURRENT_T2I_PIPE.fuse_lora(lora_scale=lora_weight)
            except Exception as e:
-                print(f"⚠️ Error loading LoRA for T2I: {e}")
 
        # Optimizations
        CURRENT_T2I_PIPE.enable_attention_slicing(slice_size="max")
 
-        # VAE slicing
        if hasattr(CURRENT_T2I_PIPE, 'vae') and hasattr(CURRENT_T2I_PIPE.vae, 'enable_slicing'):
            CURRENT_T2I_PIPE.vae.enable_slicing()
        else:
@@ -615,129 +506,26 @@ def load_t2i_model(model_name: str, lora_model: str = None, lora_weight: float =
         if device.type == "cuda":
             try:
                 CURRENT_T2I_PIPE.enable_xformers_memory_efficient_attention()
-                print("✅ xFormers enabled for T2I")
             except:
                 pass
             CURRENT_T2I_PIPE.enable_model_cpu_offload()
 
-        # Change scheduler
         try:
-            CURRENT_T2I_PIPE.scheduler = UniPCMultistepScheduler.from_config(CURRENT_T2I_PIPE.scheduler.config)
-            print("✅ Using UniPC scheduler")
         except:
-            try:
-                CURRENT_T2I_PIPE.scheduler = DPMSolverMultistepScheduler.from_config(CURRENT_T2I_PIPE.scheduler.config)
-                print("✅ Using DPM++ scheduler")
-            except:
-                pass
 
         CURRENT_T2I_MODEL = key
 
     except Exception as e:
-        print(f"❌ Error loading T2I model {model_name}: {e}")
-        print(f"⚠️ Trying to load without use_safetensors...")
-
-        # Retry without use_safetensors
-        try:
-            if is_sdxl_model(model_name):
-                CURRENT_T2I_PIPE = StableDiffusionXLPipeline.from_pretrained(
-                    model_name,
-                    torch_dtype=torch_dtype,
-                    safety_checker=None,
-                    requires_safety_checker=False
-                ).to(device)
-            else:
-                CURRENT_T2I_PIPE = StableDiffusionPipeline.from_pretrained(
-                    model_name,
-                    torch_dtype=torch_dtype,
-                    safety_checker=None,
-                    requires_safety_checker=False
-                ).to(device)
-
-            # Optimizations
-            CURRENT_T2I_PIPE.enable_attention_slicing(slice_size="max")
-            if hasattr(CURRENT_T2I_PIPE, 'vae') and hasattr(CURRENT_T2I_PIPE.vae, 'enable_slicing'):
-                CURRENT_T2I_PIPE.vae.enable_slicing()
-            else:
-                try:
-                    CURRENT_T2I_PIPE.enable_vae_slicing()
-                except:
-                    pass
-
-            if device.type == "cuda":
-                try:
-                    CURRENT_T2I_PIPE.enable_xformers_memory_efficient_attention()
-                    print("✅ xFormers enabled for T2I")
-                except:
-                    pass
-                CURRENT_T2I_PIPE.enable_model_cpu_offload()
-
-            CURRENT_T2I_MODEL = key
-
-        except Exception as retry_e:
-            print(f"❌ Error loading T2I model (retry): {retry_e}")
-            CURRENT_T2I_PIPE = None
-            CURRENT_T2I_MODEL = None
-            raise
-
-# ===== Utils =====
-def resize_image(image, max_size=1024):
-    """Resize image while maintaining aspect ratio"""
-    width, height = image.size
-    if max(width, height) > max_size:
-        ratio = max_size / max(width, height)
-        new_width = int(width * ratio)
-        new_height = int(height * ratio)
-        return image.resize((new_width, new_height), Image.LANCZOS)
-    return image
-
-def image_to_image(img, prompt, negative_prompt, model_name, strength=0.75, steps=30, scale=7.5, seed=42):
-    """Image-to-Image transformation"""
-    try:
-        load_t2i_model(model_name)
-
-        # Resize if needed
-        img = resize_image(img, 1024)
-
-        # Create img2img pipeline
-        pipe = StableDiffusionImg2ImgPipeline(
-            vae=CURRENT_T2I_PIPE.vae,
-            text_encoder=CURRENT_T2I_PIPE.text_encoder,
-            tokenizer=CURRENT_T2I_PIPE.tokenizer,
-            unet=CURRENT_T2I_PIPE.unet,
-            scheduler=CURRENT_T2I_PIPE.scheduler,
-            safety_checker=None,
-            feature_extractor=None,
-            requires_safety_checker=False,
-        ).to(device)
-
-        gen = torch.Generator(device=device).manual_seed(int(seed))
-
-        with torch.inference_mode():
-            result = pipe(
-                prompt=prompt,
-                negative_prompt=negative_prompt,
-                image=img,
-                strength=strength,
-                num_inference_steps=int(steps),
-                guidance_scale=float(scale),
-                generator=gen
-            ).images[0]
-
-        if device.type == "cuda":
-            torch.cuda.empty_cache()
-
-        return result
-    except Exception as e:
-        print(f"❌ Error in img2img: {e}")
-        error_img = Image.new('RGB', (512, 512), color='red')
-        return error_img
 
-# ===== Functions =====
-def colorize(sketch, base_model, controlnet_type, lora_model, lora_weight,
             prompt, negative_prompt, seed, steps, scale, cn_weight):
     try:
-        # Check whether the model is SDXL and the ControlNet type is unsupported
         if is_sdxl_model(base_model) and controlnet_type not in SDXL_CONTROLNET_MODELS:
             error_img = Image.new('RGB', (512, 512), color='red')
             error_msg_img = Image.new('RGB', (512, 512), color='yellow')
@@ -751,34 +539,30 @@ def colorize(sketch, base_model, controlnet_type, lora_model, lora_weight,
             draw.text((50, 230), f"{', '.join(SDXL_CONTROLNET_MODELS.keys())}", fill="black", font=font)
             return error_img, error_msg_img
 
-        # Load the pipeline
-        pipe = get_pipeline(base_model, controlnet_type, lora_model, lora_weight)
 
-        status_msg = f"🎨 Using: {base_model} + {controlnet_type} ControlNet"
         if lora_model and lora_model != "None":
             status_msg += f" + {lora_model}"
         print(status_msg)
 
-        # Prepare the condition image
         condition_img = prepare_condition_image(sketch, controlnet_type)
 
-        # Generate the image
         gen = torch.Generator(device=device).manual_seed(int(seed))
 
         with torch.inference_mode():
             out = pipe(
-                prompt,
                 negative_prompt=negative_prompt,
-                image=condition_img,
                 num_inference_steps=int(steps),
-                guidance_scale=float(scale),
                 controlnet_conditioning_scale=float(cn_weight),
                 generator=gen,
                 height=512,
                 width=512
             ).images[0]
 
-        # Clear cache
         if device.type == "cuda":
             torch.cuda.empty_cache()
 
@@ -788,14 +572,14 @@ def colorize(sketch, base_model, controlnet_type, lora_model, lora_weight,
         error_img = Image.new('RGB', (512, 512), color='red')
         return error_img, Image.new('RGB', (512, 512), color='gray')
 
-def t2i(prompt, negative_prompt, model, lora_model, lora_weight, seed, steps, scale, w, h, use_refiner=False):
     try:
-        # If a refiner is requested, use the dedicated refiner model name
         model_to_load = model
         if use_refiner and "refiner" not in model.lower():
             model_to_load = "stabilityai/stable-diffusion-xl-refiner-1.0"
 
-        load_t2i_model(model_to_load, lora_model, lora_weight)
 
         print(f"🖼️ Using T2I model: {model}")
         if lora_model and lora_model != "None":
@@ -804,50 +588,46 @@ def t2i(prompt, negative_prompt, model, lora_model, lora_weight, seed, steps, sc
         gen = torch.Generator(device=device).manual_seed(int(seed))
 
         with torch.inference_mode():
-            # SDXL with refiner
             if use_refiner and CURRENT_SDXL_REFINER is not None:
-                # First stage with base model
                 image = CURRENT_T2I_PIPE(
                     prompt=prompt,
                     negative_prompt=negative_prompt,
                     width=int(w),
                     height=int(h),
-                    num_inference_steps=int(steps//2),  # Half steps for base
                     guidance_scale=float(scale),
                     generator=gen,
                     output_type="latent"
                 ).images
 
-                # Second stage with refiner
                 result = CURRENT_SDXL_REFINER(
                     prompt=prompt,
                     negative_prompt=negative_prompt,
                     image=image,
-                    num_inference_steps=int(steps//2),  # Half steps for refiner
                     guidance_scale=float(scale),
                     generator=gen
                 ).images[0]
             else:
-                # Normal generation
                 if is_sdxl_model(model):
                     width = max(int(w), 512)
                     height = max(int(h), 512)
                     result = CURRENT_T2I_PIPE(
-                        prompt,
                         negative_prompt=negative_prompt,
-                        width=width,
                         height=height,
-                        num_inference_steps=int(steps),
                         guidance_scale=float(scale),
                         generator=gen
                     ).images[0]
                 else:
                     result = CURRENT_T2I_PIPE(
-                        prompt,
                         negative_prompt=negative_prompt,
-                        width=int(w),
                         height=int(h),
-                        num_inference_steps=int(steps),
                         guidance_scale=float(scale),
                         generator=gen
                     ).images[0]
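The refiner branch splits the step budget: the base pipeline stops early and returns latents (`output_type="latent"`), and the refiner continues denoising from them. A minimal standalone sketch of that handoff, assuming a CUDA machine and using the documented `StableDiffusionXLImg2ImgPipeline` refiner class (this commit instead reuses `StableDiffusionXLPipeline` for the refiner):

```python
import torch
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline

base = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    text_encoder_2=base.text_encoder_2,  # share components to save memory
    vae=base.vae,
    torch_dtype=torch.float16,
).to("cuda")

prompt = "a lighthouse at dusk, highly detailed"
latents = base(prompt, num_inference_steps=20, output_type="latent").images
image = refiner(prompt, image=latents, num_inference_steps=20).images[0]
```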
@@ -868,39 +648,13 @@ def t2i(prompt, negative_prompt, model, lora_model, lora_weight, seed, steps, sc
         draw.text((50, 50), f"Error: {str(e)[:50]}...", fill="white", font=font)
         return error_img
 
-def florence2_analysis(image, task_prompt, custom_prompt):
-    """Analyze image with Florence-2"""
-    try:
-        if image is None:
-            return "❌ Please upload an image first"
-
-        # Make sure the image is a PIL Image
-        if not isinstance(image, Image.Image):
-            return "❌ Invalid image format. Please upload a valid image."
-
-        # Use custom prompt if provided
-        prompt_to_use = custom_prompt.strip() if custom_prompt.strip() else task_prompt
-
-        print(f"🔍 Analyzing image with Florence-2 using prompt: {prompt_to_use}")
-        result = analyze_with_florence2(image, prompt_to_use)
-        return result
-
-    except Exception as e:
-        print(f"❌ Error in Florence-2 analysis: {e}")
-        import traceback
-        traceback.print_exc()
-        return f"Error: {str(e)}"
-
-# ===== Function to unload all models =====
 def unload_all_models():
     global CURRENT_CONTROLNET_PIPE, CURRENT_CONTROLNET_KEY
     global DETECTORS
     global CURRENT_T2I_PIPE, CURRENT_T2I_MODEL, CURRENT_SDXL_REFINER
-    global FLORENCE2_PROCESSOR, FLORENCE2_MODEL
 
-    print("Unloading all models from memory...")
 
-    # Unload ControlNet pipeline
     try:
         if CURRENT_CONTROLNET_PIPE is not None:
             del CURRENT_CONTROLNET_PIPE
@@ -909,7 +663,6 @@ def unload_all_models():
             pass
     CURRENT_CONTROLNET_KEY = None
 
-    # Unload detectors
     for detector_type in list(DETECTORS.keys()):
         try:
             del DETECTORS[detector_type]
@@ -917,7 +670,6 @@ def unload_all_models():
             pass
     DETECTORS.clear()
 
-    # Unload T2I models
     try:
         if CURRENT_T2I_PIPE is not None:
             del CURRENT_T2I_PIPE
@@ -934,22 +686,6 @@ def unload_all_models():
 
     CURRENT_T2I_MODEL = None
 
-    # Unload Florence-2
-    try:
-        if FLORENCE2_PROCESSOR is not None:
-            del FLORENCE2_PROCESSOR
-            FLORENCE2_PROCESSOR = None
-    except:
-        pass
-
-    try:
-        if FLORENCE2_MODEL is not None:
-            del FLORENCE2_MODEL
-            FLORENCE2_MODEL = None
-    except:
-        pass
-
-    # Force garbage collection
     gc.collect()
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
@@ -960,12 +696,11 @@ def unload_all_models():
     return "✅ All models unloaded from memory!"
 
 # ===== Gradio UI =====
-with gr.Blocks(title="🎨 Advanced Image Generation Suite", theme=gr.themes.Soft()) as demo:
-    gr.Markdown("# 🎨 Advanced Image Generation & Editing Suite")
-    gr.Markdown("### Powered by Stable Diffusion & ControlNet")
-    gr.Markdown("**Note:** SDXL models work with limited ControlNet types (canny_sdxl, depth_sdxl)")
 
-    # System info
     if torch.cuda.is_available():
         gpu_name = torch.cuda.get_device_name(0)
         gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
@@ -973,301 +708,257 @@ with gr.Blocks(title="🎨 Advanced Image Generation Suite", theme=gr.themes.Sof
     else:
         gr.Markdown("**⚠️ Running on CPU** - Generation will be slower")
 
-    # Add unload button
     with gr.Row():
         unload_btn = gr.Button("🗑️ Unload All Models", variant="stop", scale=1)
         status_text = gr.Textbox(label="Status", interactive=False, scale=3)
     unload_btn.click(unload_all_models, outputs=status_text)
 
-    with gr.Tab("🎨 ControlNet Colorize"):
         gr.Markdown("""
-        ### Convert sketches to colored images using ControlNet
-        **SD1.5 Models:** Support all ControlNet types
-        **SDXL Models:** Only support canny_sdxl and depth_sdxl
         """)
 
         with gr.Row():
-            inp = gr.Image(label="Input Sketch/Image", type="pil")
-            out = gr.Image(label="Colored Output")
-
-        with gr.Row():
-            condition_out = gr.Image(label="Processed Condition Image", type="pil")
-
-        with gr.Row():
-            base_model = gr.Dropdown(
-                choices=ALL_MODELS,
-                value="digiplay/ChikMix_V3",
-                label="Base Model"
-            )
-            controlnet_type = gr.Dropdown(
-                choices=list(CONTROLNET_MODELS.keys()) + list(SDXL_CONTROLNET_MODELS.keys()),
-                value="lineart_anime",
-                label="ControlNet Type"
-            )
-
-        with gr.Row():
-            lora_model = gr.Dropdown(
-                choices=list(LORA_MODELS.keys()),
-                value="None",
-                label="LoRA Model (Optional)"
-            )
-            lora_weight = gr.Slider(0.1, 1.5, 0.8, step=0.1, label="LoRA Weight")
 
         with gr.Row():
             prompt = gr.Textbox(
-                label="Prompt",
-                placeholder="e.g., 1girl, blonde hair, blue eyes, beautiful, masterpiece",
-                lines=2
             )
             negative_prompt = gr.Textbox(
-                label="Negative Prompt",
-                placeholder="e.g., ugly, deformed, bad anatomy, blurry",
-                lines=2
             )
 
         with gr.Row():
-            seed = gr.Number(value=42, label="Seed")
-            steps = gr.Slider(10, 100, 30, step=1, label="Steps")
-            scale = gr.Slider(1, 20, 7.5, step=0.5, label="CFG Scale")
             cn_weight = gr.Slider(0.1, 2.0, 1.0, step=0.1, label="ControlNet Weight")
 
-        run = gr.Button("🎨 Colorize", variant="primary")
         run.click(
-            colorize,
-            [inp, base_model, controlnet_type, lora_model, lora_weight,
-             prompt, negative_prompt, seed, steps, scale, cn_weight],
            [out, condition_out]
        )
 
-    with gr.Tab("🖼️ Text-to-Image"):
         gr.Markdown("""
         ### Generate images from text descriptions
-        Supports both SD1.5 and SDXL models with optional LoRA.
-        **Tip:** SDXL models produce higher quality but require more memory.
         """)
 
         with gr.Row():
-            t2i_out = gr.Image(label="Output", type="pil")
 
         with gr.Row():
             t2i_prompt = gr.Textbox(
-                label="Prompt",
-                lines=3,
-                placeholder="e.g., a beautiful landscape with mountains and a lake at sunset, highly detailed, 4k"
             )
             t2i_negative_prompt = gr.Textbox(
-                label="Negative Prompt",
-                lines=2,
-                placeholder="e.g., blurry, ugly, deformed, low quality"
-            )
-
-        with gr.Row():
-            t2i_model = gr.Dropdown(
-                choices=ALL_MODELS,
-                value="digiplay/ChikMix_V3",
-                label="Model"
-            )
-            t2i_lora = gr.Dropdown(
-                choices=list(LORA_MODELS.keys()),
-                value="None",
-                label="LoRA Model (Optional)"
            )
-            t2i_lora_weight = gr.Slider(0.1, 1.5, 0.8, step=0.1, label="LoRA Weight")
 
        with gr.Row():
-            t2i_seed = gr.Number(value=42, label="Seed")
-            t2i_steps = gr.Slider(10, 100, 30, step=1, label="Steps")
-            t2i_scale = gr.Slider(1, 20, 7.5, step=0.5, label="CFG Scale")
 
        with gr.Row():
-            w = gr.Slider(256, 2048, 1024, step=64, label="Width")
-            h = gr.Slider(256, 2048, 1024, step=64, label="Height")
-            use_refiner = gr.Checkbox(label="Use SDXL Refiner (SDXL only)", value=False)
 
-        gen_btn = gr.Button("🖼️ Generate", variant="primary")
        gen_btn.click(
-            t2i,
-            [t2i_prompt, t2i_negative_prompt, t2i_model, t2i_lora, t2i_lora_weight,
-             t2i_seed, t2i_steps, t2i_scale, w, h, use_refiner],
            t2i_out
        )
-
-    with gr.Tab("🔄 Image-to-Image"):
-        gr.Markdown("""
-        ### Transform existing images using img2img
-        Modify images based on prompts with control over transformation strength.
-        """)
-
-        with gr.Row():
-            img2img_input = gr.Image(label="Input Image", type="pil")
-            img2img_output = gr.Image(label="Transformed Output")
-
-        with gr.Row():
-            img2img_prompt = gr.Textbox(
-                label="Prompt",
-                lines=2,
-                placeholder="e.g., make it anime style, cyberpunk style, etc."
-            )
-            img2img_negative_prompt = gr.Textbox(
-                label="Negative Prompt",
-                lines=2,
-                placeholder="e.g., blurry, low quality"
-            )
 
-        with gr.Row():
-            img2img_model = gr.Dropdown(
-                choices=ALL_MODELS,
-                value="stablediffusionapi/realistic-vision-v51",
-                label="Model"
-            )
-            img2img_strength = gr.Slider(0.1, 0.95, 0.75, step=0.05, label="Transformation Strength")
-
-        with gr.Row():
-            img2img_seed = gr.Number(value=42, label="Seed")
-            img2img_steps = gr.Slider(10, 100, 30, step=1, label="Steps")
-            img2img_scale = gr.Slider(1, 20, 7.5, step=0.5, label="CFG Scale")
-
-        img2img_btn = gr.Button("🔄 Transform Image", variant="primary")
-        img2img_btn.click(
-            image_to_image,
-            [img2img_input, img2img_prompt, img2img_negative_prompt,
-             img2img_model, img2img_strength, img2img_steps, img2img_scale, img2img_seed],
-            img2img_output
-        )
-
-    with gr.Tab("🔍 Florence-2 Vision Analysis"):
         gr.Markdown("""
-        ### Microsoft Florence-2 Vision Language Model
-        **Pre-trained Tasks:**
-        - `<OCR>`: Text recognition (Extract text from image)
-        - `<CAPTION>`: Image captioning (Generate a caption)
-        - `<DETAILED_CAPTION>`: Detailed caption (More detailed description)
-        - `<MORE_DETAILED_CAPTION>`: More detailed caption (Even more details)
-        - `<OD>`: Object detection (Detect objects with bounding boxes)
-        - `<OPEN_VOCABULARY_DETECTION>`: Open-vocabulary detection
-        - `<REGION_PROPOSAL>`: Region proposal
-
-        **How to use:**
-        1. Upload an image
-        2. Select a task from the dropdown
-        3. Click "Analyze Image"
-        4. Results will be displayed in the text box
-
-        **Example tasks:**
-        - Extract text from a document: `<OCR>`
-        - Describe what's in the image: `<CAPTION>`
-        - Detect objects in the image: `<OD>`
         """)
-
-        with gr.Row():
-            florence_input = gr.Image(label="Input Image", type="pil")
-            florence_output = gr.Textbox(
-                label="Analysis Result",
-                lines=15,
-                interactive=False,
-                show_copy_button=True
-            )
-
-        with gr.Row():
-            florence_task = gr.Dropdown(
-                choices=[
-                    "<OCR>",
-                    "<CAPTION>",
-                    "<DETAILED_CAPTION>",
-                    "<MORE_DETAILED_CAPTION>",
-                    "<OD>",
-                    "<OPEN_VOCABULARY_DETECTION>",
-                    "<REGION_PROPOSAL>"
-                ],
-                value="<CAPTION>",
-                label="Task Prompt"
-            )
-
-            custom_prompt = gr.Textbox(
-                label="Custom Prompt (Optional)",
-                value="",
-                placeholder="e.g., Describe the main objects in this image"
-            )
-
-        with gr.Row():
-            analyze_btn = gr.Button("🔍 Analyze Image", variant="primary")
-            clear_btn = gr.Button("🗑️ Clear")
-
-        def clear_analysis():
-            return None, ""
-
-        analyze_btn.click(
-            florence2_analysis,
-            [florence_input, florence_task, custom_prompt],
-            florence_output
-        )
-
-        clear_btn.click(
-            clear_analysis,
-            [],
-            [florence_input, florence_output]
-        )
-
-    with gr.Tab("📊 Model Info"):
         gr.Markdown("""
-        ### Available Models Information
-
-        **SD1.5 Models (Support all ControlNet types):**
-        - Recommended for ControlNet workflows
-        - Faster inference, lower memory usage
-        - Wide variety of styles available
-
-        **SDXL Models (Higher quality, limited ControlNet):**
-        - Better quality, more details
-        - Larger image sizes (1024x1024+)
-        - Only supports canny_sdxl and depth_sdxl ControlNet
-
-        **Chinese Models:**
-        - Optimized for Chinese prompts
-        - Better understanding of Chinese culture elements
-
-        **Florence-2 Model:**
-        - Microsoft's vision-language model
-        - Image analysis, OCR, captioning, object detection
-
-        **LoRA Models:**
-        - Fine-tuned models for specific styles
-        - Can be combined with base models
-        - Adjust weight for stronger/weaker effect
         """)
-
-        with gr.Row():
-            with gr.Column():
-                gr.Markdown("**SD1.5 Models Count:** " + str(len(SD15_MODELS)))
-                gr.Markdown("**SDXL Models Count:** " + str(len(SDXL_MODELS)))
-                gr.Markdown("**Chinese Models Count:** " + str(len(CHINESE_MODELS)))
-                gr.Markdown("**Florence-2 Models:** " + str(len(FLORENCE2_MODELS)))
-                gr.Markdown("**ControlNet Types:** " + str(len(CONTROLNET_MODELS) + len(SDXL_CONTROLNET_MODELS)))
-                gr.Markdown("**LoRA Models:** " + str(len(LORA_MODELS) - 1))  # Subtract "None"
-
-        with gr.Row():
-            refresh_btn = gr.Button("🔄 Refresh Memory Info")
-            memory_info = gr.Textbox(label="Memory Status")
-
-        def get_memory_info():
-            info = ""
-            if torch.cuda.is_available():
-                allocated = torch.cuda.memory_allocated() / 1024**3
-                reserved = torch.cuda.memory_reserved() / 1024**3
-                max_allocated = torch.cuda.max_memory_allocated() / 1024**3
-                info = f"Allocated: {allocated:.2f} GB\n"
-                info += f"Reserved: {reserved:.2f} GB\n"
-                info += f"Max Allocated: {max_allocated:.2f} GB"
-            else:
-                info = "Running on CPU - No GPU memory info"
-            return info
-
-        refresh_btn.click(get_memory_info, outputs=memory_info)
 
 try:
     demo.launch(
-        server_name="0.0.0.0",
-        server_port=7860,
         share=False,
         show_error=True,
         quiet=False
 
 
 # Device
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 
+print(f"🖥️ Device: {device} | dtype: {dtype}")
 
+# Lazy import
 from diffusers import (
     StableDiffusionControlNetPipeline,
     ControlNetModel,
     StableDiffusionPipeline,
+    StableDiffusionXLPipeline
 )
+from diffusers import UniPCMultistepScheduler, DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler
 from controlnet_aux import (
     LineartDetector,
     LineartAnimeDetector,
 
     HEDdetector,
     PidiNetDetector,
     NormalBaeDetector,
     ZoeDetector,
     MediapipeFaceDetector
 )
 
 # Memory optimization
 if torch.cuda.is_available():
     torch.cuda.empty_cache()
     torch.cuda.set_per_process_memory_fraction(0.95)
     print(f"🔥 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
 else:
 
 # ===== Model & Config =====
 CURRENT_CONTROLNET_PIPE = None
+CURRENT_CONTROLNET_KEY = None
 CURRENT_T2I_PIPE = None
 CURRENT_T2I_MODEL = None
 CURRENT_SDXL_REFINER = None
 
+# Enhanced SDXL Models (including NSFW-capable)
 SDXL_MODELS = [
     "stabilityai/stable-diffusion-xl-base-1.0",
     "stabilityai/stable-diffusion-xl-refiner-1.0",
     "Laxhar/noobai-XL-1.1",
     "RunDiffusion/Juggernaut-XL-v9",
     "dataautogpt3/ProteusV0.4",
     "playgroundai/playground-v2.5-1024px-aesthetic",
+    "misri/epicrealismXL_v10",
+    "SG161222/RealVisXL_V4.0",
+    "stablediffusionapi/juggernaut-xl-v8",
+    "Lykon/dreamshaper-xl-1-0",
+    "digiplay/Pony_Diffusion_V6_XL"
 ]
 
+# Enhanced SD1.5 Models (including NSFW-capable)
 SD15_MODELS = [
+    # Original models
     "digiplay/ChikMix_V3",
     "digiplay/chilloutmix_NiPrunedFp16Fix",
     "gsdf/Counterfeit-V2.5",
     "stablediffusionapi/anything-v5",
     "runwayml/stable-diffusion-v1-5",
     "stablediffusionapi/realistic-vision-v51",
     "stablediffusionapi/dreamshaper-v8",
     "stablediffusionapi/rev-animated-v122",
     "stablediffusionapi/cyberrealistic-v33",
     "stablediffusionapi/meinamix-meina-v11",
     "prompthero/openjourney-v4",
     "wavymulder/Analog-Diffusion",
     "dreamlike-art/dreamlike-photoreal-2.0",
+    "segmind/SSD-1B",
     "SG161222/Realistic_Vision_V5.1_noVAE",
     "Lykon/dreamshaper-8",
     "hakurei/waifu-diffusion",
     "andite/anything-v4.0",
+    "Linaqruf/animagine-xl",
+    # Additional NSFW-capable models
+    "emilianJR/epiCRealism",
+    "stablediffusionapi/deliberate-v2",
+    "stablediffusionapi/edge-of-realism",
+    "Yntec/epiCPhotoGasm",
+    "digiplay/majicMIX_realistic_v7",
+    "stablediffusionapi/perfect-world-v6",
+    "stablediffusionapi/uber-realistic-merge",
+    "XpucT/Deliberate",
+    "prompthero/openjourney",
+    "Lykon/absolute-reality-1.81",
+    "digiplay/BeautyProMix_v2",
+    "stablediffusionapi/3d-animation-diffusion",
+    "nitrosocke/Ghibli-Diffusion",
+    "nitrosocke/mo-di-diffusion",
+    "Fictiverse/Stable_Diffusion_VoxelArt_Model"
 ]
 
+# Chinese Models
 CHINESE_MODELS = [
+    "AI-Chen/Chinese-Stable-Diffusion",
+    "IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1",
+    "AI-ModelScope/stable-diffusion-v1-5-chinese"
 ]
 
+ALL_MODELS = SD15_MODELS + SDXL_MODELS + CHINESE_MODELS
 
 # ControlNet models
 CONTROLNET_MODELS = {
 
     "depth": "lllyasviel/control_v11p_sd15_depth",
     "normal": "lllyasviel/control_v11p_sd15_normalbae",
     "openpose": "lllyasviel/control_v11p_sd15_openpose",
     "softedge": "lllyasviel/control_v11p_sd15_softedge",
     "segmentation": "lllyasviel/control_v11p_sd15_seg",
     "mlsd": "lllyasviel/control_v11p_sd15_mlsd",
     "shuffle": "lllyasviel/control_v11p_sd15_shuffle",
+    "scribble": "lllyasviel/control_v11p_sd15_scribble",
+    "tile": "lllyasviel/control_v11f1e_sd15_tile"
 }
 
+# SDXL ControlNet models
 SDXL_CONTROLNET_MODELS = {
     "canny_sdxl": "diffusers/controlnet-canny-sdxl-1.0",
+    "depth_sdxl": "diffusers/controlnet-depth-sdxl-1.0",
+    "openpose_sdxl": "thibaud/controlnet-openpose-sdxl-1.0"
 }
 
+# Expanded LoRA models list (including NSFW-capable)
 LORA_MODELS = {
     "None": None,
+    # Style LoRAs
     "Lowpoly Game Character": "nerijs/lowpoly-game-character-lora",
     "Pixel Art": "nerijs/pixel-art-xl",
     "Watercolor Style": "OedoSoldier/watercolor-style-lora",
     "Manga Style": "raemikk/Animerge_V3.0_LoRA",
+    "Cyberpunk": "artificialguybr/cyberpunk-anime-diffusion",
+    "Fantasy Art": "artificialguybr/fantasy-art-lora",
+    "Chinese Style": "yfszzx/Chinese_style_xl_LoRA",
+    "Traditional Painting": "artificialguybr/Traditional-Painting-Style-LoRA",
+    "Anime Art": "Linaqruf/anime-detailer-xl-lora",
+    "Cinematic": "artificialguybr/cinematic-diffusion",
+    "Oil Painting": "artificialguybr/oil-painting-style",
+    # Character/Face LoRAs
+    "Japanese Doll": "Norod78/sd15-JapaneseDollLikeness_lora",
+    "Korean Doll": "Norod78/sd15-KoreanDollLikeness_lora",
+    "Detail Tweaker": "nitrosocke/detail-tweaker-lora",
+    "Beautiful Realistic Asians": "etok/Beautiful_Realistic_Asians",
+    "Asian Beauty": "digiplay/AsianBeauty_V1",
+    "Perfect Hands": "Sanster/perfect-hands",
+    "Face Detail": "ostris/face-detail-lora",
+    # Body/Pose LoRAs
+    "Body Pose Control": "alvdansen/lora-body-pose",
+    "Dynamic Poses": "alvdansen/dynamic-poses-lora",
+    "Full Body": "artificialguybr/full-body-lora",
+    # Realism LoRAs
     "Photorealistic": "microsoft/lora-photorealistic",
+    "Hyper-Realistic": "dallinmackay/hyper-realistic-lora",
+    "Ultra Realistic": "artificialguybr/ultra-realistic-lora",
+    "Realistic Vision": "SG161222/Realistic_Vision_V5.1_noVAE",
+    # Lighting/Quality LoRAs
+    "Add Detail": "ostris/add-detail-lora",
+    "Sharp Details": "ostris/sharp-details-lora",
+    "Better Lighting": "artificialguybr/better-lighting-lora",
+    "Studio Lighting": "artificialguybr/studio-lighting",
+    # NSFW-capable LoRAs
+    "NSFW Master": "hearmeneigh/nsfw-master-lora",
+    "Realistic NSFW": "digiplay/RealisticNSFW_v1",
+    "Anime NSFW": "Linaqruf/anime-nsfw-lora",
+    "Hentai Diffusion": "Deltaadams/Hentai-Diffusion",
+    "Sexy Pose": "alvdansen/sexy-pose-lora"
 }
+
+# VAE models for better quality
+VAE_MODELS = {
+    "None": None,
+    "SD1.5 VAE": "stabilityai/sd-vae-ft-mse",
+    "Anime VAE": "hakurei/waifu-diffusion-v1-4",
+    "SDXL VAE": "madebyollin/sdxl-vae-fp16-fix"
+}
 
 # Detector instances
 DETECTORS = {}
 
 def is_sdxl_model(model_name: str) -> bool:
     """Check if model is SDXL"""
     return model_name in SDXL_MODELS or "xl" in model_name.lower() or "XL" in model_name
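Swapping a custom VAE into a pipeline is a one-attribute assignment once the `AutoencoderKL` is loaded. Note that `VAE_MODELS` mixes standalone VAE repos with what appears to be a full pipeline repo (`hakurei/waifu-diffusion-v1-4`); if that entry is a full pipeline, loading its VAE would need `subfolder="vae"`. A minimal sketch (the repo-layout assumption is mine, not verified by this commit):

```python
import torch
from diffusers import AutoencoderKL

# Standalone VAE repo: loads directly.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)

# Full pipeline repo: the VAE typically lives in a "vae" subfolder.
anime_vae = AutoencoderKL.from_pretrained(
    "hakurei/waifu-diffusion-v1-4", subfolder="vae", torch_dtype=torch.float16
)

# pipe.vae = vae  # then swap it into an already-loaded pipeline
```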
 
     else:
         raise ValueError(f"Unknown ControlNet type: {controlnet_type}")
 
 def prepare_condition_image(image, controlnet_type):
     """Prepare condition image for ControlNet"""
     if controlnet_type in ["lineart", "lineart_anime"]:
 
         result = detector(image, detect_resolution=512, image_resolution=512)
         return Image.fromarray(result) if isinstance(result, np.ndarray) else result
 
     return image
 
+def get_pipeline(model_name: str, controlnet_type: str = "lineart", lora_model: str = None,
+                 lora_weight: float = 0.8, vae_model: str = None):
+    """Get or create a ControlNet pipeline with optional LoRA and VAE"""
     global CURRENT_CONTROLNET_PIPE, CURRENT_CONTROLNET_KEY
 
+    key = (model_name, controlnet_type, lora_model, lora_weight, vae_model)
 
     if CURRENT_CONTROLNET_KEY == key and CURRENT_CONTROLNET_PIPE is not None:
         print(f"✅ Reusing existing ControlNet pipeline: {model_name}, type: {controlnet_type}")
         return CURRENT_CONTROLNET_PIPE
 
     if CURRENT_CONTROLNET_PIPE is not None:
         print(f"🗑️ Unloading old ControlNet pipeline: {CURRENT_CONTROLNET_KEY}")
         del CURRENT_CONTROLNET_PIPE
 
     print(f"📥 Loading ControlNet pipeline for model: {model_name}, type: {controlnet_type}")
 
     try:
         if is_sdxl_model(model_name):
+            if controlnet_type in SDXL_CONTROLNET_MODELS:
                 controlnet_model_name = get_controlnet_model(controlnet_type)
                 controlnet = ControlNetModel.from_pretrained(
                     controlnet_model_name,
+                    torch_dtype=dtype
                 ).to(device)
 
                 pipe = StableDiffusionXLPipeline.from_pretrained(
                     model_name,
                     controlnet=controlnet,
+                    torch_dtype=dtype,
+                    safety_checker=None,
                     requires_safety_checker=False,
                     use_safetensors=True,
+                    variant="fp16" if dtype == torch.float16 else None
                 ).to(device)
             else:
+                raise ValueError(f"SDXL model only supports: {list(SDXL_CONTROLNET_MODELS.keys())}")
         else:
             controlnet_model_name = get_controlnet_model(controlnet_type)
             controlnet = ControlNetModel.from_pretrained(
                 controlnet_model_name,
+                torch_dtype=dtype
             ).to(device)
 
             pipe = StableDiffusionControlNetPipeline.from_pretrained(
                 model_name,
                 controlnet=controlnet,
+                torch_dtype=dtype,
+                safety_checker=None,
                 requires_safety_checker=False,
                 use_safetensors=True,
+                variant="fp16" if dtype == torch.float16 else None
             ).to(device)
 
+        # Load custom VAE if specified
+        if vae_model and vae_model != "None":
+            try:
+                from diffusers import AutoencoderKL
+                print(f"🔄 Loading custom VAE: {vae_model}")
+                vae = AutoencoderKL.from_pretrained(vae_model, torch_dtype=dtype).to(device)
+                pipe.vae = vae
+            except Exception as e:
+                print(f"⚠️ Error loading VAE: {e}")
+
         # Apply LoRA if specified
         if lora_model and lora_model != "None":
             print(f"🔄 Applying LoRA: {lora_model} with weight: {lora_weight}")
             try:
+                pipe.load_lora_weights(lora_model)
                 pipe.fuse_lora(lora_scale=lora_weight)
             except Exception as e:
                 print(f"⚠️ Error loading LoRA: {e}")
 
         # Optimizations
         pipe.enable_attention_slicing(slice_size="max")
 
         if hasattr(pipe, 'vae') and hasattr(pipe.vae, 'enable_slicing'):
             pipe.vae.enable_slicing()
         else:
 
             pass
 
         if device.type == "cuda":
             try:
                 pipe.enable_xformers_memory_efficient_attention()
+                print("✅ xFormers enabled")
             except:
                 pass
             pipe.enable_model_cpu_offload()
 
         try:
+            pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
+            print("✅ Using Euler Ancestral scheduler")
         except:
             try:
                 pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
             except:
                 pass
 
         CURRENT_CONTROLNET_KEY = None
         raise
 
411
+ def load_t2i_model(model_name: str, lora_model: str = None, lora_weight: float = 0.8,
412
+ vae_model: str = None):
413
+ """Load text-to-image model with optional LoRA and VAE"""
414
  global CURRENT_T2I_PIPE, CURRENT_T2I_MODEL, CURRENT_SDXL_REFINER
415
 
 
416
  use_refiner = "refiner" in model_name.lower()
417
+ key = (model_name, lora_model, lora_weight, vae_model, use_refiner)
 
418
 
419
  if CURRENT_T2I_MODEL == key and CURRENT_T2I_PIPE is not None:
420
  return
421
 
 
422
  if CURRENT_T2I_PIPE is not None:
423
  print(f"🗑️ Unloading old T2I model: {CURRENT_T2I_MODEL}")
424
  del CURRENT_T2I_PIPE
 
434
 
435
  try:
436
  if is_sdxl_model(model_name):
 
437
  if use_refiner:
 
438
  CURRENT_T2I_PIPE = StableDiffusionXLPipeline.from_pretrained(
439
  "stabilityai/stable-diffusion-xl-base-1.0",
440
+ torch_dtype=dtype,
441
+ safety_checker=None,
442
  requires_safety_checker=False,
443
  use_safetensors=True,
444
+ variant="fp16" if dtype == torch.float16 else None
445
  ).to(device)
446
 
447
  CURRENT_SDXL_REFINER = StableDiffusionXLPipeline.from_pretrained(
448
  model_name,
449
+ torch_dtype=dtype,
450
  safety_checker=None,
451
  requires_safety_checker=False,
452
  use_safetensors=True,
453
+ variant="fp16" if dtype == torch.float16 else None,
454
  text_encoder_2=CURRENT_T2I_PIPE.text_encoder_2,
455
  vae=CURRENT_T2I_PIPE.vae
456
  ).to(device)
 
457
  else:
458
  CURRENT_T2I_PIPE = StableDiffusionXLPipeline.from_pretrained(
459
+ model_name,
460
+ torch_dtype=dtype,
461
+ safety_checker=None,
462
  requires_safety_checker=False,
463
  use_safetensors=True,
464
+ variant="fp16" if dtype == torch.float16 else None
465
  ).to(device)
 
466
  else:
 
467
  CURRENT_T2I_PIPE = StableDiffusionPipeline.from_pretrained(
468
+ model_name,
469
+ torch_dtype=dtype,
470
+ safety_checker=None,
471
  requires_safety_checker=False,
472
  use_safetensors=True,
473
+ variant="fp16" if dtype == torch.float16 else None
474
  ).to(device)
 
475
 
476
+ # Load custom VAE
477
+ if vae_model and vae_model != "None":
478
+ try:
479
+ from diffusers import AutoencoderKL
480
+ print(f"🔄 Loading custom VAE: {vae_model}")
481
+ vae = AutoencoderKL.from_pretrained(vae_model, torch_dtype=dtype).to(device)
482
+ CURRENT_T2I_PIPE.vae = vae
483
+ except Exception as e:
484
+ print(f"⚠️ Error loading VAE: {e}")
485
+
486
+ # Apply LoRA
487
  if lora_model and lora_model != "None":
488
+ print(f"🔄 Applying LoRA: {lora_model} with weight: {lora_weight}")
489
  try:
490
  CURRENT_T2I_PIPE.load_lora_weights(lora_model)
491
  CURRENT_T2I_PIPE.fuse_lora(lora_scale=lora_weight)
492
  except Exception as e:
493
+ print(f"⚠️ Error loading LoRA: {e}")
494
 
495
  # Optimizations
496
  CURRENT_T2I_PIPE.enable_attention_slicing(slice_size="max")
497
 
 
498
  if hasattr(CURRENT_T2I_PIPE, 'vae') and hasattr(CURRENT_T2I_PIPE.vae, 'enable_slicing'):
499
  CURRENT_T2I_PIPE.vae.enable_slicing()
500
  else:
 
506
  if device.type == "cuda":
507
  try:
508
  CURRENT_T2I_PIPE.enable_xformers_memory_efficient_attention()
 
509
  except:
510
  pass
511
  CURRENT_T2I_PIPE.enable_model_cpu_offload()
512
 
 
513
  try:
514
+ CURRENT_T2I_PIPE.scheduler = EulerAncestralDiscreteScheduler.from_config(CURRENT_T2I_PIPE.scheduler.config)
 
515
  except:
516
+ pass
 
 
 
 
517
 
518
  CURRENT_T2I_MODEL = key
519
 
520
  except Exception as e:
521
+ print(f"❌ Error loading T2I model: {e}")
522
+ CURRENT_T2I_PIPE = None
523
+ CURRENT_T2I_MODEL = None
524
+ raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
525
 
526
+ def colorize(sketch, base_model, controlnet_type, lora_model, lora_weight, vae_model,
 
527
  prompt, negative_prompt, seed, steps, scale, cn_weight):
528
  try:
 
529
  if is_sdxl_model(base_model) and controlnet_type not in SDXL_CONTROLNET_MODELS:
530
  error_img = Image.new('RGB', (512, 512), color='red')
531
  error_msg_img = Image.new('RGB', (512, 512), color='yellow')
 
539
  draw.text((50, 230), f"{', '.join(SDXL_CONTROLNET_MODELS.keys())}", fill="black", font=font)
540
  return error_img, error_msg_img
541
 
542
+ pipe = get_pipeline(base_model, controlnet_type, lora_model, lora_weight, vae_model)
 
543
 
544
+ status_msg = f"🎨 Using: {base_model} + {controlnet_type}"
545
  if lora_model and lora_model != "None":
546
  status_msg += f" + {lora_model}"
547
  print(status_msg)
548
 
 
549
  condition_img = prepare_condition_image(sketch, controlnet_type)
550
 
 
551
  gen = torch.Generator(device=device).manual_seed(int(seed))
552
 
553
  with torch.inference_mode():
554
  out = pipe(
555
+ prompt,
556
  negative_prompt=negative_prompt,
557
+ image=condition_img,
558
  num_inference_steps=int(steps),
559
+ guidance_scale=float(scale),
560
  controlnet_conditioning_scale=float(cn_weight),
561
  generator=gen,
562
  height=512,
563
  width=512
564
  ).images[0]
565
 
 
566
  if device.type == "cuda":
567
  torch.cuda.empty_cache()
568
 
 
572
  error_img = Image.new('RGB', (512, 512), color='red')
573
  return error_img, Image.new('RGB', (512, 512), color='gray')
574
 
575
+def t2i(prompt, negative_prompt, model, lora_model, lora_weight, vae_model,
+        seed, steps, scale, w, h, use_refiner=False):
    try:
        model_to_load = model
        if use_refiner and "refiner" not in model.lower():
            model_to_load = "stabilityai/stable-diffusion-xl-refiner-1.0"

+        load_t2i_model(model_to_load, lora_model, lora_weight, vae_model)

        print(f"🖼️ Using T2I model: {model}")
        if lora_model and lora_model != "None":
            # ... (unchanged lines collapsed in this diff)

+        # -1 means "random" in the UI; draw a fresh seed for it
+        seed = int(seed)
+        if seed < 0:
+            seed = torch.seed() % (2**32)
+        gen = torch.Generator(device=device).manual_seed(seed)

        with torch.inference_mode():
            if use_refiner and CURRENT_SDXL_REFINER is not None:
+                # Two-stage SDXL: run the base pass in latent space, then let
+                # the refiner spend the remaining half of the step budget
                image = CURRENT_T2I_PIPE(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    width=int(w),
                    height=int(h),
+                    num_inference_steps=int(steps) // 2,
                    guidance_scale=float(scale),
                    generator=gen,
                    output_type="latent"
                ).images

                result = CURRENT_SDXL_REFINER(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    image=image,
+                    num_inference_steps=int(steps) // 2,
                    guidance_scale=float(scale),
                    generator=gen
                ).images[0]
            else:
                if is_sdxl_model(model):
+                    # keep SDXL at a sane minimum resolution
                    width = max(int(w), 512)
                    height = max(int(h), 512)
                    result = CURRENT_T2I_PIPE(
+                        prompt,
                        negative_prompt=negative_prompt,
+                        width=width,
                        height=height,
+                        num_inference_steps=int(steps),
                        guidance_scale=float(scale),
                        generator=gen
                    ).images[0]
                else:
                    result = CURRENT_T2I_PIPE(
+                        prompt,
                        negative_prompt=negative_prompt,
+                        width=int(w),
                        height=int(h),
+                        num_inference_steps=int(steps),
                        guidance_scale=float(scale),
                        generator=gen
                    ).images[0]
        # ... (unchanged lines collapsed in this diff)
        draw.text((50, 50), f"Error: {str(e)[:50]}...", fill="white", font=font)
        return error_img
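
+# Hedged sketch of the base→refiner handoff, in the form the diffusers docs
+# recommend: rather than halving the step count, base and refiner can share one
+# noise schedule via denoising_end/denoising_start. `base` and `refiner` are
+# assumed to be SDXL base and img2img-refiner pipelines; the helper is ours.
+def _sketch_base_plus_refiner(prompt, base, refiner, steps=40, split=0.8):
+    latents = base(
+        prompt=prompt,
+        num_inference_steps=steps,
+        denoising_end=split,       # base covers the first 80% of the schedule
+        output_type="latent",
+    ).images
+    return refiner(
+        prompt=prompt,
+        image=latents,
+        num_inference_steps=steps,
+        denoising_start=split,     # refiner finishes the last 20%
+    ).images[0]
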
def unload_all_models():
    global CURRENT_CONTROLNET_PIPE, CURRENT_CONTROLNET_KEY
    global DETECTORS
    global CURRENT_T2I_PIPE, CURRENT_T2I_MODEL, CURRENT_SDXL_REFINER

+    print("🗑️ Unloading all models from memory...")

+    # Drop the ControlNet pipeline
    try:
        if CURRENT_CONTROLNET_PIPE is not None:
            del CURRENT_CONTROLNET_PIPE
    # ... (unchanged lines collapsed in this diff)
        pass
    CURRENT_CONTROLNET_KEY = None

+    # Drop every cached preprocessor/detector
    for detector_type in list(DETECTORS.keys()):
        try:
            del DETECTORS[detector_type]
    # ... (unchanged lines collapsed in this diff)
            pass
    DETECTORS.clear()

+    # Drop the text-to-image pipeline
    try:
        if CURRENT_T2I_PIPE is not None:
            del CURRENT_T2I_PIPE
    # ... (unchanged lines collapsed in this diff)
    CURRENT_T2I_MODEL = None

+    # Deleting references alone is not enough: force a GC pass, then hand the
+    # CUDA caching allocator's freed blocks back to the driver
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # ... (unchanged lines collapsed in this diff)
    return "✅ All models unloaded from memory!"
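
+# Hedged sketch (not wired into the UI): how to verify that unloading actually
+# freed GPU memory. memory_allocated/memory_reserved are standard PyTorch
+# calls; the helper name is ours.
+def _sketch_report_vram():
+    if not torch.cuda.is_available():
+        return "CPU mode - no VRAM to report"
+    allocated = torch.cuda.memory_allocated() / 1024**3   # tensors still alive
+    reserved = torch.cuda.memory_reserved() / 1024**3     # blocks cached by the allocator
+    return f"VRAM allocated: {allocated:.2f} GB | reserved: {reserved:.2f} GB"
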
# ===== Gradio UI =====
+with gr.Blocks(title="🎨 AI Image Generator Pro", theme=gr.themes.Soft()) as demo:
+    gr.Markdown("# 🎨 AI Image Generator Pro - NSFW Capable")
+    gr.Markdown("### Advanced Image Generation with ControlNet, LoRA & VAE Support")
+    gr.Markdown("⚠️ **Content Warning:** This tool can generate NSFW content. Use responsibly and in compliance with applicable laws.")

    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
        # ... (unchanged lines collapsed in this diff)
    else:
        gr.Markdown("**⚠️ Running on CPU** - Generation will be slower")

    with gr.Row():
        unload_btn = gr.Button("🗑️ Unload All Models", variant="stop", scale=1)
        status_text = gr.Textbox(label="Status", interactive=False, scale=3)
    unload_btn.click(unload_all_models, outputs=status_text)

+    with gr.Tab("🎨 ControlNet Image-to-Image"):
        gr.Markdown("""
+        ### Transform sketches/images using ControlNet
+        - **SD1.5 models:** support all ControlNet types
+        - **SDXL models:** support canny_sdxl, depth_sdxl, and openpose_sdxl only
        """)

        with gr.Row():
+            with gr.Column(scale=1):
+                inp = gr.Image(label="Input Sketch/Image", type="pil")
+
+                gr.Markdown("### Model Settings")
+                base_model = gr.Dropdown(
+                    choices=ALL_MODELS,
+                    value="digiplay/ChikMix_V3",
+                    label="Base Model"
+                )
+                controlnet_type = gr.Dropdown(
+                    choices=list(CONTROLNET_MODELS.keys()) + list(SDXL_CONTROLNET_MODELS.keys()),
+                    value="lineart_anime",
+                    label="ControlNet Type"
+                )
+
+                gr.Markdown("### Enhancement Options")
+                with gr.Row():
+                    lora_model = gr.Dropdown(
+                        choices=list(LORA_MODELS.keys()),
+                        value="None",
+                        label="LoRA Model"
+                    )
+                    lora_weight = gr.Slider(0.1, 2.0, 0.8, step=0.1, label="LoRA Weight")
+
+                vae_model = gr.Dropdown(
+                    choices=list(VAE_MODELS.keys()),
+                    value="None",
+                    label="VAE Model (Optional)"
+                )
+
+            with gr.Column(scale=1):
+                out = gr.Image(label="Generated Output")
+                condition_out = gr.Image(label="Processed Condition", type="pil")

+        gr.Markdown("### Generation Parameters")
        with gr.Row():
            prompt = gr.Textbox(
+                label="Prompt",
+                placeholder="masterpiece, best quality, 1girl, beautiful detailed eyes, long hair",
+                lines=3
            )
            negative_prompt = gr.Textbox(
+                label="Negative Prompt",
+                placeholder="lowres, bad anatomy, bad hands, text, error, missing fingers",
+                lines=3
            )

        with gr.Row():
+            seed = gr.Number(value=-1, label="Seed (-1 for random)")
+            steps = gr.Slider(10, 150, 30, step=1, label="Steps")
+            scale = gr.Slider(1, 30, 7.5, step=0.5, label="CFG Scale")
            cn_weight = gr.Slider(0.1, 2.0, 1.0, step=0.1, label="ControlNet Weight")

+        run = gr.Button("🎨 Generate", variant="primary", size="lg")
        run.click(
+            colorize,
+            [inp, base_model, controlnet_type, lora_model, lora_weight, vae_model,
+             prompt, negative_prompt, seed, steps, scale, cn_weight],
            [out, condition_out]
        )
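
+        # Note: Gradio passes these inputs positionally, so the list order must
+        # match colorize()'s parameter order exactly; the two outputs map to
+        # (generated image, processed condition image) in that order.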

+        gr.Markdown("""
+        ### Tips for Better Results:
+        - Use detailed prompts for better control
+        - Adjust ControlNet weight to balance between condition and creativity
+        - Try different LoRA models for various styles
+        - Higher steps = better quality but slower generation
+        """)

+    with gr.Tab("🖼️ Text-to-Image Generation"):
        gr.Markdown("""
        ### Generate images from text descriptions
+        Supports both SD1.5 and SDXL models with advanced features
        """)

        with gr.Row():
+            with gr.Column(scale=1):
+                gr.Markdown("### Model Configuration")
+                t2i_model = gr.Dropdown(
+                    choices=ALL_MODELS,
+                    value="digiplay/ChikMix_V3",
+                    label="Base Model"
+                )
+
+                gr.Markdown("### Enhancement Options")
+                with gr.Row():
+                    t2i_lora = gr.Dropdown(
+                        choices=list(LORA_MODELS.keys()),
+                        value="None",
+                        label="LoRA Model"
+                    )
+                    t2i_lora_weight = gr.Slider(0.1, 2.0, 0.8, step=0.1, label="LoRA Weight")
+
+                t2i_vae = gr.Dropdown(
+                    choices=list(VAE_MODELS.keys()),
+                    value="None",
+                    label="VAE Model"
+                )
+
+                use_refiner = gr.Checkbox(
+                    label="Use SDXL Refiner (SDXL only)",
+                    value=False
+                )
+
+            with gr.Column(scale=1):
+                t2i_out = gr.Image(label="Generated Image", type="pil")

+        gr.Markdown("### Prompts")
        with gr.Row():
            t2i_prompt = gr.Textbox(
+                label="Prompt",
+                lines=4,
+                placeholder="masterpiece, best quality, highly detailed, 8k, photorealistic, beautiful lighting"
            )
            t2i_negative_prompt = gr.Textbox(
+                label="Negative Prompt",
+                lines=4,
+                placeholder="lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"
            )

+        gr.Markdown("### Generation Parameters")
        with gr.Row():
+            t2i_seed = gr.Number(value=-1, label="Seed (-1 for random)")
+            t2i_steps = gr.Slider(10, 150, 30, step=1, label="Steps")
+            t2i_scale = gr.Slider(1, 30, 7.5, step=0.5, label="CFG Scale")

        with gr.Row():
+            w = gr.Slider(256, 2048, 512, step=64, label="Width")
+            h = gr.Slider(256, 2048, 768, step=64, label="Height")

+        gen_btn = gr.Button("🖼️ Generate Image", variant="primary", size="lg")
        gen_btn.click(
+            t2i,
+            [t2i_prompt, t2i_negative_prompt, t2i_model, t2i_lora, t2i_lora_weight,
+             t2i_vae, t2i_seed, t2i_steps, t2i_scale, w, h, use_refiner],
            t2i_out
        )

        gr.Markdown("""
+        ### Pro Tips:
+        - **SDXL models** produce higher quality at 1024x1024
+        - **SD1.5 models** work best at 512x512 or 512x768
+        - Use **LoRA** for specific styles (anime, realistic, etc.)
+        - Use **VAE** for better colors and details
+        - **Refiner** adds extra polish to SDXL generations
+        - Higher **CFG Scale** = more prompt adherence
        """)
+
+    with gr.Tab("📚 Quick Reference"):
        gr.Markdown("""
+        # Model & Feature Guide
+
+        ## 🎯 Recommended Models for Different Purposes
+
+        ### Realistic/Photorealistic
+        - `emilianJR/epiCRealism` - Excellent for realistic portraits
+        - `stablediffusionapi/realistic-vision-v51` - High quality realistic images
+        - `digiplay/majicMIX_realistic_v7` - Great for realistic characters
+        - `SG161222/RealVisXL_V4.0` - SDXL realistic model
+
+        ### Anime/Cartoon
+        - `digiplay/ChikMix_V3` - Versatile anime style
+        - `gsdf/Counterfeit-V2.5` - High quality anime
+        - `stablediffusionapi/anything-v5` - Popular anime model
+        - `digiplay/Pony_Diffusion_V6_XL` - SDXL anime model
+
+        ### Artistic/Stylized
+        - `stablediffusionapi/dreamshaper-v8` - Dream-like artistic style
+        - `wavymulder/Analog-Diffusion` - Analog photo aesthetic
+        - `Lykon/dreamshaper-xl-1-0` - SDXL artistic model
+
+        ## 🎨 ControlNet Types Explained
+
+        - **lineart/lineart_anime**: Convert line drawings to colored images
+        - **canny**: Edge detection based generation
+        - **depth**: Depth map based generation
+        - **openpose**: Human pose based generation
+        - **normal**: Normal map based generation
+        - **softedge**: Soft edge detection
+        - **scribble**: Scribble to image
+        - **tile**: Upscaling and detail enhancement
+
+        ## 💎 Popular LoRA Combinations
+
+        ### For Portraits
+        - Base: `digiplay/majicMIX_realistic_v7`
+        - LoRA: `Detail Tweaker` or `Face Detail`
+        - VAE: `SD1.5 VAE`
+
+        ### For Anime Characters
+        - Base: `digiplay/ChikMix_V3`
+        - LoRA: `Anime Art` or `Manga Style`
+        - VAE: `Anime VAE`
+
+        ### For NSFW Content
+        - Base: Any NSFW-capable model
+        - LoRA: `NSFW Master`, `Realistic NSFW`, or `Anime NSFW`
+        - Note: Always use responsibly and legally
+
+        ## ⚙️ Parameter Guidelines
+
+        ### Steps
+        - **20-30**: Fast, good quality
+        - **30-50**: Balanced
+        - **50-100**: High quality, slow
+
+        ### CFG Scale
+        - **5-7**: Creative, loose interpretation
+        - **7-10**: Balanced
+        - **10-15**: Strict prompt adherence
+        - **15+**: Very strict, may oversaturate
+
+        ### Resolution
+        - **SD1.5**: 512x512, 512x768, 768x512
+        - **SDXL**: 1024x1024, 1024x1536, 1536x1024
+
+        ## 🔞 NSFW Generation Guidelines
+
+        1. Use NSFW-capable base models
+        2. Apply relevant LoRA for style enhancement
+        3. Use detailed prompts
+        4. Adjust CFG scale (7-12 recommended)
+        5. Consider using higher steps (40-60)
+        6. **Always comply with local laws and regulations**
+
+        ## 🚀 Performance Tips
+
+        - Unload models when switching between different types
+        - Use lower resolutions for testing
+        - Enable xFormers if available (automatic)
+        - Use appropriate batch sizes for your GPU
+        - Monitor GPU memory usage
        """)
 
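+        # Illustrative preset table condensing the step/CFG guidance above.
+        # The dict and its keys are hypothetical and are not referenced by any
+        # control in this UI; they just restate the guide as data.
+        QUALITY_PRESETS = {
+            "fast": {"steps": 25, "scale": 7.0},       # quick drafts
+            "balanced": {"steps": 40, "scale": 8.5},   # everyday use
+            "high": {"steps": 60, "scale": 10.0},      # final renders
+        }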
try:
    demo.launch(
+        server_name="0.0.0.0",   # listen on all interfaces (needed in containers/Spaces)
+        server_port=7860,        # the port Hugging Face Spaces expects
        share=False,
        show_error=True,
        quiet=False