# LightDiffusion-Next: src/Attention/AttentionMethods.py
"""Attention implementations supporting PyTorch, XFormers, and SageAttention."""
try:
import sageattention
except ImportError:
sageattention = None
try:
import spas_sage_attn
except ImportError:
spas_sage_attn = None
try:
import xformers
BROKEN_XFORMERS = xformers.__version__.startswith("0.0.2") and not xformers.__version__.startswith("0.0.20")
except ImportError:
xformers = None
BROKEN_XFORMERS = False
import torch
import torch.nn.functional as F
# Pre-computed padding targets for SageAttention supported dimensions
# Maps dimension -> (target_dim, padding_amount) or None if no padding needed
_SAGE_PAD_CACHE: dict[int, tuple[int, int] | None] = {}
def _get_sage_padding(dim: int) -> tuple[int, int] | None:
"""Get pre-computed padding target for a given dimension.
Returns (target_dim, pad_amount) or None if no padding needed.
"""
if dim not in _SAGE_PAD_CACHE:
if dim in (64, 96, 128):
_SAGE_PAD_CACHE[dim] = None # No padding needed
elif dim < 64:
_SAGE_PAD_CACHE[dim] = (64, 64 - dim)
elif dim < 128:
_SAGE_PAD_CACHE[dim] = (128, 128 - dim)
else:
_SAGE_PAD_CACHE[dim] = None # Unsupported, no padding
return _SAGE_PAD_CACHE[dim]
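# Illustrative examples of the mapping (values follow directly from the branches above):
#   _get_sage_padding(64)  -> None        # already a supported size
#   _get_sage_padding(40)  -> (64, 24)    # pad 40 up to 64
#   _get_sage_padding(80)  -> (128, 48)   # pad 80 up to 128 (96 is never used as a target)
#   _get_sage_padding(160) -> None        # > 128: unsupported, callers must fall back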
def _pad_for_sage(q, k, v, dim):
    """Pad the last dimension of q, k, v up to a SageAttention-supported size (64, 96, 128).

    Returns the (possibly padded) tensors along with the original head dim, so callers
    can slice the padding back off the attention output.
    """
    padding = _get_sage_padding(dim)
    if padding is None:
        return q, k, v, dim
    _, pad = padding
    return F.pad(q, (0, pad)), F.pad(k, (0, pad)), F.pad(v, (0, pad)), dim
def _reshape_for_heads(q, k, v, heads, flux=False, skip_reshape=False):
"""Reshape tensors for multi-head attention."""
if flux and skip_reshape:
return q, k, v, q.shape[-1]
b = q.shape[0]
dim_head = q.shape[-1] // heads
reshape_fn = lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2).contiguous()
return reshape_fn(q), reshape_fn(k), reshape_fn(v), dim_head
def _reshape_output(out, b, heads, dim_head, flux=False, skip_reshape=False):
    """Reshape attention output from (b, heads, seq, dim_head) back to (b, seq, heads * dim_head).

    The flux and skip_reshape flags are kept for signature parity with _reshape_for_heads;
    the output layout is the same in every case.
    """
    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)
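# Shape round-trip for the two helpers above (illustrative):
#   x: (b, seq, heads * dim_head)
#   _reshape_for_heads(x, x, x, heads)        -> q, k, v of shape (b, heads, seq, dim_head)
#   _reshape_output(out, b, heads, dim_head)  -> (b, seq, heads * dim_head)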
def attention_pytorch(q, k, v, heads, mask=None, skip_reshape=False, flux=False):
"""Multi-head attention using PyTorch SDPA."""
b = q.shape[0]
if not flux:
seq_q, seq_kv = q.shape[1], k.shape[1]
dim_head = q.shape[-1] // heads
q = q.view(b, seq_q, heads, dim_head).transpose(1, 2)
k = k.view(b, seq_kv, heads, dim_head).transpose(1, 2)
v = v.view(b, seq_kv, heads, dim_head).transpose(1, 2)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
return out.transpose(1, 2).reshape(b, seq_q, heads * dim_head)
dim_head = q.shape[-1] if skip_reshape else q.shape[-1] // heads
if not skip_reshape:
q, k, v = [t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v)]
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
return out.transpose(1, 2).reshape(b, -1, heads * dim_head)
def attention_xformers(q, k, v, heads, mask=None, skip_reshape=False, flux=False):
"""Multi-head attention using XFormers."""
b = q.shape[0]
if not flux:
dim_head = q.shape[-1] // heads
q, k, v = [t.view(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head).contiguous()
for t in (q, k, v)]
try:
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
except (NotImplementedError, RuntimeError):
out = F.scaled_dot_product_attention(
q.view(b, heads, -1, dim_head), k.view(b, heads, -1, dim_head), v.view(b, heads, -1, dim_head),
attn_mask=mask, dropout_p=0.0, is_causal=False).reshape(b * heads, -1, dim_head)
return out.view(b, heads, -1, dim_head).permute(0, 2, 1, 3).reshape(b, -1, heads * dim_head)
dim_head = q.shape[-1] if skip_reshape else q.shape[-1] // heads
if BROKEN_XFORMERS and b * heads > 65535:
return attention_pytorch(q, k, v, heads, mask, skip_reshape, flux)
if skip_reshape:
q, k, v = [t.reshape(b * heads, -1, dim_head) for t in (q, k, v)]
else:
q, k, v = [t.reshape(b, -1, heads, dim_head) for t in (q, k, v)]
try:
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
    except (NotImplementedError, RuntimeError):
        if skip_reshape:
            # q, k, v are (b * heads, seq, dim_head); regroup heads for SDPA
            out = F.scaled_dot_product_attention(
                q.view(b, heads, -1, dim_head), k.view(b, heads, -1, dim_head), v.view(b, heads, -1, dim_head),
                attn_mask=mask, dropout_p=0.0, is_causal=False).reshape(b * heads, -1, dim_head)
        else:
            # q, k, v are (b, seq, heads, dim_head); SDPA wants heads before seq
            out = F.scaled_dot_product_attention(
                q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
                attn_mask=mask, dropout_p=0.0, is_causal=False).transpose(1, 2)
    if skip_reshape:
        return out.view(b, heads, -1, dim_head).permute(0, 2, 1, 3).reshape(b, -1, heads * dim_head)
    return out.reshape(b, -1, heads * dim_head)
def attention_sage(q, k, v, heads, mask=None, skip_reshape=False, flux=False):
"""Multi-head attention using SageAttention."""
if mask is not None and mask.device != q.device:
mask = mask.to(q.device)
b = q.shape[0]
dim_head = q.shape[-1] if (flux and skip_reshape) else q.shape[-1] // heads
if not (flux and skip_reshape):
if not flux:
q, k, v = [t.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous()
for t in (q, k, v)]
else:
q, k, v = [t.reshape(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v)]
    # Pad to a SageAttention-supported head dim (64, 96, 128) and compute attention
    qp, kp, vp, _ = _pad_for_sage(q, k, v, dim_head)
    if dim_head in (64, 96, 128):
        out = sageattention.sageattn(qp, kp, vp, tensor_layout="HND", attn_mask=mask, is_causal=False)
    elif dim_head > 128:
        # Fallback for unsupported (too large) head dims
        try:
            out = xformers.ops.memory_efficient_attention(
                q.reshape(b * heads, -1, dim_head), k.reshape(b * heads, -1, dim_head),
                v.reshape(b * heads, -1, dim_head), attn_bias=mask)
            out = out.reshape(b, heads, -1, dim_head)
        except Exception:
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
    else:
        # Head dim was padded up to a supported size; slice the padding back off
        out = sageattention.sageattn(qp, kp, vp, tensor_layout="HND", attn_mask=mask, is_causal=False)
        out = out[..., :dim_head]
if not flux:
return out.reshape(b, heads, -1, dim_head).permute(0, 2, 1, 3).reshape(b, -1, heads * dim_head)
return out.transpose(1, 2).reshape(b, -1, heads * dim_head)
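# Example use of attention_sage (illustrative; assumes the sageattention package and a
# CUDA device). A head dim of 80 is padded to 128 before sageattn and sliced back, so
# the output layout matches attention_pytorch:
#   q = k = v = torch.randn(2, 256, 8 * 80, device="cuda", dtype=torch.float16)
#   out = attention_sage(q, k, v, heads=8)  # -> shape (2, 256, 640)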
def attention_sparge(q, k, v, heads, mask=None, skip_reshape=False, flux=False):
"""Multi-head attention using SpargeAttn (Sparse + SageAttention)."""
b = q.shape[0]
dim_head = q.shape[-1] if (flux and skip_reshape) else q.shape[-1] // heads
if not (flux and skip_reshape):
if not flux:
q, k, v = [t.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous()
for t in (q, k, v)]
else:
q, k, v = [t.reshape(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v)]
    qp, kp, vp, _ = _pad_for_sage(q, k, v, dim_head)
    sparge_kwargs = dict(simthreshd1=0.6, cdfthreshd=0.97, pvthreshd=15, is_causal=False)
    if dim_head in (64, 96, 128):
        out = spas_sage_attn.spas_sage2_attn_meansim_cuda(qp, kp, vp, **sparge_kwargs)
    elif dim_head > 128:
        # Fallback for unsupported (too large) head dims
        out = sageattention.sageattn(q, k, v, tensor_layout="HND", attn_mask=mask, is_causal=False)
    else:
        # Head dim was padded up to a supported size; slice the padding back off
        out = spas_sage_attn.spas_sage2_attn_meansim_cuda(qp, kp, vp, **sparge_kwargs)
        out = out[..., :dim_head]
if not flux:
return out.reshape(b, heads, -1, dim_head).permute(0, 2, 1, 3).reshape(b, -1, heads * dim_head)
return out.transpose(1, 2).reshape(b, -1, heads * dim_head)
# Simple 4D attention variants (B, C, H, W format)
def sage_attention(q, k, v):
    """SageAttention for 4D tensors (B, C, H, W)."""
    B, C, H, W = q.shape
    q, k, v = [t.view(B, 1, C, -1).transpose(2, 3).contiguous() for t in (q, k, v)]
    if C > 128:
        # Channel dim too large for SageAttention; fall back to PyTorch SDPA
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    else:
        qp, kp, vp, _ = _pad_for_sage(q, k, v, C)
        out = sageattention.sageattn(qp, kp, vp, tensor_layout="HND", is_causal=False)
        # Slice off any padding added to reach a supported size
        out = out[..., :C]
    return out.transpose(2, 3).reshape(B, C, H, W)
def sparge_attention(q, k, v):
    """SpargeAttn for 4D tensors (B, C, H, W)."""
    B, C, H, W = q.shape
    q, k, v = [t.view(B, 1, C, -1).transpose(2, 3).contiguous() for t in (q, k, v)]
    sparge_kwargs = dict(simthreshd1=0.6, cdfthreshd=0.97, pvthreshd=15, is_causal=False)
    if C > 128:
        # Channel dim too large for SpargeAttn; fall back to plain SageAttention
        out = sageattention.sageattn(q, k, v, tensor_layout="HND", is_causal=False)
    else:
        qp, kp, vp, _ = _pad_for_sage(q, k, v, C)
        out = spas_sage_attn.spas_sage2_attn_meansim_cuda(qp, kp, vp, **sparge_kwargs)
        # Slice off any padding added to reach a supported size
        out = out[..., :C]
    return out.transpose(2, 3).reshape(B, C, H, W)
def xformers_attention(q, k, v):
"""XFormers attention for 4D tensors (B, C, H, W)."""
B, C, H, W = q.shape
q, k, v = [t.view(B, C, -1).transpose(1, 2).contiguous() for t in (q, k, v)]
try:
out = xformers.ops.memory_efficient_attention(q, k, v)
except (NotImplementedError, RuntimeError):
out = F.scaled_dot_product_attention(q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1), dropout_p=0.0, is_causal=False).squeeze(1)
return out.transpose(1, 2).reshape(B, C, H, W)
def pytorch_attention(q, k, v):
"""PyTorch attention for 4D tensors (B, C, H, W)."""
B, C, H, W = q.shape
q, k, v = [t.view(B, 1, C, -1).transpose(2, 3).contiguous() for t in (q, k, v)]
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
return out.transpose(2, 3).reshape(B, C, H, W)
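
if __name__ == "__main__":
    # Minimal smoke test (illustrative only). It exercises the pure-PyTorch paths,
    # which need no optional dependencies; shapes are arbitrary example values.
    b, seq, heads, dim_head = 2, 64, 4, 32
    q3 = torch.randn(b, seq, heads * dim_head)
    out3 = attention_pytorch(q3, q3, q3, heads)
    print("attention_pytorch:", tuple(out3.shape))  # (2, 64, 128)

    B, C, H, W = 1, 64, 16, 16
    q4 = torch.randn(B, C, H, W)
    out4 = pytorch_attention(q4, q4, q4)
    print("pytorch_attention:", tuple(out4.shape))  # (1, 64, 16, 16)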