import torch
import numpy as np
import torch.nn.functional as F
from importlib import import_module

# tqdm is optional: degrade to a plain range (no progress bar) when it is not
# installed, so the module stays importable in minimal environments.
try:
    from tqdm.auto import trange
except ImportError:
    def trange(n, **kwargs):
        return range(n)

sampling = None
BACKEND = None
if not BACKEND:
    try:
        # Importing the WebUI module is only a presence probe; the sampling
        # helpers themselves come from k_diffusion.
        _ = import_module("modules.sd_samplers_kdiffusion")
        sampling = import_module("k_diffusion.sampling")
        BACKEND = "WebUI"
    except ImportError:
        pass
if not BACKEND:
    try:
        sampling = import_module("comfy.k_diffusion.sampling")
        BACKEND = "ComfyUI"
    except ImportError:
        pass

if sampling is None:
    # Standalone fallback: neither WebUI nor ComfyUI is available.  The old
    # code left `sampling` as None and crashed with AttributeError on the
    # first `sampling.to_d` call; provide the one helper the samplers need.
    class _FallbackSampling:
        @staticmethod
        def to_d(x, sigma, denoised):
            """Convert a denoiser prediction to a Karras ODE derivative."""
            return (x - denoised) / sigma

    sampling = _FallbackSampling()


class _Rescaler:
    """Context manager that resizes backend latents/masks to match a tensor.

    On entry, any stored latent or mask whose spatial size differs from
    ``x`` is interpolated to ``x``'s size; on exit the originals captured at
    construction time are restored on the model.
    """

    def __init__(self, model, x, mode='nearest-exact', **extra_args):
        self.model = model
        self.x = x
        self.mode = mode
        self.extra_args = extra_args
        self.backend = BACKEND
        # Snapshot whatever the active backend stores so __exit__ can restore it.
        if self.backend == "WebUI":
            self.init_latent = getattr(model, "init_latent", None)
            self.mask = getattr(model, "mask", None)
            self.nmask = getattr(model, "nmask", None)
        elif self.backend == "ComfyUI":
            self.latent_image = getattr(model, "latent_image", None)
            self.noise = getattr(model, "noise", None)
            self.denoise_mask = self.extra_args.get("denoise_mask", None)

    def __enter__(self):
        if self.x.shape[1] not in [1, 3, 4]:
            raise ValueError(f"Unsupported number of channels: {self.x.shape[1]}")
        size = self.x.shape[2:4]
        if self.backend == "WebUI":
            if self.init_latent is not None and self.init_latent.shape[2:4] != size:
                self.model.init_latent = F.interpolate(self.init_latent, size=size, mode=self.mode)
            # WebUI masks carry no batch dim (spatial dims at [1:3]); add a
            # temporary batch dim around interpolate.
            if self.mask is not None and self.mask.shape[1:3] != size:
                self.model.mask = F.interpolate(self.mask.unsqueeze(0), size=size, mode=self.mode).squeeze(0)
            if self.nmask is not None and self.nmask.shape[1:3] != size:
                self.model.nmask = F.interpolate(self.nmask.unsqueeze(0), size=size, mode=self.mode).squeeze(0)
        elif self.backend == "ComfyUI":
            if self.latent_image is not None and self.latent_image.shape[2:4] != size:
                self.model.latent_image = F.interpolate(self.latent_image, size=size, mode=self.mode)
            if self.noise is not None and self.noise.shape[2:4] != size:
                self.model.noise = F.interpolate(self.noise, size=size, mode=self.mode)
            if self.denoise_mask is not None and self.denoise_mask.shape[2:4] != size:
                self.extra_args["denoise_mask"] = F.interpolate(self.denoise_mask, size=size, mode=self.mode)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the originals captured in __init__ (the attributes only
        # exist for the active backend, hence the hasattr guards).
        if self.backend == "WebUI":
            if hasattr(self, "init_latent"):
                self.model.init_latent = self.init_latent
            if hasattr(self, "mask"):
                self.model.mask = self.mask
            if hasattr(self, "nmask"):
                self.model.nmask = self.nmask
        elif self.backend == "ComfyUI":
            if hasattr(self, "latent_image"):
                self.model.latent_image = self.latent_image
            if hasattr(self, "noise"):
                self.model.noise = self.noise
            if hasattr(self, "denoise_mask"):
                self.extra_args["denoise_mask"] = self.denoise_mask


def default_noise_sampler(x):
    """Return a noise sampler producing Gaussian noise shaped like x."""
    return lambda sigma, sigma_next: torch.randn_like(x)


def get_ancestral_step(sigma_from, sigma_to, eta=1.):
    """Split a sigma step into deterministic (sigma_down) and noise (sigma_up) parts.

    For eta <= 1 the result satisfies sigma_down**2 + sigma_up**2 == sigma_to**2.
    With eta == 0 the step is fully deterministic.
    """
    if not eta:
        return sigma_to, 0.
    sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5)
    sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
    return sigma_down, sigma_up


def _mean_channel_gradients(x):
    """Average ``x`` over channels and return clamped spatial gradients.

    Args:
        x: Tensor of shape [batch, channels, height, width], both spatial
           dims at least 2.

    Returns:
        (grad_x, grad_y): each of shape [batch, height, width], clamped to
        [-1e2, 1e2] for numerical stability.

    Raises:
        ValueError: on wrong rank or too-small spatial dims.
    """
    if x.dim() != 4 or min(x.shape[2], x.shape[3]) < 2:
        raise ValueError(f"Invalid tensor dimensions or size: {x.shape}")
    surface = torch.mean(x, dim=1)
    grad_x, grad_y = torch.gradient(surface, dim=(1, 2))
    return torch.clamp(grad_x, -1e2, 1e2), torch.clamp(grad_y, -1e2, 1e2)


def compute_gaussian_curvature(x):
    """Compute clamped Gaussian curvature of the channel-averaged surface.

    Args:
        x: Input tensor of shape [batch, channels, height, width].

    Returns:
        torch.Tensor: Curvature of shape [batch, height, width], in [-0.5, 0.5].
    """
    grad_x, grad_y = _mean_channel_gradients(x)
    grad_xx, grad_xy = torch.gradient(grad_x, dim=(1, 2))
    # Only f_yy is needed from the second input; the old code also computed
    # an unused grad_yx.
    _, grad_yy = torch.gradient(grad_y, dim=(1, 2))
    grad_xx = torch.clamp(grad_xx, -1e2, 1e2)
    grad_xy = torch.clamp(grad_xy, -1e2, 1e2)
    grad_yy = torch.clamp(grad_yy, -1e2, 1e2)
    # K = (f_xx * f_yy - f_xy^2) / (1 + f_x^2 + f_y^2)^2, eps for stability.
    curvature = (grad_xx * grad_yy - grad_xy ** 2) / (1 + grad_x ** 2 + grad_y ** 2 + 1e-8) ** 2
    # TODO: Implement convolution-based gradient for better performance
    return torch.clamp(curvature, min=-0.5, max=0.5)


def compute_simple_curvature(x):
    """Cheap curvature proxy: L1 gradient magnitude, clamped to [0, 0.5].

    Args:
        x: Input tensor of shape [batch, channels, height, width].

    Returns:
        torch.Tensor: Curvature tensor of shape [batch, height, width].
    """
    grad_x, grad_y = _mean_channel_gradients(x)
    curvature = torch.abs(grad_x) + torch.abs(grad_y)
    return torch.clamp(curvature, min=0.0, max=0.5)


def compute_normals(x):
    """Compute unit surface normals of the channel-averaged height field.

    Args:
        x: Input tensor of shape [batch, channels, height, width].

    Returns:
        torch.Tensor: Normals tensor of shape [batch, 3, height, width].
    """
    grad_x, grad_y = _mean_channel_gradients(x)
    normals = torch.stack([-grad_x, -grad_y, torch.ones_like(grad_x)], dim=1)
    norm = torch.norm(normals, dim=1, keepdim=True)
    # TODO: Implement convolution-based gradient for better performance
    return normals / (norm + 1e-6)


def compute_dynamic_eta(sigma, sigma_max, eta_start=0.0, eta_end=0.5):
    """Linearly interpolate eta: eta_start at sigma == sigma_max, eta_end at sigma == 0."""
    sigma_ratio = sigma / sigma_max
    return eta_end + (eta_start - eta_end) * sigma_ratio


def apply_geometric_blur(x, curvature, sigma=1.0):
    """Apply a separable Gaussian blur whose strength shrinks with mean curvature.

    Args:
        x: Input tensor of shape [batch, channels, height, width].
        curvature: Curvature tensor; only its mean is used to modulate sigma.
        sigma: Base sigma for the Gaussian blur.

    Returns:
        torch.Tensor: Blurred tensor of the same shape as x.
    """
    if x.dim() != 4:
        raise ValueError(f"Invalid tensor dimensions: {x.shape}")
    # Higher average curvature -> weaker blur; floor keeps sigma positive.
    sigma = max(sigma * (1 - curvature.mean().item()), 1e-3)
    kernel_size = min(int(2 * np.ceil(2 * sigma) + 1), 15)  # Cap kernel size
    if kernel_size % 2 == 0:
        kernel_size += 1
    # Fix: torch.nn.functional has no `gaussian_blur` (that lives in
    # torchvision); build a normalized 1-D Gaussian kernel and apply it
    # separably with grouped conv2d (zero padding at the borders).
    coords = torch.arange(kernel_size, device=x.device, dtype=x.dtype) - (kernel_size - 1) / 2
    kernel = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
    kernel = kernel / kernel.sum()
    channels = x.shape[1]
    pad = kernel_size // 2
    kernel_x = kernel.view(1, 1, 1, -1).expand(channels, 1, 1, -1)
    kernel_y = kernel.view(1, 1, -1, 1).expand(channels, 1, -1, 1)
    x = F.conv2d(x, kernel_x, padding=(0, pad), groups=channels)
    x = F.conv2d(x, kernel_y, padding=(pad, 0), groups=channels)
    return x


def apply_mask(x, mask=None, latent_mask=None):
    """Blend ``mask`` into ``x`` where ``latent_mask`` is set.

    Args:
        x: Input tensor of shape [batch, channels, height, width].
        mask: Replacement content, same shape as x.
        latent_mask: Blend weights, same shape as x (1 = take mask).

    Returns:
        torch.Tensor: Blended tensor; ``x`` unchanged if either mask is None.
    """
    if mask is not None and latent_mask is not None:
        if mask.shape != x.shape or latent_mask.shape != x.shape:
            raise ValueError(f"Mismatch in mask shapes: x={x.shape}, mask={mask.shape}, latent_mask={latent_mask.shape}")
        x = x * (1 - latent_mask) + mask * latent_mask
    return x


def _geometric_grad(grad, curvature, normals):
    """Weight ``grad`` by local curvature and add a small normal-based correction.

    Args:
        grad: (batch, channels, h, w) update direction.
        curvature: (batch, h, w) curvature map.
        normals: (batch, 3, h, w) surface normals.
    """
    # Fix: unsqueeze a channel dim so the (batch, h, w) curvature broadcasts
    # against (batch, channels, h, w).  The old expression errored whenever
    # batch > 1 and batch != channels, and silently misaligned batch with
    # channels when they happened to be equal.
    curvature_weight = 1.0 + 0.5 * torch.abs(curvature).unsqueeze(1)
    # NOTE(review): this einsum sums the 3 normal components per pixel and
    # scales grad by that sum — confirm this is the intended projection.
    normal_correction = torch.einsum('bchw,bkhw->bchw', grad, normals)
    normal_correction = torch.clamp(normal_correction, -1e2, 1e2)
    return grad * curvature_weight + 0.05 * normal_correction


@torch.no_grad()
def _in_resized_space_vec(x, model, dt, sigma_hat, interpolation_mode='nearest-exact', **extra_args):
    """Denoise at a slightly enlarged resolution and map the step back to x's size.

    Returns the scaled derivative ``d * dt`` resized to x's spatial size.
    """
    if x.dim() != 4 or min(x.shape[2], x.shape[3]) < 2:
        raise ValueError(f"Invalid tensor dimensions or size: {x.shape}")
    m, n = x.shape[2], x.shape[3]
    y = F.interpolate(x, size=(m + 2, n + 2), mode=interpolation_mode)
    with _Rescaler(model, y, interpolation_mode, **extra_args) as rescaler:
        # Use the rescaler's (possibly resized) extra args so any denoise
        # mask matches y's resolution — consistent with dy_sampling_step.
        denoised = model(y, sigma_hat * y.new_ones([y.shape[0]]), **rescaler.extra_args)
    d = (y - denoised) / sigma_hat
    d = torch.clamp(d, -1e2, 1e2)
    return F.interpolate(d * dt, size=(m, n), mode=interpolation_mode)


@torch.no_grad()
def dy_sampling_step(x, model, dt, sigma_hat, interpolation_mode='nearest-exact', **extra_args):
    """Euler step on a half-resolution subgrid.

    Takes the bottom-right sample of every 2x2 patch, advances it one Euler
    step, and scatters the result back.  Odd trailing rows/columns are
    carried through unchanged.
    """
    if x.shape[1] not in [1, 3, 4]:
        raise ValueError(f"Unsupported number of channels: {x.shape[1]}")
    original_shape = x.shape
    batch_size, channels = original_shape[0], original_shape[1]
    m, n = original_shape[2] // 2, original_shape[3] // 2
    extra_row = x.shape[2] % 2 == 1
    extra_col = x.shape[3] % 2 == 1
    if extra_row:
        extra_row_content = x[:, :, -1:, :]
        x = x[:, :, :-1, :]
    if extra_col:
        extra_col_content = x[:, :, :, -1:]
        x = x[:, :, :, :-1]
    # (b, c, m*n, 2, 2): every 2x2 patch, flattened over the spatial grid.
    a_list = x.unfold(2, 2, 2).unfold(3, 2, 2).contiguous().view(batch_size, channels, m * n, 2, 2)
    c = a_list[:, :, :, 1, 1].view(batch_size, channels, m, n)
    with _Rescaler(model, c, interpolation_mode, **extra_args) as rescaler:
        denoised = model(c, sigma_hat * c.new_ones([c.shape[0]]), **rescaler.extra_args)
    d = sampling.to_d(c, sigma_hat, denoised)
    c = c + d * dt
    a_list[:, :, :, 1, 1] = c.view(batch_size, channels, m * n)
    x = a_list.view(batch_size, channels, m, n, 2, 2).permute(0, 1, 2, 4, 3, 5).reshape(batch_size, channels, 2 * m, 2 * n)
    if extra_row or extra_col:
        x_expanded = torch.zeros(original_shape, dtype=x.dtype, device=x.device)
        x_expanded[:, :, :2 * m, :2 * n] = x
        if extra_row:
            # extra_row_content spans the full original width, so this also
            # restores the bottom-right corner pixel when extra_col is set.
            x_expanded[:, :, -1:, :2 * n + 1] = extra_row_content
        if extra_col:
            x_expanded[:, :, :2 * m, -1:] = extra_col_content
        # Fix: the old code overwrote the corner with extra_col_content's last
        # row, which is the pixel one row ABOVE the true corner; the row
        # assignment above already restored the correct corner value.
        x = x_expanded
    return x


@torch.no_grad()
def sample_Kohaku_LoNyu_Yog_v1_test(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None, eta=1., interpolation_mode='nearest-exact'):
    """Kohaku_LoNyu_Yog sampling: for the first half of the schedule, averages
    the standard derivative with one probed from a sign-inverted latent and
    adds ancestral noise; afterwards falls back to plain Euler steps.

    Standard k-diffusion sampler signature; returns the final latent.
    """
    if x.shape[1] not in [1, 3, 4]:
        raise ValueError(f"Unsupported number of channels: {x.shape[1]}")
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    for i in trange(len(sigmas) - 1, disable=disable):
        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            # Churn: draw noise only when it is used (the old code consumed
            # RNG state on every step regardless of gamma).
            eps = torch.randn_like(x) * s_noise
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = sampling.to_d(x, sigma_hat, denoised)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        dt = sigma_down - sigmas[i]
        # Fix: callback was accepted but never invoked.
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        if i <= (len(sigmas) - 1) / 2:
            x2 = -x
            with _Rescaler(model, x2, interpolation_mode, **extra_args) as rescaler:
                denoised2 = model(x2, sigma_hat * s_in, **rescaler.extra_args)
            d2 = sampling.to_d(x2, sigma_hat, denoised2)
            x3 = x + ((d + d2) / 2) * dt
            with _Rescaler(model, x3, interpolation_mode, **extra_args) as rescaler:
                denoised3 = model(x3, sigma_hat * s_in, **rescaler.extra_args)
            d3 = sampling.to_d(x3, sigma_hat, denoised3)
            real_d = (d + d3) / 2
            x = x + real_d * dt
            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
        else:
            # NOTE(review): this tail uses the ancestral sigma_down step but
            # adds no sigma_up noise — confirm that is intended.
            x = x + d * dt
    return x


@torch.no_grad()
def kohaku_lonyu_yog_stochastic_v1_test(model, x, sigmas, extra_args=None, callback=None, disable=None, langevin_strength=0.05, interpolation_mode='nearest-exact'):
    """Stochastic sampler: Euler steps plus curvature-modulated Langevin noise."""
    if x.shape[1] not in [1, 3, 4]:
        raise ValueError(f"Unsupported number of channels: {x.shape[1]}")
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        dt = sigmas[i + 1] - sigmas[i]
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        curvature = compute_simple_curvature(x)
        # Noise amplitude grows with mean curvature, hard-capped at 0.4.
        noise_scale = min(langevin_strength * curvature.mean(), 0.4)
        noise = torch.randn_like(x) * noise_scale * torch.sqrt(sigmas[i])
        grad = (x - denoised) / sigmas[i]
        grad = torch.clamp(grad, -1e2, 1e2)
        # Fix: curvature is (batch, h, w); add a channel dim so it broadcasts
        # against (batch, channels, h, w) for any batch size.
        x = x + grad * dt + noise * curvature.unsqueeze(1)
    return x


@torch.no_grad()
def kohaku_lonyu_yog_compatible_v1_test(model, x, sigmas, extra_args=None, callback=None, disable=None, interpolation_mode='nearest-exact'):
    """Plain Euler sampler with optional inpainting-mask blending every step."""
    if x.shape[1] not in [1, 3, 4]:
        raise ValueError(f"Unsupported number of channels: {x.shape[1]}")
    extra_args = {} if extra_args is None else extra_args
    mask = extra_args.get('mask', None)
    latent_mask = extra_args.get('latent_mask', None)
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        dt = sigmas[i + 1] - sigmas[i]
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        grad = torch.clamp((x - denoised) / sigmas[i], -1e2, 1e2)
        x = x + grad * dt
        x = apply_mask(x, mask, latent_mask)
    return x


@torch.no_grad()
def sample_Kohaku_LoNyu_Yog_v2_v1_test(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.0, noise_sampler=None, eta_start=0.9, eta_end=0.6, use_normals=True, interpolation_mode='nearest-exact'):
    """Kohaku_LoNyu_Yog v2: Euler steps with curvature/normal-corrected
    gradients and dynamic-eta curvature-weighted noise injection.
    """
    if x.shape[1] not in [1, 3, 4]:
        raise ValueError(f"Unsupported number of channels: {x.shape[1]}")
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    sigma_max = torch.max(sigmas)
    old_denoised = None
    for i in trange(len(sigmas) - 1, disable=disable):
        sigma = sigmas[i]
        dt = sigmas[i + 1] - sigma
        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.
        sigma_hat = sigma * (1 + gamma)
        curvature = compute_gaussian_curvature(x)
        eta = compute_dynamic_eta(sigma, sigma_max, eta_start, eta_end)
        if gamma > 0:
            eps = torch.randn_like(x) * s_noise
            x = x + eps * torch.sqrt(sigma_hat ** 2 - sigma ** 2)
        denoised = model(x, sigma_hat * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'sigma_hat': sigma_hat, 'denoised': denoised})
        grad = torch.clamp((x - denoised) / sigma_hat, -1e2, 1e2)
        if use_normals:
            grad = _geometric_grad(grad, curvature, compute_normals(x))
        # NOTE(review): the EMA'd denoised only feeds old_denoised; the update
        # above uses the raw prediction — confirm the ordering is intended.
        if old_denoised is not None:
            denoised = 0.6 * denoised + 0.4 * old_denoised
        x = x + grad * dt
        if sigmas[i + 1] > 0:
            noise = noise_sampler(sigma, sigmas[i + 1]) * s_noise * eta
            # Fix: broadcast the (batch, h, w) curvature over channels.
            x = x + noise * curvature.unsqueeze(1)
        old_denoised = denoised
    return x


@torch.no_grad()
def kohaku_lonyu_yog_geo_compatible_v1_test(model, x, sigmas, extra_args=None, callback=None, disable=None, interpolation_mode='nearest-exact'):
    """Mask-compatible Euler sampler with curvature/normal-corrected gradients."""
    if x.shape[1] not in [1, 3, 4]:
        raise ValueError(f"Unsupported number of channels: {x.shape[1]}")
    extra_args = {} if extra_args is None else extra_args
    mask = extra_args.get('mask', None)
    latent_mask = extra_args.get('latent_mask', None)
    s_in = x.new_ones([x.shape[0]])
    old_denoised = None
    for i in trange(len(sigmas) - 1, disable=disable):
        dt = sigmas[i + 1] - sigmas[i]
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        curvature = compute_gaussian_curvature(x)
        normals = compute_normals(x)
        grad = torch.clamp((x - denoised) / sigmas[i], -1e2, 1e2)
        corrected_grad = _geometric_grad(grad, curvature, normals)
        # NOTE(review): the EMA'd denoised only feeds old_denoised; the update
        # below uses corrected_grad from the raw prediction — confirm intended.
        if old_denoised is not None:
            denoised = 0.6 * denoised + 0.4 * old_denoised
        x = x + corrected_grad * dt
        x = apply_mask(x, mask, latent_mask)
        old_denoised = denoised
    return x


@torch.no_grad()
def kohaku_lonyu_yog_dy_v1_test(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0.05, s_tmin=0., s_tmax=float('inf'), s_noise=0.5, interpolation_mode='nearest-exact'):
    """Sampler alternating full-grid corrected Euler steps with half-grid
    dy_sampling_step updates on odd iterations.
    """
    if x.shape[1] not in [1, 3, 4]:
        raise ValueError(f"Unsupported number of channels: {x.shape[1]}")
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    old_denoised = None
    for i in trange(len(sigmas) - 1, disable=disable):
        sigma = sigmas[i]
        dt = sigmas[i + 1] - sigma
        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.
        sigma_hat = sigma * (1 + gamma)
        if gamma > 0:
            eps = torch.randn_like(x) * s_noise
            x = x + eps * torch.sqrt(sigma_hat ** 2 - sigma ** 2)
        denoised = model(x, sigma_hat * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'sigma_hat': sigma_hat, 'denoised': denoised})
        grad = torch.clamp((x - denoised) / sigma_hat, -1e2, 1e2)
        corrected_grad = _geometric_grad(grad, compute_gaussian_curvature(x), compute_normals(x))
        if sigmas[i + 1] > 0 and i % 2 == 1:
            # Odd steps: advance only the half-resolution subgrid.
            x = dy_sampling_step(x, model, dt, sigma_hat, interpolation_mode, **extra_args)
        else:
            x = x + corrected_grad * dt
        # NOTE(review): this EMA never feeds back into the update above; it
        # only smooths the value stored in old_denoised — confirm intended.
        if old_denoised is not None:
            denoised = 0.6 * denoised + 0.4 * old_denoised
        old_denoised = denoised
    return x