Dan Johansson Charles Xu commited on
Commit
0612f1f
·
1 Parent(s): 40840d0

ggml-cpu: Integrate fp32=bf16xbf16 SME KleidiAI kernel (llama/13053)

Browse files

* ggml-cpu: Integrate fp32=bf16xbf16 SME KleidiAI kernel

Signed-off-by: Dan Johansson <[email protected]>

* * code review fixes

Signed-off-by: Dan Johansson <[email protected]>

* * adds a comment that clarifies barrier usage

Signed-off-by: Dan Johansson <[email protected]>

---------

Signed-off-by: Dan Johansson <[email protected]>
Co-authored-by: Charles Xu <[email protected]>

ggml/src/ggml-cpu/CMakeLists.txt CHANGED
@@ -428,6 +428,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
428
  ${KLEIDIAI_SRC}/kai/ukernels/
429
  ${KLEIDIAI_SRC}/kai/ukernels/matmul/
430
  ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/
 
431
  ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/)
432
 
433
  set(ARCH_FLAGS_TEMP "${ARCH_FLAGS}")
@@ -438,17 +439,19 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
438
  string(FIND "${ARCH_FLAGS_TEMP}" "+i8mm" I8MM_ENABLED)
439
  string(FIND "${ARCH_FLAGS_TEMP}" "+sme" SME_ENABLED)
440
 
441
- set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS})
442
 
443
- list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c)
444
- list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c)
445
- list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c)
446
- list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)
 
447
 
448
  if (NOT DOTPROD_ENABLED MATCHES -1)
449
- list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c)
450
- list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c)
451
- list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c)
 
452
  endif()
453
 
454
  if (NOT I8MM_ENABLED MATCHES -1)
@@ -456,9 +459,13 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
456
  endif()
457
 
458
  if (NOT SME_ENABLED MATCHES -1)
459
- list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c)
460
- list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c)
461
- set(PRIVATE_ARCH_FLAGS "${PRIVATE_ARCH_FLAGS}+sve+sve2")
 
 
 
 
462
  endif()
463
 
464
  set_source_files_properties(${GGML_KLEIDIAI_SOURCES} PROPERTIES COMPILE_OPTIONS "${PRIVATE_ARCH_FLAGS}")
 
428
  ${KLEIDIAI_SRC}/kai/ukernels/
429
  ${KLEIDIAI_SRC}/kai/ukernels/matmul/
430
  ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/
431
+ ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/
432
  ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/)
433
 
434
  set(ARCH_FLAGS_TEMP "${ARCH_FLAGS}")
 
439
  string(FIND "${ARCH_FLAGS_TEMP}" "+i8mm" I8MM_ENABLED)
440
  string(FIND "${ARCH_FLAGS_TEMP}" "+sme" SME_ENABLED)
441
 
442
+ set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS_TEMP})
443
 
444
+ list(APPEND GGML_KLEIDIAI_SOURCES
445
+ ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c
446
+ ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c
447
+ ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c
448
+ ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)
449
 
450
  if (NOT DOTPROD_ENABLED MATCHES -1)
451
+ list(APPEND GGML_KLEIDIAI_SOURCES
452
+ ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c
453
+ ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c
454
+ ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c)
455
  endif()
456
 
457
  if (NOT I8MM_ENABLED MATCHES -1)
 
459
  endif()
460
 
461
  if (NOT SME_ENABLED MATCHES -1)
462
+ list(APPEND GGML_KLEIDIAI_SOURCES
463
+ ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c
464
+ ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c
465
+ ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c
466
+ ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c
467
+ ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c)
468
+ set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2")
469
  endif()
470
 
471
  set_source_files_properties(${GGML_KLEIDIAI_SOURCES} PROPERTIES COMPILE_OPTIONS "${PRIVATE_ARCH_FLAGS}")
ggml/src/ggml-cpu/kleidiai/kernels.cpp CHANGED
@@ -4,16 +4,22 @@
4
 
5
  // KleidiAI micro-kernels
6
  #include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
7
- #include "kai_lhs_quant_pack_qsi8d32p_f32.h"
8
- #include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
9
- #include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
10
- #include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
11
  #include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
12
  #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
13
  #include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
14
  #include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h"
15
  #include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
16
  #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
 
 
 
 
 
 
 
 
 
 
17
  #include "kai_common.h"
18
 
19
  #include "kernels.h"
@@ -61,6 +67,53 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
61
  /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
62
  },
63
  /* .required_cpu = */ CPU_FEATURE_SME,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  },
65
  #endif
66
  #if defined(__APPLE__)
@@ -105,6 +158,9 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
105
  /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
106
  },
107
  /* .required_cpu = */ CPU_FEATURE_DOTPROD,
 
 
 
108
  },
109
  #endif
110
  #if defined(__ARM_FEATURE_MATMUL_INT8)
@@ -148,6 +204,9 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
148
  /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
149
  },
150
  /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
 
 
 
151
  },
152
  #endif
153
  #else
@@ -192,6 +251,9 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
192
  /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
193
  },
194
  /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
 
 
 
195
  },
196
  #endif
197
  #if defined(__ARM_FEATURE_DOTPROD)
@@ -235,12 +297,33 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
235
  /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
236
  },
237
  /* .required_cpu = */ CPU_FEATURE_DOTPROD,
 
 
 
238
  },
239
  #endif
240
  #endif
241
  };
242
 
243
- ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature features) {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
  ggml_kleidiai_kernels * kernels = nullptr;
245
 
246
  for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
 
4
 
5
  // KleidiAI micro-kernels
6
  #include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
 
 
 
 
7
  #include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
8
  #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
9
  #include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
10
  #include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h"
11
  #include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
12
  #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
13
+ #include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
14
+
15
+ #include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
16
+ #include "kai_lhs_quant_pack_qsi8d32p_f32.h"
17
+ #include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
18
+
19
+ #include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
20
+ #include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
21
+ #include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
22
+
23
  #include "kai_common.h"
24
 
25
  #include "kernels.h"
 
67
  /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
68
  },
69
  /* .required_cpu = */ CPU_FEATURE_SME,
70
+ /* .lhs_type = */ GGML_TYPE_F32,
71
+ /* .rhs_type = */ GGML_TYPE_Q4_0,
72
+ /* .op_type = */ GGML_TYPE_F32,
73
+ },
74
+ {
75
+ /* SME GEMM */
76
+ /* .kern_info = */ {
77
+ /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
78
+ /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
79
+ /* .get_mr = */ kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
80
+ /* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
81
+ /* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
82
+ /* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
83
+ /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
84
+ /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
85
+ /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
86
+ /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
87
+ /* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
88
+ },
89
+ /* SME GEMV */
90
+ /* .kern_info = */ {
91
+ /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
92
+ /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
93
+ /* .get_mr = */ kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
94
+ /* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
95
+ /* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
96
+ /* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
97
+ /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
98
+ /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
99
+ /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
100
+ /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
101
+ /* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
102
+ },
103
+ /* .lhs_info = */ {
104
+ /* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,
105
+ /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme,
106
+ /* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme,
107
+ /* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
108
+ },
109
+ /* .rhs_info = */ {
110
+ /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
111
+ /* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
112
+ },
113
+ /* .required_cpu = */ CPU_FEATURE_SME,
114
+ /* .lhs_type = */ GGML_TYPE_F32,
115
+ /* .rhs_type = */ GGML_TYPE_F16,
116
+ /* .op_type = */ GGML_TYPE_F32,
117
  },
118
  #endif
119
  #if defined(__APPLE__)
 
158
  /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
159
  },
160
  /* .required_cpu = */ CPU_FEATURE_DOTPROD,
161
+ /* .lhs_type = */ GGML_TYPE_F32,
162
+ /* .rhs_type = */ GGML_TYPE_Q4_0,
163
+ /* .op_type = */ GGML_TYPE_F32,
164
  },
165
  #endif
166
  #if defined(__ARM_FEATURE_MATMUL_INT8)
 
204
  /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
205
  },
206
  /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
207
+ /* .lhs_type = */ GGML_TYPE_F32,
208
+ /* .rhs_type = */ GGML_TYPE_Q4_0,
209
+ /* .op_type = */ GGML_TYPE_F32,
210
  },
211
  #endif
212
  #else
 
251
  /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
252
  },
253
  /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
254
+ /* .lhs_type = */ GGML_TYPE_F32,
255
+ /* .rhs_type = */ GGML_TYPE_Q4_0,
256
+ /* .op_type = */ GGML_TYPE_F32,
257
  },
258
  #endif
259
  #if defined(__ARM_FEATURE_DOTPROD)
 
297
  /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
298
  },
299
  /* .required_cpu = */ CPU_FEATURE_DOTPROD,
300
+ /* .lhs_type = */ GGML_TYPE_F32,
301
+ /* .rhs_type = */ GGML_TYPE_Q4_0,
302
+ /* .op_type = */ GGML_TYPE_F32,
303
  },
304
  #endif
305
  #endif
306
  };
307
 
308
+ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) {
309
+ ggml_kleidiai_kernels * kernel = nullptr;
310
+
311
+ if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) {
312
+ for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
313
+ if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu &&
314
+ gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type &&
315
+ gemm_gemv_kernels[i].rhs_type == tensor->src[0]->type &&
316
+ gemm_gemv_kernels[i].op_type == tensor->type) {
317
+ kernel = &gemm_gemv_kernels[i];
318
+ break;
319
+ }
320
+ }
321
+ }
322
+
323
+ return kernel;
324
+ }
325
+
326
+ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features) {
327
  ggml_kleidiai_kernels * kernels = nullptr;
328
 
329
  for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
ggml/src/ggml-cpu/kleidiai/kernels.h CHANGED
@@ -4,6 +4,9 @@
4
 
5
  #pragma once
6
 
 
 
 
7
  enum cpu_feature {
8
  CPU_FEATURE_NONE = 0,
9
  CPU_FEATURE_DOTPROD = 1,
@@ -26,26 +29,53 @@ struct kernel_info {
26
  size_t (*get_nr)(void);
27
  size_t (*get_kr)(void);
28
  size_t (*get_sr)(void);
29
- size_t (*get_lhs_offset)(size_t m_idx, size_t k, size_t bl);
30
- size_t (*get_rhs_packed_offset)(size_t n_idx, size_t k, size_t bl);
 
 
 
 
 
 
31
  size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride);
32
  size_t (*get_dst_size)(size_t m, size_t n);
33
- void (*run_kernel)(size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed,
34
- float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max);
 
 
 
 
35
  };
36
 
37
  struct lhs_packing_info {
38
  size_t (*get_offset)(size_t m_idx, size_t lhs_stride);
39
- size_t (*get_packed_offset)(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
40
- size_t (*packed_size)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
41
- void (*pack_func)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
42
- size_t lhs_stride, void* lhs_packed);
 
 
 
 
 
 
 
 
 
 
43
  };
44
 
45
  struct rhs_packing_info {
46
- size_t (*packed_size)(size_t n, size_t k, size_t nr, size_t kr, size_t bl);
47
- void (*pack_func)(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
48
- const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params);
 
 
 
 
 
 
 
49
  };
50
 
51
  struct ggml_kleidiai_kernels {
@@ -55,6 +85,10 @@ struct ggml_kleidiai_kernels {
55
  rhs_packing_info rhs_info;
56
 
57
  cpu_feature required_cpu;
 
 
 
58
  };
59
 
60
- ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features);
 
 
4
 
5
  #pragma once
6
 
7
+ #include <functional>
8
+ #include "ggml.h"
9
+
10
  enum cpu_feature {
11
  CPU_FEATURE_NONE = 0,
12
  CPU_FEATURE_DOTPROD = 1,
 
29
  size_t (*get_nr)(void);
30
  size_t (*get_kr)(void);
31
  size_t (*get_sr)(void);
32
+ std::variant<
33
+ std::function<size_t(size_t n_idx, size_t k, size_t bl)>,
34
+ std::function<size_t(size_t m_idx, size_t k)>
35
+ > get_lhs_offset;
36
+ std::variant<
37
+ std::function<size_t(size_t n_idx, size_t k, size_t bl)>,
38
+ std::function<size_t(size_t n_idx, size_t k)>
39
+ > get_rhs_packed_offset;
40
  size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride);
41
  size_t (*get_dst_size)(size_t m, size_t n);
42
+ std::variant<
43
+ std::function<void(size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed,
44
+ float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max)>,
45
+ std::function<void(size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row,
46
+ size_t dst_stride_col, float clamp_min, float clamp_max)>
47
+ > run_kernel;
48
  };
49
 
50
  struct lhs_packing_info {
51
  size_t (*get_offset)(size_t m_idx, size_t lhs_stride);
52
+ std::variant<
53
+ std::function<size_t(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr)>,
54
+ std::function<size_t(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr)>
55
+ > get_packed_offset;
56
+ std::variant<
57
+ std::function<size_t(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr)>,
58
+ std::function<size_t(size_t m, size_t k, size_t mr, size_t kr, size_t sr)>
59
+ > packed_size;
60
+ std::variant<
61
+ std::function<void(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
62
+ size_t lhs_stride, void* lhs_packed)>,
63
+ std::function<void(size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride,
64
+ void* lhs_packed)>
65
+ > pack_func;
66
  };
67
 
68
  struct rhs_packing_info {
69
+ std::variant<
70
+ std::function<size_t(size_t n, size_t k, size_t nr, size_t kr, size_t bl)>,
71
+ std::function<size_t(size_t n, size_t k)>
72
+ > packed_size;
73
+ std::variant<
74
+ std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
75
+ const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params)>,
76
+ std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
77
+ const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)>
78
+ > pack_func;
79
  };
80
 
81
  struct ggml_kleidiai_kernels {
 
85
  rhs_packing_info rhs_info;
86
 
87
  cpu_feature required_cpu;
88
+ ggml_type lhs_type;
89
+ ggml_type rhs_type;
90
+ ggml_type op_type;
91
  };
92
 
93
+ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor);
94
+ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features);
ggml/src/ggml-cpu/kleidiai/kleidiai.cpp CHANGED
@@ -34,8 +34,9 @@
34
  #include "ggml-common.h"
35
 
36
  struct ggml_kleidiai_context {
 
37
  ggml_kleidiai_kernels * kernels;
38
- } static ctx = { NULL };
39
 
40
  static void init_kleidiai_context(void) {
41
 
@@ -47,18 +48,18 @@ static void init_kleidiai_context(void) {
47
  const char *env_var = getenv("GGML_KLEIDIAI_SME");
48
  int sme_enabled = 0;
49
 
50
- cpu_feature features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
51
- (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) |
52
- (ggml_cpu_has_sve() ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);
53
 
54
  if (env_var) {
55
  sme_enabled = atoi(env_var);
56
  }
57
 
58
  if (sme_enabled != 0) {
59
- features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
60
  }
61
- ctx.kernels = ggml_kleidiai_select_kernels(features);
62
  }
63
  ggml_critical_section_end();
64
  }
@@ -68,95 +69,275 @@ static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
68
  return tensor->ne[dim];
69
  }
70
 
 
 
 
 
 
 
 
 
 
 
 
71
  namespace ggml::cpu::kleidiai {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  class tensor_traits : public ggml::cpu::tensor_traits {
73
  bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
74
- GGML_ASSERT(ctx.kernels);
75
- kernel_info * kernel = op->src[1]->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm;
 
76
 
77
  size_t k = op->src[0]->ne[0];
 
78
  size_t m = op->src[1]->ne[1];
79
 
80
  size_t mr = kernel->get_mr();
81
  size_t kr = kernel->get_kr();
82
  size_t sr = kernel->get_sr();
83
 
84
- size = ctx.kernels->lhs_info.packed_size(m, k, QK4_0, mr, kr, sr);
 
 
 
 
 
 
 
 
85
 
86
  return true;
87
  }
88
 
 
89
  bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
90
  if (dst->op == GGML_OP_MUL_MAT) {
91
- const ggml_tensor * src0 = dst->src[0];
92
- const ggml_tensor * src1 = dst->src[1];
 
 
 
 
 
 
93
 
94
- GGML_TENSOR_BINARY_OP_LOCALS
 
95
 
96
- GGML_ASSERT(ctx.kernels);
97
- kernel_info * kernel = src1->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm;
98
- lhs_packing_info * lhs_info = &ctx.kernels->lhs_info;
99
 
100
- GGML_ASSERT(kernel);
101
 
102
- const int ith = params->ith;
103
- const int nth = params->nth;
104
 
105
- const size_t k = ne00;
106
- const size_t m = ne11;
107
- const size_t n = ne01;
108
 
109
- const size_t n_step = kernel->get_n_step();
110
- const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
111
- const size_t n_start = ith * num_n_per_thread;
112
 
113
- size_t n_to_process = num_n_per_thread;
114
- if ((n_start + n_to_process) > n) {
115
- n_to_process = n - n_start;
116
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
- const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
119
- uint8_t * lhs_packed = (uint8_t*)params->wdata;
120
- const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
 
121
 
122
- size_t mr = kernel->get_mr();
123
- size_t kr = kernel->get_kr();
124
- size_t sr = kernel->get_sr();
125
 
126
- // Calculate number of columns to be processed per thread
127
- const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
128
- const size_t m_start = ith * num_m_per_thread;
129
- size_t m_to_process = num_m_per_thread;
130
- if ((m_start + m_to_process) > m) {
131
- m_to_process = m - m_start;
 
 
 
 
 
132
  }
133
 
134
- if(m_start < m) {
135
- // Transform LHS
136
- const size_t src_stride = src1->nb[1];
137
- const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
138
- const size_t lhs_packed_offset = lhs_info->get_packed_offset(m_start, k, QK4_0, mr, kr, sr);
139
- void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
140
 
141
- lhs_info->pack_func(m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
 
142
  }
143
 
144
  ggml_barrier(params->threadpool);
145
 
146
- // Perform the operation
147
- const size_t dst_stride = dst->nb[1];
148
- const size_t lhs_packed_offset = lhs_info->get_packed_offset(0, k, QK4_0, mr, kr, sr);
149
- const size_t rhs_packed_offset = kernel->get_rhs_packed_offset(n_start, k, QK4_0);
150
- const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
151
- const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
152
- const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
153
- float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
154
-
155
- kernel->run_kernel(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr,
156
- dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
157
- return true;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  }
159
- return false;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  }
161
 
162
  public:
@@ -169,13 +350,13 @@ public:
169
  size_t sr = ctx.kernels->gemm.get_sr();
170
 
171
  #ifndef NDEBUG
172
- const size_t repacked_size = ctx.kernels->rhs_info.packed_size(n, k, nr, kr, QK4_0);
173
  GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!");
174
  #endif
175
  struct kai_rhs_pack_qs4cxs1s0_param params;
176
  params.lhs_zero_point = 1;
177
  params.rhs_zero_point = 8;
178
- ctx.kernels->rhs_info.pack_func(1, n, k, nr, kr, sr, QK4_0, (const uint8_t *)data, NULL, tensor->data, 0, &params);
179
 
180
  return 0;
181
 
@@ -189,7 +370,7 @@ static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struc
189
  }
190
  } // namespace ggml::cpu::kleidiai
191
 
192
- GGML_API enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
193
  tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);
194
 
195
  GGML_UNUSED(buffer);
@@ -238,12 +419,11 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_b
238
  namespace ggml::cpu::kleidiai {
239
  class extra_buffer_type : ggml::cpu::extra_buffer_type {
240
  bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
241
- if ( op->op == GGML_OP_MUL_MAT &&
242
- op->src[0]->type == GGML_TYPE_Q4_0 &&
243
- op->src[0]->buffer &&
244
- (ggml_n_dims(op->src[0]) == 2) &&
245
- op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels
246
- ) {
247
  if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
248
  return false;
249
  }
@@ -260,6 +440,19 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
260
  if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
261
  return (ggml::cpu::tensor_traits *) op->src[0]->extra;
262
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
263
  }
264
  return nullptr;
265
  }
 
34
  #include "ggml-common.h"
35
 
36
  struct ggml_kleidiai_context {
37
+ cpu_feature features;
38
  ggml_kleidiai_kernels * kernels;
39
+ } static ctx = { CPU_FEATURE_NONE, NULL };
40
 
41
  static void init_kleidiai_context(void) {
42
 
 
48
  const char *env_var = getenv("GGML_KLEIDIAI_SME");
49
  int sme_enabled = 0;
50
 
51
+ ctx.features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
52
+ (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) |
53
+ (ggml_cpu_has_sve() ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);
54
 
55
  if (env_var) {
56
  sme_enabled = atoi(env_var);
57
  }
58
 
59
  if (sme_enabled != 0) {
60
+ ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
61
  }
62
+ ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
63
  }
64
  ggml_critical_section_end();
65
  }
 
69
  return tensor->ne[dim];
70
  }
71
 
72
+ template<typename Ret, typename Variant, typename... Args>
73
+ static Ret variant_call(const Variant & var, Args&&... args) {
74
+ return std::visit([&](auto&& func) -> Ret {
75
+ if constexpr (std::is_invocable_r_v<Ret, decltype(func), Args...>) {
76
+ return func(std::forward<Args>(args)...);
77
+ } else {
78
+ throw std::runtime_error("Invalid function type in variant_call");
79
+ }
80
+ }, var);
81
+ }
82
+
83
  namespace ggml::cpu::kleidiai {
84
+
85
+ static size_t round_down(size_t x, size_t y) {
86
+ return y == 0 ? x : x - (x % y);
87
+ }
88
+
89
+ static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint16_t * src, size_t rhs_stride) {
90
+ size_t src_stride = rhs_stride / sizeof(uint16_t);
91
+ size_t dst_stride = n;
92
+
93
+ for (size_t k_idx = 0; k_idx < k; ++k_idx) {
94
+ for (size_t n_idx = 0; n_idx < n; ++n_idx) {
95
+ uint16_t v = *(src + k_idx + n_idx * src_stride);
96
+ *(dst + n_idx + k_idx * dst_stride) = kai_cast_f32_f16(v);
97
+ }
98
+ }
99
+ }
100
+
101
  class tensor_traits : public ggml::cpu::tensor_traits {
102
  bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
103
+ ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
104
+ GGML_ASSERT(kernels);
105
+ kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
106
 
107
  size_t k = op->src[0]->ne[0];
108
+ size_t n = op->src[0]->ne[1];
109
  size_t m = op->src[1]->ne[1];
110
 
111
  size_t mr = kernel->get_mr();
112
  size_t kr = kernel->get_kr();
113
  size_t sr = kernel->get_sr();
114
 
115
+ if (kernels->rhs_type == GGML_TYPE_Q4_0) {
116
+ size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, QK4_0, mr, kr, sr);
117
+ } else if (kernels->rhs_type == GGML_TYPE_F16) {
118
+ size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr) +
119
+ variant_call<size_t>(kernels->rhs_info.packed_size, n, k) +
120
+ k * n * sizeof(float) + n * sizeof(float);
121
+ } else {
122
+ GGML_ASSERT(false);
123
+ }
124
 
125
  return true;
126
  }
127
 
128
+
129
  bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
130
  if (dst->op == GGML_OP_MUL_MAT) {
131
+ if (dst->src[0]->type == GGML_TYPE_Q4_0) {
132
+ return compute_forward_q4_0(params, dst);
133
+ } else if (dst->src[0]->type == GGML_TYPE_F16) {
134
+ return compute_forward_kv_cache(params, dst);
135
+ }
136
+ }
137
+ return false;
138
+ }
139
 
140
+ bool compute_forward_kv_cache(ggml_compute_params * params, struct ggml_tensor * dst) {
141
+ static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT;
142
 
143
+ const ggml_tensor * src0 = dst->src[0];
144
+ const ggml_tensor * src1 = dst->src[1];
 
145
 
146
+ GGML_TENSOR_BINARY_OP_LOCALS
147
 
148
+ ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
149
+ GGML_ASSERT(kernels);
150
 
151
+ kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
152
+ GGML_ASSERT(kernel);
 
153
 
154
+ const int nth = params->nth;
155
+ const int ith = params->ith;
 
156
 
157
+ const int64_t lhs_batch_size0 = ne12;
158
+ const int64_t rhs_batch_size0 = ne02;
159
+ const int64_t batch_size = rhs_batch_size0;
160
+
161
+ const int64_t r = lhs_batch_size0 / rhs_batch_size0;
162
+
163
+ const int64_t m = ne11 * r;
164
+ const int64_t n = ne01;
165
+ const int64_t k = ne00;
166
+
167
+ const size_t lhs_stride = src1->nb[1];
168
+ const size_t rhs_stride = src0->nb[1];
169
+ const size_t dst_stride = dst->nb[1];
170
+
171
+ const int64_t mr = static_cast<int64_t>(kernel->get_mr());
172
+ const int64_t nr = static_cast<int64_t>(kernel->get_nr());
173
+ const int64_t kr = static_cast<int64_t>(kernel->get_kr());
174
+ const int64_t sr = static_cast<int64_t>(kernel->get_sr());
175
+
176
+ const size_t lhs_packed_size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr);
177
+ const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, n, k);
178
+ const size_t kxn_size = k * n * sizeof(float);
179
+ const size_t bias_size = n * sizeof(float);
180
+
181
+ const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size;
182
+ GGML_ASSERT(wsize_required <= params->wsize);
183
+
184
+ uint8_t * lhs_packed = static_cast<uint8_t *>(params->wdata);
185
+ uint8_t * rhs_packed = lhs_packed + lhs_packed_size;
186
+ uint8_t * rhs_kxn = rhs_packed + rhs_packed_size;
187
+ uint8_t * bias = rhs_kxn + kxn_size;
188
+
189
+ for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
190
+ const uint8_t * lhs_batch = static_cast<const uint8_t *>(src1->data) + batch_idx * m * lhs_stride;
191
+ const uint8_t * rhs_batch = static_cast<const uint8_t *>(src0->data) + batch_idx * n * rhs_stride;
192
+ uint8_t * dst_batch = static_cast<uint8_t *>(dst->data) + batch_idx * m * dst_stride;
193
 
194
+ // LHS packing
195
+ {
196
+ const int64_t m_roundup_mr = kai_roundup(m, mr);
197
+ const int64_t num_threads = KAI_MIN(m_roundup_mr / mr, nth);
198
 
199
+ if (ith < num_threads) {
200
+ const int64_t num_m_per_thread0 = round_down(m_roundup_mr / num_threads, mr);
201
+ const int64_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0;
202
 
203
+ const int64_t m_start = ith * num_m_per_thread0;
204
+ const int64_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
205
+
206
+ const size_t lhs_offset = variant_call<size_t>(kernels->gemm.get_lhs_offset, m_start, lhs_stride);
207
+ const size_t lhs_packed_offset = variant_call<size_t>(kernels->lhs_info.get_packed_offset, m_start, k, mr, kr, sr);
208
+
209
+ const void * src_ptr = static_cast<const uint8_t *>(lhs_batch) + lhs_offset;
210
+ void * dst_ptr = static_cast<uint8_t *>(lhs_packed) + lhs_packed_offset;
211
+
212
+ variant_call<void>(kernels->lhs_info.pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
213
+ }
214
  }
215
 
216
+ // RHS packing
217
+ if (first_to_arrive.test_and_set(std::memory_order_acquire) == false) {
218
+ // First thread to reach this point handles RHS packing
219
+ memset(bias, 0, n * sizeof(float));
220
+ transpose_f32kxn_f16nxk(n, k, reinterpret_cast<float *>(rhs_kxn),
221
+ reinterpret_cast<const uint16_t *>(rhs_batch), rhs_stride);
222
 
223
+ variant_call<void>(kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, n * sizeof(float),
224
+ rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr);
225
  }
226
 
227
  ggml_barrier(params->threadpool);
228
 
229
+ first_to_arrive.clear(std::memory_order_release);
230
+
231
+ // Perform the matmul
232
+ {
233
+ const int64_t m_to_process = m;
234
+ const int64_t m_start = 0;
235
+
236
+ const int64_t n_step = static_cast<int64_t>(kernel->get_n_step());
237
+ const int64_t num_threads = KAI_MIN(n / n_step, nth);
238
+
239
+ if (ith < num_threads) {
240
+ const int64_t num_n_per_thread0 = round_down(n / num_threads, n_step);
241
+ const int64_t num_n_per_threadN_1 = n - (num_threads - 1) * num_n_per_thread0;
242
+
243
+ const int64_t n_start = ith * num_n_per_thread0;
244
+ const int64_t n_to_process = (ith == num_threads - 1) ? num_n_per_threadN_1 : num_n_per_thread0;
245
+
246
+ const size_t lhs_packed_offset = variant_call<size_t>(kernel->get_lhs_offset, m_start, k);
247
+ const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k);
248
+ const size_t dst_offset = kernel->get_dst_offset(m_start, n_start, dst_stride);
249
+
250
+ const void * lhs_ptr = lhs_packed + lhs_packed_offset;
251
+ const void * rhs_ptr = rhs_packed + rhs_packed_offset;
252
+ float * dst_ptr = reinterpret_cast<float *>(dst_batch + dst_offset);
253
+
254
+ variant_call<void>(kernel->run_kernel, m_to_process, n_to_process, k, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
255
+ }
256
+ }
257
+
258
+ if (batch_idx != batch_size - 1) {
259
+ // This barrier is necessary when the batch size is larger than 1. While processing a batch,
260
+ // the work data buffer (params->wdata) is used as temporary storage which means that only
261
+ // a single batch can be processed at any given time. No barrier is needed for the last
262
+ // batch since GGML inserts a barrier between the execution of every operator.
263
+ ggml_barrier(params->threadpool);
264
+ }
265
  }
266
+
267
+ return true;
268
+ }
269
+
270
+ bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
271
+ const ggml_tensor * src0 = dst->src[0];
272
+ const ggml_tensor * src1 = dst->src[1];
273
+
274
+ GGML_TENSOR_BINARY_OP_LOCALS
275
+
276
+ ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
277
+ GGML_ASSERT(kernels);
278
+
279
+ kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
280
+ lhs_packing_info * lhs_info = &kernels->lhs_info;
281
+
282
+ GGML_ASSERT(kernel);
283
+
284
+ const int ith = params->ith;
285
+ const int nth = params->nth;
286
+
287
+ const size_t k = ne00;
288
+ const size_t m = ne11;
289
+ const size_t n = ne01;
290
+
291
+ size_t mr = kernel->get_mr();
292
+ size_t kr = kernel->get_kr();
293
+ size_t sr = kernel->get_sr();
294
+
295
+ const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
296
+ uint8_t * lhs_packed = (uint8_t*)params->wdata;
297
+ const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
298
+
299
+ const size_t n_step = kernel->get_n_step();
300
+ const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
301
+ const size_t n_start = ith * num_n_per_thread;
302
+
303
+ size_t n_to_process = num_n_per_thread;
304
+ if ((n_start + n_to_process) > n) {
305
+ n_to_process = n - n_start;
306
+ }
307
+
308
+ // Calculate number of columns to be processed per thread
309
+ const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
310
+ const size_t m_start = ith * num_m_per_thread;
311
+ size_t m_to_process = num_m_per_thread;
312
+ if ((m_start + m_to_process) > m) {
313
+ m_to_process = m - m_start;
314
+ }
315
+
316
+ if (m_start < m) {
317
+ // Transform LHS
318
+ const size_t src_stride = src1->nb[1];
319
+ const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
320
+ const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, m_start, k, QK4_0, mr, kr, sr);
321
+ void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
322
+
323
+ variant_call<void>(lhs_info->pack_func, m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
324
+ }
325
+
326
+ ggml_barrier(params->threadpool);
327
+
328
+ // Perform the operation
329
+ const size_t dst_stride = dst->nb[1];
330
+ const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, 0, k, QK4_0, mr, kr, sr);
331
+ const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k, QK4_0);
332
+ const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
333
+ const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
334
+ const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
335
+ float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
336
+
337
+ variant_call<void>(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
338
+ sizeof(float), -FLT_MAX, FLT_MAX);
339
+
340
+ return true;
341
  }
342
 
343
  public:
 
350
  size_t sr = ctx.kernels->gemm.get_sr();
351
 
352
  #ifndef NDEBUG
353
+ const size_t repacked_size = variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
354
  GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!");
355
  #endif
356
  struct kai_rhs_pack_qs4cxs1s0_param params;
357
  params.lhs_zero_point = 1;
358
  params.rhs_zero_point = 8;
359
+ variant_call<void>(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, &params);
360
 
361
  return 0;
362
 
 
370
  }
371
  } // namespace ggml::cpu::kleidiai
372
 
373
+ static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
374
  tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);
375
 
376
  GGML_UNUSED(buffer);
 
419
  namespace ggml::cpu::kleidiai {
420
  class extra_buffer_type : ggml::cpu::extra_buffer_type {
421
  bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
422
+ if (op->op == GGML_OP_MUL_MAT &&
423
+ op->src[0]->type == GGML_TYPE_Q4_0 &&
424
+ op->src[0]->buffer &&
425
+ (ggml_n_dims(op->src[0]) == 2) &&
426
+ op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
 
427
  if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
428
  return false;
429
  }
 
440
  if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
441
  return (ggml::cpu::tensor_traits *) op->src[0]->extra;
442
  }
443
+ else if (ggml_kleidiai_select_kernels(ctx.features, op) &&
444
+ op->src[0]->op == GGML_OP_VIEW &&
445
+ (op->src[1]->op == GGML_OP_PERMUTE || op->src[1]->op == GGML_OP_SOFT_MAX) &&
446
+ op->src[1]->ne[1] > 1) {
447
+ if ((op->src[0]->nb[0] != 2) ||
448
+ (op->src[1]->nb[0] != 4) ||
449
+ (op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
450
+ (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
451
+ return nullptr;
452
+ }
453
+
454
+ return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL);
455
+ }
456
  }
457
  return nullptr;
458
  }