Spaces:
Running
Running
Commit
·
41a76e6
1
Parent(s):
a6fa78e
vulkan: Support mul_mat_id with f32 accumulators (llama/15337)
Browse files* vulkan: Add missing bounds checking to scalar/coopmat1 mul_mat_id
* vulkan: Support mul_mat_id with f32 accumulators, but they are not hooked up
- There's no explicit way to request f32 precision for mul_mat_id, but there
probably should be, and this gets the code in place for that.
- A couple fixes to check_results.
- Remove casts to fp16 in coopmat1 FA shader (found by inspection).
ggml/src/ggml-vulkan/ggml-vulkan.cpp
CHANGED
|
@@ -2387,26 +2387,26 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
| 2387 |
CREATE_MM(pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
|
| 2388 |
}
|
| 2389 |
#endif
|
| 2390 |
-
|
| 2391 |
-
|
| 2392 |
-
|
| 2393 |
-
|
| 2394 |
-
|
| 2395 |
-
|
| 2396 |
-
|
| 2397 |
-
|
| 2398 |
-
|
| 2399 |
-
|
| 2400 |
-
|
| 2401 |
-
|
| 2402 |
-
|
| 2403 |
-
|
| 2404 |
-
|
| 2405 |
-
|
| 2406 |
-
|
| 2407 |
-
|
| 2408 |
-
|
| 2409 |
-
|
| 2410 |
#undef CREATE_MM
|
| 2411 |
#undef CREATE_MM2
|
| 2412 |
} else
|
|
@@ -2502,51 +2502,27 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
| 2502 |
}
|
| 2503 |
#endif
|
| 2504 |
|
| 2505 |
-
|
| 2506 |
-
|
| 2507 |
-
|
| 2508 |
-
|
| 2509 |
-
|
| 2510 |
-
|
| 2511 |
-
|
| 2512 |
-
|
| 2513 |
-
|
| 2514 |
-
|
| 2515 |
-
|
| 2516 |
-
|
| 2517 |
-
|
| 2518 |
-
|
| 2519 |
-
|
| 2520 |
-
|
| 2521 |
-
|
| 2522 |
-
|
| 2523 |
-
|
| 2524 |
-
|
| 2525 |
-
|
| 2526 |
-
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2527 |
-
} else {
|
| 2528 |
-
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2529 |
-
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2530 |
-
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2531 |
-
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2532 |
-
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2533 |
-
|
| 2534 |
-
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2535 |
-
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2536 |
-
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2537 |
-
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2538 |
-
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2539 |
-
CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2540 |
-
CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2541 |
-
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2542 |
-
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2543 |
-
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2544 |
-
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2545 |
-
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2546 |
-
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2547 |
-
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2548 |
-
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2549 |
-
}
|
| 2550 |
#undef CREATE_MM2
|
| 2551 |
#undef CREATE_MM
|
| 2552 |
} else
|
|
@@ -2631,27 +2607,27 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
| 2631 |
|
| 2632 |
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
|
| 2633 |
|
| 2634 |
-
|
| 2635 |
-
|
| 2636 |
-
|
| 2637 |
-
|
| 2638 |
-
|
| 2639 |
-
|
| 2640 |
-
|
| 2641 |
-
|
| 2642 |
-
|
| 2643 |
-
|
| 2644 |
-
|
| 2645 |
-
|
| 2646 |
-
|
| 2647 |
-
|
| 2648 |
-
|
| 2649 |
-
|
| 2650 |
-
|
| 2651 |
-
|
| 2652 |
-
|
| 2653 |
-
|
| 2654 |
-
|
| 2655 |
#undef CREATE_MM2
|
| 2656 |
#undef CREATE_MMQ
|
| 2657 |
#undef CREATE_MM
|
|
@@ -4470,7 +4446,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
|
|
| 4470 |
return nullptr;
|
| 4471 |
}
|
| 4472 |
|
| 4473 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4474 |
}
|
| 4475 |
|
| 4476 |
static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
|
|
@@ -11723,6 +11709,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
|
|
| 11723 |
} else {
|
| 11724 |
tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]);
|
| 11725 |
}
|
|
|
|
|
|
|
| 11726 |
} else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
|
| 11727 |
if (src1 == nullptr) {
|
| 11728 |
tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);
|
|
@@ -11807,6 +11795,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
|
|
| 11807 |
src_clone[0]->flags = src0->flags;
|
| 11808 |
tensor_clone = ggml_opt_step_sgd(ggml_ctx, src_clone[0], src_clone[1],
|
| 11809 |
src_clone[2]);
|
|
|
|
|
|
|
| 11810 |
}
|
| 11811 |
else {
|
| 11812 |
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
|
|
|
|
| 2387 |
CREATE_MM(pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
|
| 2388 |
}
|
| 2389 |
#endif
|
| 2390 |
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 2391 |
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 2392 |
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 2393 |
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 2394 |
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 2395 |
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 2396 |
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 2397 |
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 2398 |
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 2399 |
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 2400 |
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 2401 |
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 2402 |
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 2403 |
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 2404 |
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 2405 |
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 2406 |
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 2407 |
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 2408 |
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 2409 |
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 2410 |
#undef CREATE_MM
|
| 2411 |
#undef CREATE_MM2
|
| 2412 |
} else
|
|
|
|
| 2502 |
}
|
| 2503 |
#endif
|
| 2504 |
|
| 2505 |
+
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2506 |
+
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2507 |
+
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2508 |
+
CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2509 |
+
CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2510 |
+
|
| 2511 |
+
CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2512 |
+
CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2513 |
+
CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2514 |
+
CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2515 |
+
CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2516 |
+
CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2517 |
+
CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2518 |
+
CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2519 |
+
CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2520 |
+
CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2521 |
+
CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2522 |
+
CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2523 |
+
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2524 |
+
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2525 |
+
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2526 |
#undef CREATE_MM2
|
| 2527 |
#undef CREATE_MM
|
| 2528 |
} else
|
|
|
|
| 2607 |
|
| 2608 |
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
|
| 2609 |
|
| 2610 |
+
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2611 |
+
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2612 |
+
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2613 |
+
CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2614 |
+
CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2615 |
+
|
| 2616 |
+
CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2617 |
+
CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2618 |
+
CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2619 |
+
CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2620 |
+
CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2621 |
+
CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2622 |
+
CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2623 |
+
CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2624 |
+
CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2625 |
+
CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2626 |
+
CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2627 |
+
CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2628 |
+
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2629 |
+
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2630 |
+
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2631 |
#undef CREATE_MM2
|
| 2632 |
#undef CREATE_MMQ
|
| 2633 |
#undef CREATE_MM
|
|
|
|
| 4446 |
return nullptr;
|
| 4447 |
}
|
| 4448 |
|
| 4449 |
+
// XXX TODO 'prec' is not actually allowed in mul_mat_id.
|
| 4450 |
+
bool prefer_fp16acc = ctx->device->fp16 /*&& prec == GGML_PREC_DEFAULT*/;
|
| 4451 |
+
bool support_fp16acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc != nullptr;
|
| 4452 |
+
bool support_fp32acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc != nullptr;
|
| 4453 |
+
|
| 4454 |
+
if (support_fp16acc && (prefer_fp16acc || !support_fp32acc)) {
|
| 4455 |
+
return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc;
|
| 4456 |
+
} else {
|
| 4457 |
+
GGML_ASSERT(support_fp32acc);
|
| 4458 |
+
return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc;
|
| 4459 |
+
}
|
| 4460 |
}
|
| 4461 |
|
| 4462 |
static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
|
|
|
|
| 11709 |
} else {
|
| 11710 |
tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]);
|
| 11711 |
}
|
| 11712 |
+
ggml_set_op_params_i32(tensor_clone, 2, ggml_get_op_params_i32(tensor, 2));
|
| 11713 |
+
ggml_set_op_params_i32(tensor_clone, 3, ggml_get_op_params_i32(tensor, 3));
|
| 11714 |
} else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
|
| 11715 |
if (src1 == nullptr) {
|
| 11716 |
tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);
|
|
|
|
| 11795 |
src_clone[0]->flags = src0->flags;
|
| 11796 |
tensor_clone = ggml_opt_step_sgd(ggml_ctx, src_clone[0], src_clone[1],
|
| 11797 |
src_clone[2]);
|
| 11798 |
+
} else if (tensor->op == GGML_OP_ADD_ID) {
|
| 11799 |
+
tensor_clone = ggml_add_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]);
|
| 11800 |
}
|
| 11801 |
else {
|
| 11802 |
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
|
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
CHANGED
|
@@ -210,7 +210,7 @@ void main() {
|
|
| 210 |
|
| 211 |
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
| 212 |
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
| 213 |
-
Of[r][d] =
|
| 214 |
}
|
| 215 |
}
|
| 216 |
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
@@ -233,7 +233,7 @@ void main() {
|
|
| 233 |
vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
|
| 234 |
#endif
|
| 235 |
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
| 236 |
-
Of[r][d] +=
|
| 237 |
}
|
| 238 |
}
|
| 239 |
}
|
|
@@ -288,7 +288,7 @@ void main() {
|
|
| 288 |
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
| 289 |
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
| 290 |
|
| 291 |
-
Of[r][d] =
|
| 292 |
tmpshv4[tid] = Of[r][d];
|
| 293 |
|
| 294 |
barrier();
|
|
@@ -357,7 +357,7 @@ void main() {
|
|
| 357 |
|
| 358 |
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
| 359 |
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
| 360 |
-
Of[r][d] *=
|
| 361 |
}
|
| 362 |
}
|
| 363 |
|
|
|
|
| 210 |
|
| 211 |
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
| 212 |
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
| 213 |
+
Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d];
|
| 214 |
}
|
| 215 |
}
|
| 216 |
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
|
|
| 233 |
vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
|
| 234 |
#endif
|
| 235 |
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
| 236 |
+
Of[r][d] += ACC_TYPE(Pf[r]) * ACC_TYPEV4(Vf);
|
| 237 |
}
|
| 238 |
}
|
| 239 |
}
|
|
|
|
| 288 |
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
| 289 |
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
| 290 |
|
| 291 |
+
Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d];
|
| 292 |
tmpshv4[tid] = Of[r][d];
|
| 293 |
|
| 294 |
barrier();
|
|
|
|
| 357 |
|
| 358 |
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
| 359 |
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
| 360 |
+
Of[r][d] *= ACC_TYPE(Lfrcp[r]);
|
| 361 |
}
|
| 362 |
}
|
| 363 |
|