jeffbolznv committed
Commit d7e9115 · 1 Parent(s): 5ec4382

vulkan: support fattn sinks (llama/15126)

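Attention sinks add one learned logit per attention head that participates in the softmax normalization but has no value row of its own. With the online softmax used by flash attention, the sink can be folded in once, after the loop over KV tiles, by rescaling the running denominator and the output accumulator; that is the correction each shader change below applies. A minimal scalar sketch of the idea, with illustrative names (m/l/o are hypothetical stand-ins for the per-row running max, denominator, and accumulator):

```cpp
#include <cmath>

// Sketch of the sink correction applied after the online-softmax KV loop.
// Illustrative only, not the shader code itself:
//   m: running maximum of the row's logits
//   l: running sum of exp(logit - m)
//   o: accumulator of exp(logit - m) * V, length hsv
void apply_sink(float sink, float &m, float &l, float *o, int hsv) {
    float ms = 1.0f; // rescale factor for the existing state
    float vs = 1.0f; // the sink's own weight in the denominator
    if (sink > m) {
        // the sink becomes the new maximum: rescale what was accumulated so far
        ms = std::exp(m - sink);
        for (int d = 0; d < hsv; ++d) {
            o[d] *= ms;
        }
        m = sink;
    } else {
        vs = std::exp(sink - m);
    }
    // the sink adds exp(sink - m_new) to the denominator, nothing to o
    l = l * ms + vs;
}
```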
ggml/src/ggml-vulkan/ggml-vulkan.cpp CHANGED
@@ -2286,14 +2286,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
     };
 
 #define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, HSK, HSV, HEAD_SIZES) \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
 
 #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
         CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64, 64, 64) \

@@ -2910,7 +2910,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 4 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, 5 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
 
     for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {

@@ -6507,11 +6507,14 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
     return supported;
 }
 
-static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) {
+static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, const ggml_tensor * sinks, ggml_tensor * dst, bool dryrun = false) {
     VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3];
     std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3];
     std::cerr << "), (" << v << ", name=" << v->name << ", type=" << v->type << ", ne0=" << v->ne[0] << ", ne1=" << v->ne[1] << ", ne2=" << v->ne[2] << ", ne3=" << v->ne[3] << ", nb0=" << v->nb[0] << ", nb1=" << v->nb[1] << ", nb2=" << v->nb[2] << ", nb3=" << v->nb[3];
     std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
+    if (sinks) {
+        std::cerr << "), (" << sinks << ", name=" << sinks->name << ", type=" << sinks->type << ", ne0=" << sinks->ne[0] << ", ne1=" << sinks->ne[1] << ", ne2=" << sinks->ne[2] << ", ne3=" << sinks->ne[3] << ", nb0=" << sinks->nb[0] << ", nb1=" << sinks->nb[1] << ", nb2=" << sinks->nb[2] << ", nb3=" << sinks->nb[3];
+    }
     std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
 
     GGML_TENSOR_LOCALS(int64_t, neq, q, ne)

@@ -6710,10 +6713,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr;
-    size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0;
+    vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr, d_S = nullptr;
+    size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0, s_buf_offset = 0;
 
-    bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false;
+    bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false, S_uma = false;
 
     if (ctx->device->uma) {
         ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset);

@@ -6728,6 +6731,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
             ggml_vk_host_get(ctx->device, mask->data, d_M, m_buf_offset);
             M_uma = d_M != nullptr;
         }
+        if (sinks) {
+            ggml_vk_host_get(ctx->device, sinks->data, d_S, s_buf_offset);
+            S_uma = d_S != nullptr;
+        }
     }
 
 

@@ -6763,7 +6770,17 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         }
     }
 
-    uint32_t mask_n_head_log2 = ((mask != nullptr) << 16) | n_head_log2;
+    if (!S_uma) {
+        d_S = d_Q;
+        s_buf_offset = q_buf_offset;
+        if (sinks) {
+            ggml_backend_vk_buffer_context * s_buf_ctx = (ggml_backend_vk_buffer_context*)sinks->buffer->context;
+            d_S = s_buf_ctx->dev_buffer;
+            s_buf_offset = vk_tensor_offset(sinks) + sinks->view_offs;
+        }
+    }
+
+    uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2;
 
     const vk_flash_attn_push_constants pc = { N, KV,
                                               (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,

@@ -6787,6 +6804,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
                 vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
                 vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
                 vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
+                vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE},
                 vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
             },
             // We only use split_k when group query attention is enabled, which means

@@ -6796,10 +6814,11 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
             pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
 
         ggml_vk_sync_buffers(subctx);
-        const std::array<uint32_t, 4> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k };
+        const std::array<uint32_t, 5> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) };
         ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
             {
                 vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
+                vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE},
                 vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
             },
             pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });

@@ -6810,6 +6829,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
                 vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
                 vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
                 vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
+                vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE},
                 vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
             },
             pc, { workgroups_x, workgroups_y, workgroups_z });

@@ -9874,7 +9894,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         break;
 
     case GGML_OP_FLASH_ATTN_EXT:
-        ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node, dryrun);
+        ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node->src[4], node, dryrun);
 
         break;
 

@@ -10951,8 +10971,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) {
                 return false;
             }
-            // TODO: support attention sinks [TAG_ATTN_SINKS]
-            if (op->src[4]) {
+            if (op->src[4] && op->src[4]->type != GGML_TYPE_F32) {
                 return false;
             }
             if (op->src[0]->type != GGML_TYPE_F32) {

@@ -11547,6 +11566,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
     if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
         const float * params = (const float *)tensor->op_params;
         tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]);
+        if (src_clone[4]) {
+            ggml_flash_attn_ext_add_sinks(tensor_clone, src_clone[4]);
+        }
     } else if (tensor->op == GGML_OP_MUL_MAT) {
        tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]);
    } else if (tensor->op == GGML_OP_MUL_MAT_ID) {
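On the host side, the presence of sinks is signaled through the existing mask_n_head_log2 push constant: bit 24 flags sinks, bit 16 flags the mask, and the low 16 bits carry n_head_log2. A small sketch of the packing and the shader-side decode, mirroring the defines in flash_attn_base.comp:

```cpp
#include <cstdint>

// Mirrors the packing in ggml_vk_flash_attn and the SINK_ENABLE_BIT /
// MASK_ENABLE_BIT / N_LOG2_MASK defines in flash_attn_base.comp.
constexpr uint32_t SINK_ENABLE_BIT = 1u << 24;
constexpr uint32_t MASK_ENABLE_BIT = 1u << 16;
constexpr uint32_t N_LOG2_MASK     = 0xFFFFu;

uint32_t pack_mask_n_head_log2(bool has_sinks, bool has_mask, uint32_t n_head_log2) {
    return (uint32_t(has_sinks) << 24) | (uint32_t(has_mask) << 16) | (n_head_log2 & N_LOG2_MASK);
}

bool     sinks_enabled(uint32_t v) { return (v & SINK_ENABLE_BIT) != 0; }
bool     mask_enabled (uint32_t v) { return (v & MASK_ENABLE_BIT) != 0; }
uint32_t n_head_log2  (uint32_t v) { return v & N_LOG2_MASK; }
```

Note that when no sink tensor is present, d_S falls back to d_Q, so binding 4 always receives a valid buffer even though the shaders never read it unless SINK_ENABLE_BIT is set.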
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp CHANGED
@@ -305,6 +305,27 @@ void main() {
         return;
     }
 
+    if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
+        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+            float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2);
+
+            float ms = 1.0f;
+            float vs = 1.0f;
+
+            if (sink > Mf[r]) {
+                ms = exp(Mf[r] - sink);
+
+                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
+                    Of[r][d] *= ms;
+                }
+            } else {
+                vs = exp(sink - Mf[r]);
+            }
+
+            Lf[r] = Lf[r]*ms + vs;
+        }
+    }
+
     float Lfrcp[Br];
     [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
         Lfrcp[r] = 1.0 / Lf[r];
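This block runs once per row, after every KV tile has been folded into the running state (Mf, Lf, Of), and only touches the normalization. Treating the sink as one extra logit whose value row is zero, the correction should reproduce a direct softmax over the extended logit list; a self-contained check of that equivalence (illustrative reference code, not from the source):

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Self-check (illustrative): folding a sink into online-softmax state must
// match a direct softmax over the logits plus one sink logit with zero value.
int main() {
    const std::vector<float> logits = {0.5f, 2.0f, -1.0f};
    const std::vector<float> value  = {1.0f, -2.0f, 0.5f}; // 1-d V per key
    const float sink = 1.0f;

    // online-softmax pass over the real keys
    float m = -INFINITY, l = 0.0f, o = 0.0f;
    for (size_t i = 0; i < logits.size(); ++i) {
        const float m_new = std::fmax(m, logits[i]);
        const float scale = std::exp(m - m_new);
        l = l * scale + std::exp(logits[i] - m_new);
        o = o * scale + std::exp(logits[i] - m_new) * value[i];
        m = m_new;
    }
    // sink correction, as in the shader
    float ms = 1.0f, vs = 1.0f;
    if (sink > m) { ms = std::exp(m - sink); o *= ms; } else { vs = std::exp(sink - m); }
    l = l * ms + vs;
    const float out = o / l;

    // direct reference: softmax over [logits..., sink], zero value for the sink
    float mx = sink;
    for (float x : logits) mx = std::fmax(mx, x);
    float denom = std::exp(sink - mx), num = 0.0f;
    for (size_t i = 0; i < logits.size(); ++i) {
        denom += std::exp(logits[i] - mx);
        num   += std::exp(logits[i] - mx) * value[i];
    }
    assert(std::fabs(out - num / denom) < 1e-6f);
}
```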
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp CHANGED
@@ -50,10 +50,13 @@ layout (push_constant) uniform parameter {
     uint32_t k_num;
 } p;
 
+#define SINK_ENABLE_BIT (1<<24)
 #define MASK_ENABLE_BIT (1<<16)
 #define N_LOG2_MASK 0xFFFF
 
-layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
+layout (binding = 4) readonly buffer S {float data_s[];};
+
+layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
 
 #if defined(A_TYPE_PACKED16)
 #define BINDING_IDX_K 0

@@ -111,6 +114,14 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
     return ACC_TYPE(pow(base, ACC_TYPE(exph)));
 }
 
+// Load the sink value, indexed by Q's dimension 2.
+ACC_TYPE perElemOpGetSink(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
+{
+    const uint32_t h = iq2 + (r % p.gqa_ratio);
+
+    return ACC_TYPE(data_s[h]);
+}
+
 uint32_t i, N, KV, split_k_index, Tr, start_j, end_j,
          iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
          q_stride, k_stride, v_stride, m_stride;
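data_s holds one float per query head. A workgroup covers gqa_ratio query heads that share one KV head, so tile row r maps to query head iq2 + (r % p.gqa_ratio), which is all perElemOpGetSink computes; the unused c and elem arguments exist only to fit the per-element callback shape used elsewhere in these shaders. The same mapping on the host, as a hypothetical sketch:

```cpp
#include <cstdint>

// Hypothetical host-side mirror of perElemOpGetSink: one sink per query
// head; a GQA tile starting at head iq2 assigns row r to head
// iq2 + (r % gqa_ratio).
float sink_for_row(const float *data_s, uint32_t iq2, uint32_t gqa_ratio, uint32_t r) {
    const uint32_t h = iq2 + (r % gqa_ratio);
    return data_s[h];
}
```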
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp CHANGED
@@ -329,6 +329,27 @@ void main() {
         return;
     }
 
+    if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
+        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+            float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2);
+
+            float ms = 1.0f;
+            float vs = 1.0f;
+
+            if (sink > Mf[r]) {
+                ms = exp(Mf[r] - sink);
+
+                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
+                    Of[r][d] *= ACC_TYPE(ms);
+                }
+            } else {
+                vs = exp(sink - Mf[r]);
+            }
+
+            Lf[r] = Lf[r]*ms + vs;
+        }
+    }
+
     float Lfrcp[rows_per_thread];
     [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
         Lfrcp[r] = 1.0 / Lf[r];
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp CHANGED
@@ -248,6 +248,34 @@ void main() {
     // resize L by using smear/reduce
     coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);
 
+    if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
+        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> S;
+        coopMatPerElementNV(S, S, perElemOpGetSink, iq2);
+
+        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> Mr;
+
+        // resize M by using smear/reduce
+        coopMatReduceNV(Mr, M, gl_CooperativeMatrixReduceRowNV, smearReduce);
+
+        // O, Ldiag, Mr all have the same type so all element locations match
+        [[unroll]] for (uint32_t i = 0; i < Ldiag.length(); ++i) {
+            ACC_TYPE sink = S[i];
+
+            ACC_TYPE ms = ACC_TYPE(1.0f);
+            ACC_TYPE vs = ACC_TYPE(1.0f);
+
+            if (sink > Mr[i]) {
+                ms = exp(Mr[i] - sink);
+
+                O[i] *= ms;
+            } else {
+                vs = exp(sink - Mr[i]);
+            }
+
+            Ldiag[i] = Ldiag[i]*ms + vs;
+        }
+    }
+
     [[unroll]]
     for (int k = 0; k < Ldiag.length(); ++k) {
         Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k];
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp CHANGED
@@ -7,13 +7,15 @@ layout(constant_id = 0) const uint BLOCK_SIZE = 32;
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer A {float data_a[];};
-layout (binding = 1) writeonly buffer D {float data_d[];};
+layout (binding = 1) readonly buffer B {float data_s[];};
+layout (binding = 2) writeonly buffer D {float data_d[];};
 
 layout (push_constant) uniform parameter {
     uint D;
     uint N;
     uint ne3;
     uint k_num;
+    uint sinks;
 } p;
 
 shared float tmpsh[BLOCK_SIZE];

@@ -73,6 +75,22 @@ void main() {
     }
     L = tmpsh[0];
 
+    float sink;
+    if (p.sinks != 0) {
+        sink = data_s[n];
+
+        float ms = 1.0f;
+        float vs = 1.0f;
+
+        if (sink > m_max) {
+            ms = exp(m_max - sink);
+        } else {
+            vs = exp(sink - m_max);
+        }
+
+        L = L*ms + vs;
+    }
+
     L = 1.0 / L;
 
     // D dimension is split across workgroups in the y dimension

@@ -85,6 +103,13 @@ void main() {
         float m = data_a[m_offset + k * lm_stride];
         O += exp(m - m_max) * data_a[o_offset];
     }
+    if (p.sinks != 0) {
+        if (sink > m_max) {
+            float ms = 1.0f;
+            ms = exp(m_max - sink);
+            O *= ms;
+        }
+    }
     O *= L;
     data_d[iq3 * D * N + D * n + d] = O;
 }
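With split_k, each of the k_num splits writes a partial (O, L, m) triple to prealloc_split_k, and the sink must be folded in exactly once, in this reduction, rather than once per split. A scalar sketch of the combined reduce under that assumption (layout and names are illustrative; the shader streams the same math per output element):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Scalar sketch of fa_split_k_reduce with sinks. Split k contributes a
// partial row max m[k], partial denominator l[k], and partial accumulator
// o[k] (already scaled by exp(... - m[k])). The sink enters exactly once.
float reduce_with_sink(const std::vector<float> &m, const std::vector<float> &l,
                       const std::vector<float> &o, bool has_sink, float sink) {
    // global max over the splits
    float m_max = -INFINITY;
    for (float mk : m) m_max = std::fmax(m_max, mk);

    // combine denominators and numerators, rescaled to m_max
    float L = 0.0f, O = 0.0f;
    for (size_t k = 0; k < m.size(); ++k) {
        L += std::exp(m[k] - m_max) * l[k];
        O += std::exp(m[k] - m_max) * o[k];
    }

    // fold the sink into the denominator; rescale O only if the sink wins
    if (has_sink) {
        if (sink > m_max) {
            const float ms = std::exp(m_max - sink);
            O *= ms;
            L = L * ms + 1.0f;
        } else {
            L += std::exp(sink - m_max);
        }
    }
    return O / L;
}
```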