Commit d7e9115 · 1 parent: 5ec4382
vulkan: support fattn sinks (llama/15126)
Files changed:
- ggml/src/ggml-vulkan/ggml-vulkan.cpp (+40 -18)
- ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp (+21 -0)
- ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp (+12 -1)
- ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp (+21 -0)
- ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp (+28 -0)
- ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp (+26 -1)
ggml/src/ggml-vulkan/ggml-vulkan.cpp
CHANGED

@@ -2286,14 +2286,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
     };

 #define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, HSK, HSV, HEAD_SIZES) \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main",
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main",
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main",
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main",
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main",
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main",
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main",
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main",
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \

 #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
     CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64, 64, 64) \

@@ -2910,7 +2910,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);

     ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main",
+    ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, 5 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);

     for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {

@@ -6507,11 +6507,14 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
     return supported;
 }

-static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) {
+static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, const ggml_tensor * sinks, ggml_tensor * dst, bool dryrun = false) {
     VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3];
     std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3];
     std::cerr << "), (" << v << ", name=" << v->name << ", type=" << v->type << ", ne0=" << v->ne[0] << ", ne1=" << v->ne[1] << ", ne2=" << v->ne[2] << ", ne3=" << v->ne[3] << ", nb0=" << v->nb[0] << ", nb1=" << v->nb[1] << ", nb2=" << v->nb[2] << ", nb3=" << v->nb[3];
     std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
+    if (sinks) {
+        std::cerr << "), (" << sinks << ", name=" << sinks->name << ", type=" << sinks->type << ", ne0=" << sinks->ne[0] << ", ne1=" << sinks->ne[1] << ", ne2=" << sinks->ne[2] << ", ne3=" << sinks->ne[3] << ", nb0=" << sinks->nb[0] << ", nb1=" << sinks->nb[1] << ", nb2=" << sinks->nb[2] << ", nb3=" << sinks->nb[3];
+    }
     std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");

     GGML_TENSOR_LOCALS(int64_t, neq, q, ne)

@@ -6710,10 +6713,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

-    vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr;
-    size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0;
+    vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr, d_S = nullptr;
+    size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0, s_buf_offset = 0;

-    bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false;
+    bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false, S_uma = false;

     if (ctx->device->uma) {
         ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset);

@@ -6728,6 +6731,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
             ggml_vk_host_get(ctx->device, mask->data, d_M, m_buf_offset);
             M_uma = d_M != nullptr;
         }
+        if (sinks) {
+            ggml_vk_host_get(ctx->device, sinks->data, d_S, s_buf_offset);
+            S_uma = d_S != nullptr;
+        }
     }

@@ -6763,7 +6770,17 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         }
     }

-    uint32_t mask_n_head_log2 = ((mask != nullptr) << 16) | n_head_log2;
+    if (!S_uma) {
+        d_S = d_Q;
+        s_buf_offset = q_buf_offset;
+        if (sinks) {
+            ggml_backend_vk_buffer_context * s_buf_ctx = (ggml_backend_vk_buffer_context*)sinks->buffer->context;
+            d_S = s_buf_ctx->dev_buffer;
+            s_buf_offset = vk_tensor_offset(sinks) + sinks->view_offs;
+        }
+    }
+
+    uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2;

     const vk_flash_attn_push_constants pc = { N, KV,
         (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,

@@ -6787,6 +6804,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
                 vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
                 vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
                 vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
+                vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE},
                 vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
             },
             // We only use split_k when group query attention is enabled, which means

@@ -6796,10 +6814,11 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
             pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });

         ggml_vk_sync_buffers(subctx);
-        const std::array<uint32_t,
+        const std::array<uint32_t, 5> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) };
         ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
             {
                 vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
+                vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE},
                 vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
             },
             pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });

@@ -6810,6 +6829,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
                 vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
                 vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
                 vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
+                vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE},
                 vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
             },
             pc, { workgroups_x, workgroups_y, workgroups_z });

@@ -9874,7 +9894,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         break;

     case GGML_OP_FLASH_ATTN_EXT:
-        ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node, dryrun);
+        ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node->src[4], node, dryrun);

        break;

@@ -10951,8 +10971,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) {
         return false;
     }
-
-    if (op->src[4]) {
+    if (op->src[4] && op->src[4]->type != GGML_TYPE_F32) {
         return false;
     }
     if (op->src[0]->type != GGML_TYPE_F32) {

@@ -11547,6 +11566,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
     if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
         const float * params = (const float *)tensor->op_params;
         tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]);
+        if (src_clone[4]) {
+            ggml_flash_attn_ext_add_sinks(tensor_clone, src_clone[4]);
+        }
     } else if (tensor->op == GGML_OP_MUL_MAT) {
         tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_MUL_MAT_ID) {
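The host side tells the shaders whether a mask and a sink tensor are bound through the single packed word `mask_n_head_log2` built above (the corresponding `SINK_ENABLE_BIT`, `MASK_ENABLE_BIT`, and `N_LOG2_MASK` constants appear in flash_attn_base.comp below). A minimal standalone sketch of that packing, where `pack_flags` and the `main` harness are illustrative names, not part of the ggml source:

#include <cstdint>
#include <cassert>

// Bit layout of the flash-attention flag word: bit 24 = sinks present,
// bit 16 = mask present, low 16 bits = n_head_log2 (ALiBi slope base).
constexpr uint32_t SINK_ENABLE_BIT = 1u << 24;
constexpr uint32_t MASK_ENABLE_BIT = 1u << 16;
constexpr uint32_t N_LOG2_MASK     = 0xFFFFu;

static uint32_t pack_flags(bool has_sinks, bool has_mask, uint32_t n_head_log2) {
    return (uint32_t(has_sinks) << 24) | (uint32_t(has_mask) << 16) | (n_head_log2 & N_LOG2_MASK);
}

int main() {
    const uint32_t w = pack_flags(true, false, 5);
    assert((w & SINK_ENABLE_BIT) != 0);  // shader takes the sink epilogue
    assert((w & MASK_ENABLE_BIT) == 0);  // no mask bound for this dispatch
    assert((w & N_LOG2_MASK) == 5);      // n_head_log2 survives unmasked
    return 0;
}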
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
CHANGED

@@ -305,6 +305,27 @@ void main() {
         return;
     }

+    if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
+        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+            float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2);
+
+            float ms = 1.0f;
+            float vs = 1.0f;
+
+            if (sink > Mf[r]) {
+                ms = exp(Mf[r] - sink);
+
+                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
+                    Of[r][d] *= ms;
+                }
+            } else {
+                vs = exp(sink - Mf[r]);
+            }
+
+            Lf[r] = Lf[r]*ms + vs;
+        }
+    }
+
     float Lfrcp[Br];
     [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
         Lfrcp[r] = 1.0 / Lf[r];
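Conceptually, this epilogue folds one extra logit (the per-head "sink") into the online-softmax state just before normalization. With $M_r$ the running row maximum, $L_r$ the running denominator, and $O_r$ the unnormalized output row, the block computes

$$
(m_s,\, v_s) =
\begin{cases}
\left(e^{M_r - s},\ 1\right) & \text{if } s > M_r,\\
\left(1,\ e^{s - M_r}\right) & \text{otherwise,}
\end{cases}
\qquad
O_r \leftarrow m_s\, O_r, \qquad
L_r \leftarrow m_s\, L_r + v_s,
$$

which is exactly softmax over the KV logits plus one virtual key with logit $s$ and an all-zero value row: the sink contributes $e^{s - \max(M_r, s)}$ to the denominator but nothing to the numerator, so the existing `Lfrcp[r] = 1.0 / Lf[r]` normalization that follows needs no change.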
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp
CHANGED

@@ -50,10 +50,13 @@ layout (push_constant) uniform parameter {
     uint32_t k_num;
 } p;

+#define SINK_ENABLE_BIT (1<<24)
 #define MASK_ENABLE_BIT (1<<16)
 #define N_LOG2_MASK 0xFFFF

-layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
+layout (binding = 4) readonly buffer S {float data_s[];};
+
+layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};

 #if defined(A_TYPE_PACKED16)
 #define BINDING_IDX_K 0

@@ -111,6 +114,14 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
     return ACC_TYPE(pow(base, ACC_TYPE(exph)));
 }

+// Load the sink value, indexed by Q's dimension 2.
+ACC_TYPE perElemOpGetSink(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
+{
+    const uint32_t h = iq2 + (r % p.gqa_ratio);
+
+    return ACC_TYPE(data_s[h]);
+}
+
 uint32_t i, N, KV, split_k_index, Tr, start_j, end_j,
          iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
          q_stride, k_stride, v_stride, m_stride;
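`perElemOpGetSink` addresses the sink buffer per Q head: when a workgroup packs `gqa_ratio` query heads that share one KV head starting at `iq2`, row `r` belongs to head `iq2 + (r % p.gqa_ratio)`. A scalar sketch of that addressing; the formula matches the shader, while the loop, buffer contents, and `main` harness are illustrative:

#include <cstdio>
#include <cstdint>
#include <vector>

// One sink logit per Q head (data_s in the shader). Rows of a GQA-packed
// workgroup cycle through the gqa_ratio consecutive Q heads at iq2.
int main() {
    const uint32_t n_head = 8, gqa_ratio = 4, iq2 = 4, Br = 16;
    std::vector<float> sinks(n_head);
    for (uint32_t h = 0; h < n_head; ++h) sinks[h] = 0.1f * h;

    for (uint32_t r = 0; r < Br; ++r) {
        const uint32_t h = iq2 + (r % gqa_ratio);  // same formula as perElemOpGetSink
        std::printf("row %2u -> head %u, sink %.2f\n", (unsigned)r, (unsigned)h, sinks[h]);
    }
    return 0;
}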
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
CHANGED

@@ -329,6 +329,27 @@ void main() {
         return;
     }

+    if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
+        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+            float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2);
+
+            float ms = 1.0f;
+            float vs = 1.0f;
+
+            if (sink > Mf[r]) {
+                ms = exp(Mf[r] - sink);
+
+                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
+                    Of[r][d] *= ACC_TYPE(ms);
+                }
+            } else {
+                vs = exp(sink - Mf[r]);
+            }
+
+            Lf[r] = Lf[r]*ms + vs;
+        }
+    }
+
     float Lfrcp[rows_per_thread];
     [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
         Lfrcp[r] = 1.0 / Lf[r];
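The scalar path above, this coopmat1 path, and the coopmat2 path below all apply the identical correction, only over different register layouts (`Of`/`Lf`/`Mf` arrays here, cooperative matrices in cm2). As a cross-check oracle, a plain scalar version of the whole epilogue might look like the following; `finish_row_with_sink` is a hypothetical helper, with float accumulators assumed:

#include <cmath>
#include <vector>

// Fold one sink logit into an online-softmax row state, then normalize.
// O: unnormalized output row, L: running denominator, M: running max.
void finish_row_with_sink(std::vector<float>& O, float& L, float M, float sink) {
    float ms = 1.0f, vs = 1.0f;
    if (sink > M) {
        ms = std::exp(M - sink);      // rescale existing state to the new max
        for (float& o : O) o *= ms;
    } else {
        vs = std::exp(sink - M);      // sink joins under the existing max
    }
    L = L * ms + vs;

    const float Lrcp = 1.0f / L;      // matches the Lfrcp step in the shaders
    for (float& o : O) o *= Lrcp;
}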
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
CHANGED

@@ -248,6 +248,34 @@ void main() {
     // resize L by using smear/reduce
     coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);

+    if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
+        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> S;
+        coopMatPerElementNV(S, S, perElemOpGetSink, iq2);
+
+        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> Mr;
+
+        // resize M by using smear/reduce
+        coopMatReduceNV(Mr, M, gl_CooperativeMatrixReduceRowNV, smearReduce);
+
+        // O, Ldiag, Mr all have the same type so all element locations match
+        [[unroll]] for (uint32_t i = 0; i < Ldiag.length(); ++i) {
+            ACC_TYPE sink = S[i];
+
+            ACC_TYPE ms = ACC_TYPE(1.0f);
+            ACC_TYPE vs = ACC_TYPE(1.0f);
+
+            if (sink > Mr[i]) {
+                ms = exp(Mr[i] - sink);
+
+                O[i] *= ms;
+            } else {
+                vs = exp(sink - Mr[i]);
+            }
+
+            Ldiag[i] = Ldiag[i]*ms + vs;
+        }
+    }
+
     [[unroll]]
     for (int k = 0; k < Ldiag.length(); ++k) {
         Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k];
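The coopmat2 path cannot index per-row scalars directly, so it first reshapes the row state to the accumulator shape: the row reduce with a "smear" combiner turns the maxima in M into a matrix Mr whose elements each hold their row's maximum, letting the element-wise loop treat O, Ldiag, and Mr uniformly. A toy emulation of that smear/reduce on plain arrays, under the assumption that the combiner is `max`; this is illustrative only, not the coopmat API:

#include <algorithm>

// Emulate a row reduce + smear on a Br x HSV tile: reduce each row with
// max, then replicate the result across every column of that row.
template <int Br, int HSV>
void row_reduce_smear(const float (&in)[Br][HSV], float (&out)[Br][HSV]) {
    for (int r = 0; r < Br; ++r) {
        float m = in[r][0];
        for (int c = 1; c < HSV; ++c) m = std::max(m, in[r][c]);
        for (int c = 0; c < HSV; ++c) out[r][c] = m;  // smear across the row
    }
}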
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp
CHANGED

@@ -7,13 +7,15 @@ layout(constant_id = 0) const uint BLOCK_SIZE = 32;
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

 layout (binding = 0) readonly buffer A {float data_a[];};
-layout (binding = 1) writeonly buffer D {float data_d[];};
+layout (binding = 1) readonly buffer B {float data_s[];};
+layout (binding = 2) writeonly buffer D {float data_d[];};

 layout (push_constant) uniform parameter {
     uint D;
     uint N;
     uint ne3;
     uint k_num;
+    uint sinks;
 } p;

 shared float tmpsh[BLOCK_SIZE];

@@ -73,6 +75,22 @@ void main() {
     }
     L = tmpsh[0];

+    float sink;
+    if (p.sinks != 0) {
+        sink = data_s[n];
+
+        float ms = 1.0f;
+        float vs = 1.0f;
+
+        if (sink > m_max) {
+            ms = exp(m_max - sink);
+        } else {
+            vs = exp(sink - m_max);
+        }
+
+        L = L*ms + vs;
+    }
+
     L = 1.0 / L;

     // D dimension is split across workgroups in the y dimension

@@ -85,6 +103,13 @@ void main() {
         float m = data_a[m_offset + k * lm_stride];
         O += exp(m - m_max) * data_a[o_offset];
     }
+    if (p.sinks != 0) {
+        if (sink > m_max) {
+            float ms = 1.0f;
+            ms = exp(m_max - sink);
+            O *= ms;
+        }
+    }
     O *= L;
     data_d[iq3 * D * N + D * n + d] = O;
 }
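With split_k, each of the `k_num` passes produces a partial (L, m, O) triple, and the sink must be folded in exactly once; that is why the correction lives in this reduction shader (gated by the new `sinks` push constant) rather than in the main kernel when split_k is active. A scalar sketch of the combine for one output row, with the `Split` struct, `reduce_splits` name, and flattened buffer layout all illustrative simplifications of the real strided buffers:

#include <algorithm>
#include <cmath>
#include <vector>

struct Split { float L, m; std::vector<float> O; };  // partial denominator, max, output

// Combine the split-k partials for one row, folding in an optional sink.
// splits must be non-empty; sink is nullptr when no sink tensor is bound.
std::vector<float> reduce_splits(const std::vector<Split>& splits, const float* sink) {
    float m_max = -INFINITY;
    for (const Split& s : splits) m_max = std::max(m_max, s.m);

    float L = 0.0f;
    std::vector<float> O(splits[0].O.size(), 0.0f);
    for (const Split& s : splits) {
        const float w = std::exp(s.m - m_max);   // rebase each split to the global max
        L += w * s.L;
        for (size_t d = 0; d < O.size(); ++d) O[d] += w * s.O[d];
    }

    if (sink) {                                  // the sink is applied once, here
        float ms = 1.0f, vs = 1.0f;
        if (*sink > m_max) {
            ms = std::exp(m_max - *sink);
            for (float& o : O) o *= ms;
        } else {
            vs = std::exp(*sink - m_max);
        }
        L = L * ms + vs;
    }

    for (float& o : O) o /= L;                   // final normalization
    return O;
}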