jeffbolznv committed · Commit 45fbb42 · 1 Parent(s): 367fa85

vulkan: optimize flash attention split_k_reduce (llama/14554)


* vulkan: allow FA split_k with smaller KV values

* vulkan: spread split_k_reduce work across more threads

k_num can get rather large. Use the whole workgroup to reduce the M/L values.
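For context, here is the merge this reduction implements, in my own notation (not the diff's): per split k, the shader reads a partial row maximum m_k and partial row sum l_k at m_offset/l_offset, and an unnormalized partial output O_{k,d} at o_offset, then combines them with the standard numerically stable log-sum-exp merge:

    m_{\max} = \max_{k < \mathtt{k\_num}} m_k, \qquad
    L = \sum_{k < \mathtt{k\_num}} e^{m_k - m_{\max}}\, l_k, \qquad
    O_d = \frac{1}{L} \sum_{k < \mathtt{k\_num}} e^{m_k - m_{\max}}\, O_{k,d}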

Launch a thread for each element in the HSV dimension of the output. This helps a
lot for large HSV (e.g. DeepSeek).
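A sketch of the resulting thread mapping (mine, not part of the commit), assuming, as for other pipelines in ggml-vulkan.cpp, that ggml_vk_dispatch_pipeline divides each dispatched element count by the pipeline's workgroup-size denominators, which the diff below sets to {1, device->subgroup_size, 1}; fa_reduce_grid and ceil_div are hypothetical helper names:

    #include <cstdint>

    static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

    struct Grid { uint32_t x, y, z; };

    // The new dispatch passes elements { ne1, HSV, ne3 }.
    static Grid fa_reduce_grid(uint32_t ne1, uint32_t HSV, uint32_t ne3,
                               uint32_t subgroup_size /* == BLOCK_SIZE */) {
        return Grid{
            ne1,                          // x: one output row per workgroup
            ceil_div(HSV, subgroup_size), // y: BLOCK_SIZE output elements per workgroup
            ne3,                          // z: batch
        };
    }
    // In the shader, thread tid of y-workgroup wgy owns output element
    // d = tid + wgy * BLOCK_SIZE, so each HSV element gets its own thread
    // instead of one fixed 32-thread workgroup striding over all of HSV.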

ggml/src/ggml-vulkan/ggml-vulkan.cpp CHANGED
@@ -2706,7 +2706,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 3 * sizeof(uint32_t), {1, 1, 1}, {}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 4 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
 
     for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
@@ -6252,13 +6252,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
 
     // Try to use split_k when KV is large enough to be worth the overhead
-    if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
+    if (workgroups_x == 1 && shader_core_count > 0) {
         // Try to run two workgroups per SM.
         split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
         if (split_k > 1) {
             // Try to evenly split KV into split_k chunks, but it needs to be a multiple
             // of "align", so recompute split_k based on that.
-            split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align);
+            split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), pipelines[1]->align);
             split_k = CEIL_DIV(KV, split_kv);
             workgroups_x = split_k;
         }
@@ -6392,7 +6392,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
                 vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
                 vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
             },
-            pc2, { (uint32_t)ne1, 1, (uint32_t)ne3 });
+            pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
     } else {
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
             {
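One detail worth calling out in the pipeline-creation change above: BLOCK_SIZE is no longer a #define but Vulkan specialization constant 0, which the shader also binds to local_size_x, so the host can set the workgroup width to the device's subgroup size at pipeline-creation time. A minimal raw-Vulkan sketch of that wiring (mine; ggml_vk_create_pipeline's {device->subgroup_size} argument is doing the equivalent internally):

    #include <vulkan/vulkan.h>

    // constant_id 0 doubles as BLOCK_SIZE and local_size_x (local_size_x_id = 0),
    // so a single map entry specializes both.
    VkSpecializationInfo make_block_size_spec(const uint32_t *subgroup_size) {
        static const VkSpecializationMapEntry entry = {
            /*constantID=*/0,
            /*offset=*/0,
            /*size=*/sizeof(uint32_t),
        };
        VkSpecializationInfo info = {};
        info.mapEntryCount = 1;
        info.pMapEntries   = &entry;
        info.dataSize      = sizeof(uint32_t);
        info.pData         = subgroup_size; // must outlive vkCreateComputePipelines
        return info; // assign to VkPipelineShaderStageCreateInfo::pSpecializationInfo
    }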
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp CHANGED
@@ -2,9 +2,9 @@
 
 #extension GL_EXT_control_flow_attributes : enable
 
-#define BLOCK_SIZE 32
+layout(constant_id = 0) const uint BLOCK_SIZE = 32;
 
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer A {float data_a[];};
 layout (binding = 1) writeonly buffer D {float data_d[];};
@@ -16,6 +16,8 @@ layout (push_constant) uniform parameter {
     uint k_num;
 } p;
 
+shared float tmpsh[BLOCK_SIZE];
+
 void main() {
     // Each workgroup handles a row
     const uint n = gl_WorkGroupID.x;
@@ -32,23 +34,51 @@ void main() {
 
     // Compute the max m value for the row
     float m_max = -1.0/0.0;
-    [[unroll]] for (uint k = 0; k < k_num; ++k) {
-        float m = data_a[m_offset + k * lm_stride];
+    for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
+        float m = data_a[m_offset + (k + tid) * lm_stride];
         m_max = max(m_max, m);
     }
 
+    // reduce across the workgroup
+    tmpsh[tid] = m_max;
+    barrier();
+    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
+        if (tid < s) {
+            m_max = max(m_max, tmpsh[tid + s]);
+            tmpsh[tid] = m_max;
+        }
+        barrier();
+    }
+    m_max = tmpsh[0];
+
+    barrier();
+
     // Compute L based on m_max
     float L = 0;
-    [[unroll]] for (uint k = 0; k < k_num; ++k) {
-        float l = data_a[l_offset + k * lm_stride];
-        float m = data_a[m_offset + k * lm_stride];
+    for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
+        float l = data_a[l_offset + (k + tid) * lm_stride];
+        float m = data_a[m_offset + (k + tid) * lm_stride];
         L += exp(m - m_max) * l;
     }
 
+    // reduce across the workgroup
+    tmpsh[tid] = L;
+    barrier();
+    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
+        if (tid < s) {
+            L += tmpsh[tid + s];
+            tmpsh[tid] = L;
+        }
+        barrier();
+    }
+    L = tmpsh[0];
+
     L = 1.0 / L;
 
+    // D dimension is split across workgroups in the y dimension
+    uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE;
     // Scale and sum the O contributions based on m_max and store the result to memory
-    for (uint d = tid; d < D; d += BLOCK_SIZE) {
+    if (d < D) {
         float O = 0.0;
         [[unroll]] for (uint k = 0; k < k_num; ++k) {
             uint o_offset = D * N * (k + iq3 * k_num) + D * n + d;
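To make the new shared-memory reduction above easier to follow, here is a host-side model (mine, not part of the commit) of the tree reduction used for m_max, starting from the point where each thread has already folded its strided slice of the k_num values into one private value; the L reduction is the same shape with + in place of max:

    #include <algorithm>
    #include <cassert>
    #include <vector>

    // tmpsh models the shared array; each outer loop iteration corresponds to
    // one barrier()-separated step in which only threads tid < s are active.
    float workgroup_max(std::vector<float> tmpsh) {
        const size_t BLOCK_SIZE = tmpsh.size();
        assert((BLOCK_SIZE & (BLOCK_SIZE - 1)) == 0); // power of two, as the shader assumes
        for (size_t s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
            for (size_t tid = 0; tid < s; ++tid) {
                tmpsh[tid] = std::max(tmpsh[tid], tmpsh[tid + s]);
            }
        }
        return tmpsh[0]; // every thread then reads the result from slot 0
    }

The extra barrier() after m_max = tmpsh[0] is what keeps those reads from racing with the L pass, which immediately reuses the same tmpsh array.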