Spaces:
Running
Running
Commit ·
e7bebe6
1
Parent(s): 1cecf5d
vulkan: Implement grouped query attention in the coopmat2 FA shader (llama/12559)
Browse files
When adjacent batches of Q share the same batches of K/V, batch them into
the same workgroup. For example, when:
dst(128,32,1,1) = FA(q(128,1,32,1), k(128,16640,8,1), v(128,16640,8,1))
previously we would run 32 workgroups computing 1 result each, now we will
run 8 workgroups computing 4 results each.
This doesn't directly translate to better performance (at least when you have
>=32 SMs), but in a subsequent change I'll enable split_k which will scale much
better with 4x fewer workgroups.
ggml/src/ggml-vulkan/ggml-vulkan.cpp
CHANGED
|
@@ -31,6 +31,7 @@
|
|
| 31 |
|
| 32 |
#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
|
| 33 |
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
|
|
|
|
| 34 |
|
| 35 |
#define VK_VENDOR_ID_AMD 0x1002
|
| 36 |
#define VK_VENDOR_ID_APPLE 0x106b
|
|
@@ -501,6 +502,8 @@ struct vk_flash_attn_push_constants {
|
|
| 501 |
uint32_t n_head_log2;
|
| 502 |
float m0;
|
| 503 |
float m1;
|
|
|
|
|
|
|
| 504 |
};
|
| 505 |
|
| 506 |
struct vk_op_push_constants {
|
|
@@ -5402,7 +5405,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
| 5402 |
const uint32_t nbm1 = mask ? mask->nb[1] : 0;
|
| 5403 |
|
| 5404 |
const uint32_t D = neq0;
|
| 5405 |
-
|
| 5406 |
const uint32_t KV = nek1;
|
| 5407 |
|
| 5408 |
GGML_ASSERT(ne0 == D);
|
|
@@ -5460,6 +5463,22 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
| 5460 |
vk_pipeline pipeline = pipelines[aligned];
|
| 5461 |
assert(pipeline);
|
| 5462 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5463 |
if (dryrun) {
|
| 5464 |
// Request descriptor sets
|
| 5465 |
ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
|
|
@@ -5549,7 +5568,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
| 5549 |
v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
|
| 5550 |
nbm1,
|
| 5551 |
scale, max_bias, logit_softcap,
|
| 5552 |
-
mask != nullptr, n_head_log2, m0, m1 };
|
| 5553 |
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
| 5554 |
{
|
| 5555 |
vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
|
|
@@ -5558,7 +5577,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
| 5558 |
vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
|
| 5559 |
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
|
| 5560 |
},
|
| 5561 |
-
sizeof(vk_flash_attn_push_constants), &pc, {
|
| 5562 |
}
|
| 5563 |
|
| 5564 |
static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
|
|
|
|
| 31 |
|
| 32 |
#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
|
| 33 |
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
|
| 34 |
+
// Returns true when x is a power of two strictly greater than 1 (2, 4, 8, ...).
// Both 0 and 1 are rejected.
static bool is_pow2(uint32_t x) {
    if (x <= 1) {
        return false;
    }
    // A power of two has a single bit set, so clearing the lowest set bit leaves zero.
    return (x & (x - 1)) == 0;
}
|
| 35 |
|
| 36 |
#define VK_VENDOR_ID_AMD 0x1002
|
| 37 |
#define VK_VENDOR_ID_APPLE 0x106b
|
|
|
|
| 502 |
uint32_t n_head_log2;
|
| 503 |
float m0;
|
| 504 |
float m1;
|
| 505 |
+
|
| 506 |
+
uint32_t gqa_ratio;
|
| 507 |
};
|
| 508 |
|
| 509 |
struct vk_op_push_constants {
|
|
|
|
| 5405 |
const uint32_t nbm1 = mask ? mask->nb[1] : 0;
|
| 5406 |
|
| 5407 |
const uint32_t D = neq0;
|
| 5408 |
+
uint32_t N = neq1;
|
| 5409 |
const uint32_t KV = nek1;
|
| 5410 |
|
| 5411 |
GGML_ASSERT(ne0 == D);
|
|
|
|
| 5463 |
vk_pipeline pipeline = pipelines[aligned];
|
| 5464 |
assert(pipeline);
|
| 5465 |
|
| 5466 |
+
uint32_t gqa_ratio = 1;
|
| 5467 |
+
uint32_t qk_ratio = neq2 / nek2;
|
| 5468 |
+
uint32_t workgroups_x = (uint32_t)neq1;
|
| 5469 |
+
uint32_t workgroups_y = (uint32_t)neq2;
|
| 5470 |
+
uint32_t workgroups_z = (uint32_t)neq3;
|
| 5471 |
+
|
| 5472 |
+
if (N == 1 && qk_ratio > 1 && is_pow2(qk_ratio) && gqa_ratio <= flash_attention_num_small_rows &&
|
| 5473 |
+
qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
|
| 5474 |
+
// grouped query attention - make the N dimension equal to gqa_ratio, reduce
|
| 5475 |
+
// workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
|
| 5476 |
+
// and change addressing calculations to index Q's dimension 2.
|
| 5477 |
+
gqa_ratio = qk_ratio;
|
| 5478 |
+
N = gqa_ratio;
|
| 5479 |
+
workgroups_y /= N;
|
| 5480 |
+
}
|
| 5481 |
+
|
| 5482 |
if (dryrun) {
|
| 5483 |
// Request descriptor sets
|
| 5484 |
ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
|
|
|
|
| 5568 |
v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
|
| 5569 |
nbm1,
|
| 5570 |
scale, max_bias, logit_softcap,
|
| 5571 |
+
mask != nullptr, n_head_log2, m0, m1, gqa_ratio };
|
| 5572 |
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
| 5573 |
{
|
| 5574 |
vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
|
|
|
|
| 5577 |
vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
|
| 5578 |
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
|
| 5579 |
},
|
| 5580 |
+
sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x, workgroups_y, workgroups_z });
|
| 5581 |
}
|
| 5582 |
|
| 5583 |
static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
|
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
CHANGED
|
@@ -61,6 +61,8 @@ layout (push_constant) uniform parameter {
|
|
| 61 |
uint32_t n_head_log2;
|
| 62 |
float m0;
|
| 63 |
float m1;
|
|
|
|
|
|
|
| 64 |
} p;
|
| 65 |
|
| 66 |
layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
|
|
@@ -103,6 +105,28 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
|
|
| 103 |
#define DECODEFUNC
|
| 104 |
#endif
|
| 105 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
void main() {
|
| 107 |
#ifdef NEEDS_INIT_IQ_SHMEM
|
| 108 |
init_iq_shmem(gl_WorkGroupSize);
|
|
@@ -116,7 +140,9 @@ void main() {
|
|
| 116 |
|
| 117 |
const uint32_t i = gl_WorkGroupID.x;
|
| 118 |
|
| 119 |
-
|
|
|
|
|
|
|
| 120 |
const uint32_t iq3 = gl_WorkGroupID.z;
|
| 121 |
|
| 122 |
// broadcast factors
|
|
@@ -149,8 +175,10 @@ void main() {
|
|
| 149 |
tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
|
| 150 |
tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
|
| 151 |
|
| 152 |
-
// nb?1 are already divided by the type size and are in units of elements
|
| 153 |
-
|
|
|
|
|
|
|
| 154 |
uint32_t k_stride = p.nb11;
|
| 155 |
uint32_t v_stride = p.nb21;
|
| 156 |
// hint to the compiler that strides are aligned for the aligned variant of the shader
|
|
@@ -182,16 +210,11 @@ void main() {
|
|
| 182 |
L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
|
| 183 |
M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-1.0/0.0);
|
| 184 |
|
| 185 |
-
ACC_TYPE
|
| 186 |
|
| 187 |
// ALiBi
|
| 188 |
if (p.max_bias > 0.0f) {
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
|
| 192 |
-
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
|
| 193 |
-
|
| 194 |
-
slope = pow(base, ACC_TYPE(exph));
|
| 195 |
}
|
| 196 |
|
| 197 |
[[dont_unroll]]
|
|
@@ -215,12 +238,16 @@ void main() {
|
|
| 215 |
if (p.mask != 0) {
|
| 216 |
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
|
| 217 |
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
|
| 219 |
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
|
| 220 |
|
| 221 |
coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
|
| 222 |
|
| 223 |
-
S +=
|
| 224 |
}
|
| 225 |
|
| 226 |
// Clear padding elements to -inf, so they don't contribute to rowmax
|
|
@@ -297,13 +324,18 @@ void main() {
|
|
| 297 |
|
| 298 |
O = Ldiag*O;
|
| 299 |
|
| 300 |
-
tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
|
| 301 |
-
tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D);
|
| 302 |
-
|
| 303 |
-
// permute dimensions
|
| 304 |
-
tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
|
| 305 |
uint32_t o_offset = iq3*p.ne2*p.ne1;
|
| 306 |
|
| 307 |
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
|
| 308 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
}
|
|
|
|
| 61 |
uint32_t n_head_log2;
|
| 62 |
float m0;
|
| 63 |
float m1;
|
| 64 |
+
|
| 65 |
+
uint32_t gqa_ratio;
|
| 66 |
} p;
|
| 67 |
|
| 68 |
layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
|
|
|
|
| 105 |
#define DECODEFUNC
|
| 106 |
#endif
|
| 107 |
|
| 108 |
+
// Store the output when doing grouped query attention.
|
| 109 |
+
// Rows are indexed by Q's dimension 2, and the first N rows are valid.
|
| 110 |
+
// Per-element callback (passed to coopMatPerElementNV): writes elem into data_o for
// in-range (r, c) positions and returns elem unchanged.
//   r, c     - row and column of the element within the matrix
//   elem     - the element value to store
//   o_offset - base offset added to the computed index into data_o
//   iq2      - index of the first row handled by this workgroup
//   N        - number of valid rows; rows at or beyond N are not written
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
|
| 111 |
+
{
|
| 112 |
+
if (r < N && c < D) { // only rows below N and columns below D are written
|
| 113 |
+
uint32_t offset = (iq2 + r) * D + c; // row r lands at row index iq2 + r, with D elements per row
|
| 114 |
+
data_o[o_offset + offset] = D_TYPE(elem);
|
| 115 |
+
}
|
| 116 |
+
return elem; // the element value itself is passed back unmodified
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
// Compute the slope matrix, indexed by Q's dimension 2.
|
| 120 |
+
// Per-element callback (passed to coopMatPerElementNV): returns the slope for row r.
// The column c and the incoming element value elem are not used; only the row
// and iq2 determine the result.
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
|
| 121 |
+
{
|
| 122 |
+
// Masking with (gqa_ratio - 1) keeps only the low bits of r; when gqa_ratio is 1
// the mask is 0 and h == iq2. The mask equals r % gqa_ratio only for power-of-two ratios.
const uint32_t h = iq2 + (r & (p.gqa_ratio - 1));
|
| 123 |
+
|
| 124 |
+
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); // m0 for h below n_head_log2, otherwise m1
|
| 125 |
+
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1); // exponent: h + 1, or the odd value 2*(h - n_head_log2) + 1
|
| 126 |
+
|
| 127 |
+
return ACC_TYPE(pow(base, ACC_TYPE(exph))); // slope = base raised to exph
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
void main() {
|
| 131 |
#ifdef NEEDS_INIT_IQ_SHMEM
|
| 132 |
init_iq_shmem(gl_WorkGroupSize);
|
|
|
|
| 140 |
|
| 141 |
const uint32_t i = gl_WorkGroupID.x;
|
| 142 |
|
| 143 |
+
// When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
|
| 144 |
+
// When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
|
| 145 |
+
const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
|
| 146 |
const uint32_t iq3 = gl_WorkGroupID.z;
|
| 147 |
|
| 148 |
// broadcast factors
|
|
|
|
| 175 |
tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
|
| 176 |
tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
|
| 177 |
|
| 178 |
+
// nb?1 are already divided by the type size and are in units of elements.
|
| 179 |
+
// When using grouped query attention, Q is indexed by iq2, so the stride
|
| 180 |
+
// should be nb02 (which is in bytes).
|
| 181 |
+
uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
|
| 182 |
uint32_t k_stride = p.nb11;
|
| 183 |
uint32_t v_stride = p.nb21;
|
| 184 |
// hint to the compiler that strides are aligned for the aligned variant of the shader
|
|
|
|
| 210 |
L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
|
| 211 |
M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-1.0/0.0);
|
| 212 |
|
| 213 |
+
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);
|
| 214 |
|
| 215 |
// ALiBi
|
| 216 |
if (p.max_bias > 0.0f) {
|
| 217 |
+
coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
}
|
| 219 |
|
| 220 |
[[dont_unroll]]
|
|
|
|
| 238 |
if (p.mask != 0) {
|
| 239 |
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
|
| 240 |
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
|
| 241 |
+
// When using grouped query attention, all rows use the same mask.
|
| 242 |
+
if (p.gqa_ratio > 1) {
|
| 243 |
+
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, 0, 1);
|
| 244 |
+
}
|
| 245 |
|
| 246 |
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
|
| 247 |
|
| 248 |
coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
|
| 249 |
|
| 250 |
+
S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
|
| 251 |
}
|
| 252 |
|
| 253 |
// Clear padding elements to -inf, so they don't contribute to rowmax
|
|
|
|
| 324 |
|
| 325 |
O = Ldiag*O;
|
| 326 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 327 |
uint32_t o_offset = iq3*p.ne2*p.ne1;
|
| 328 |
|
| 329 |
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
|
| 330 |
+
if (p.gqa_ratio > 1) {
|
| 331 |
+
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
|
| 332 |
+
} else {
|
| 333 |
+
tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
|
| 334 |
+
tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D);
|
| 335 |
+
|
| 336 |
+
// permute dimensions
|
| 337 |
+
tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
|
| 338 |
+
|
| 339 |
+
coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, D), tensorViewPermute);
|
| 340 |
+
}
|
| 341 |
}
|