r3gm committed
Commit e9bbfb1 · verified · 1 Parent(s): 107040a

test update rife

Files changed (6):
  1. README.md +1 -1
  2. app.py +341 -59
  3. model/loss.py +128 -0
  4. model/pytorch_msssim/__init__.py +198 -0
  5. model/warplayer.py +24 -0
  6. requirements.txt +7 -0
README.md CHANGED
@@ -1,5 +1,5 @@
 ---
-title: Wan2.2 14B Fast Preview
+title: Wan2.2 14B Preview
 emoji: 🐌
 colorFrom: yellow
 colorTo: pink
app.py CHANGED
@@ -1,24 +1,24 @@
 
1
  import spaces
2
- import torch
3
- from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
4
- from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
5
- from diffusers.utils.export_utils import export_to_video
6
- import gradio as gr
7
- import tempfile
8
- import numpy as np
9
- from PIL import Image
10
  import random
 
 
 
11
  import gc
12
- import copy
13
- import os
14
- import shutil
15
 
16
- from torchao.quantization import quantize_
17
- from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
18
- from torchao.quantization import Int8WeightOnlyConfig
19
-
20
- import aoti
21
 
 
22
  from diffusers import (
23
  FlowMatchEulerDiscreteScheduler,
24
  SASolverScheduler,
@@ -28,15 +28,211 @@ from diffusers import (
28
  DPMSolverMultistepScheduler,
29
  DPMSolverSinglestepScheduler,
30
  )
 
 
31
32
 
33
  MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
 
34
 
35
  MAX_DIM = 832
36
  MIN_DIM = 480
37
  SQUARE_DIM = 640
38
  MULTIPLE_OF = 16
39
-
40
  MAX_SEED = np.iinfo(np.int32).max
41
 
42
  FIXED_FPS = 16
@@ -46,8 +242,6 @@ MAX_FRAMES_MODEL = 160
46
  MIN_DURATION = round(MIN_FRAMES_MODEL / FIXED_FPS, 1)
47
  MAX_DURATION = round(MAX_FRAMES_MODEL / FIXED_FPS, 1)
48
 
49
- CACHE_DIR = os.path.expanduser("~/.cache/huggingface/")
50
-
51
  SCHEDULER_MAP = {
52
  "FlowMatchEulerDiscrete": FlowMatchEulerDiscreteScheduler,
53
  "SASolver": SASolverScheduler,
@@ -63,7 +257,6 @@ pipe = WanImageToVideoPipeline.from_pretrained(
63
  torch_dtype=torch.bfloat16,
64
  ).to('cuda')
65
  original_scheduler = copy.deepcopy(pipe.scheduler)
66
- print(original_scheduler)
67
 
68
  if os.path.exists(CACHE_DIR):
69
  shutil.rmtree(CACHE_DIR)
@@ -78,6 +271,8 @@ quantize_(pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())
78
  aoti.aoti_blocks_load(pipe.transformer, 'zerogpu-aoti/Wan2', variant='fp8da')
79
  aoti.aoti_blocks_load(pipe.transformer_2, 'zerogpu-aoti/Wan2', variant='fp8da')
80
 
 
 
81
 
82
  default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
83
  default_negative_prompt = "色调艳丽, 过曝, 静态, 细节模糊不清, 字幕, 风格, 作品, 画作, 画面, 静止, 整体发灰, 最差质量, 低质量, JPEG压缩残留, 丑陋的, 残缺的, 多余的手指, 画得不好的手部, 画得不好的脸部, 畸形的, 毁容的, 形态畸形的肢体, 手指融合, 静止不动的画面, 杂乱的背景, 三条腿, 背景人很多, 倒着走"
@@ -88,44 +283,36 @@ def resize_image(image: Image.Image) -> Image.Image:
88
  Resizes an image to fit within the model's constraints, preserving aspect ratio as much as possible.
89
  """
90
  width, height = image.size
91
-
92
- # Handle square case
93
  if width == height:
94
  return image.resize((SQUARE_DIM, SQUARE_DIM), Image.LANCZOS)
95
-
96
  aspect_ratio = width / height
97
-
98
  MAX_ASPECT_RATIO = MAX_DIM / MIN_DIM
99
  MIN_ASPECT_RATIO = MIN_DIM / MAX_DIM
100
 
101
  image_to_resize = image
102
-
103
  if aspect_ratio > MAX_ASPECT_RATIO:
104
- # Very wide image -> crop width to fit 832x480 aspect ratio
105
  target_w, target_h = MAX_DIM, MIN_DIM
106
  crop_width = int(round(height * MAX_ASPECT_RATIO))
107
  left = (width - crop_width) // 2
108
  image_to_resize = image.crop((left, 0, left + crop_width, height))
109
  elif aspect_ratio < MIN_ASPECT_RATIO:
110
- # Very tall image -> crop height to fit 480x832 aspect ratio
111
  target_w, target_h = MIN_DIM, MAX_DIM
112
  crop_height = int(round(width / MIN_ASPECT_RATIO))
113
  top = (height - crop_height) // 2
114
  image_to_resize = image.crop((0, top, width, top + crop_height))
115
  else:
116
- if width > height: # Landscape
117
  target_w = MAX_DIM
118
  target_h = int(round(target_w / aspect_ratio))
119
- else: # Portrait
120
  target_h = MAX_DIM
121
  target_w = int(round(target_h * aspect_ratio))
122
 
123
  final_w = round(target_w / MULTIPLE_OF) * MULTIPLE_OF
124
  final_h = round(target_h / MULTIPLE_OF) * MULTIPLE_OF
125
-
126
  final_w = max(MIN_DIM, min(MAX_DIM, final_w))
127
  final_h = max(MIN_DIM, min(MAX_DIM, final_h))
128
-
129
  return image_to_resize.resize((final_w, final_h), Image.LANCZOS)
130
 
131
 
@@ -160,6 +347,9 @@ def get_inference_duration(
160
  current_seed,
161
  scheduler_name,
162
  flow_shift,
 
 
 
163
  progress
164
  ):
165
  BASE_FRAMES_HEIGHT_WIDTH = 81 * 832 * 624
@@ -167,7 +357,18 @@ def get_inference_duration(
167
  width, height = resized_image.size
168
  factor = num_frames * width * height / BASE_FRAMES_HEIGHT_WIDTH
169
  step_duration = BASE_STEP_DURATION * factor ** 1.5
170
- return 5 + int(steps) * step_duration
171
 
172
 
173
  @spaces.GPU(duration=get_inference_duration)
@@ -183,9 +384,11 @@ def run_inference(
183
  current_seed,
184
  scheduler_name,
185
  flow_shift,
 
 
 
186
  progress=gr.Progress(track_tqdm=True),
187
  ):
188
-
189
  scheduler_class = SCHEDULER_MAP.get(scheduler_name)
190
  if scheduler_class.__name__ != pipe.scheduler.config._class_name or flow_shift != pipe.scheduler.config.get("flow_shift", "shift"):
191
  config = copy.deepcopy(original_scheduler.config)
@@ -195,6 +398,11 @@ def run_inference(
195
  config['flow_shift'] = flow_shift
196
  pipe.scheduler = scheduler_class.from_config(config)
197
 
 
 
 
 
 
198
  result = pipe(
199
  image=resized_image,
200
  last_image=processed_last_image,
@@ -207,10 +415,33 @@ def run_inference(
207
  guidance_scale_2=float(guidance_scale_2),
208
  num_inference_steps=int(steps),
209
  generator=torch.Generator(device="cuda").manual_seed(current_seed),
210
- ).frames[0]
211
-
 
 
212
  pipe.scheduler = original_scheduler
213
- return result
214
 
215
 
216
  def generate_video(
@@ -227,15 +458,15 @@ def generate_video(
227
  quality=5,
228
  scheduler="UniPCMultistep",
229
  flow_shift=6.0,
 
 
230
  progress=gr.Progress(track_tqdm=True),
231
  ):
232
  """
233
  Generate a video from an input image using the Wan 2.2 14B I2V model with Lightning LoRA.
234
-
235
  This function takes an input image and generates a video animation based on the provided
236
  prompt and parameters. It uses an FP8 quantized Wan 2.2 14B Image-to-Video model with Lightning LoRA
237
  for fast generation in 4-8 steps.
238
-
239
  Args:
240
  input_image (PIL.Image): The input image to animate. Will be resized to target dimensions.
241
  last_image (PIL.Image, optional): The optional last image for the video.
@@ -258,23 +489,24 @@ def generate_video(
258
  Highest quality is 10, lowest is 1.
259
  scheduler (str, optional): The name of the scheduler to use for inference. Defaults to "UniPCMultistep".
260
  flow_shift (float, optional): The flow shift value for compatible schedulers. Defaults to 6.0.
 
 
 
261
  progress (gr.Progress, optional): Gradio progress tracker. Defaults to gr.Progress(track_tqdm=True).
262
-
263
  Returns:
264
  tuple: A tuple containing:
265
  - video_path (str): Path for the video component.
266
  - video_path (str): Path for the file download component. Attempt to avoid reconversion in video component.
267
  - current_seed (int): The seed used for generation.
268
-
269
  Raises:
270
  gr.Error: If input_image is None (no image uploaded).
271
-
272
  Note:
273
  - Frame count is calculated as duration_seconds * FIXED_FPS (16)
274
  - Output dimensions are adjusted to be multiples of MULTIPLE_OF (16)
275
  - The function uses GPU acceleration via the @spaces.GPU decorator
276
  - Generation time varies based on steps and duration (see get_inference_duration)
277
  """
 
278
  if input_image is None:
279
  raise gr.Error("Please upload an input image.")
280
 
@@ -286,7 +518,7 @@ def generate_video(
286
  if last_image:
287
  processed_last_image = resize_and_crop_to_match(last_image, resized_image)
288
 
289
- output_frames_list = run_inference(
290
  resized_image,
291
  processed_last_image,
292
  prompt,
@@ -298,36 +530,55 @@ def generate_video(
298
  current_seed,
299
  scheduler,
300
  flow_shift,
 
 
 
301
  progress,
302
  )
 
303
 
304
- with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
305
- video_path = tmpfile.name
306
 
307
- export_to_video(output_frames_list, video_path, fps=FIXED_FPS, quality=quality)
308
 
309
- return video_path, video_path, current_seed
310
 
311
 
312
- with gr.Blocks(theme=gr.themes.Soft(), delete_cache=(3600, 10800)) as demo:
313
- gr.Markdown("# WAMU V2 - Wan 2.2 I2V (14B) 🐢")
314
- gr.Markdown("## ℹ️ **A Note on Performance:** This version prioritizes a straightforward setup over maximum speed, so performance may vary.")
315
  gr.Markdown('Try the previous version: [WAMU v1](https://huggingface.co/spaces/r3gm/wan2-2-fp8da-aoti-preview2)')
316
  gr.Markdown("Run Wan 2.2 in just 4-8 steps, fp8 quantization & AoT compilation - compatible with 🧨 diffusers and ZeroGPU")
 
317
  with gr.Row():
318
  with gr.Column():
319
- input_image_component = gr.Image(type="pil", label="Input Image")
320
  prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
321
  duration_seconds_input = gr.Slider(minimum=MIN_DURATION, maximum=MAX_DURATION, step=0.1, value=3.5, label="Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
322
- steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=6, label="Inference Steps")
323
-
 
 
 
 
324
  with gr.Accordion("Advanced Settings", open=False):
325
- last_image_component = gr.Image(type="pil", label="Last Image (Optional)")
326
  negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, info="Used if any Guidance Scale > 1.", lines=3)
327
- quality_slider = gr.Slider(minimum=1, maximum=10, step=1, value=6, label="Video Quality")
328
  seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, interactive=True)
329
  randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True, interactive=True)
330
- guidance_scale_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance Scale - high noise stage")
 
331
  guidance_scale_2_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance Scale 2 - low noise stage")
332
  scheduler_dropdown = gr.Dropdown(
333
  label="Scheduler",
@@ -336,20 +587,51 @@ with gr.Blocks(theme=gr.themes.Soft(), delete_cache=(3600, 10800)) as demo:
336
  info="Select a custom scheduler."
337
  )
338
  flow_shift_slider = gr.Slider(minimum=0.5, maximum=15.0, step=0.1, value=3.0, label="Flow Shift")
 
339
 
340
  generate_button = gr.Button("Generate Video", variant="primary")
 
341
  with gr.Column():
342
- video_output = gr.Video(label="Generated Video", autoplay=True, interactive=False)
343
  file_output = gr.File(label="Download Video")
344
 
345
  ui_inputs = [
346
  input_image_component, last_image_component, prompt_input, steps_slider,
347
  negative_prompt_input, duration_seconds_input,
348
  guidance_scale_input, guidance_scale_2_input, seed_input, randomize_seed_checkbox,
349
- quality_slider, scheduler_dropdown, flow_shift_slider,
 
350
  ]
351
- generate_button.click(fn=generate_video, inputs=ui_inputs, outputs=[video_output, file_output, seed_input])
352
-
353
 
354
  if __name__ == "__main__":
355
  demo.queue().launch(
 
1
+ import os
2
  import spaces
3
+ import shutil
4
+ import subprocess
5
+ import sys
6
+ import copy
 
 
 
 
7
  import random
8
+ import tempfile
9
+ import warnings
10
+ import time
11
  import gc
12
+ import uuid
13
+ from tqdm import tqdm
 
14
 
15
+ import cv2
16
+ import numpy as np
17
+ import torch
18
+ from torch.nn import functional as F
19
+ from PIL import Image
20
 
21
+ import gradio as gr
22
  from diffusers import (
23
  FlowMatchEulerDiscreteScheduler,
24
  SASolverScheduler,
 
28
  DPMSolverMultistepScheduler,
29
  DPMSolverSinglestepScheduler,
30
  )
31
+ from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
32
+ from diffusers.utils.export_utils import export_to_video
33
 
34
+ from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig, Int8WeightOnlyConfig
35
+ import aoti
36
+
37
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"
38
+ warnings.filterwarnings("ignore")
39
+
40
+ # --- FRAME EXTRACTION JS & LOGIC ---
41
+
42
+ # JS to grab timestamp from the output video
43
+ get_timestamp_js = """
44
+ function() {
45
+ // Select the video element specifically inside the component with id 'generated-video'
46
+ const video = document.querySelector('#generated-video video');
47
+
48
+ if (video) {
49
+ console.log("Video found! Time: " + video.currentTime);
50
+ return video.currentTime;
51
+ } else {
52
+ console.log("No video element found.");
53
+ return 0;
54
+ }
55
+ }
56
+ """
57
+
58
+
59
+ def extract_frame(video_path, timestamp):
60
+ # Safety check: if no video is present
61
+ if not video_path:
62
+ return None
63
+
64
+ print(f"Extracting frame at timestamp: {timestamp}")
65
+
66
+ cap = cv2.VideoCapture(video_path)
67
+
68
+ if not cap.isOpened():
69
+ return None
70
+
71
+ # Calculate frame number
72
+ fps = cap.get(cv2.CAP_PROP_FPS)
73
+ target_frame_num = int(float(timestamp) * fps)
74
+
75
+ # Cap total frames to prevent errors at the very end of video
76
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
77
+ if target_frame_num >= total_frames:
78
+ target_frame_num = total_frames - 1
79
+
80
+ # Set position
81
+ cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame_num)
82
+ ret, frame = cap.read()
83
+ cap.release()
84
+
85
+ if ret:
86
+ # Convert from BGR (OpenCV) to RGB (Gradio)
87
+ # Gradio Image component handles Numpy array -> PIL conversion automatically
88
+ return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
89
+
90
+ return None
91
+
92
+ # --- END FRAME EXTRACTION LOGIC ---
93
+
94
+
95
+ def clear_vram():
96
+ gc.collect()
97
+ torch.cuda.empty_cache()
98
+
99
+
100
+ # RIFE
101
+ if not os.path.exists("RIFEv4.26_0921.zip"):
102
+ print("Downloading RIFE Model...")
103
+ subprocess.run([
104
+ "wget", "-q",
105
+ "https://huggingface.co/r3gm/RIFE/resolve/main/RIFEv4.26_0921.zip",
106
+ "-O", "RIFEv4.26_0921.zip"
107
+ ], check=True)
108
+ subprocess.run(["unzip", "-o", "RIFEv4.26_0921.zip"], check=True)
109
+
110
+ # sys.path.append(os.getcwd())
111
+
112
+ from train_log.RIFE_HDv3 import Model
113
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
114
+ rife_model = Model()
115
+ rife_model.load_model("train_log", -1)
116
+ rife_model.eval()
117
+
118
+
119
+ @torch.no_grad()
120
+ def interpolate_bits(frames_np, multiplier=2, scale=1.0):
121
+ """
122
+ Interpolation maintaining Numpy Float 0-1 format.
123
+ Args:
124
+ frames_np: Numpy Array (Time, Height, Width, Channels) - Float32 [0.0, 1.0]
125
+ multiplier: int (2, 4, 8)
126
+ Returns:
127
+ List of Numpy Arrays (Height, Width, Channels) - Float32 [0.0, 1.0]
128
+ """
129
+
130
+ # Handle input shape
131
+ if isinstance(frames_np, list):
132
+ # Convert list of arrays to one big array for easier shape handling if needed,
133
+ # but here we just grab dims from first frame
134
+ T = len(frames_np)
135
+ H, W, C = frames_np[0].shape
136
+ else:
137
+ T, H, W, C = frames_np.shape
138
+
139
+ # 1. No Interpolation Case
140
+ if multiplier < 2:
141
+ # Just convert 4D array to list of 3D arrays
142
+ if isinstance(frames_np, np.ndarray):
143
+ return list(frames_np)
144
+ return frames_np
145
+
146
+ n_interp = multiplier - 1
147
+
148
+ # Pre-calc padding for RIFE (dimensions are padded up to a multiple of max(128, 128/scale))
149
+ tmp = max(128, int(128 / scale))
150
+ ph = ((H - 1) // tmp + 1) * tmp
151
+ pw = ((W - 1) // tmp + 1) * tmp
152
+ padding = (0, pw - W, 0, ph - H)
153
+
154
+ # Helper: Numpy (H, W, C) Float -> Tensor (1, C, H, W) Half
155
+ def to_tensor(frame_np):
156
+ # frame_np is float32 0-1
157
+ t = torch.from_numpy(frame_np).to(device)
158
+ # HWC -> CHW
159
+ t = t.permute(2, 0, 1).unsqueeze(0)
160
+ return F.pad(t, padding).half()
161
+
162
+ # Helper: Tensor (1, C, H, W) Half -> Numpy (H, W, C) Float
163
+ def from_tensor(tensor):
164
+ # Crop padding
165
+ t = tensor[0, :, :H, :W]
166
+ # CHW -> HWC
167
+ t = t.permute(1, 2, 0)
168
+ # Keep as float32, range 0-1
169
+ return t.float().cpu().numpy()
170
+
171
+ def make_inference(I0, I1, n):
172
+ if rife_model.version >= 3.9:
173
+ res = []
174
+ for i in range(n):
175
+ res.append(rife_model.inference(I0, I1, (i+1) * 1. / (n+1), scale))
176
+ return res
177
+ else:
178
+ middle = rife_model.inference(I0, I1, scale)
179
+ if n == 1:
180
+ return [middle]
181
+ first_half = make_inference(I0, middle, n=n//2)
182
+ second_half = make_inference(middle, I1, n=n//2)
183
+ if n % 2:
184
+ return [*first_half, middle, *second_half]
185
+ else:
186
+ return [*first_half, *second_half]
187
+
188
+ output_frames = []
189
+
190
+ # Process Frames
191
+ # Load first frame into GPU
192
+ I1 = to_tensor(frames_np[0])
193
+
194
+ total_steps = T - 1
195
+
196
+ with tqdm(total=total_steps, desc="Interpolating", unit="frame") as pbar:
197
+
198
+ for i in range(total_steps):
199
+ I0 = I1
200
+ # Add original frame to output
201
+ output_frames.append(from_tensor(I0))
202
+
203
+ # Load next frame
204
+ I1 = to_tensor(frames_np[i+1])
205
+
206
+ # Generate intermediate frames
207
+ mid_tensors = make_inference(I0, I1, n_interp)
208
+
209
+ # Append intermediate frames
210
+ for mid in mid_tensors:
211
+ output_frames.append(from_tensor(mid))
212
+
213
+ if (i + 1) % 50 == 0:
214
+ pbar.update(50)
215
+ pbar.update(total_steps % 50)
216
+
217
+ # Add the very last frame
218
+ output_frames.append(from_tensor(I1))
219
+
220
+ # Cleanup
221
+ del I0, I1, mid_tensors
222
+ torch.cuda.empty_cache()
223
+
224
+ return output_frames
225
+
226
+
227
+ # WAN
228
 
229
  MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
230
+ CACHE_DIR = os.path.expanduser("~/.cache/huggingface/")
231
 
232
  MAX_DIM = 832
233
  MIN_DIM = 480
234
  SQUARE_DIM = 640
235
  MULTIPLE_OF = 16
 
236
  MAX_SEED = np.iinfo(np.int32).max
237
 
238
  FIXED_FPS = 16
 
242
  MIN_DURATION = round(MIN_FRAMES_MODEL / FIXED_FPS, 1)
243
  MAX_DURATION = round(MAX_FRAMES_MODEL / FIXED_FPS, 1)
244
 
 
 
245
  SCHEDULER_MAP = {
246
  "FlowMatchEulerDiscrete": FlowMatchEulerDiscreteScheduler,
247
  "SASolver": SASolverScheduler,
 
257
  torch_dtype=torch.bfloat16,
258
  ).to('cuda')
259
  original_scheduler = copy.deepcopy(pipe.scheduler)
 
260
 
261
  if os.path.exists(CACHE_DIR):
262
  shutil.rmtree(CACHE_DIR)
 
271
  aoti.aoti_blocks_load(pipe.transformer, 'zerogpu-aoti/Wan2', variant='fp8da')
272
  aoti.aoti_blocks_load(pipe.transformer_2, 'zerogpu-aoti/Wan2', variant='fp8da')
273
 
274
+ # pipe.vae.enable_slicing()
275
+ # pipe.vae.enable_tiling()
276
 
277
  default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
278
  default_negative_prompt = "色调艳丽, 过曝, 静态, 细节模糊不清, 字幕, 风格, 作品, 画作, 画面, 静止, 整体发灰, 最差质量, 低质量, JPEG压缩残留, 丑陋的, 残缺的, 多余的手指, 画得不好的手部, 画得不好的脸部, 畸形的, 毁容的, 形态畸形的肢体, 手指融合, 静止不动的画面, 杂乱的背景, 三条腿, 背景人很多, 倒着走"
 
283
  Resizes an image to fit within the model's constraints, preserving aspect ratio as much as possible.
284
  """
285
  width, height = image.size
 
 
286
  if width == height:
287
  return image.resize((SQUARE_DIM, SQUARE_DIM), Image.LANCZOS)
288
+
289
  aspect_ratio = width / height
 
290
  MAX_ASPECT_RATIO = MAX_DIM / MIN_DIM
291
  MIN_ASPECT_RATIO = MIN_DIM / MAX_DIM
292
 
293
  image_to_resize = image
 
294
  if aspect_ratio > MAX_ASPECT_RATIO:
 
295
  target_w, target_h = MAX_DIM, MIN_DIM
296
  crop_width = int(round(height * MAX_ASPECT_RATIO))
297
  left = (width - crop_width) // 2
298
  image_to_resize = image.crop((left, 0, left + crop_width, height))
299
  elif aspect_ratio < MIN_ASPECT_RATIO:
 
300
  target_w, target_h = MIN_DIM, MAX_DIM
301
  crop_height = int(round(width / MIN_ASPECT_RATIO))
302
  top = (height - crop_height) // 2
303
  image_to_resize = image.crop((0, top, width, top + crop_height))
304
  else:
305
+ if width > height:
306
  target_w = MAX_DIM
307
  target_h = int(round(target_w / aspect_ratio))
308
+ else:
309
  target_h = MAX_DIM
310
  target_w = int(round(target_h * aspect_ratio))
311
 
312
  final_w = round(target_w / MULTIPLE_OF) * MULTIPLE_OF
313
  final_h = round(target_h / MULTIPLE_OF) * MULTIPLE_OF
 
314
  final_w = max(MIN_DIM, min(MAX_DIM, final_w))
315
  final_h = max(MIN_DIM, min(MAX_DIM, final_h))
 
316
  return image_to_resize.resize((final_w, final_h), Image.LANCZOS)
317
 
318
 
 
347
  current_seed,
348
  scheduler_name,
349
  flow_shift,
350
+ frame_multiplier,
351
+ quality,
352
+ duration_seconds,
353
  progress
354
  ):
355
  BASE_FRAMES_HEIGHT_WIDTH = 81 * 832 * 624
 
357
  width, height = resized_image.size
358
  factor = num_frames * width * height / BASE_FRAMES_HEIGHT_WIDTH
359
  step_duration = BASE_STEP_DURATION * factor ** 1.5
360
+ gen_time = int(steps) * step_duration
361
+
362
+ if guidance_scale > 1:
363
+ gen_time = gen_time * 1.8
364
+
365
+ frame_factor = frame_multiplier // FIXED_FPS
366
+ if frame_factor > 1:
367
+ total_out_frames = (num_frames * frame_factor) - num_frames
368
+ inter_time = (total_out_frames * 0.02)
369
+ gen_time += inter_time
370
+
371
+ return 10 + gen_time
372
 
373
 
374
  @spaces.GPU(duration=get_inference_duration)
 
384
  current_seed,
385
  scheduler_name,
386
  flow_shift,
387
+ frame_multiplier,
388
+ quality,
389
+ duration_seconds,
390
  progress=gr.Progress(track_tqdm=True),
391
  ):
 
392
  scheduler_class = SCHEDULER_MAP.get(scheduler_name)
393
  if scheduler_class.__name__ != pipe.scheduler.config._class_name or flow_shift != pipe.scheduler.config.get("flow_shift", "shift"):
394
  config = copy.deepcopy(original_scheduler.config)
 
398
  config['flow_shift'] = flow_shift
399
  pipe.scheduler = scheduler_class.from_config(config)
400
 
401
+ clear_vram()
402
+
403
+ task_name = str(uuid.uuid4())[:8]
404
+ print(f"Task: {task_name}, {duration_seconds}, {resized_image.size}, FM={frame_multiplier}")
405
+ start = time.time()
406
  result = pipe(
407
  image=resized_image,
408
  last_image=processed_last_image,
 
415
  guidance_scale_2=float(guidance_scale_2),
416
  num_inference_steps=int(steps),
417
  generator=torch.Generator(device="cuda").manual_seed(current_seed),
418
+ output_type="np"
419
+ )
420
+
421
+ raw_frames_np = result.frames[0] # Returns (T, H, W, C) float32
422
  pipe.scheduler = original_scheduler
423
+
424
+ frame_factor = frame_multiplier // FIXED_FPS
425
+ if frame_factor > 1:
426
+ start = time.time()
427
+ rife_model.device()
428
+ rife_model.flownet = rife_model.flownet.half()
429
+ final_frames = interpolate_bits(raw_frames_np, multiplier=int(frame_factor))
430
+ else:
431
+ final_frames = list(raw_frames_np)
432
+
433
+ final_fps = FIXED_FPS * int(frame_factor)
434
+
435
+ with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
436
+ video_path = tmpfile.name
437
+
438
+ start = time.time()
439
+ with tqdm(total=3, desc="Rendering Media", unit="clip") as pbar:
440
+ pbar.update(2)
441
+ export_to_video(final_frames, video_path, fps=final_fps, quality=quality)
442
+ pbar.update(1)
443
+
444
+ return video_path, task_name
445
 
446
 
447
  def generate_video(
 
458
  quality=5,
459
  scheduler="UniPCMultistep",
460
  flow_shift=6.0,
461
+ frame_multiplier=16,
462
+ video_component=True,
463
  progress=gr.Progress(track_tqdm=True),
464
  ):
465
  """
466
  Generate a video from an input image using the Wan 2.2 14B I2V model with Lightning LoRA.
 
467
  This function takes an input image and generates a video animation based on the provided
468
  prompt and parameters. It uses an FP8 quantized Wan 2.2 14B Image-to-Video model with Lightning LoRA
469
  for fast generation in 4-8 steps.
 
470
  Args:
471
  input_image (PIL.Image): The input image to animate. Will be resized to target dimensions.
472
  last_image (PIL.Image, optional): The optional last image for the video.
 
489
  Highest quality is 10, lowest is 1.
490
  scheduler (str, optional): The name of the scheduler to use for inference. Defaults to "UniPCMultistep".
491
  flow_shift (float, optional): The flow shift value for compatible schedulers. Defaults to 6.0.
492
+ frame_multiplier (int, optional): Target output FPS for the RIFE frame-interpolation (fluidity) enhancer. Defaults to 16 (no interpolation).
493
+ video_component (bool, optional): Whether to show the video player in the output.
494
+ Defaults to True.
495
  progress (gr.Progress, optional): Gradio progress tracker. Defaults to gr.Progress(track_tqdm=True).
 
496
  Returns:
497
  tuple: A tuple containing:
498
  - video_path (str): Path for the video component.
499
  - video_path (str): Path for the file download component. Attempt to avoid reconversion in video component.
500
  - current_seed (int): The seed used for generation.
 
501
  Raises:
502
  gr.Error: If input_image is None (no image uploaded).
 
503
  Note:
504
  - Frame count is calculated as duration_seconds * FIXED_FPS (16)
505
  - Output dimensions are adjusted to be multiples of MULTIPLE_OF (16)
506
  - The function uses GPU acceleration via the @spaces.GPU decorator
507
  - Generation time varies based on steps and duration (see get_inference_duration)
508
  """
509
+
510
  if input_image is None:
511
  raise gr.Error("Please upload an input image.")
512
 
 
518
  if last_image:
519
  processed_last_image = resize_and_crop_to_match(last_image, resized_image)
520
 
521
+ video_path, task_n = run_inference(
522
  resized_image,
523
  processed_last_image,
524
  prompt,
 
530
  current_seed,
531
  scheduler,
532
  flow_shift,
533
+ frame_multiplier,
534
+ quality,
535
+ duration_seconds,
536
  progress,
537
  )
538
+ print(f"GPU complete: {task_n}")
539
 
540
+ return (video_path if video_component else None), video_path, current_seed
 
541
 
 
542
 
543
+ CSS = """
544
+ #hidden-timestamp {
545
+ opacity: 0;
546
+ height: 0px;
547
+ width: 0px;
548
+ margin: 0px;
549
+ padding: 0px;
550
+ overflow: hidden;
551
+ position: absolute;
552
+ pointer-events: none;
553
+ }
554
+ """
555
 
556
 
557
+ with gr.Blocks(theme=gr.themes.Soft(), css=CSS, delete_cache=(3600, 10800)) as demo:
558
+ gr.Markdown("## WAMU V2 - Wan 2.2 I2V (14B) 🐢🐢")
559
+ gr.Markdown("#### ℹ️ **A Note on Performance:** This version prioritizes a straightforward setup over maximum speed, so performance may vary.")
560
  gr.Markdown('Try the previous version: [WAMU v1](https://huggingface.co/spaces/r3gm/wan2-2-fp8da-aoti-preview2)')
561
  gr.Markdown("Run Wan 2.2 in just 4-8 steps, fp8 quantization & AoT compilation - compatible with 🧨 diffusers and ZeroGPU")
562
+
563
  with gr.Row():
564
  with gr.Column():
565
+ input_image_component = gr.Image(type="pil", label="Input Image", sources=["upload", "clipboard"])
566
  prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
567
  duration_seconds_input = gr.Slider(minimum=MIN_DURATION, maximum=MAX_DURATION, step=0.1, value=3.5, label="Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
568
+ frame_multi = gr.Dropdown(
569
+ choices=[FIXED_FPS, FIXED_FPS*2, FIXED_FPS*4],
570
+ value=FIXED_FPS,
571
+ label="Video Fluidity (Frames per Second)",
572
+ info="Extra frames will be generated using flow estimation, which estimates motion between frames to make the video smoother."
573
+ )
574
  with gr.Accordion("Advanced Settings", open=False):
575
+ last_image_component = gr.Image(type="pil", label="Last Image (Optional)", sources=["upload", "clipboard"])
576
  negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, info="Used if any Guidance Scale > 1.", lines=3)
577
+ quality_slider = gr.Slider(minimum=1, maximum=10, step=1, value=6, label="Video Quality", info="If set to 10, the generated video may be too large and won't play in the Gradio preview.")
578
  seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, interactive=True)
579
  randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True, interactive=True)
580
+ steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=6, label="Inference Steps")
581
+ guidance_scale_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance Scale - high noise stage", info="Values above 1 increase GPU usage and may take longer to process.")
582
  guidance_scale_2_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance Scale 2 - low noise stage")
583
  scheduler_dropdown = gr.Dropdown(
584
  label="Scheduler",
 
587
  info="Select a custom scheduler."
588
  )
589
  flow_shift_slider = gr.Slider(minimum=0.5, maximum=15.0, step=0.1, value=3.0, label="Flow Shift")
590
+ play_result_video = gr.Checkbox(label="Display result", value=True, interactive=True)
591
 
592
  generate_button = gr.Button("Generate Video", variant="primary")
593
+
594
  with gr.Column():
595
+ # ASSIGNED elem_id="generated-video" so JS can find it
596
+ video_output = gr.Video(label="Generated Video", autoplay=True, sources=["upload"], show_download_button=True, show_share_button=True, interactive=False, elem_id="generated-video")
597
+
598
+ # --- Frame Grabbing UI ---
599
+ with gr.Row():
600
+ grab_frame_btn = gr.Button("📸 Use Current Frame as Input", variant="secondary")
601
+ timestamp_box = gr.Number(value=0, label="Timestamp", visible=True, elem_id="hidden-timestamp")
602
+ # -------------------------
603
+
604
  file_output = gr.File(label="Download Video")
605
 
606
  ui_inputs = [
607
  input_image_component, last_image_component, prompt_input, steps_slider,
608
  negative_prompt_input, duration_seconds_input,
609
  guidance_scale_input, guidance_scale_2_input, seed_input, randomize_seed_checkbox,
610
+ quality_slider, scheduler_dropdown, flow_shift_slider, frame_multi,
611
+ play_result_video
612
  ]
613
+
614
+ generate_button.click(
615
+ fn=generate_video,
616
+ inputs=ui_inputs,
617
+ outputs=[video_output, file_output, seed_input]
618
+ )
619
+
620
+ # --- Frame Grabbing Events ---
621
+ # 1. Click button -> JS runs -> puts time in hidden number box
622
+ grab_frame_btn.click(
623
+ fn=None,
624
+ inputs=None,
625
+ outputs=[timestamp_box],
626
+ js=get_timestamp_js
627
+ )
628
+
629
+ # 2. Hidden number box changes -> Python runs -> puts frame in Input Image
630
+ timestamp_box.change(
631
+ fn=extract_frame,
632
+ inputs=[video_output, timestamp_box],
633
+ outputs=[input_image_component]
634
+ )
635
 
636
  if __name__ == "__main__":
637
  demo.queue().launch(
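A minimal sketch of the RIFE frame-interpolation flow this commit adds to app.py, assuming the app.py context above (rife_model, interpolate_bits, FIXED_FPS, export_to_video, and a pipeline `result` obtained with output_type="np"):

# Sketch only: mirrors run_inference's interpolation branch in app.py.
raw_frames_np = result.frames[0]              # pipe(..., output_type="np") -> (T, H, W, C) float32 in [0, 1]
frame_multiplier = FIXED_FPS * 2              # UI "Video Fluidity" choice: 16, 32 or 64

frame_factor = frame_multiplier // FIXED_FPS  # 2 -> one interpolated frame between each original pair
if frame_factor > 1:
    rife_model.device()                       # move the RIFE flownet to the GPU
    rife_model.flownet = rife_model.flownet.half()
    final_frames = interpolate_bits(raw_frames_np, multiplier=int(frame_factor))
else:
    final_frames = list(raw_frames_np)

export_to_video(final_frames, "out.mp4", fps=FIXED_FPS * int(frame_factor), quality=6)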
model/loss.py ADDED
@@ -0,0 +1,128 @@
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torchvision.models as models
6
+
7
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
+
9
+
10
+ class EPE(nn.Module):
11
+ def __init__(self):
12
+ super(EPE, self).__init__()
13
+
14
+ def forward(self, flow, gt, loss_mask):
15
+ loss_map = (flow - gt.detach()) ** 2
16
+ loss_map = (loss_map.sum(1, True) + 1e-6) ** 0.5
17
+ return (loss_map * loss_mask)
18
+
19
+
20
+ class Ternary(nn.Module):
21
+ def __init__(self):
22
+ super(Ternary, self).__init__()
23
+ patch_size = 7
24
+ out_channels = patch_size * patch_size
25
+ self.w = np.eye(out_channels).reshape(
26
+ (patch_size, patch_size, 1, out_channels))
27
+ self.w = np.transpose(self.w, (3, 2, 0, 1))
28
+ self.w = torch.tensor(self.w).float().to(device)
29
+
30
+ def transform(self, img):
31
+ patches = F.conv2d(img, self.w, padding=3, bias=None)
32
+ transf = patches - img
33
+ transf_norm = transf / torch.sqrt(0.81 + transf**2)
34
+ return transf_norm
35
+
36
+ def rgb2gray(self, rgb):
37
+ r, g, b = rgb[:, 0:1, :, :], rgb[:, 1:2, :, :], rgb[:, 2:3, :, :]
38
+ gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
39
+ return gray
40
+
41
+ def hamming(self, t1, t2):
42
+ dist = (t1 - t2) ** 2
43
+ dist_norm = torch.mean(dist / (0.1 + dist), 1, True)
44
+ return dist_norm
45
+
46
+ def valid_mask(self, t, padding):
47
+ n, _, h, w = t.size()
48
+ inner = torch.ones(n, 1, h - 2 * padding, w - 2 * padding).type_as(t)
49
+ mask = F.pad(inner, [padding] * 4)
50
+ return mask
51
+
52
+ def forward(self, img0, img1):
53
+ img0 = self.transform(self.rgb2gray(img0))
54
+ img1 = self.transform(self.rgb2gray(img1))
55
+ return self.hamming(img0, img1) * self.valid_mask(img0, 1)
56
+
57
+
58
+ class SOBEL(nn.Module):
59
+ def __init__(self):
60
+ super(SOBEL, self).__init__()
61
+ self.kernelX = torch.tensor([
62
+ [1, 0, -1],
63
+ [2, 0, -2],
64
+ [1, 0, -1],
65
+ ]).float()
66
+ self.kernelY = self.kernelX.clone().T
67
+ self.kernelX = self.kernelX.unsqueeze(0).unsqueeze(0).to(device)
68
+ self.kernelY = self.kernelY.unsqueeze(0).unsqueeze(0).to(device)
69
+
70
+ def forward(self, pred, gt):
71
+ N, C, H, W = pred.shape[0], pred.shape[1], pred.shape[2], pred.shape[3]
72
+ img_stack = torch.cat(
73
+ [pred.reshape(N*C, 1, H, W), gt.reshape(N*C, 1, H, W)], 0)
74
+ sobel_stack_x = F.conv2d(img_stack, self.kernelX, padding=1)
75
+ sobel_stack_y = F.conv2d(img_stack, self.kernelY, padding=1)
76
+ pred_X, gt_X = sobel_stack_x[:N*C], sobel_stack_x[N*C:]
77
+ pred_Y, gt_Y = sobel_stack_y[:N*C], sobel_stack_y[N*C:]
78
+
79
+ L1X, L1Y = torch.abs(pred_X-gt_X), torch.abs(pred_Y-gt_Y)
80
+ loss = (L1X+L1Y)
81
+ return loss
82
+
83
+ class MeanShift(nn.Conv2d):
84
+ def __init__(self, data_mean, data_std, data_range=1, norm=True):
85
+ c = len(data_mean)
86
+ super(MeanShift, self).__init__(c, c, kernel_size=1)
87
+ std = torch.Tensor(data_std)
88
+ self.weight.data = torch.eye(c).view(c, c, 1, 1)
89
+ if norm:
90
+ self.weight.data.div_(std.view(c, 1, 1, 1))
91
+ self.bias.data = -1 * data_range * torch.Tensor(data_mean)
92
+ self.bias.data.div_(std)
93
+ else:
94
+ self.weight.data.mul_(std.view(c, 1, 1, 1))
95
+ self.bias.data = data_range * torch.Tensor(data_mean)
96
+ self.requires_grad = False
97
+
98
+ class VGGPerceptualLoss(torch.nn.Module):
99
+ def __init__(self, rank=0):
100
+ super(VGGPerceptualLoss, self).__init__()
101
+ blocks = []
102
+ pretrained = True
103
+ self.vgg_pretrained_features = models.vgg19(pretrained=pretrained).features
104
+ self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda()
105
+ for param in self.parameters():
106
+ param.requires_grad = False
107
+
108
+ def forward(self, X, Y, indices=None):
109
+ X = self.normalize(X)
110
+ Y = self.normalize(Y)
111
+ indices = [2, 7, 12, 21, 30]
112
+ weights = [1.0/2.6, 1.0/4.8, 1.0/3.7, 1.0/5.6, 10/1.5]
113
+ k = 0
114
+ loss = 0
115
+ for i in range(indices[-1]):
116
+ X = self.vgg_pretrained_features[i](X)
117
+ Y = self.vgg_pretrained_features[i](Y)
118
+ if (i+1) in indices:
119
+ loss += weights[k] * (X - Y.detach()).abs().mean() * 0.1
120
+ k += 1
121
+ return loss
122
+
123
+ if __name__ == '__main__':
124
+ img0 = torch.zeros(3, 3, 256, 256).float().to(device)
125
+ img1 = torch.tensor(np.random.normal(
126
+ 0, 1, (3, 3, 256, 256))).float().to(device)
127
+ ternary_loss = Ternary()
128
+ print(ternary_loss(img0, img1).shape)
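A minimal sketch of exercising the training losses above, assuming the model/ directory is importable as a package (these losses belong to RIFE's training code and are not called by the Space at inference time):

import torch
from model.loss import SOBEL, device  # device: cuda if available, else cpu

pred = torch.rand(2, 3, 64, 64).to(device)  # prediction batch, NCHW in [0, 1]
gt = torch.rand(2, 3, 64, 64).to(device)    # target batch
edge_loss = SOBEL()(pred, gt)               # per-pixel L1 distance between Sobel gradient maps
print(edge_loss.mean().item())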
model/pytorch_msssim/__init__.py ADDED
@@ -0,0 +1,198 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from math import exp
4
+ import numpy as np
5
+
6
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
+
8
+ def gaussian(window_size, sigma):
9
+ gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
10
+ return gauss/gauss.sum()
11
+
12
+
13
+ def create_window(window_size, channel=1):
14
+ _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
15
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).to(device)
16
+ window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
17
+ return window
18
+
19
+ def create_window_3d(window_size, channel=1):
20
+ _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
21
+ _2D_window = _1D_window.mm(_1D_window.t())
22
+ _3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t())
23
+ window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device)
24
+ return window
25
+
26
+
27
+ def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
28
+ # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
29
+ if val_range is None:
30
+ if torch.max(img1) > 128:
31
+ max_val = 255
32
+ else:
33
+ max_val = 1
34
+
35
+ if torch.min(img1) < -0.5:
36
+ min_val = -1
37
+ else:
38
+ min_val = 0
39
+ L = max_val - min_val
40
+ else:
41
+ L = val_range
42
+
43
+ padd = 0
44
+ (_, channel, height, width) = img1.size()
45
+ if window is None:
46
+ real_size = min(window_size, height, width)
47
+ window = create_window(real_size, channel=channel).to(img1.device).type_as(img1)
48
+
49
+ mu1 = F.conv2d(F.pad(img1, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel)
50
+ mu2 = F.conv2d(F.pad(img2, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel)
51
+
52
+ mu1_sq = mu1.pow(2)
53
+ mu2_sq = mu2.pow(2)
54
+ mu1_mu2 = mu1 * mu2
55
+
56
+ sigma1_sq = F.conv2d(F.pad(img1 * img1, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_sq
57
+ sigma2_sq = F.conv2d(F.pad(img2 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu2_sq
58
+ sigma12 = F.conv2d(F.pad(img1 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_mu2
59
+
60
+ C1 = (0.01 * L) ** 2
61
+ C2 = (0.03 * L) ** 2
62
+
63
+ v1 = 2.0 * sigma12 + C2
64
+ v2 = sigma1_sq + sigma2_sq + C2
65
+ cs = torch.mean(v1 / v2) # contrast sensitivity
66
+
67
+ ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
68
+
69
+ if size_average:
70
+ ret = ssim_map.mean()
71
+ else:
72
+ ret = ssim_map.mean(1).mean(1).mean(1)
73
+
74
+ if full:
75
+ return ret, cs
76
+ return ret
77
+
78
+
79
+ def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
80
+ # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
81
+ if val_range is None:
82
+ if torch.max(img1) > 128:
83
+ max_val = 255
84
+ else:
85
+ max_val = 1
86
+
87
+ if torch.min(img1) < -0.5:
88
+ min_val = -1
89
+ else:
90
+ min_val = 0
91
+ L = max_val - min_val
92
+ else:
93
+ L = val_range
94
+
95
+ padd = 0
96
+ (_, _, height, width) = img1.size()
97
+ if window is None:
98
+ real_size = min(window_size, height, width)
99
+ window = create_window_3d(real_size, channel=1).to(img1.device).type_as(img1)
100
+ # Channel is set to 1 since we consider color images as volumetric images
101
+
102
+ img1 = img1.unsqueeze(1)
103
+ img2 = img2.unsqueeze(1)
104
+
105
+ mu1 = F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1)
106
+ mu2 = F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1)
107
+
108
+ mu1_sq = mu1.pow(2)
109
+ mu2_sq = mu2.pow(2)
110
+ mu1_mu2 = mu1 * mu2
111
+
112
+ sigma1_sq = F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_sq
113
+ sigma2_sq = F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu2_sq
114
+ sigma12 = F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_mu2
115
+
116
+ C1 = (0.01 * L) ** 2
117
+ C2 = (0.03 * L) ** 2
118
+
119
+ v1 = 2.0 * sigma12 + C2
120
+ v2 = sigma1_sq + sigma2_sq + C2
121
+ cs = torch.mean(v1 / v2) # contrast sensitivity
122
+
123
+ ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
124
+
125
+ if size_average:
126
+ ret = ssim_map.mean()
127
+ else:
128
+ ret = ssim_map.mean(1).mean(1).mean(1)
129
+
130
+ if full:
131
+ return ret, cs
132
+ return ret
133
+
134
+
135
+ def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False):
136
+ device = img1.device
137
+ weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device).type_as(img1)
138
+ levels = weights.size()[0]
139
+ mssim = []
140
+ mcs = []
141
+ for _ in range(levels):
142
+ sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
143
+ mssim.append(sim)
144
+ mcs.append(cs)
145
+
146
+ img1 = F.avg_pool2d(img1, (2, 2))
147
+ img2 = F.avg_pool2d(img2, (2, 2))
148
+
149
+ mssim = torch.stack(mssim)
150
+ mcs = torch.stack(mcs)
151
+
152
+ # Normalize (to avoid NaNs during training unstable models, not compliant with original definition)
153
+ if normalize:
154
+ mssim = (mssim + 1) / 2
155
+ mcs = (mcs + 1) / 2
156
+
157
+ pow1 = mcs ** weights
158
+ pow2 = mssim ** weights
159
+ # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
160
+ output = torch.prod(pow1[:-1] * pow2[-1])
161
+ return output
162
+
163
+
164
+ # Classes to re-use window
165
+ class SSIM(torch.nn.Module):
166
+ def __init__(self, window_size=11, size_average=True, val_range=None):
167
+ super(SSIM, self).__init__()
168
+ self.window_size = window_size
169
+ self.size_average = size_average
170
+ self.val_range = val_range
171
+
172
+ # Assume 3 channel for SSIM
173
+ self.channel = 3
174
+ self.window = create_window(window_size, channel=self.channel)
175
+
176
+ def forward(self, img1, img2):
177
+ (_, channel, _, _) = img1.size()
178
+
179
+ if channel == self.channel and self.window.dtype == img1.dtype:
180
+ window = self.window
181
+ else:
182
+ window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
183
+ self.window = window
184
+ self.channel = channel
185
+
186
+ _ssim = ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
187
+ dssim = (1 - _ssim) / 2
188
+ return dssim
189
+
190
+ class MSSSIM(torch.nn.Module):
191
+ def __init__(self, window_size=11, size_average=True, channel=3):
192
+ super(MSSSIM, self).__init__()
193
+ self.window_size = window_size
194
+ self.size_average = size_average
195
+ self.channel = channel
196
+
197
+ def forward(self, img1, img2):
198
+ return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)
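A minimal sketch of scoring two image batches with the SSIM implementation above, assuming the model/ directory is importable (inputs are NCHW tensors in [0, 1]; identical inputs score close to 1.0):

import torch
from model.pytorch_msssim import ssim_matlab

img_a = torch.rand(1, 3, 128, 128)                        # NCHW, values in [0, 1]
img_b = img_a.clone()
print(float(ssim_matlab(img_a, img_b)))                   # ~1.0 for identical images
print(float(ssim_matlab(img_a, torch.rand_like(img_a))))  # noticeably lower for unrelated images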
model/warplayer.py ADDED
@@ -0,0 +1,24 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
5
+ backwarp_tenGrid = {}
6
+
7
+
8
+ def warp(tenInput, tenFlow):
9
+ k = (str(tenFlow.device), str(tenFlow.size()))
10
+ if k not in backwarp_tenGrid:
11
+ tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=tenFlow.device).view(
12
+ 1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
13
+ tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=tenFlow.device).view(
14
+ 1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
15
+ backwarp_tenGrid[k] = torch.cat(
16
+ [tenHorizontal, tenVertical], 1).to(tenFlow.device)
17
+
18
+ tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
19
+ tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
20
+
21
+ grid = backwarp_tenGrid[k].type_as(tenFlow)
22
+
23
+ g = (grid + tenFlow).permute(0, 2, 3, 1)
24
+ return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
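A minimal sketch of the backward warp above, assuming the model/ directory is importable: sampling an image with an all-zero flow field should reproduce the input (up to bilinear sampling precision):

import torch
from model.warplayer import warp

img = torch.rand(1, 3, 64, 64)    # NCHW image
flow = torch.zeros(1, 2, 64, 64)  # per-pixel (dx, dy) displacement in pixels
out = warp(img, flow)             # backward warp: sample img at grid + normalized flow
print(torch.allclose(out, img, atol=1e-5))  # True: zero flow is (numerically) the identity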
requirements.txt CHANGED
@@ -10,3 +10,10 @@ imageio
 imageio-ffmpeg
 opencv-python
 torchao==0.11.0
+
+numpy>=1.16, <=1.23.5
+# tqdm>=4.35.0
+# sk-video>=1.1.10
+# opencv-python>=4.1.2
+# moviepy>=1.0.3
+torchvision