K1Z3M1112 commited on
Commit
332035a
·
verified ·
1 Parent(s): 24e1d83

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -977
app.py CHANGED
@@ -1,996 +1,141 @@
1
- import os
2
  import gradio as gr
3
  import numpy as np
4
- import cv2
5
  from PIL import Image
6
  import torch
7
- from gradio.themes import Soft
8
- from gradio.themes.utils import colors, fonts, sizes
9
  import gc
10
 
11
- colors.steel_blue = colors.Color(
12
- name="steel_blue",
13
- c50="#EBF3F8",
14
- c100="#D3E5F0",
15
- c200="#A8CCE1",
16
- c300="#7DB3D2",
17
- c400="#529AC3",
18
- c500="#4682B4",
19
- c600="#3E72A0",
20
- c700="#36638C",
21
- c800="#2E5378",
22
- c900="#264364",
23
- c950="#1E3450",
24
- )
25
-
26
- class SteelBlueTheme(Soft):
27
- def __init__(self, **kwargs):
28
- super().__init__(
29
- primary_hue=colors.gray,
30
- secondary_hue=colors.steel_blue,
31
- neutral_hue=colors.slate,
32
- text_size=sizes.text_lg,
33
- font=(fonts.GoogleFont("Outfit"), "Arial", "sans-serif"),
34
- font_mono=(fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace"),
35
- )
36
-
37
- steel_blue_theme = SteelBlueTheme()
38
-
39
- print("=" * 50)
40
- print("🎨 Style2Paints - Uncensored Line Art Colorization & Text-to-Image")
41
- print("=" * 50)
42
-
43
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
44
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
45
 
46
- from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, DDIMScheduler, StableDiffusionPipeline, StableDiffusionXLPipeline, EulerDiscreteScheduler
 
47
  from controlnet_aux import LineartDetector, LineartAnimeDetector
48
 
49
- # ===== 模型配置 =====
50
- AVAILABLE_MODELS = {
51
- "LineArt Colorization": ["Anything V3 (ControlNet)"],
52
- "Text-to-Image": [
53
- "Linaqruf/anything-v3.0",
54
- "digiplay/ChikMix_V3",
55
- "digiplay/chilloutmix_NiPrunedFp16Fix",
56
- "LyliaEngine/Pony_Diffusion_V6_XL",
57
- "wootwoot/abyssorangemix3-popupparade-fp16",
58
- "John6666/wai-nsfw-illustrious-v80-sdxl"
59
- ]
60
- }
61
-
62
- MODEL_CONFIGS = {
63
- "Linaqruf/anything-v3.0": {
64
- "type": "sd15",
65
- "description": "Anything V3 - 全能模型",
66
- "default_resolution": (512, 768)
67
- },
68
- "digiplay/ChikMix_V3": {
69
- "type": "sd15",
70
- "description": "ChikMix V3 - 高质量动漫模型",
71
- "default_resolution": (512, 768)
72
- },
73
- "digiplay/chilloutmix_NiPrunedFp16Fix": {
74
- "type": "sd15",
75
- "description": "ChilloutMix - 真人风格",
76
- "default_resolution": (512, 768)
77
- },
78
- "LyliaEngine/Pony_Diffusion_V6_XL": {
79
- "type": "sdxl",
80
- "description": "Pony Diffusion V6 XL - SDXL动漫模型",
81
- "default_resolution": (1024, 1024)
82
- },
83
- "wootwoot/abyssorangemix3-popupparade-fp16": {
84
- "type": "sd15",
85
- "description": "AbyssOrangeMix3 - 色彩鲜艳",
86
- "default_resolution": (512, 768)
87
- },
88
- "John6666/wai-nsfw-illustrious-v80-sdxl": {
89
- "type": "sdxl",
90
- "description": "WAI NSFW Illustrious - SDXL成人内容优化",
91
- "default_resolution": (1024, 1024)
92
- }
93
- }
94
-
95
- # ===== 全局模型变量 =====
96
- pipe_standard = None
97
- pipe_anime = None
98
- lineart_detector = None
99
- lineart_anime_detector = None
100
- current_t2i_model = None
101
- current_t2i_pipe = None
102
-
103
- def load_text_to_image_model(model_name, progress=gr.Progress()):
104
- """动态加载文本到图像模型"""
105
- global current_t2i_model, current_t2i_pipe
106
-
107
- if model_name == current_t2i_model and current_t2i_pipe is not None:
108
- print(f"✅ 模型 {model_name} 已加载,跳过重新加载")
109
- return True
110
-
111
- try:
112
- # 清理之前的模型
113
- if current_t2i_pipe is not None:
114
- del current_t2i_pipe
115
- current_t2i_pipe = None
116
- current_t2i_model = None
117
- gc.collect()
118
- if torch.cuda.is_available():
119
- torch.cuda.empty_cache()
120
-
121
- print(f"🔄 正在加载模型: {model_name}")
122
- progress(0.3, desc=f"正在加载 {model_name}")
123
-
124
- model_config = MODEL_CONFIGS.get(model_name, {})
125
- model_type = model_config.get("type", "sd15")
126
-
127
- if model_type == "sdxl":
128
- # SDXL 模型
129
- pipe = StableDiffusionXLPipeline.from_pretrained(
130
- model_name,
131
- torch_dtype=dtype,
132
- safety_checker=None,
133
- requires_safety_checker=False,
134
- use_safetensors=True,
135
- variant="fp16" if dtype == torch.float16 else None
136
- ).to(device)
137
-
138
- # SDXL 推荐使用 Euler scheduler
139
- pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
140
-
141
- else:
142
- # SD1.5 模型
143
- pipe = StableDiffusionPipeline.from_pretrained(
144
- model_name,
145
- torch_dtype=dtype,
146
- safety_checker=None,
147
- requires_safety_checker=False,
148
- use_safetensors=True,
149
- variant="fp16" if dtype == torch.float16 else None
150
- ).to(device)
151
-
152
- # SD1.5 使用 DDIM
153
- pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
154
-
155
- # 优化设置
156
- if device.type == "cuda":
157
- pipe.enable_model_cpu_offload()
158
- try:
159
- pipe.enable_xformers_memory_efficient_attention()
160
- except:
161
- print("⚠️ XFormers 不可用,跳过")
162
- pipe.enable_attention_slicing()
163
-
164
- current_t2i_model = model_name
165
- current_t2i_pipe = pipe
166
-
167
- print(f"✅ 模型 {model_name} 加载成功!")
168
- progress(1.0, desc="模型加载完成")
169
- return True
170
-
171
- except Exception as e:
172
- import traceback
173
- print(f"❌ 加载模型失败: {str(e)}")
174
- print(f"详细错误: {traceback.format_exc()}")
175
- return False
176
 
177
  def load_lineart_models():
178
- """加载线稿着色模型"""
179
- global pipe_standard, pipe_anime, lineart_detector, lineart_anime_detector
180
-
181
- try:
182
- print("🔄 加载线稿着色模型...")
183
-
184
- # Load STANDARD lineart ControlNet
185
- print("📦 加载标准线稿模型...")
186
- controlnet_standard = ControlNetModel.from_pretrained(
187
- "lllyasviel/control_v11p_sd15_lineart",
188
- torch_dtype=dtype
189
  ).to(device)
190
-
191
- # Load ANIME lineart ControlNet
192
- print("📦 加载动漫线稿模型...")
193
- controlnet_anime = ControlNetModel.from_pretrained(
194
- "lllyasviel/control_v11p_sd15s2_lineart_anime",
195
- torch_dtype=dtype
196
  ).to(device)
197
-
198
- # 使用 Anything V3 作为基础模型
199
- print("📦 加载基础模型 (Anything V3)...")
200
- pipe_standard = StableDiffusionControlNetPipeline.from_pretrained(
201
- "Linaqruf/anything-v3.0",
202
- controlnet=controlnet_standard,
203
- torch_dtype=dtype,
204
- safety_checker=None,
205
- requires_safety_checker=False
206
- ).to(device)
207
-
208
- pipe_anime = StableDiffusionControlNetPipeline.from_pretrained(
209
- "Linaqruf/anything-v3.0",
210
- controlnet=controlnet_anime,
211
- torch_dtype=dtype,
212
- safety_checker=None,
213
- requires_safety_checker=False
214
- ).to(device)
215
-
216
- # 配置两个管道
217
- for pipe in [pipe_standard, pipe_anime]:
218
- pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
219
-
220
  if device.type == "cuda":
221
  pipe.enable_model_cpu_offload()
222
- try:
223
- pipe.enable_xformers_memory_efficient_attention()
224
- except:
225
- pass
226
- pipe.enable_attention_slicing()
227
-
228
- # 加载线稿检测器
229
- print("📦 加载线稿检测器...")
230
- lineart_detector = LineartDetector.from_pretrained("lllyasviel/Annotators")
231
- lineart_anime_detector = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
232
-
233
- print("✅ 线稿着色模型加载成功!")
234
- return True
235
-
236
- except Exception as e:
237
- print(f"❌ 加载线稿模型失败: {e}")
238
- return False
239
-
240
- # 加载线稿着色模型
241
- load_lineart_models()
242
-
243
- COLOR_STYLES = {
244
- "Anime Style": "anime, masterpiece, best quality, highly detailed, vibrant colors",
245
- "Manga Color": "manga coloring, cel shading, clean colors, professional",
246
- "Soft Shading": "soft shading, gradient colors, smooth, gentle lighting",
247
- "Vibrant": "vibrant colors, saturated, bold colors, eye-catching",
248
- "Realistic Skin": "realistic skin tones, detailed anatomy, natural colors",
249
- "Pastel Soft": "pastel colors, soft aesthetic, gentle tones, dreamy",
250
- "Dark Moody": "dark colors, moody lighting, dramatic shadows, cinematic",
251
- "Watercolor": "watercolor style, artistic, painterly, soft edges",
252
- }
253
-
254
- CONTENT_TEMPLATES = {
255
- "Character Portrait": "1girl, solo, portrait, detailed face, beautiful",
256
- "Full Body": "1girl, full body, standing, detailed",
257
- "Multiple Characters": "2girls, multiple girls, detailed",
258
- "Pin-up Style": "1girl, posing, detailed body, attractive pose",
259
- "Action Scene": "1girl, dynamic pose, action, movement",
260
- "Intimate Scene": "2girls, close together, intimate",
261
- "Custom": ""
262
- }
263
-
264
- def is_already_lineart(image):
265
- """检查图像是否已经是线稿"""
266
- if isinstance(image, Image.Image):
267
- image = np.array(image)
268
-
269
- gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) if len(image.shape) == 3 else image
270
- unique_vals = len(np.unique(gray))
271
- black_white_ratio = np.sum((gray < 50) | (gray > 200)) / gray.size
272
-
273
- return black_white_ratio > 0.7 or unique_vals < 30
274
-
275
- def extract_lineart(image, lineart_type="Standard", skip_if_lineart=True):
276
- """从图像中提取线稿"""
277
- if isinstance(image, np.ndarray):
278
- image = Image.fromarray(image)
279
-
280
- if skip_if_lineart and is_already_lineart(image):
281
- print("✅ 已经是线稿,跳过提取")
282
- return image.convert('RGB')
283
-
284
- print(f"🔄 提取线稿 ({lineart_type})...")
285
-
286
- if lineart_type == "Anime" and lineart_anime_detector is not None:
287
- lineart = lineart_anime_detector(image, detect_resolution=512, image_resolution=512)
288
- else:
289
- lineart = lineart_detector(image, detect_resolution=512, image_resolution=512)
290
-
291
- if isinstance(lineart, np.ndarray):
292
- lineart = Image.fromarray(lineart)
293
-
294
- return lineart
295
 
296
- def colorize_lineart(
297
- sketch_image,
298
- lineart_type,
299
- content_type,
300
- style,
301
- custom_prompt,
302
- quality_tags,
303
- nsfw_level,
304
- seed,
305
- randomize_seed,
306
- guidance_scale,
307
- num_steps,
308
- controlnet_strength,
309
- progress=gr.Progress(track_tqdm=True)
310
- ):
311
- """线稿着色"""
312
- if sketch_image is None:
313
- raise gr.Error("请上传线稿图像")
314
-
315
- # 根据线稿类型选择管道
316
- if lineart_type == "Anime" and pipe_anime is not None:
317
- pipe = pipe_anime
318
- print("🎨 使用动漫线稿模型")
319
- else:
320
- pipe = pipe_standard
321
- print("🎨 使用标准线稿模型")
322
-
323
- if pipe is None:
324
- raise gr.Error("模型未加载")
325
-
326
- # 转换数值输入
327
- seed = int(seed)
328
- guidance_scale = float(guidance_scale)
329
- num_steps = int(num_steps)
330
- controlnet_strength = float(controlnet_strength)
331
-
332
- if randomize_seed:
333
- import random
334
- seed = random.randint(0, 2**32-1)
335
-
336
- generator = torch.Generator(device=device).manual_seed(seed)
337
-
338
- # 转换和调整大小
339
- if isinstance(sketch_image, np.ndarray):
340
- sketch_image = Image.fromarray(sketch_image)
341
-
342
- width, height = sketch_image.size
343
- max_size = 512
344
-
345
- if width > max_size or height > max_size:
346
- if width > height:
347
- new_width = max_size
348
- new_height = int(height * (max_size / width))
349
- else:
350
- new_height = max_size
351
- new_width = int(width * (max_size / height))
352
-
353
- new_width = (new_width // 8) * 8
354
- new_height = (new_height // 8) * 8
355
- sketch_image = sketch_image.resize((new_width, new_height), Image.LANCZOS)
356
-
357
- # 提取线稿
358
- lineart = extract_lineart(sketch_image, lineart_type=lineart_type, skip_if_lineart=True)
359
-
360
- # 构建提示词
361
- prompt_parts = []
362
-
363
- # 内容模板
364
- content_template = CONTENT_TEMPLATES.get(content_type, "")
365
- if content_template:
366
- prompt_parts.append(content_template)
367
-
368
- # 自定义提示词
369
- if custom_prompt.strip():
370
- prompt_parts.append(custom_prompt.strip())
371
-
372
- # 质量标签
373
- if quality_tags:
374
- prompt_parts.append(quality_tags)
375
-
376
- # 风格
377
- style_prompt = COLOR_STYLES.get(style, COLOR_STYLES["Anime Style"])
378
- prompt_parts.append(style_prompt)
379
-
380
- # NSFW 级别标签
381
- if nsfw_level == "Safe":
382
- nsfw_tags = "sfw, safe for work"
383
- elif nsfw_level == "Suggestive":
384
- nsfw_tags = "suggestive, slightly revealing"
385
- elif nsfw_level == "Mild":
386
- nsfw_tags = "nsfw, ecchi, revealing clothing"
387
- elif nsfw_level == "Moderate":
388
- nsfw_tags = "nsfw, nude, explicit"
389
- else: # Explicit
390
- nsfw_tags = "nsfw, explicit, uncensored"
391
-
392
- prompt_parts.append(nsfw_tags)
393
-
394
- full_prompt = ", ".join(prompt_parts)
395
-
396
- # 负面提示词
397
- negative_prompt = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, jpeg artifacts, signature, watermark, username, blurry, artist name, black and white, monochrome"
398
-
399
- if nsfw_level in ["Moderate", "Explicit"]:
400
- negative_prompt = negative_prompt.replace("nsfw, ", "")
401
-
402
- print(f"🎨 提示词: {full_prompt}")
403
- print(f"🎛️ ControlNet 强度: {controlnet_strength}")
404
- print(f"🖼️ 线稿类型: {lineart_type}")
405
-
406
- try:
407
- progress(0.3, desc="正在生成颜色...")
408
-
409
- result = pipe(
410
- prompt=full_prompt,
411
- negative_prompt=negative_prompt,
412
- image=lineart,
413
- num_inference_steps=num_steps,
414
- guidance_scale=guidance_scale,
415
- controlnet_conditioning_scale=controlnet_strength,
416
- generator=generator,
417
- ).images[0]
418
-
419
- if device.type == "cuda":
420
  torch.cuda.empty_cache()
421
-
422
- return result, lineart, seed, full_prompt
423
-
424
- except Exception as e:
425
- import traceback
426
- print(f"❌ 完整错误: {traceback.format_exc()}")
427
- raise gr.Error(f"错误: {str(e)}")
428
-
429
- def generate_text_to_image(
430
- model_name,
431
- content_type,
432
- style,
433
- custom_prompt,
434
- quality_tags,
435
- nsfw_level,
436
- seed,
437
- randomize_seed,
438
- guidance_scale,
439
- num_steps,
440
- width,
441
- height,
442
- progress=gr.Progress(track_tqdm=True)
443
- ):
444
- """文本到图像生成"""
445
- # 加载模型
446
- if not load_text_to_image_model(model_name, progress):
447
- raise gr.Error(f"无法加载模型: {model_name}")
448
-
449
- # 转换数值输入
450
- seed = int(seed)
451
- guidance_scale = float(guidance_scale)
452
- num_steps = int(num_steps)
453
- width = int(width)
454
- height = int(height)
455
-
456
- if randomize_seed:
457
- import random
458
- seed = random.randint(0, 2**32-1)
459
-
460
- generator = torch.Generator(device=device).manual_seed(seed)
461
-
462
- # 构建提示词
463
- prompt_parts = []
464
-
465
- # 内容模板
466
- content_template = CONTENT_TEMPLATES.get(content_type, "")
467
- if content_template:
468
- prompt_parts.append(content_template)
469
-
470
- # 自定义提示词
471
- if custom_prompt.strip():
472
- prompt_parts.append(custom_prompt.strip())
473
-
474
- # 质量标签
475
- if quality_tags:
476
- prompt_parts.append(quality_tags)
477
-
478
- # 风格
479
- style_prompt = COLOR_STYLES.get(style, COLOR_STYLES["Anime Style"])
480
- prompt_parts.append(style_prompt)
481
-
482
- # NSFW 级别标签
483
- if nsfw_level == "Safe":
484
- nsfw_tags = "sfw, safe for work"
485
- elif nsfw_level == "Suggestive":
486
- nsfw_tags = "suggestive, slightly revealing"
487
- elif nsfw_level == "Mild":
488
- nsfw_tags = "nsfw, ecchi, revealing clothing"
489
- elif nsfw_level == "Moderate":
490
- nsfw_tags = "nsfw, nude, explicit"
491
- else: # Explicit
492
- nsfw_tags = "nsfw, explicit, uncensored"
493
-
494
- prompt_parts.append(nsfw_tags)
495
-
496
- full_prompt = ", ".join(prompt_parts)
497
-
498
- # 负面提示词
499
- negative_prompt = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, jpeg artifacts, signature, watermark, username, blurry, artist name, black and white, monochrome"
500
-
501
- if nsfw_level in ["Moderate", "Explicit"]:
502
- negative_prompt = negative_prompt.replace("nsfw, ", "")
503
-
504
- print(f"🎨 提示词: {full_prompt}")
505
- print(f"📐 分辨率: {width}x{height}")
506
- print(f"🎛️ 引导尺度: {guidance_scale}")
507
- print(f"🔄 步数: {num_steps}")
508
-
509
- try:
510
- progress(0.5, desc="正在生成图像...")
511
-
512
- result = current_t2i_pipe(
513
- prompt=full_prompt,
514
- negative_prompt=negative_prompt,
515
- width=width,
516
- height=height,
517
- num_inference_steps=num_steps,
518
- guidance_scale=guidance_scale,
519
- generator=generator,
520
- ).images[0]
521
-
522
- if device.type == "cuda":
523
- torch.cuda.empty_cache()
524
-
525
- return result, seed, full_prompt
526
-
527
- except Exception as e:
528
- import traceback
529
- print(f"❌ 完整错误: {traceback.format_exc()}")
530
- raise gr.Error(f"错误: {str(e)}")
531
-
532
- def update_resolution_from_model(model_name):
533
- """根据选择的模型更新推荐分辨率"""
534
- config = MODEL_CONFIGS.get(model_name, {})
535
- default_res = config.get("default_resolution", (512, 768))
536
- description = config.get("description", "通用模型")
537
-
538
- width, height = default_res
539
- return (
540
- gr.update(value=width, minimum=256, maximum=2048, step=8),
541
- gr.update(value=height, minimum=256, maximum=2048, step=8),
542
- gr.update(value=f"📊 推荐分辨率: {width}x{height} ({description})")
543
- )
544
-
545
- css="""
546
- #col-container {
547
- margin: 0 auto;
548
- max-width: 1600px;
549
- }
550
- #main-title h1 {
551
- font-size: 2.8em !important;
552
- text-align: center;
553
- background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
554
- -webkit-background-clip: text;
555
- -webkit-text-fill-color: transparent;
556
- background-clip: text;
557
- }
558
- .feature-box {
559
- background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
560
- color: white;
561
- border-radius: 12px;
562
- padding: 25px;
563
- margin: 20px 0;
564
- }
565
- .warning-box {
566
- background: #fff3cd;
567
- border-left: 4px solid #ffc107;
568
- padding: 15px;
569
- margin: 15px 0;
570
- border-radius: 8px;
571
- color: #856404;
572
- }
573
- .info-box {
574
- background: #f0f7ff;
575
- border-left: 4px solid #4682B4;
576
- padding: 15px;
577
- margin: 10px 0;
578
- border-radius: 8px;
579
- }
580
- .model-badge {
581
- display: inline-block;
582
- padding: 5px 12px;
583
- background: #4682B4;
584
- color: white;
585
- border-radius: 20px;
586
- font-size: 0.9em;
587
- margin: 5px;
588
- }
589
- .tab-buttons {
590
- margin-bottom: 20px;
591
- }
592
- .tab-nav {
593
- border-bottom: 2px solid #e0e0e0;
594
- }
595
- """
596
-
597
- with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
598
- with gr.Column(elem_id="col-container"):
599
- gr.Markdown("# 🎨 Style2Paints - 全能图像生成", elem_id="main-title")
600
- gr.Markdown("### ✨ 专业线稿着色与文本生成图像")
601
-
602
- gr.HTML("""
603
- <div class="warning-box">
604
- <strong>⚠️ 内容警告:</strong> 此工具支持所有类型内容的生成,包括 NSFW/成人内容。
605
- 请负责任地使用,并确保符合当地法律。生成成人内容必须年满18岁。
606
- </div>
607
- """)
608
-
609
- gr.HTML("""
610
- <div class="feature-box">
611
- <h3>✨ 核心功能</h3>
612
- <ul style="color:white; font-size:1.1em;">
613
- <li>🎨 <strong>双线稿模型</strong> - 标准和动漫专用线稿检测</li>
614
- <li>🖼️ <strong>文本生成图像</strong> - 从文本描述直接生成图像</li>
615
- <li>🎭 <strong>多模型支持</strong> - 6种不同风格的模型可选</li>
616
- <li>📝 <strong>内容模板</strong> - 常见场景的预设提示词</li>
617
- <li>🎚️ <strong>NSFW级别控制</strong> - 精确的内容级别控制</li>
618
- <li>⚡ <strong>智能模型加载</strong> - 按需加载,节省显存</li>
619
- </ul>
620
- <div style="margin-top:15px;">
621
- <span class="model-badge">6种文本生成模型</span>
622
- <span class="model-badge">2种线稿模型</span>
623
- </div>
624
- </div>
625
- """)
626
-
627
- with gr.Tabs() as tabs:
628
- # ===== 标签页 1: 线稿着色 =====
629
- with gr.TabItem("🎨 线稿着色"):
630
- with gr.Row():
631
- with gr.Column(scale=1):
632
- input_image = gr.Image(
633
- label="📤 上传线稿",
634
- type="pil",
635
- height=400
636
- )
637
-
638
- gr.Markdown("### 🎨 内容设置")
639
-
640
- lineart_type = gr.Radio(
641
- choices=["Standard", "Anime"],
642
- label="🖊️ 线稿模型",
643
- value="Anime",
644
- info="动漫模型更适合动漫/漫画风格"
645
- )
646
-
647
- content_type = gr.Dropdown(
648
- choices=list(CONTENT_TEMPLATES.keys()),
649
- label="📋 内容模板",
650
- value="Character Portrait",
651
- info="提示词起点"
652
- )
653
-
654
- custom_prompt = gr.Textbox(
655
- label="✍️ 详细描述 (重要)",
656
- placeholder="描述您想要的内容:发色、服装、姿势、背景、身体特征等",
657
- lines=3,
658
- info="请具体描述!这是最重要的字段。"
659
- )
660
-
661
- gr.HTML("""
662
- <div class="info-box">
663
- <strong>💡 提示词示例:</strong><br>
664
- • "金发,蓝眼睛,女仆装,丰满"<br>
665
- • "红色马尾辫,校服,短裙,过膝袜"<br>
666
- • "白发,猫耳,裸体,躺在床上"<br>
667
- • "两个女孩,接吻,亲密,卧室"
668
- </div>
669
- """)
670
-
671
- with gr.Row():
672
- style = gr.Dropdown(
673
- choices=list(COLOR_STYLES.keys()),
674
- label="🎨 颜色风格",
675
- value="Anime Style"
676
- )
677
-
678
- nsfw_level = gr.Dropdown(
679
- choices=["Safe", "Suggestive", "Mild", "Moderate", "Explicit"],
680
- label="🔞 内容级别",
681
- value="Moderate",
682
- info="内容明确程度"
683
- )
684
-
685
- quality_tags = gr.Textbox(
686
- label="⭐ 质量标签 (可选)",
687
- placeholder="masterpiece, best quality, highly detailed",
688
- value="masterpiece, best quality, highly detailed"
689
- )
690
-
691
- colorize_button = gr.Button("✨ 开始着色!", variant="primary", size="lg")
692
-
693
- with gr.Column(scale=2):
694
- with gr.Row():
695
- lineart_output = gr.Image(
696
- label="🖊️ 提取的线稿",
697
- type="pil",
698
- height=380
699
- )
700
- output_image = gr.Image(
701
- label="🎨 着色结果",
702
- type="pil",
703
- height=380
704
- )
705
-
706
- generated_prompt = gr.Textbox(
707
- label="📝 生成的提示词",
708
- lines=3,
709
- interactive=False,
710
- show_copy_button=True
711
- )
712
-
713
- with gr.Accordion("⚙️ 高级设置", open=True):
714
- with gr.Row():
715
- seed = gr.Slider(
716
- label="🎲 种子",
717
- minimum=0,
718
- maximum=2**32-1,
719
- step=1,
720
- value=42
721
- )
722
- randomize_seed = gr.Checkbox(
723
- label="🔀 随机种子",
724
- value=True
725
- )
726
-
727
- with gr.Row():
728
- guidance_scale = gr.Slider(
729
- label="💬 引导尺度",
730
- minimum=5.0,
731
- maximum=15.0,
732
- step=0.5,
733
- value=8.0,
734
- info="7-9 推荐用于 NSFW"
735
- )
736
-
737
- num_steps = gr.Slider(
738
- label="🔢 步数",
739
- minimum=10,
740
- maximum=30,
741
- step=5,
742
- value=20,
743
- info="20 是良好平衡"
744
- )
745
-
746
- controlnet_strength = gr.Slider(
747
- label="🎛️ 线稿保留强度",
748
- minimum=0.5,
749
- maximum=1.5,
750
- step=0.1,
751
- value=1.0,
752
- info="严格遵循线稿的程度"
753
- )
754
-
755
- colorize_button.click(
756
- fn=colorize_lineart,
757
- inputs=[
758
- input_image, lineart_type, content_type, style, custom_prompt, quality_tags, nsfw_level,
759
- seed, randomize_seed, guidance_scale, num_steps, controlnet_strength
760
- ],
761
- outputs=[output_image, lineart_output, seed, generated_prompt]
762
- )
763
-
764
- # ===== 标签页 2: 文本生成图像 =====
765
- with gr.TabItem("🖼️ 文本生成图像"):
766
- with gr.Row():
767
- with gr.Column(scale=1):
768
- gr.Markdown("### 🤖 模型选择")
769
-
770
- model_selector = gr.Dropdown(
771
- choices=AVAILABLE_MODELS["Text-to-Image"],
772
- label="🎯 选择模型",
773
- value="Linaqruf/anything-v3.0",
774
- info="选择要使用的生成模型"
775
- )
776
-
777
- model_info = gr.Textbox(
778
- label="📊 模型信息",
779
- value="📊 推荐分辨率: 512x768 (Anything V3 - 全能模型)",
780
- interactive=False
781
- )
782
-
783
- load_model_btn = gr.Button("🔄 加载模型", variant="secondary")
784
- model_status = gr.Textbox(
785
- label="✅ 状态",
786
- value="✅ 模型已就绪",
787
- interactive=False
788
- )
789
-
790
- gr.Markdown("### 🎨 内容设置")
791
-
792
- t2i_content_type = gr.Dropdown(
793
- choices=list(CONTENT_TEMPLATES.keys()),
794
- label="📋 内容模板",
795
- value="Character Portrait",
796
- info="提示词起点"
797
- )
798
-
799
- t2i_custom_prompt = gr.Textbox(
800
- label="✍️ 详细描述 (重要)",
801
- placeholder="详细描述您想要生成的图像:角色特征、服装、姿势、场景等",
802
- lines=3,
803
- info="描述越详细,生成效果越好"
804
- )
805
-
806
- gr.HTML("""
807
- <div class="info-box">
808
- <strong>💡 提示词示例:</strong><br>
809
- • "美丽的女孩,金色长发,蓝色眼睛,穿着白色连衣裙,站在花园里"<br>
810
- • "性感的女战士,红色铠甲,手持长剑,动态姿势,战场背景"<br>
811
- • "两个女孩在咖啡馆约会,温馨的氛围,详细的面部表情"<br>
812
- • "幻想风格的女精灵,尖耳朵,魔法光效,森林背景"
813
- </div>
814
- """)
815
-
816
- with gr.Row():
817
- t2i_style = gr.Dropdown(
818
- choices=list(COLOR_STYLES.keys()),
819
- label="🎨 艺术风格",
820
- value="Anime Style"
821
- )
822
-
823
- t2i_nsfw_level = gr.Dropdown(
824
- choices=["Safe", "Suggestive", "Mild", "Moderate", "Explicit"],
825
- label="🔞 内容级别",
826
- value="Moderate",
827
- info="内容明确程度"
828
- )
829
-
830
- t2i_quality_tags = gr.Textbox(
831
- label="⭐ 质量标签 (可选)",
832
- placeholder="masterpiece, best quality, highly detailed",
833
- value="masterpiece, best quality, highly detailed"
834
- )
835
-
836
- generate_button = gr.Button("✨ 生成图像!", variant="primary", size="lg")
837
-
838
- with gr.Column(scale=2):
839
- t2i_output_image = gr.Image(
840
- label="🖼️ 生成的图像",
841
- type="pil",
842
- height=500
843
- )
844
-
845
- t2i_generated_prompt = gr.Textbox(
846
- label="📝 生成的提示词",
847
- lines=3,
848
- interactive=False,
849
- show_copy_button=True
850
- )
851
-
852
- with gr.Accordion("⚙️ 高级设置", open=True):
853
- with gr.Row():
854
- t2i_seed = gr.Slider(
855
- label="🎲 种子",
856
- minimum=0,
857
- maximum=2**32-1,
858
- step=1,
859
- value=42
860
- )
861
- t2i_randomize_seed = gr.Checkbox(
862
- label="🔀 随机种子",
863
- value=True
864
- )
865
-
866
- with gr.Row():
867
- t2i_guidance_scale = gr.Slider(
868
- label="💬 引导尺度",
869
- minimum=5.0,
870
- maximum=15.0,
871
- step=0.5,
872
- value=7.5,
873
- info="控制提示词影响力"
874
- )
875
-
876
- t2i_num_steps = gr.Slider(
877
- label="🔢 生成步数",
878
- minimum=10,
879
- maximum=50,
880
- step=5,
881
- value=30,
882
- info="步数越多质量越高但越慢"
883
- )
884
-
885
- with gr.Row():
886
- t2i_width = gr.Slider(
887
- label="📏 宽度",
888
- minimum=256,
889
- maximum=2048,
890
- step=8,
891
- value=512,
892
- info="图像宽度"
893
- )
894
-
895
- t2i_height = gr.Slider(
896
- label="📐 高度",
897
- minimum=256,
898
- maximum=2048,
899
- step=8,
900
- value=768,
901
- info="图像高度"
902
- )
903
-
904
- # 事件处理
905
- model_selector.change(
906
- fn=update_resolution_from_model,
907
- inputs=[model_selector],
908
- outputs=[t2i_width, t2i_height, model_info]
909
- )
910
-
911
- load_model_btn.click(
912
- fn=lambda model_name: (
913
- load_text_to_image_model(model_name, gr.Progress()) and
914
- gr.update(value=f"✅ {model_name} 加载成功")
915
- ),
916
- inputs=[model_selector],
917
- outputs=[model_status]
918
- )
919
-
920
- generate_button.click(
921
- fn=generate_text_to_image,
922
- inputs=[
923
- model_selector,
924
- t2i_content_type,
925
- t2i_style,
926
- t2i_custom_prompt,
927
- t2i_quality_tags,
928
- t2i_nsfw_level,
929
- t2i_seed,
930
- t2i_randomize_seed,
931
- t2i_guidance_scale,
932
- t2i_num_steps,
933
- t2i_width,
934
- t2i_height
935
- ],
936
- outputs=[t2i_output_image, t2i_seed, t2i_generated_prompt]
937
- )
938
-
939
- gr.Markdown("""
940
- ---
941
- ## 📚 快速开始指南
942
-
943
- ### 🆕 **新功能: 文本生成图像**
944
-
945
- 此版本新增文本生成图像功能,支持6种不同的模型:
946
-
947
- #### 🤖 **可用模型:**
948
-
949
- 1. **Anything V3** (`Linaqruf/anything-v3.0`) - 全能动漫模型
950
- 2. **ChikMix V3** (`digiplay/ChikMix_V3`) - 高质量动漫模型
951
- 3. **ChilloutMix** (`digiplay/chilloutmix_NiPrunedFp16Fix`) - 真人风格模型
952
- 4. **Pony Diffusion V6 XL** (`LyliaEngine/Pony_Diffusion_V6_XL`) - SDXL动漫模型 (高分辨率)
953
- 5. **AbyssOrangeMix3** (`wootwoot/abyssorangemix3-popupparade-fp16`) - 色彩鲜艳的动漫模型
954
- 6. **WAI NSFW Illustrious** (`John6666/wai-nsfw-illustrious-v80-sdxl`) - SDXL成人内容优化模型
955
-
956
- ### 🎨 **线稿着色模型**
957
-
958
- 线稿着色功能提供两种线稿模型:
959
- - **标准线稿** (`control_v11p_sd15_lineart`) - 适合一般艺术作品
960
- - **动漫线稿** (`control_v11p_sd15s2_lineart_anime`) - 专为动漫/漫画风格优化 ✨
961
-
962
- ### ✅ **如何使用**
963
-
964
- #### **线稿着色:**
965
- 1. 上传您的线稿(黑白线条在白底上效果最好)
966
- 2. 选择线稿模型 - 动漫风格使用"Anime"模型
967
- 3. 选择内容模板作为起点
968
- 4. 编写详细描述 - 具体说明颜色、特征、服装等
969
- 5. 设置NSFW级别以匹配您的内容
970
- 6. 点击"开始着色!"
971
-
972
- #### **文本生成图像:**
973
- 1. 选择您想要使用的模型
974
- 2. 点击"加载模型"按钮(首次使用或切换模型时需要)
975
- 3. 编写详细描述您想要生成的图像
976
- 4. 调整分辨率和生成参数
977
- 5. 点击"生成图像!"
978
-
979
- ### 💡 **最佳实践提示**
980
-
981
- - **模型选择**: SDXL模型需要更多显存但生成质量更高
982
- - **详细描述**: 描述越详细,生成效果越好
983
- - **分辨率设置**: SDXL模型推荐使用1024x1024,SD1.5模型推荐512x768
984
- - **显存管理**: 模型按需加载,切换模型时会自动清理之前的模型
985
-
986
- ---
987
-
988
- <div style="text-align:center; color:#666; padding:20px;">
989
- <strong>🔞 负责任使用</strong><br>
990
- 此工具用于艺术创作目的。用户必须年满18岁。<br>
991
- 请尊重版权、同意和当地法律。<br>
992
- <em>由 Stable Diffusion + ControlNet + 多种生成模型驱动</em>
993
- </div>
994
- """)
995
-
996
- demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
 
 
1
  import gradio as gr
2
  import numpy as np
 
3
  from PIL import Image
4
  import torch
 
 
5
  import gc
6
 
7
+ # Device
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
10
 
11
+ # Lazy import (to avoid long startup if unused)
12
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, StableDiffusionPipeline
13
  from controlnet_aux import LineartDetector, LineartAnimeDetector
14
 
15
+ # ===== Model & Config =====
16
+ PIPE_STANDARD = None
17
+ PIPE_ANIME = None
18
+ LINEART_DETECTOR = None
19
+ LINEART_ANIME_DETECTOR = None
20
+ CURRENT_T2I_PIPE = None
21
+ CURRENT_T2I_MODEL = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  def load_lineart_models():
24
+ global PIPE_STANDARD, PIPE_ANIME, LINEART_DETECTOR, LINEART_ANIME_DETECTOR
25
+ if PIPE_STANDARD is None:
26
+ print("Loading lineart models...")
27
+ controlnet_std = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_lineart", torch_dtype=dtype).to(device)
28
+ controlnet_anime = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15s2_lineart_anime", torch_dtype=dtype).to(device)
29
+
30
+ PIPE_STANDARD = StableDiffusionControlNetPipeline.from_pretrained(
31
+ "Linaqruf/anything-v3.0", controlnet=controlnet_std, torch_dtype=dtype,
32
+ safety_checker=None, requires_safety_checker=False
 
 
33
  ).to(device)
34
+ PIPE_ANIME = StableDiffusionControlNetPipeline.from_pretrained(
35
+ "Linaqruf/anything-v3.0", controlnet=controlnet_anime, torch_dtype=dtype,
36
+ safety_checker=None, requires_safety_checker=False
 
 
 
37
  ).to(device)
38
+
39
+ for pipe in [PIPE_STANDARD, PIPE_ANIME]:
40
+ pipe.enable_attention_slicing()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  if device.type == "cuda":
42
  pipe.enable_model_cpu_offload()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
+ LINEART_DETECTOR = LineartDetector.from_pretrained("lllyasviel/Annotators")
45
+ LINEART_ANIME_DETECTOR = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
46
+
47
+ def load_t2i_model(model_name: str):
48
+ global CURRENT_T2I_PIPE, CURRENT_T2I_MODEL
49
+ if CURRENT_T2I_MODEL == model_name and CURRENT_T2I_PIPE is not None:
50
+ return
51
+ if CURRENT_T2I_PIPE is not None:
52
+ del CURRENT_T2I_PIPE
53
+ gc.collect()
54
+ if torch.cuda.is_available():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  torch.cuda.empty_cache()
56
+ print(f"Loading: {model_name}")
57
+ CURRENT_T2I_PIPE = StableDiffusionPipeline.from_pretrained(
58
+ model_name, torch_dtype=dtype, safety_checker=None, requires_safety_checker=False
59
+ ).to(device)
60
+ CURRENT_T2I_PIPE.enable_attention_slicing()
61
+ if device.type == "cuda":
62
+ CURRENT_T2I_PIPE.enable_model_cpu_offload()
63
+ CURRENT_T2I_MODEL = model_name
64
+
65
+ # ===== Utils =====
66
+ def is_lineart(img: Image.Image) -> bool:
67
+ arr = np.array(img.convert("L"))
68
+ black_white_ratio = np.sum((arr < 50) | (arr > 200)) / arr.size
69
+ return black_white_ratio > 0.7
70
+
71
+ def extract_lineart(img, anime: bool = False):
72
+ if is_lineart(img):
73
+ return img.convert("RGB")
74
+ detector = LINEART_ANIME_DETECTOR if anime else LINEART_DETECTOR
75
+ out = detector(img, detect_resolution=512, image_resolution=512)
76
+ return Image.fromarray(out) if isinstance(out, np.ndarray) else out
77
+
78
+ # ===== Functions =====
79
+ def colorize(sketch, anime_model, prompt, seed, steps, scale, cn_weight):
80
+ load_lineart_models()
81
+ pipe = PIPE_ANIME if anime_model else PIPE_STANDARD
82
+ lineart = extract_lineart(sketch, anime_model)
83
+ gen = torch.Generator(device=device).manual_seed(int(seed))
84
+ out = pipe(
85
+ prompt, image=lineart, num_inference_steps=int(steps),
86
+ guidance_scale=float(scale), controlnet_conditioning_scale=float(cn_weight),
87
+ generator=gen
88
+ ).images[0]
89
+ return out, lineart
90
+
91
+ def t2i(prompt, model, seed, steps, scale, w, h):
92
+ load_t2i_model(model)
93
+ gen = torch.Generator(device=device).manual_seed(int(seed))
94
+ return CURRENT_T2I_PIPE(
95
+ prompt, width=int(w), height=int(h),
96
+ num_inference_steps=int(steps), guidance_scale=float(scale),
97
+ generator=gen
98
+ ).images[0]
99
+
100
+ # ===== Gradio UI (Minimal) =====
101
+ with gr.Blocks() as demo:
102
+ gr.Markdown("# 🎨 Minimal Style2Paints")
103
+
104
+ with gr.Tab("🎨 Colorize"):
105
+ with gr.Row():
106
+ inp = gr.Image(label="Lineart", type="pil")
107
+ out = gr.Image(label="Colored")
108
+ with gr.Row():
109
+ sketch_out = gr.Image(label="Detected Lineart", type="pil")
110
+ anime_chk = gr.Checkbox(label="Anime Model")
111
+ with gr.Row():
112
+ prompt = gr.Textbox(label="Prompt", placeholder="e.g., 1girl, blonde hair, blue eyes")
113
+ seed = gr.Number(value=42, label="Seed")
114
+ with gr.Row():
115
+ steps = gr.Slider(10, 30, 20, step=5, label="Steps")
116
+ scale = gr.Slider(5, 15, 8, step=0.5, label="CFG Scale")
117
+ cn_weight = gr.Slider(0.5, 1.5, 1.0, step=0.1, label="CN Weight")
118
+ run = gr.Button("🎨 Colorize")
119
+ run.click(colorize, [inp, anime_chk, prompt, seed, steps, scale, cn_weight], [out, sketch_out])
120
+
121
+ with gr.Tab("🖼️ Text-to-Image"):
122
+ with gr.Row():
123
+ t2i_out = gr.Image(label="Output", type="pil")
124
+ with gr.Row():
125
+ t2i_prompt = gr.Textbox(label="Prompt", lines=2)
126
+ t2i_model = gr.Dropdown([
127
+ "Linaqruf/anything-v3.0",
128
+ "digiplay/ChikMix_V3",
129
+ "digiplay/chilloutmix_NiPrunedFp16Fix"
130
+ ], value="Linaqruf/anything-v3.0", label="Model")
131
+ with gr.Row():
132
+ t2i_seed = gr.Number(value=42, label="Seed")
133
+ t2i_steps = gr.Slider(10, 50, 30, step=5, label="Steps")
134
+ t2i_scale = gr.Slider(5, 15, 7.5, step=0.5, label="CFG Scale")
135
+ with gr.Row():
136
+ w = gr.Slider(256, 1024, 512, step=64, label="Width")
137
+ h = gr.Slider(256, 1024, 768, step=64, label="Height")
138
+ gen_btn = gr.Button("🖼️ Generate")
139
+ gen_btn.click(t2i, [t2i_prompt, t2i_model, t2i_seed, t2i_steps, t2i_scale, w, h], t2i_out)
140
+
141
+ demo.launch()