import math

import pytest
import torch

import sage_attention as sa

cuda_available = torch.cuda.is_available()


@pytest.mark.skipif(not cuda_available, reason="CUDA is required")
@pytest.mark.parametrize("tensor_layout", ["HND", "NHD"])
def test_per_block_int8_shapes_and_types(tensor_layout):
    device = "cuda"
    dtype = torch.float16
    if tensor_layout == "HND":
        q = torch.randn(2, 4, 129, 128, dtype=dtype, device=device)
        k = torch.randn(2, 4, 257, 128, dtype=dtype, device=device)
    else:
        q = torch.randn(2, 129, 4, 128, dtype=dtype, device=device)
        k = torch.randn(2, 257, 4, 128, dtype=dtype, device=device)
    # The expected scale shapes are identical for both layouts.
    expected_q_scale_shape = (2, 4, math.ceil(129 / 128))
    expected_k_scale_shape = (2, 4, math.ceil(257 / 64))
    # km has the same (2, 4, 128) shape for both layouts, so no branching is needed.
    km = torch.randn(2, 4, 128, dtype=dtype, device=device)
    q_int8, q_scale, k_int8, k_scale = sa.per_block_int8(
        q, k, km, tensor_layout=tensor_layout
    )
    assert q_int8.shape == q.shape and q_int8.dtype == torch.int8
    assert k_int8.shape == k.shape and k_int8.dtype == torch.int8
    assert q_scale.shape == expected_q_scale_shape and q_scale.dtype == torch.float32
    assert k_scale.shape == expected_k_scale_shape and k_scale.dtype == torch.float32
    assert q_int8.device == q.device == k.device == q_scale.device == k_scale.device
    assert torch.isfinite(q_scale).all()
    assert torch.isfinite(k_scale).all()
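

# --- Illustrative sketch (not part of the original test suite): a minimal
# pure-PyTorch reference for per-block symmetric INT8 quantization, assuming
# one scale per block of `blk` tokens along the sequence dimension with
# scale = amax / 127 (HND layout). The actual sa.per_block_int8 kernel may
# differ in details; this only illustrates why q_scale has
# ceil(seq_len / 128) entries and k_scale has ceil(seq_len / 64) entries
# per head.
def _reference_per_block_int8_hnd(x_hnd, blk):
    b, h, s, d = x_hnd.shape
    n_blocks = math.ceil(s / blk)
    pad = n_blocks * blk - s
    # Pad the sequence dimension so it splits evenly into blocks.
    x_pad = torch.nn.functional.pad(x_hnd.float(), (0, 0, 0, pad))
    x_blocks = x_pad.view(b, h, n_blocks, blk, d)
    scale = x_blocks.abs().amax(dim=(-1, -2)) / 127.0  # (b, h, n_blocks)
    x_int8 = torch.round(x_blocks / scale[..., None, None].clamp(min=1e-8))
    x_int8 = x_int8.clamp(-127, 127).to(torch.int8)
    return x_int8.view(b, h, n_blocks * blk, d)[:, :, :s], scale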


@pytest.mark.skipif(not cuda_available, reason="CUDA is required")
@pytest.mark.parametrize("tensor_layout", ["HND", "NHD"])
@pytest.mark.parametrize("head_dim", [64, 128])
def test_per_warp_int8_shapes_and_types(tensor_layout, head_dim):
    device = "cuda"
    dtype = torch.float16
    if tensor_layout == "HND":
        q = torch.randn(1, 2, 130, head_dim, dtype=dtype, device=device)
        k = torch.randn(1, 2, 70, head_dim, dtype=dtype, device=device)
    else:
        q = torch.randn(1, 130, 2, head_dim, dtype=dtype, device=device)
        k = torch.randn(1, 70, 2, head_dim, dtype=dtype, device=device)
    # The expected scale shapes are identical for both layouts; WARPQ matches
    # the kernel call below.
    warpq = 16 if head_dim == 128 else 32
    expected_q_scale_shape = (1, 2, math.ceil(130 / 128) * (128 // warpq))
    expected_k_scale_shape = (1, 2, math.ceil(70 / 64))
    q_int8, q_scale, k_int8, k_scale = sa.per_warp_int8(
        q,
        k,
        tensor_layout=tensor_layout,
        BLKQ=128,
        WARPQ=warpq,
        BLKK=64,
    )
    assert q_int8.shape == q.shape and q_int8.dtype == torch.int8
    assert k_int8.shape == k.shape and k_int8.dtype == torch.int8
    assert q_scale.shape == expected_q_scale_shape and q_scale.dtype == torch.float32
    assert k_scale.shape == expected_k_scale_shape and k_scale.dtype == torch.float32
    assert torch.isfinite(q_scale).all()
    assert torch.isfinite(k_scale).all()
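

# --- Illustrative sketch (not part of the original test suite): the per-warp
# Q-scale count checked above, assuming one scale per warp tile, with each
# BLKQ-token query block split into BLKQ // WARPQ tiles.
# For the parametrization above (seq_len=130):
#   head_dim=128, WARPQ=16 -> ceil(130 / 128) * (128 // 16) = 2 * 8 = 16
#   head_dim=64,  WARPQ=32 -> ceil(130 / 128) * (128 // 32) = 2 * 4 = 8
def _expected_per_warp_q_scale_entries(seq_len, blkq=128, warpq=16):
    return math.ceil(seq_len / blkq) * (blkq // warpq)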


@pytest.mark.skipif(not cuda_available, reason="CUDA is required")
@pytest.mark.parametrize("tensor_layout", ["HND", "NHD"])
def test_sub_mean_properties(tensor_layout):
    device = "cuda"
    dtype = torch.float16
    if tensor_layout == "HND":
        v = torch.randn(2, 3, 65, 128, dtype=dtype, device=device)
        seq_dim = 2
        nh_dim = 1
    else:
        v = torch.randn(2, 65, 3, 128, dtype=dtype, device=device)
        seq_dim = 1
        nh_dim = 2
    v_smoothed, vm = sa.sub_mean(v, tensor_layout=tensor_layout)
    assert v_smoothed.shape == v.shape and v_smoothed.dtype == torch.float16
    assert vm.shape == (v.size(0), v.size(nh_dim), v.size(-1)) and vm.dtype == v.dtype
    # The mean along the sequence dimension of smoothed v should be ~0 (in fp16)
    mean_after = v_smoothed.mean(dim=seq_dim)
    assert torch.isfinite(mean_after).all()
    assert (mean_after.abs() < 1e-1).all()
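

# --- Illustrative sketch (not part of the original test suite): a minimal
# pure-PyTorch reference for the mean smoothing that the test above relies on,
# assuming sa.sub_mean subtracts the per-(batch, head, channel) mean taken
# along the sequence dimension and also returns that mean (HND layout shown).
def _reference_sub_mean_hnd(v_hnd):
    vm = v_hnd.float().mean(dim=2)  # (batch, heads, head_dim)
    v_smoothed = (v_hnd.float() - vm.unsqueeze(2)).to(v_hnd.dtype)
    return v_smoothed, vm.to(v_hnd.dtype)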


@pytest.mark.skipif(not cuda_available, reason="CUDA is required")
@pytest.mark.parametrize("tensor_layout", ["HND", "NHD"])
@pytest.mark.parametrize("smooth_v", [True, False])
def test_per_channel_fp8_shapes_and_outputs(tensor_layout, smooth_v):
    device = "cuda"
    dtype = torch.float16
    if tensor_layout == "HND":
        v = torch.randn(2, 3, 77, 128, dtype=dtype, device=device)
        kv_len = v.size(2)
    else:
        v = torch.randn(2, 77, 3, 128, dtype=dtype, device=device)
        kv_len = v.size(1)
    v_fp8, v_scale, vm = sa.per_channel_fp8(
        v, tensor_layout=tensor_layout, smooth_v=smooth_v
    )
    assert v_fp8.dtype == torch.float8_e4m3fn
    assert v_scale.shape == (2, 3, 128)
    if smooth_v:
        assert vm is not None and vm.shape == (2, 3, 128) and vm.dtype == torch.float32
    else:
        assert vm is None
    # The padded sequence length should be a multiple of 64
    padded_len = ((kv_len + 63) // 64) * 64
    if tensor_layout == "HND":
        assert v_fp8.shape == (2, 3, 128, padded_len)
    else:
        assert v_fp8.shape == (2, 128, 3, padded_len)
    assert torch.isfinite(v_scale).all()
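

# --- Illustrative sketch (not part of the original test suite): a minimal
# pure-PyTorch reference for per-channel FP8 quantization, assuming one scale
# per (batch, head, channel) computed as amax / 448 (448 being the largest
# finite float8_e4m3fn value) and padding of the sequence dimension to a
# multiple of 64. The test above additionally checks that the kernel places
# the head_dim axis ahead of the padded sequence axis; that transpose is
# omitted here, and the real sa.per_channel_fp8 may differ in other details.
def _reference_per_channel_fp8_hnd(v_hnd):
    b, h, s, d = v_hnd.shape
    padded_len = ((s + 63) // 64) * 64
    v_pad = torch.nn.functional.pad(v_hnd.float(), (0, 0, 0, padded_len - s))
    scale = v_pad.abs().amax(dim=2) / 448.0  # (batch, heads, head_dim)
    v_fp8 = (v_pad / scale.unsqueeze(2).clamp(min=1e-8)).to(torch.float8_e4m3fn)
    return v_fp8, scale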