r3gm committed
Commit 67e379c · verified · 1 Parent(s): 3a9a5d3

Update app.py

Files changed (1): app.py (+56, -23)
app.py CHANGED
@@ -17,6 +17,16 @@ from torchao.quantization import Int8WeightOnlyConfig
 
 import aoti
 
+from diffusers import (
+    FlowMatchEulerDiscreteScheduler,
+    SASolverScheduler,
+    DEISMultistepScheduler,
+    DPMSolverMultistepInverseScheduler,
+    UniPCMultistepScheduler,
+    DPMSolverMultistepScheduler,
+    DPMSolverSinglestepScheduler,
+)
+
 
 MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
 
@@ -34,6 +44,16 @@ MAX_FRAMES_MODEL = 160
 MIN_DURATION = round(MIN_FRAMES_MODEL / FIXED_FPS, 1)
 MAX_DURATION = round(MAX_FRAMES_MODEL / FIXED_FPS, 1)
 
+SCHEDULER_MAP = {
+    "FlowMatchEulerDiscrete": FlowMatchEulerDiscreteScheduler,
+    "SASolver": SASolverScheduler,
+    "DEISMultistep": DEISMultistepScheduler,
+    "DPMSolverMultistepInverse": DPMSolverMultistepInverseScheduler,
+    "UniPCMultistep": UniPCMultistepScheduler,
+    "DPMSolverMultistep": DPMSolverMultistepScheduler,
+    "DPMSolverSinglestep": DPMSolverSinglestepScheduler,
+}
+
 
 pipe = WanImageToVideoPipeline.from_pretrained(
     MODEL_ID,
@@ -52,6 +72,9 @@ pipe = WanImageToVideoPipeline.from_pretrained(
     torch_dtype=torch.bfloat16,
 ).to('cuda')
 
+original_scheduler = copy.deepcopy(pipe.scheduler.config)
+print(original_scheduler)
+
 pipe.load_lora_weights(
     "Kijai/WanVideo_comfy",
     weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
@@ -76,24 +99,17 @@ pipe.load_lora_weights(
     weight_name="livewallpaper_wan22_14b_i2v_low_model_0_1_e26.safetensors",
     adapter_name="livewallpaper"
 )
-
-default_transformer = copy.deepcopy(pipe.transformer)
-
 pipe.set_adapters(["livewallpaper"], adapter_weights=[1.])
-pipe.fuse_lora(adapter_names=["livewallpaper"], lora_scale=2., components=["transformer"])
+pipe.fuse_lora(adapter_names=["livewallpaper"], lora_scale=1., components=["transformer"])
 pipe.unload_lora_weights()
 
 quantize_(pipe.text_encoder, Int8WeightOnlyConfig())
 quantize_(pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
 quantize_(pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())
-quantize_(default_transformer, Float8DynamicActivationFloat8WeightConfig())
 
 aoti.aoti_blocks_load(pipe.transformer, 'zerogpu-aoti/Wan2', variant='fp8da')
 aoti.aoti_blocks_load(pipe.transformer_2, 'zerogpu-aoti/Wan2', variant='fp8da')
-aoti.aoti_blocks_load(default_transformer, 'zerogpu-aoti/Wan2', variant='fp8da')
 
-static_transformer = pipe.transformer
-pipe.transformer = default_transformer
 
 default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
 default_negative_prompt = "色调艳丽, 过曝, 静态, 细节模糊不清, 字幕, 风格, 作品, 画作, 画面, 静止, 整体发灰, 最差质量, 低质量, JPEG压缩残留, 丑陋的, 残缺的, 多余的手指, 画得不好的手部, 画得不好的脸部, 畸形的, 毁容的, 形态畸形的肢体, 手指融合, 静止不动的画面, 杂乱的背景, 三条腿, 背景人很多, 倒着走"
@@ -174,7 +190,9 @@ def get_inference_duration(
     guidance_scale,
     guidance_scale_2,
     current_seed,
-    live_wallpaper_style,
+    scheduler_name,
+    flow_shift,
+    progress
 ):
     BASE_FRAMES_HEIGHT_WIDTH = 81 * 832 * 624
     BASE_STEP_DURATION = 15
@@ -195,14 +213,22 @@ def run_inference(
     guidance_scale,
     guidance_scale_2,
     current_seed,
-    live_wallpaper_style,
+    scheduler_name,
+    flow_shift,
     progress=gr.Progress(track_tqdm=True),
 ):
 
-    if live_wallpaper_style:
-        pipe.transformer = static_transformer
-
-    output_frames = pipe(
+    scheduler_class = SCHEDULER_MAP.get(scheduler_name)
+    if scheduler_class != pipe.scheduler._class_name or flow_shift != pipe.scheduler.config.get("flow_shift", "shift"):
+        config = copy.deepcopy(original_scheduler.config)
+        print("update scheduler")
+        if scheduler_class == FlowMatchEulerDiscreteScheduler:
+            config['shift'] = flow_shift
+        else:
+            config['flow_shift'] = flow_shift
+        pipe.scheduler = scheduler_class.from_config(config)
+
+    result = pipe(
         image=resized_image,
         last_image=processed_last_image,
         prompt=prompt,
@@ -216,9 +242,8 @@
         generator=torch.Generator(device="cuda").manual_seed(current_seed),
     ).frames[0]
 
-    pipe.transformer = default_transformer
-
-    return output_frames
+    pipe.scheduler = original_scheduler
+    return result
 
 
 def generate_video(
@@ -233,7 +258,8 @@ def generate_video(
     seed=42,
     randomize_seed=False,
     quality=5,
-    live_wallpaper_style=False,
+    scheduler="UniPCMultistep",
+    flow_shift=6.0,
     progress=gr.Progress(track_tqdm=True),
 ):
     """
@@ -263,8 +289,8 @@ def generate_video(
            Defaults to False.
        quality (float, optional): Video output quality. Default is 5. Uses variable bit rate.
            Highest quality is 10, lowest is 1.
-       live_wallpaper_style (bool, optional): Whether to use the live wallpaper transformer.
-           Defaults to False.
+       scheduler (str, optional): The name of the scheduler to use for inference. Defaults to "UniPCMultistep".
+       flow_shift (float, optional): The flow shift value for compatible schedulers. Defaults to 6.0.
        progress (gr.Progress, optional): Gradio progress tracker. Defaults to gr.Progress(track_tqdm=True).
 
     Returns:
@@ -303,7 +329,8 @@ def generate_video(
        guidance_scale,
        guidance_scale_2,
        current_seed,
-       live_wallpaper_style,
+       scheduler,
+       flow_shift,
        progress,
    )
 
@@ -323,7 +350,6 @@ with gr.Blocks() as demo:
        with gr.Column():
            input_image_component = gr.Image(type="pil", label="Input Image")
            prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
-           live_wallpaper_style_checkbox = gr.Checkbox(label="Live Wallpaper Style", value=False, interactive=True)
            duration_seconds_input = gr.Slider(minimum=MIN_DURATION, maximum=MAX_DURATION, step=0.1, value=3.5, label="Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
            steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=6, label="Inference Steps")
            quality_slider = gr.Slider(minimum=1, maximum=10, step=1, value=6, label="Video Quality")
@@ -335,6 +361,13 @@ with gr.Blocks() as demo:
                randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True, interactive=True)
                guidance_scale_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance Scale - high noise stage")
                guidance_scale_2_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance Scale 2 - low noise stage")
+               scheduler_dropdown = gr.Dropdown(
+                   label="Scheduler",
+                   choices=list(SCHEDULER_MAP.keys()),
+                   value="UniPCMultistep",
+                   info="Select a custom scheduler."
+               )
+               flow_shift_slider = gr.Slider(minimum=0.5, maximum=15.0, step=0.1, value=3.0, label="Flow Shift")
 
            generate_button = gr.Button("Generate Video", variant="primary")
        with gr.Column():
@@ -345,7 +378,7 @@ with gr.Blocks() as demo:
        input_image_component, last_image_component, prompt_input, steps_slider,
        negative_prompt_input, duration_seconds_input,
        guidance_scale_input, guidance_scale_2_input, seed_input, randomize_seed_checkbox,
-       quality_slider, live_wallpaper_style_checkbox
+       quality_slider, scheduler_dropdown, flow_shift_slider,
    ]
    generate_button.click(fn=generate_video, inputs=ui_inputs, outputs=[video_output, file_output, seed_input])
 
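
The new scheduler path in this diff boils down to rebuilding the pipeline's scheduler from a saved copy of its config with the shift value overridden, via SchedulerClass.from_config(...). The sketch below illustrates that pattern in isolation; pipe stands for an already-loaded diffusers pipeline, and original_scheduler_config / apply_scheduler are illustrative names that do not appear in the commit. That every non-flow-match scheduler listed in SCHEDULER_MAP reads a flow_shift config key is an assumption carried over from the diff.

import copy
from diffusers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler

# Snapshot the pipeline's scheduler config once at startup (plain dict copy).
original_scheduler_config = dict(pipe.scheduler.config)

def apply_scheduler(pipe, scheduler_class, flow_shift):
    # Rebuild the scheduler from the saved config with the shift overridden.
    # FlowMatchEulerDiscreteScheduler exposes the parameter as "shift"; the
    # other schedulers are assumed (per the diff) to read "flow_shift".
    config = copy.deepcopy(original_scheduler_config)
    if scheduler_class is FlowMatchEulerDiscreteScheduler:
        config["shift"] = flow_shift
    else:
        config["flow_shift"] = flow_shift
    pipe.scheduler = scheduler_class.from_config(config)

# Example: switch to UniPC with the UI default flow shift before calling pipe(...).
apply_scheduler(pipe, UniPCMultistepScheduler, flow_shift=6.0)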