jeffbolznv committed on
Commit
5ab06d6
·
1 Parent(s): fac18c1

vulkan: Implement split_k for coopmat2 flash attention. (llama/12627)

Browse files

When using group query attention, we have one workgroup per KV batch and this
can be very few workgroups (e.g. just 8 in some models). Enable split_k to
spread the work across SMs. This helps a lot when the KV cache is large.

ggml/src/ggml-vulkan/ggml-vulkan.cpp CHANGED
@@ -353,6 +353,7 @@ struct vk_device_struct {
353
  vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
354
  vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
355
  vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
 
356
 
357
  std::unordered_map<std::string, vk_pipeline_ref> pipelines;
358
  std::unordered_map<std::string, uint64_t> pipeline_descriptor_set_requirements;
@@ -504,6 +505,8 @@ struct vk_flash_attn_push_constants {
504
  float m1;
505
 
506
  uint32_t gqa_ratio;
 
 
507
  };
508
 
509
  struct vk_op_push_constants {
@@ -1476,7 +1479,7 @@ static std::array<uint32_t, 2> fa_rows_cols(uint32_t D, uint32_t clamp, ggml_typ
1476
 
1477
  // small rows, large cols
1478
  if (small_rows) {
1479
- return {flash_attention_num_small_rows, 128};
1480
  }
1481
  // small cols to reduce register count
1482
  if (ggml_is_quantized(type) || D == 256) {
@@ -2332,6 +2335,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2332
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2333
 
2334
  ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
 
2335
  ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
2336
 
2337
  for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
@@ -5479,9 +5483,38 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
5479
  workgroups_y /= N;
5480
  }
5481
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5482
  if (dryrun) {
5483
  // Request descriptor sets
5484
  ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
 
 
 
5485
  return;
5486
  }
5487
 
@@ -5502,8 +5535,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
5502
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
5503
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
5504
 
5505
- ggml_vk_sync_buffers(subctx);
5506
-
5507
  vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr;
5508
  size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0;
5509
 
@@ -5568,16 +5599,45 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
5568
  v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
5569
  nbm1,
5570
  scale, max_bias, logit_softcap,
5571
- mask != nullptr, n_head_log2, m0, m1, gqa_ratio };
5572
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
5573
- {
5574
- vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
5575
- vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
5576
- vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
5577
- vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
5578
- vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
5579
- },
5580
- sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x, workgroups_y, workgroups_z });
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5581
  }
5582
 
5583
  static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
 
353
  vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
354
  vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
355
  vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
356
+ vk_pipeline pipeline_flash_attn_split_k_reduce;
357
 
358
  std::unordered_map<std::string, vk_pipeline_ref> pipelines;
359
  std::unordered_map<std::string, uint64_t> pipeline_descriptor_set_requirements;
 
505
  float m1;
506
 
507
  uint32_t gqa_ratio;
508
+ uint32_t split_kv;
509
+ uint32_t k_num;
510
  };
511
 
512
  struct vk_op_push_constants {
 
1479
 
1480
  // small rows, large cols
1481
  if (small_rows) {
1482
+ return {flash_attention_num_small_rows, 64};
1483
  }
1484
  // small cols to reduce register count
1485
  if (ggml_is_quantized(type) || D == 256) {
 
2335
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2336
 
2337
  ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
2338
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 3 * sizeof(uint32_t), {1, 1, 1}, {}, 1, true);
2339
  ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
2340
 
2341
  for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
 
5483
  workgroups_y /= N;
5484
  }
5485
 
5486
+ uint32_t split_kv = KV;
5487
+ uint32_t split_k = 1;
5488
+
5489
+ if (gqa_ratio > 1 && ctx->device->shader_core_count > 0) {
5490
+ GGML_ASSERT(workgroups_x == 1);
5491
+ // Try to run two workgroups per SM.
5492
+ split_k = ctx->device->shader_core_count * 2 / workgroups_y;
5493
+ if (split_k > 1) {
5494
+ // Try to evenly split KV into split_k chunks, but it needs to be a multiple
5495
+ // of "align", so recompute split_k based on that.
5496
+ split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align);
5497
+ split_k = CEIL_DIV(KV, split_kv);
5498
+ workgroups_x = split_k;
5499
+ }
5500
+ }
5501
+
5502
+ // Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1)
5503
+ // and the per-row m and L values (ne1 rows).
5504
+ const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k : 0;
5505
+ if (split_k_size > ctx->device->max_memory_allocation_size) {
5506
+ GGML_ABORT("Requested preallocation size is too large");
5507
+ }
5508
+ if (ctx->prealloc_size_split_k < split_k_size) {
5509
+ ctx->prealloc_size_split_k = split_k_size;
5510
+ }
5511
+
5512
  if (dryrun) {
5513
  // Request descriptor sets
5514
  ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
5515
+ if (split_k > 1) {
5516
+ ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
5517
+ }
5518
  return;
5519
  }
5520
 
 
5535
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
5536
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
5537
 
 
 
5538
  vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr;
5539
  size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0;
5540
 
 
5599
  v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
5600
  nbm1,
5601
  scale, max_bias, logit_softcap,
5602
+ mask != nullptr, n_head_log2, m0, m1,
5603
+ gqa_ratio, split_kv, split_k };
5604
+
5605
+ ggml_vk_sync_buffers(subctx);
5606
+
5607
+ if (split_k > 1) {
5608
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
5609
+ {
5610
+ vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
5611
+ vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
5612
+ vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
5613
+ vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
5614
+ vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
5615
+ },
5616
+ // We only use split_k when group query attention is enabled, which means
5617
+ // there's no more than one tile of rows (i.e. workgroups_x would have been
5618
+ // one). We reuse workgroups_x to mean the number of splits, so we need to
5619
+ // cancel out the divide by wg_denoms[0].
5620
+ sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
5621
+
5622
+ ggml_vk_sync_buffers(subctx);
5623
+ const std::array<uint32_t, 3> pc2 = { D, (uint32_t)ne1, split_k };
5624
+ ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
5625
+ {
5626
+ vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
5627
+ vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
5628
+ },
5629
+ pc2.size() * uint32_t{sizeof(uint32_t)}, pc2.data(), { (uint32_t)ne1, 1, 1 });
5630
+ } else {
5631
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
5632
+ {
5633
+ vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
5634
+ vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
5635
+ vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
5636
+ vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
5637
+ vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
5638
+ },
5639
+ sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x, workgroups_y, workgroups_z });
5640
+ }
5641
  }
5642
 
5643
  static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp CHANGED
@@ -63,6 +63,8 @@ layout (push_constant) uniform parameter {
63
  float m1;
64
 
65
  uint32_t gqa_ratio;
 
 
66
  } p;
67
 
68
  layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
@@ -116,6 +118,16 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
116
  return elem;
117
  }
118
 
 
 
 
 
 
 
 
 
 
 
119
  // Load the slope matrix, indexed by Q's dimension 2.
120
  ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
121
  {
@@ -135,10 +147,18 @@ void main() {
135
  const uint32_t N = p.N;
136
  const uint32_t KV = p.KV;
137
 
 
 
 
 
 
 
 
 
138
  const uint32_t Tr = CEIL_DIV(N, Br);
139
- const uint32_t Tc = CEIL_DIV(KV, Bc);
140
 
141
- const uint32_t i = gl_WorkGroupID.x;
 
142
 
143
  // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
144
  // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
@@ -218,7 +238,7 @@ void main() {
218
  }
219
 
220
  [[dont_unroll]]
221
- for (uint32_t j = 0; j < Tc; ++j) {
222
 
223
  coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
224
 
@@ -312,6 +332,20 @@ void main() {
312
  O = coopMatMulAdd(P_A, V, O);
313
  }
314
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
  coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Ldiag;
316
 
317
  // resize L by using smear/reduce
 
63
  float m1;
64
 
65
  uint32_t gqa_ratio;
66
+ uint32_t split_kv;
67
+ uint32_t k_num;
68
  } p;
69
 
70
  layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
 
118
  return elem;
119
  }
120
 
121
+ // Store column zero. This is used to save per-row m and L values for split_k.
122
+ ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
123
+ {
124
+ if (r < N && c == 0) {
125
+ uint32_t offset = iq2 + r;
126
+ data_o[o_offset + offset] = D_TYPE(elem);
127
+ }
128
+ return elem;
129
+ }
130
+
131
  // Load the slope matrix, indexed by Q's dimension 2.
132
  ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
133
  {
 
147
  const uint32_t N = p.N;
148
  const uint32_t KV = p.KV;
149
 
150
+ uint32_t i = gl_WorkGroupID.x;
151
+ uint32_t split_k_index = 0;
152
+
153
+ if (p.k_num > 1) {
154
+ i = 0;
155
+ split_k_index = gl_WorkGroupID.x;
156
+ }
157
+
158
  const uint32_t Tr = CEIL_DIV(N, Br);
 
159
 
160
+ const uint32_t start_j = split_k_index * p.split_kv / Bc;
161
+ const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
162
 
163
  // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
164
  // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
 
238
  }
239
 
240
  [[dont_unroll]]
241
+ for (uint32_t j = start_j; j < end_j; ++j) {
242
 
243
  coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
244
 
 
332
  O = coopMatMulAdd(P_A, V, O);
333
  }
334
 
335
+ // If there is split_k, then the split_k resolve shader does the final
336
+ // division by L. Store the intermediate O value and per-row m and L values.
337
+ if (p.k_num > 1) {
338
+ coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
339
+
340
+ uint32_t o_offset = D * p.ne1 * split_k_index;
341
+ coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
342
+
343
+ o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
344
+ coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
345
+ coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
346
+ return;
347
+ }
348
+
349
  coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Ldiag;
350
 
351
  // resize L by using smear/reduce
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #version 450
2
+
3
+ #extension GL_EXT_control_flow_attributes : enable
4
+
5
+ #define BLOCK_SIZE 32
6
+
7
+ layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
8
+
9
+ layout (binding = 0) readonly buffer A {float data_a[];};
10
+ layout (binding = 1) writeonly buffer D {float data_d[];};
11
+
12
+ layout (push_constant) uniform parameter {
13
+ uint D;
14
+ uint N;
15
+ uint k_num;
16
+ } p;
17
+
18
+ void main() {
19
+ // Each workgroup handles a row
20
+ const uint n = gl_WorkGroupID.x;
21
+ const uint tid = gl_LocalInvocationID.x;
22
+
23
+ uint D = p.D;
24
+ uint N = p.N;
25
+ uint k_num = p.k_num;
26
+
27
+ uint l_offset = D * N * k_num + n;
28
+ uint m_offset = D * N * k_num + N + n;
29
+ uint lm_stride = N * 2;
30
+
31
+ // Compute the max m value for the row
32
+ float m_max = -1.0/0.0;
33
+ [[unroll]] for (uint k = 0; k < k_num; ++k) {
34
+ float m = data_a[m_offset + k * lm_stride];
35
+ m_max = max(m_max, m);
36
+ }
37
+
38
+ // Compute L based on m_max
39
+ float L = 0;
40
+ [[unroll]] for (uint k = 0; k < k_num; ++k) {
41
+ float l = data_a[l_offset + k * lm_stride];
42
+ float m = data_a[m_offset + k * lm_stride];
43
+ L += exp(m - m_max) * l;
44
+ }
45
+
46
+ L = 1.0 / L;
47
+
48
+ // Scale and sum the O contributions based on m_max and store the result to memory
49
+ for (uint d = tid; d < D; d += BLOCK_SIZE) {
50
+ float O = 0.0;
51
+ [[unroll]] for (uint k = 0; k < k_num; ++k) {
52
+ uint o_offset = D * N * k + D * n + d;
53
+ float m = data_a[m_offset + k * lm_stride];
54
+ O += exp(m - m_max) * data_a[o_offset];
55
+ }
56
+ O *= L;
57
+ data_d[D * n + d] = O;
58
+ }
59
+ }
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp CHANGED
@@ -465,6 +465,7 @@ void process_shaders() {
465
  string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
466
 
467
  string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
 
468
  string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {});
469
 
470
  string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 
465
  string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
466
 
467
  string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
468
+ string_to_spv("fa_split_k_reduce", "flash_attn_split_k_reduce.comp", {});
469
  string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {});
470
 
471
  string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});