jeffbolznv committed on
Commit
e7bebe6
·
1 Parent(s): 1cecf5d

vulkan: Implement grouped query attention in the coopmat2 FA shader (llama/12559)

Browse files

When adjacent batches of Q share the same batches of K/V, batch them into
the same workgroup. For example, when:

dst(128,32,1,1) = FA(q(128,1,32,1), k(128,16640,8,1), v(128,16640,8,1))

previously we would run 32 workgroups computing 1 result each, now we will
run 8 workgroups computing 4 results each.

This doesn't directly translate to better performance (at least when you have
>=32 SMs), but in a subsequent change I'll enable split_k which will scale much
better with 4x fewer workgroups.

ggml/src/ggml-vulkan/ggml-vulkan.cpp CHANGED
@@ -31,6 +31,7 @@
31
 
32
  #define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
33
  #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
 
34
 
35
  #define VK_VENDOR_ID_AMD 0x1002
36
  #define VK_VENDOR_ID_APPLE 0x106b
@@ -501,6 +502,8 @@ struct vk_flash_attn_push_constants {
501
  uint32_t n_head_log2;
502
  float m0;
503
  float m1;
 
 
504
  };
505
 
506
  struct vk_op_push_constants {
@@ -5402,7 +5405,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
5402
  const uint32_t nbm1 = mask ? mask->nb[1] : 0;
5403
 
5404
  const uint32_t D = neq0;
5405
- const uint32_t N = neq1;
5406
  const uint32_t KV = nek1;
5407
 
5408
  GGML_ASSERT(ne0 == D);
@@ -5460,6 +5463,22 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
5460
  vk_pipeline pipeline = pipelines[aligned];
5461
  assert(pipeline);
5462
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5463
  if (dryrun) {
5464
  // Request descriptor sets
5465
  ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
@@ -5549,7 +5568,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
5549
  v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
5550
  nbm1,
5551
  scale, max_bias, logit_softcap,
5552
- mask != nullptr, n_head_log2, m0, m1 };
5553
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
5554
  {
5555
  vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
@@ -5558,7 +5577,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
5558
  vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
5559
  vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
5560
  },
5561
- sizeof(vk_flash_attn_push_constants), &pc, { (uint32_t)neq1, (uint32_t)neq2, (uint32_t)neq3 });
5562
  }
5563
 
5564
  static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
 
31
 
32
  #define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
33
  #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
34
+ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
35
 
36
  #define VK_VENDOR_ID_AMD 0x1002
37
  #define VK_VENDOR_ID_APPLE 0x106b
 
502
  uint32_t n_head_log2;
503
  float m0;
504
  float m1;
505
+
506
+ uint32_t gqa_ratio;
507
  };
508
 
509
  struct vk_op_push_constants {
 
5405
  const uint32_t nbm1 = mask ? mask->nb[1] : 0;
5406
 
5407
  const uint32_t D = neq0;
5408
+ uint32_t N = neq1;
5409
  const uint32_t KV = nek1;
5410
 
5411
  GGML_ASSERT(ne0 == D);
 
5463
  vk_pipeline pipeline = pipelines[aligned];
5464
  assert(pipeline);
5465
 
5466
+ uint32_t gqa_ratio = 1;
5467
+ uint32_t qk_ratio = neq2 / nek2;
5468
+ uint32_t workgroups_x = (uint32_t)neq1;
5469
+ uint32_t workgroups_y = (uint32_t)neq2;
5470
+ uint32_t workgroups_z = (uint32_t)neq3;
5471
+
5472
+ if (N == 1 && qk_ratio > 1 && is_pow2(qk_ratio) && gqa_ratio <= flash_attention_num_small_rows &&
5473
+ qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
5474
+ // grouped query attention - make the N dimension equal to gqa_ratio, reduce
5475
+ // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
5476
+ // and change addressing calculations to index Q's dimension 2.
5477
+ gqa_ratio = qk_ratio;
5478
+ N = gqa_ratio;
5479
+ workgroups_y /= N;
5480
+ }
5481
+
5482
  if (dryrun) {
5483
  // Request descriptor sets
5484
  ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
 
5568
  v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
5569
  nbm1,
5570
  scale, max_bias, logit_softcap,
5571
+ mask != nullptr, n_head_log2, m0, m1, gqa_ratio };
5572
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
5573
  {
5574
  vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
 
5577
  vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
5578
  vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
5579
  },
5580
+ sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x, workgroups_y, workgroups_z });
5581
  }
5582
 
5583
  static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp CHANGED
@@ -61,6 +61,8 @@ layout (push_constant) uniform parameter {
61
  uint32_t n_head_log2;
62
  float m0;
63
  float m1;
 
 
64
  } p;
65
 
66
  layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
@@ -103,6 +105,28 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
103
  #define DECODEFUNC
104
  #endif
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  void main() {
107
  #ifdef NEEDS_INIT_IQ_SHMEM
108
  init_iq_shmem(gl_WorkGroupSize);
@@ -116,7 +140,9 @@ void main() {
116
 
117
  const uint32_t i = gl_WorkGroupID.x;
118
 
119
- const uint32_t iq2 = gl_WorkGroupID.y;
 
 
120
  const uint32_t iq3 = gl_WorkGroupID.z;
121
 
122
  // broadcast factors
@@ -149,8 +175,10 @@ void main() {
149
  tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
150
  tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
151
 
152
- // nb?1 are already divided by the type size and are in units of elements
153
- uint32_t q_stride = p.nb01;
 
 
154
  uint32_t k_stride = p.nb11;
155
  uint32_t v_stride = p.nb21;
156
  // hint to the compiler that strides are aligned for the aligned variant of the shader
@@ -182,16 +210,11 @@ void main() {
182
  L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
183
  M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-1.0/0.0);
184
 
185
- ACC_TYPE slope = ACC_TYPE(1.0);
186
 
187
  // ALiBi
188
  if (p.max_bias > 0.0f) {
189
- const uint32_t h = iq2;
190
-
191
- const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
192
- const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
193
-
194
- slope = pow(base, ACC_TYPE(exph));
195
  }
196
 
197
  [[dont_unroll]]
@@ -215,12 +238,16 @@ void main() {
215
  if (p.mask != 0) {
216
  tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
217
  tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
 
 
 
 
218
 
219
  coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
220
 
221
  coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
222
 
223
- S += slope*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
224
  }
225
 
226
  // Clear padding elements to -inf, so they don't contribute to rowmax
@@ -297,13 +324,18 @@ void main() {
297
 
298
  O = Ldiag*O;
299
 
300
- tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
301
- tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D);
302
-
303
- // permute dimensions
304
- tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
305
  uint32_t o_offset = iq3*p.ne2*p.ne1;
306
 
307
  coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
308
- coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, 1, 0, D), tensorViewPermute);
 
 
 
 
 
 
 
 
 
 
309
  }
 
61
  uint32_t n_head_log2;
62
  float m0;
63
  float m1;
64
+
65
+ uint32_t gqa_ratio;
66
  } p;
67
 
68
  layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
 
105
  #define DECODEFUNC
106
  #endif
107
 
108
+ // Store the output when doing grouped query attention.
109
+ // Rows index by Q's dimension 2, and the first N rows are valid.
110
+ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
111
+ {
112
+ if (r < N && c < D) {
113
+ uint32_t offset = (iq2 + r) * D + c;
114
+ data_o[o_offset + offset] = D_TYPE(elem);
115
+ }
116
+ return elem;
117
+ }
118
+
119
+ // Load the slope matrix, indexed by Q's dimension 2.
120
+ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
121
+ {
122
+ const uint32_t h = iq2 + (r & (p.gqa_ratio - 1));
123
+
124
+ const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
125
+ const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
126
+
127
+ return ACC_TYPE(pow(base, ACC_TYPE(exph)));
128
+ }
129
+
130
  void main() {
131
  #ifdef NEEDS_INIT_IQ_SHMEM
132
  init_iq_shmem(gl_WorkGroupSize);
 
140
 
141
  const uint32_t i = gl_WorkGroupID.x;
142
 
143
+ // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
144
+ // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
145
+ const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
146
  const uint32_t iq3 = gl_WorkGroupID.z;
147
 
148
  // broadcast factors
 
175
  tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
176
  tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
177
 
178
+ // nb?1 are already divided by the type size and are in units of elements.
179
+ // When using grouped query attention, Q is indexed by iq2, so the stride
180
+ // should be nb02 (which is in bytes).
181
+ uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
182
  uint32_t k_stride = p.nb11;
183
  uint32_t v_stride = p.nb21;
184
  // hint to the compiler that strides are aligned for the aligned variant of the shader
 
210
  L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
211
  M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-1.0/0.0);
212
 
213
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);
214
 
215
  // ALiBi
216
  if (p.max_bias > 0.0f) {
217
+ coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
 
 
 
 
 
218
  }
219
 
220
  [[dont_unroll]]
 
238
  if (p.mask != 0) {
239
  tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
240
  tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
241
+ // When using grouped query attention, all rows use the same mask.
242
+ if (p.gqa_ratio > 1) {
243
+ tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, 0, 1);
244
+ }
245
 
246
  coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
247
 
248
  coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
249
 
250
+ S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
251
  }
252
 
253
  // Clear padding elements to -inf, so they don't contribute to rowmax
 
324
 
325
  O = Ldiag*O;
326
 
 
 
 
 
 
327
  uint32_t o_offset = iq3*p.ne2*p.ne1;
328
 
329
  coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
330
+ if (p.gqa_ratio > 1) {
331
+ coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
332
+ } else {
333
+ tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
334
+ tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D);
335
+
336
+ // permute dimensions
337
+ tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
338
+
339
+ coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, D), tensorViewPermute);
340
+ }
341
  }