Spaces:
Running
Running
Commit ·
5ab06d6
1
Parent(s): fac18c1
vulkan: Implement split_k for coopmat2 flash attention. (llama/12627)
Browse filesWhen using group query attention, we have one workgroup per KV batch and this
can be very few workgroups (e.g. just 8 in some models). Enable split_k to
spread the work across SMs. This helps a lot when the KV cache is large.
ggml/src/ggml-vulkan/ggml-vulkan.cpp
CHANGED
|
@@ -353,6 +353,7 @@ struct vk_device_struct {
|
|
| 353 |
vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
|
| 354 |
vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
|
| 355 |
vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
|
|
|
|
| 356 |
|
| 357 |
std::unordered_map<std::string, vk_pipeline_ref> pipelines;
|
| 358 |
std::unordered_map<std::string, uint64_t> pipeline_descriptor_set_requirements;
|
|
@@ -504,6 +505,8 @@ struct vk_flash_attn_push_constants {
|
|
| 504 |
float m1;
|
| 505 |
|
| 506 |
uint32_t gqa_ratio;
|
|
|
|
|
|
|
| 507 |
};
|
| 508 |
|
| 509 |
struct vk_op_push_constants {
|
|
@@ -1476,7 +1479,7 @@ static std::array<uint32_t, 2> fa_rows_cols(uint32_t D, uint32_t clamp, ggml_typ
|
|
| 1476 |
|
| 1477 |
// small rows, large cols
|
| 1478 |
if (small_rows) {
|
| 1479 |
-
return {flash_attention_num_small_rows,
|
| 1480 |
}
|
| 1481 |
// small cols to reduce register count
|
| 1482 |
if (ggml_is_quantized(type) || D == 256) {
|
|
@@ -2332,6 +2335,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
| 2332 |
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
| 2333 |
|
| 2334 |
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
|
|
|
|
| 2335 |
ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
|
| 2336 |
|
| 2337 |
for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
|
|
@@ -5479,9 +5483,38 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
| 5479 |
workgroups_y /= N;
|
| 5480 |
}
|
| 5481 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5482 |
if (dryrun) {
|
| 5483 |
// Request descriptor sets
|
| 5484 |
ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
|
|
|
|
|
|
|
|
|
|
| 5485 |
return;
|
| 5486 |
}
|
| 5487 |
|
|
@@ -5502,8 +5535,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
| 5502 |
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
| 5503 |
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
| 5504 |
|
| 5505 |
-
ggml_vk_sync_buffers(subctx);
|
| 5506 |
-
|
| 5507 |
vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr;
|
| 5508 |
size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0;
|
| 5509 |
|
|
@@ -5568,16 +5599,45 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
| 5568 |
v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
|
| 5569 |
nbm1,
|
| 5570 |
scale, max_bias, logit_softcap,
|
| 5571 |
-
mask != nullptr, n_head_log2, m0, m1,
|
| 5572 |
-
|
| 5573 |
-
|
| 5574 |
-
|
| 5575 |
-
|
| 5576 |
-
|
| 5577 |
-
|
| 5578 |
-
|
| 5579 |
-
|
| 5580 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5581 |
}
|
| 5582 |
|
| 5583 |
static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
|
|
|
|
| 353 |
vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
|
| 354 |
vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
|
| 355 |
vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
|
| 356 |
+
vk_pipeline pipeline_flash_attn_split_k_reduce;
|
| 357 |
|
| 358 |
std::unordered_map<std::string, vk_pipeline_ref> pipelines;
|
| 359 |
std::unordered_map<std::string, uint64_t> pipeline_descriptor_set_requirements;
|
|
|
|
| 505 |
float m1;
|
| 506 |
|
| 507 |
uint32_t gqa_ratio;
|
| 508 |
+
uint32_t split_kv;
|
| 509 |
+
uint32_t k_num;
|
| 510 |
};
|
| 511 |
|
| 512 |
struct vk_op_push_constants {
|
|
|
|
| 1479 |
|
| 1480 |
// small rows, large cols
|
| 1481 |
if (small_rows) {
|
| 1482 |
+
return {flash_attention_num_small_rows, 64};
|
| 1483 |
}
|
| 1484 |
// small cols to reduce register count
|
| 1485 |
if (ggml_is_quantized(type) || D == 256) {
|
|
|
|
| 2335 |
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
| 2336 |
|
| 2337 |
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
|
| 2338 |
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 3 * sizeof(uint32_t), {1, 1, 1}, {}, 1, true);
|
| 2339 |
ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
|
| 2340 |
|
| 2341 |
for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
|
|
|
|
| 5483 |
workgroups_y /= N;
|
| 5484 |
}
|
| 5485 |
|
| 5486 |
+
uint32_t split_kv = KV;
|
| 5487 |
+
uint32_t split_k = 1;
|
| 5488 |
+
|
| 5489 |
+
if (gqa_ratio > 1 && ctx->device->shader_core_count > 0) {
|
| 5490 |
+
GGML_ASSERT(workgroups_x == 1);
|
| 5491 |
+
// Try to run two workgroups per SM.
|
| 5492 |
+
split_k = ctx->device->shader_core_count * 2 / workgroups_y;
|
| 5493 |
+
if (split_k > 1) {
|
| 5494 |
+
// Try to evenly split KV into split_k chunks, but it needs to be a multiple
|
| 5495 |
+
// of "align", so recompute split_k based on that.
|
| 5496 |
+
split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align);
|
| 5497 |
+
split_k = CEIL_DIV(KV, split_kv);
|
| 5498 |
+
workgroups_x = split_k;
|
| 5499 |
+
}
|
| 5500 |
+
}
|
| 5501 |
+
|
| 5502 |
+
// Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1)
|
| 5503 |
+
// and the per-row m and L values (ne1 rows).
|
| 5504 |
+
const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k : 0;
|
| 5505 |
+
if (split_k_size > ctx->device->max_memory_allocation_size) {
|
| 5506 |
+
GGML_ABORT("Requested preallocation size is too large");
|
| 5507 |
+
}
|
| 5508 |
+
if (ctx->prealloc_size_split_k < split_k_size) {
|
| 5509 |
+
ctx->prealloc_size_split_k = split_k_size;
|
| 5510 |
+
}
|
| 5511 |
+
|
| 5512 |
if (dryrun) {
|
| 5513 |
// Request descriptor sets
|
| 5514 |
ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
|
| 5515 |
+
if (split_k > 1) {
|
| 5516 |
+
ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
|
| 5517 |
+
}
|
| 5518 |
return;
|
| 5519 |
}
|
| 5520 |
|
|
|
|
| 5535 |
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
| 5536 |
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
| 5537 |
|
|
|
|
|
|
|
| 5538 |
vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr;
|
| 5539 |
size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0;
|
| 5540 |
|
|
|
|
| 5599 |
v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
|
| 5600 |
nbm1,
|
| 5601 |
scale, max_bias, logit_softcap,
|
| 5602 |
+
mask != nullptr, n_head_log2, m0, m1,
|
| 5603 |
+
gqa_ratio, split_kv, split_k };
|
| 5604 |
+
|
| 5605 |
+
ggml_vk_sync_buffers(subctx);
|
| 5606 |
+
|
| 5607 |
+
if (split_k > 1) {
|
| 5608 |
+
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
| 5609 |
+
{
|
| 5610 |
+
vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
|
| 5611 |
+
vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
|
| 5612 |
+
vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
|
| 5613 |
+
vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
|
| 5614 |
+
vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
|
| 5615 |
+
},
|
| 5616 |
+
// We only use split_k when group query attention is enabled, which means
|
| 5617 |
+
// there's no more than one tile of rows (i.e. workgroups_x would have been
|
| 5618 |
+
// one). We reuse workgroups_x to mean the number of splits, so we need to
|
| 5619 |
+
// cancel out the divide by wg_denoms[0].
|
| 5620 |
+
sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
|
| 5621 |
+
|
| 5622 |
+
ggml_vk_sync_buffers(subctx);
|
| 5623 |
+
const std::array<uint32_t, 3> pc2 = { D, (uint32_t)ne1, split_k };
|
| 5624 |
+
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
|
| 5625 |
+
{
|
| 5626 |
+
vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
|
| 5627 |
+
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
|
| 5628 |
+
},
|
| 5629 |
+
pc2.size() * uint32_t{sizeof(uint32_t)}, pc2.data(), { (uint32_t)ne1, 1, 1 });
|
| 5630 |
+
} else {
|
| 5631 |
+
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
| 5632 |
+
{
|
| 5633 |
+
vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
|
| 5634 |
+
vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
|
| 5635 |
+
vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
|
| 5636 |
+
vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
|
| 5637 |
+
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
|
| 5638 |
+
},
|
| 5639 |
+
sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x, workgroups_y, workgroups_z });
|
| 5640 |
+
}
|
| 5641 |
}
|
| 5642 |
|
| 5643 |
static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
|
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
CHANGED
|
@@ -63,6 +63,8 @@ layout (push_constant) uniform parameter {
|
|
| 63 |
float m1;
|
| 64 |
|
| 65 |
uint32_t gqa_ratio;
|
|
|
|
|
|
|
| 66 |
} p;
|
| 67 |
|
| 68 |
layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
|
|
@@ -116,6 +118,16 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
|
|
| 116 |
return elem;
|
| 117 |
}
|
| 118 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
// Load the slope matrix, indexed by Q's dimension 2.
|
| 120 |
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
|
| 121 |
{
|
|
@@ -135,10 +147,18 @@ void main() {
|
|
| 135 |
const uint32_t N = p.N;
|
| 136 |
const uint32_t KV = p.KV;
|
| 137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
const uint32_t Tr = CEIL_DIV(N, Br);
|
| 139 |
-
const uint32_t Tc = CEIL_DIV(KV, Bc);
|
| 140 |
|
| 141 |
-
const uint32_t
|
|
|
|
| 142 |
|
| 143 |
// When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
|
| 144 |
// When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
|
|
@@ -218,7 +238,7 @@ void main() {
|
|
| 218 |
}
|
| 219 |
|
| 220 |
[[dont_unroll]]
|
| 221 |
-
for (uint32_t j =
|
| 222 |
|
| 223 |
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
|
| 224 |
|
|
@@ -312,6 +332,20 @@ void main() {
|
|
| 312 |
O = coopMatMulAdd(P_A, V, O);
|
| 313 |
}
|
| 314 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Ldiag;
|
| 316 |
|
| 317 |
// resize L by using smear/reduce
|
|
|
|
| 63 |
float m1;
|
| 64 |
|
| 65 |
uint32_t gqa_ratio;
|
| 66 |
+
uint32_t split_kv;
|
| 67 |
+
uint32_t k_num;
|
| 68 |
} p;
|
| 69 |
|
| 70 |
layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
|
|
|
|
| 118 |
return elem;
|
| 119 |
}
|
| 120 |
|
| 121 |
+
// Store column zero. This is used to save per-row m and L values for split_k.
|
| 122 |
+
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
|
| 123 |
+
{
|
| 124 |
+
if (r < N && c == 0) {
|
| 125 |
+
uint32_t offset = iq2 + r;
|
| 126 |
+
data_o[o_offset + offset] = D_TYPE(elem);
|
| 127 |
+
}
|
| 128 |
+
return elem;
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
// Load the slope matrix, indexed by Q's dimension 2.
|
| 132 |
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
|
| 133 |
{
|
|
|
|
| 147 |
const uint32_t N = p.N;
|
| 148 |
const uint32_t KV = p.KV;
|
| 149 |
|
| 150 |
+
uint32_t i = gl_WorkGroupID.x;
|
| 151 |
+
uint32_t split_k_index = 0;
|
| 152 |
+
|
| 153 |
+
if (p.k_num > 1) {
|
| 154 |
+
i = 0;
|
| 155 |
+
split_k_index = gl_WorkGroupID.x;
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
const uint32_t Tr = CEIL_DIV(N, Br);
|
|
|
|
| 159 |
|
| 160 |
+
const uint32_t start_j = split_k_index * p.split_kv / Bc;
|
| 161 |
+
const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
|
| 162 |
|
| 163 |
// When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
|
| 164 |
// When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
|
|
|
|
| 238 |
}
|
| 239 |
|
| 240 |
[[dont_unroll]]
|
| 241 |
+
for (uint32_t j = start_j; j < end_j; ++j) {
|
| 242 |
|
| 243 |
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
|
| 244 |
|
|
|
|
| 332 |
O = coopMatMulAdd(P_A, V, O);
|
| 333 |
}
|
| 334 |
|
| 335 |
+
// If there is split_k, then the split_k resolve shader does the final
|
| 336 |
+
// division by L. Store the intermediate O value and per-row m and L values.
|
| 337 |
+
if (p.k_num > 1) {
|
| 338 |
+
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
|
| 339 |
+
|
| 340 |
+
uint32_t o_offset = D * p.ne1 * split_k_index;
|
| 341 |
+
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
|
| 342 |
+
|
| 343 |
+
o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
|
| 344 |
+
coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
|
| 345 |
+
coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
|
| 346 |
+
return;
|
| 347 |
+
}
|
| 348 |
+
|
| 349 |
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Ldiag;
|
| 350 |
|
| 351 |
// resize L by using smear/reduce
|
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#version 450
|
| 2 |
+
|
| 3 |
+
#extension GL_EXT_control_flow_attributes : enable
|
| 4 |
+
|
| 5 |
+
#define BLOCK_SIZE 32
|
| 6 |
+
|
| 7 |
+
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
| 8 |
+
|
| 9 |
+
layout (binding = 0) readonly buffer A {float data_a[];};
|
| 10 |
+
layout (binding = 1) writeonly buffer D {float data_d[];};
|
| 11 |
+
|
| 12 |
+
layout (push_constant) uniform parameter {
|
| 13 |
+
uint D;
|
| 14 |
+
uint N;
|
| 15 |
+
uint k_num;
|
| 16 |
+
} p;
|
| 17 |
+
|
| 18 |
+
void main() {
|
| 19 |
+
// Each workgroup handles a row
|
| 20 |
+
const uint n = gl_WorkGroupID.x;
|
| 21 |
+
const uint tid = gl_LocalInvocationID.x;
|
| 22 |
+
|
| 23 |
+
uint D = p.D;
|
| 24 |
+
uint N = p.N;
|
| 25 |
+
uint k_num = p.k_num;
|
| 26 |
+
|
| 27 |
+
uint l_offset = D * N * k_num + n;
|
| 28 |
+
uint m_offset = D * N * k_num + N + n;
|
| 29 |
+
uint lm_stride = N * 2;
|
| 30 |
+
|
| 31 |
+
// Compute the max m value for the row
|
| 32 |
+
float m_max = -1.0/0.0;
|
| 33 |
+
[[unroll]] for (uint k = 0; k < k_num; ++k) {
|
| 34 |
+
float m = data_a[m_offset + k * lm_stride];
|
| 35 |
+
m_max = max(m_max, m);
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
// Compute L based on m_max
|
| 39 |
+
float L = 0;
|
| 40 |
+
[[unroll]] for (uint k = 0; k < k_num; ++k) {
|
| 41 |
+
float l = data_a[l_offset + k * lm_stride];
|
| 42 |
+
float m = data_a[m_offset + k * lm_stride];
|
| 43 |
+
L += exp(m - m_max) * l;
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
L = 1.0 / L;
|
| 47 |
+
|
| 48 |
+
// Scale and sum the O contributions based on m_max and store the result to memory
|
| 49 |
+
for (uint d = tid; d < D; d += BLOCK_SIZE) {
|
| 50 |
+
float O = 0.0;
|
| 51 |
+
[[unroll]] for (uint k = 0; k < k_num; ++k) {
|
| 52 |
+
uint o_offset = D * N * k + D * n + d;
|
| 53 |
+
float m = data_a[m_offset + k * lm_stride];
|
| 54 |
+
O += exp(m - m_max) * data_a[o_offset];
|
| 55 |
+
}
|
| 56 |
+
O *= L;
|
| 57 |
+
data_d[D * n + d] = O;
|
| 58 |
+
}
|
| 59 |
+
}
|
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
CHANGED
|
@@ -465,6 +465,7 @@ void process_shaders() {
|
|
| 465 |
string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
| 466 |
|
| 467 |
string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
|
|
|
|
| 468 |
string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {});
|
| 469 |
|
| 470 |
string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
|
|
|
| 465 |
string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
| 466 |
|
| 467 |
string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
|
| 468 |
+
string_to_spv("fa_split_k_reduce", "flash_attn_split_k_reduce.comp", {});
|
| 469 |
string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {});
|
| 470 |
|
| 471 |
string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|