Commit bf225d6 · 1 parent: 344310a
Committed by: ggerganov, ngxson, slaren

llama : add gpt-oss (llama/15091)

* oai moe

* compat with new checkpoint

* add attn sink impl

* add rope scaling yarn

* logits match with latest transformers code

* wip chat template

* rm trailing space

* use ggml_scale_bias

* rm redundant is_swa_all

* convert interleaved gate_up

* graph : fix activation function to match reference (llama/7)

* vocab : handle o200k_harmony special tokens

* ggml : add attention sinks support (llama/1)

* llama : add attn sinks

* ggml : add attn sinks

* cuda : add attn sinks

* vulkan : add support for sinks in softmax

remove unnecessary return

* ggml : add fused swiglu_oai op (llama/11)

* ggml : add fused swiglu_oai op

* Update ggml/src/ggml-cpu/ops.cpp

Co-authored-by: Georgi Gerganov <[email protected]>

* update CUDA impl

* cont : metal impl

* add vulkan impl

* test-backend-ops : more test cases, clean up

* llama : remove unfused impl

* remove extra lines

---------

Co-authored-by: Georgi Gerganov <[email protected]>

---------

Co-authored-by: slaren <[email protected]>

* repack mxfp4 upon conversion

* clean up a bit

* enable thinking

* add quick hack to render only some special tokens

* fix bf16 conversion

* remove vocab hack

* webui ok

* support chat parsing for gpt-oss

* fix webui

* direct mapping mxfp4, FINALLY

* force using mxfp4

* properly use lazy tensor

* ggml : add mxfp4

ggml : use e8m0 conversion instead of powf

Co-authored-by: Diego Devesa <[email protected]>

change kvalues_mxfp4 table to match e2m1 (llama/6)

metal : remove quantization for now (not used)

cuda : fix disabled CUDA graphs due to ffn moe bias

vulkan : add support for mxfp4

cont : add cm2 dequant

* ggml : add ggml_add_id (llama/13)

* ggml : add ggml_add_id

* add cuda impl

* llama : add weight support check for add_id

* perf opt

* add vulkan impl

* rename cuda files

* add metal impl

* allow in-place ggml_add_id

* llama : keep biases on CPU with --cpu-moe

* llama : fix compile error

ggml-ci

* cuda : add fallback for __nv_cvt_e8m0_to_bf16raw

ggml-ci

* cleanup

ggml-ci

* sycl : fix supports_op for MXFP4

ggml-ci

* fix Unknown reasoning format

* ggml-cpu : fix AVX build

ggml-ci

* fix hip build

ggml-ci

* cuda : add mxfp4 dequantization support for cuBLAS

ggml-ci

* ggml-cpu : fix mxfp4 fallback definitions for some architectures

ggml-ci

* cuda : fix version required for __nv_cvt_e8m0_to_bf16raw

---------

Co-authored-by: Xuan Son Nguyen <[email protected]>
Co-authored-by: slaren <[email protected]>

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. ggml/include/ggml.h +37 -1
  2. ggml/src/ggml-alloc.c +1 -0
  3. ggml/src/ggml-cann/ggml-cann.cpp +8 -0
  4. ggml/src/ggml-common.h +17 -0
  5. ggml/src/ggml-cpu/arch-fallback.h +6 -0
  6. ggml/src/ggml-cpu/arch/arm/quants.c +61 -0
  7. ggml/src/ggml-cpu/arch/x86/quants.c +96 -8
  8. ggml/src/ggml-cpu/ggml-cpu.c +14 -1
  9. ggml/src/ggml-cpu/ops.cpp +207 -9
  10. ggml/src/ggml-cpu/ops.h +2 -7
  11. ggml/src/ggml-cpu/quants.c +35 -0
  12. ggml/src/ggml-cpu/quants.h +8 -0
  13. ggml/src/ggml-cpu/vec.h +19 -4
  14. ggml/src/ggml-cuda/add-id.cu +58 -0
  15. ggml/src/ggml-cuda/add-id.cuh +3 -0
  16. ggml/src/ggml-cuda/common.cuh +26 -0
  17. ggml/src/ggml-cuda/convert.cu +28 -0
  18. ggml/src/ggml-cuda/fattn-common.cuh +4 -1
  19. ggml/src/ggml-cuda/fattn-mma-f16.cuh +3 -1
  20. ggml/src/ggml-cuda/fattn-tile-f16.cu +2 -1
  21. ggml/src/ggml-cuda/fattn-tile-f32.cu +2 -1
  22. ggml/src/ggml-cuda/fattn-vec-f16.cuh +39 -3
  23. ggml/src/ggml-cuda/fattn-vec-f32.cuh +38 -2
  24. ggml/src/ggml-cuda/fattn-wmma-f16.cu +2 -1
  25. ggml/src/ggml-cuda/fattn.cu +16 -5
  26. ggml/src/ggml-cuda/ggml-cuda.cu +24 -1
  27. ggml/src/ggml-cuda/im2col.cu +3 -2
  28. ggml/src/ggml-cuda/mmq.cu +4 -0
  29. ggml/src/ggml-cuda/mmq.cuh +80 -2
  30. ggml/src/ggml-cuda/mmvq.cu +9 -0
  31. ggml/src/ggml-cuda/softmax.cu +16 -10
  32. ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  33. ggml/src/ggml-cuda/unary.cu +75 -0
  34. ggml/src/ggml-cuda/unary.cuh +2 -0
  35. ggml/src/ggml-cuda/vecdotq.cuh +52 -16
  36. ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  37. ggml/src/ggml-impl.h +61 -0
  38. ggml/src/ggml-metal/ggml-metal-impl.h +14 -0
  39. ggml/src/ggml-metal/ggml-metal.m +109 -9
  40. ggml/src/ggml-metal/ggml-metal.metal +272 -15
  41. ggml/src/ggml-opencl/ggml-opencl.cpp +2 -0
  42. ggml/src/ggml-quants.c +105 -11
  43. ggml/src/ggml-quants.h +6 -0
  44. ggml/src/ggml-sycl/ggml-sycl.cpp +10 -10
  45. ggml/src/ggml-vulkan/ggml-vulkan.cpp +129 -10
  46. ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  47. ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  48. ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +18 -0
  49. ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  50. ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
ggml/include/ggml.h CHANGED
@@ -304,6 +304,16 @@
     GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
     GGML_TENSOR_LOCALS(size_t,  nb, dst, nb)
 
+#define GGML_TENSOR_TERNARY_OP_LOCALS \
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne2, src2, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb2, src2, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
+
 #define GGML_TENSOR_BINARY_OP_LOCALS01 \
     GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
     GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
@@ -395,7 +405,8 @@ extern "C" {
         // GGML_TYPE_IQ4_NL_4_4 = 36,
         // GGML_TYPE_IQ4_NL_4_8 = 37,
         // GGML_TYPE_IQ4_NL_8_8 = 38,
-        GGML_TYPE_COUNT   = 39,
+        GGML_TYPE_MXFP4   = 39, // MXFP4 (1 block)
+        GGML_TYPE_COUNT   = 40,
     };
 
     // precision
@@ -430,6 +441,7 @@ extern "C" {
         GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ1_M  = 23, // except 1d tensors
         GGML_FTYPE_MOSTLY_BF16   = 24, // except 1d tensors
+        GGML_FTYPE_MOSTLY_MXFP4  = 25, // except 1d tensors
     };
 
     // available tensor operations:
@@ -438,6 +450,7 @@ extern "C" {
 
         GGML_OP_DUP,
         GGML_OP_ADD,
+        GGML_OP_ADD_ID,
         GGML_OP_ADD1,
         GGML_OP_ACC,
         GGML_OP_SUB,
@@ -557,6 +570,7 @@ extern "C" {
         GGML_GLU_OP_REGLU,
        GGML_GLU_OP_GEGLU,
         GGML_GLU_OP_SWIGLU,
+        GGML_GLU_OP_SWIGLU_OAI,
         GGML_GLU_OP_GEGLU_ERF,
         GGML_GLU_OP_GEGLU_QUICK,
 
@@ -831,6 +845,13 @@ extern "C" {
             struct ggml_tensor  * b,
             enum   ggml_type      type);
 
+    // dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]]
+    GGML_API struct ggml_tensor * ggml_add_id(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * ids);
+
     GGML_API struct ggml_tensor * ggml_add1(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -1198,6 +1219,13 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    GGML_API struct ggml_tensor * ggml_swiglu_oai(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            float                 alpha,
+            float                 limit);
+
     // normalize along rows
     GGML_API struct ggml_tensor * ggml_norm(
             struct ggml_context * ctx,
@@ -1570,6 +1598,10 @@ extern "C" {
             float                 scale,
             float                 max_bias);
 
+    GGML_API void ggml_soft_max_add_sinks(
+            struct ggml_tensor * a,
+            struct ggml_tensor * sinks);
+
     GGML_API struct ggml_tensor * ggml_soft_max_ext_back(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -2052,6 +2084,10 @@ extern "C" {
     GGML_API enum ggml_prec ggml_flash_attn_ext_get_prec(
             const struct ggml_tensor * a);
 
+    GGML_API void ggml_flash_attn_ext_add_sinks(
+            struct ggml_tensor * a,
+            struct ggml_tensor * sinks);
+
     // TODO: needs to be adapted to ggml_flash_attn_ext
     GGML_API struct ggml_tensor * ggml_flash_attn_back(
             struct ggml_context * ctx,
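
The header comment above fully determines the semantics of the new op: each row of `a` gets one row of `b` added to it, with the row chosen per (i1, i2) by the integer `ids` tensor — this is what lets a per-expert bias be applied after the MoE matmuls. Below is a minimal scalar sketch of that indexing over plain contiguous arrays; the shape names and the row-major layout are illustrative assumptions, not the ggml API itself.

// dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]]
// a:   [ne0, ne1, ne2]  contiguous float
// b:   [ne0, n_rows_b]  contiguous float (e.g. one bias row per expert)
// ids: [ne1, ne2]       int32 row indices into b
static void add_id_ref(float * dst, const float * a, const float * b,
                       const int32_t * ids, int ne0, int ne1, int ne2) {
    for (int i2 = 0; i2 < ne2; ++i2) {
        for (int i1 = 0; i1 < ne1; ++i1) {
            const int32_t row = ids[i2*ne1 + i1];               // which row of b to add
            const float * bp  = b + (size_t) row*ne0;
            const float * ap  = a   + ((size_t) i2*ne1 + i1)*ne0;
            float       * dp  = dst + ((size_t) i2*ne1 + i1)*ne0;
            for (int i0 = 0; i0 < ne0; ++i0) {
                dp[i0] = ap[i0] + bp[i0];
            }
        }
    }
}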
ggml/src/ggml-alloc.c CHANGED
@@ -29,6 +29,7 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
         case GGML_OP_DIAG_MASK_ZERO:
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_ADD:
+        case GGML_OP_ADD_ID:
         case GGML_OP_ADD1:
         case GGML_OP_SUB:
         case GGML_OP_MUL:
ggml/src/ggml-cann/ggml-cann.cpp CHANGED
@@ -2340,6 +2340,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
             memcpy(&bias, (float*)op->op_params + 1, sizeof(float));
             return bias == 0.0f; // TODO: support bias != 0.0f
         case GGML_OP_SOFT_MAX:
+            // TODO: support attention sinks [TAG_ATTN_SINKS]
+            if (op->src[2]) {
+                return false;
+            }
             // TODO: support broadcast
             // ref: https://github.com/ggml-org/llama.cpp/pull/14435
             return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
@@ -2354,6 +2358,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
             if(op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_BF16){
                 return false;
             }
+            // TODO: support attention sinks [TAG_ATTN_SINKS]
+            if (op->src[4]) {
+                return false;
+            }
             if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
                 // different head sizes of K and V are not supported yet
                 return false;
ggml/src/ggml-common.h CHANGED
@@ -99,6 +99,9 @@ typedef sycl::half2 ggml_half2;
 #define QI4_1 (QK4_1 / (4 * QR4_1))
 #define QR4_1 2
 
+#define QI_MXFP4 (QK_MXFP4 / (4 * QR_MXFP4))
+#define QR_MXFP4 2
+
 #define QI5_0 (QK5_0 / (4 * QR5_0))
 #define QR5_0 2
 
@@ -184,6 +187,13 @@ typedef struct {
 } block_q4_1;
 static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding");
 
+#define QK_MXFP4 32
+typedef struct {
+    uint8_t e; // E8M0
+    uint8_t qs[QK_MXFP4/2];
+} block_mxfp4;
+static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + QK_MXFP4/2, "wrong mxfp4 block size/padding");
+
 #define QK5_0 32
 typedef struct {
     ggml_half d; // delta
@@ -1074,10 +1084,17 @@ GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512)
     0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
 GGML_TABLE_END()
 
+// TODO: fix name to kvalues_iq4_nl
 GGML_TABLE_BEGIN(int8_t, kvalues_iq4nl, 16)
     -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
 GGML_TABLE_END()
 
+// e2m1 values (doubled)
+// ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+GGML_TABLE_BEGIN(int8_t, kvalues_mxfp4, 16)
+    0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12,
+GGML_TABLE_END()
+
 #define NGRID_IQ1S 2048
 #define IQ1S_DELTA 0.125f
 #define IQ1M_DELTA 0.125f
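
A `block_mxfp4` packs 32 weights as 16 nibbles plus one shared E8M0 scale, and `kvalues_mxfp4` stores the E2M1 code points doubled (0, ±0.5, ±1, …, ±6 become 0, ±1, ±2, …, ±12). A hedged scalar dequantization sketch follows; it assumes the usual E8M0 convention of 2^(e − 127), with the doubled table compensated by halving the scale, which is what the `GGML_E8M0_TO_FP32_HALF` macro used by the dot products suggests.

#include <math.h>
#include <stdint.h>

#define QK_MXFP4 32

typedef struct {
    uint8_t e;                // shared E8M0 scale (biased exponent)
    uint8_t qs[QK_MXFP4/2];   // 32 E2M1 values, two per byte
} block_mxfp4;

// doubled E2M1 values, as in kvalues_mxfp4
static const int8_t kvalues_mxfp4[16] = {
    0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12,
};

// assumption: E8M0 decodes as 2^(e - 127); the 0.5f undoes the doubled table
static void dequant_block_mxfp4(const block_mxfp4 * x, float * y) {
    const float d = ldexpf(1.0f, (int) x->e - 127) * 0.5f;
    for (int j = 0; j < QK_MXFP4/2; ++j) {
        y[j]              = d * kvalues_mxfp4[x->qs[j] & 0x0f]; // low nibbles: first 16 values
        y[j + QK_MXFP4/2] = d * kvalues_mxfp4[x->qs[j] >> 4];   // high nibbles: last 16 values
    }
}

The low/high-nibble layout matches the scalar fallback in the dot-product kernels below (the first half of each Q8_0 block multiplies the low nibbles, the second half the high nibbles).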
ggml/src/ggml-cpu/arch-fallback.h CHANGED
@@ -13,6 +13,7 @@
 #define ggml_vec_dot_q5_0_q8_0_generic ggml_vec_dot_q5_0_q8_0
 #define ggml_vec_dot_q5_1_q8_1_generic ggml_vec_dot_q5_1_q8_1
 #define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
 #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
 #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
@@ -68,6 +69,7 @@
 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
 #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -90,6 +92,7 @@
 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
 #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -120,6 +123,7 @@
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
 #define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
 #define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -149,6 +153,7 @@
 #define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K
 #define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -179,6 +184,7 @@
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
 #define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
 #define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
ggml/src/ggml-cpu/arch/arm/quants.c CHANGED
@@ -589,6 +589,67 @@ void ggml_vec_dot_q4_1_q8_1(...)
     *s = sumf;
 }
 
+void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK_MXFP4 == 0);
+    static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
+
+    const block_mxfp4 * GGML_RESTRICT x = vx;
+    const block_q8_0  * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_MXFP4;
+
+    int ib = 0;
+    float sumf = 0;
+
+#if defined __ARM_NEON
+    const int8x16_t  values = vld1q_s8(kvalues_mxfp4);
+    const uint8x16_t m4b    = vdupq_n_u8(0x0f);
+    uint8x16x2_t q4bits;
+    int8x16x4_t  q4b;
+    int8x16x4_t  q8b;
+    int32x4_t    prod_1;
+    int32x4_t    prod_2;
+
+    for (; ib + 1 < nb; ib += 2) {
+        q4bits.val[0] = vld1q_u8(x[ib + 0].qs);
+        q4bits.val[1] = vld1q_u8(x[ib + 1].qs);
+        q8b.val[0]    = vld1q_s8(y[ib + 0].qs);
+        q8b.val[1]    = vld1q_s8(y[ib + 0].qs + 16);
+        q8b.val[2]    = vld1q_s8(y[ib + 1].qs);
+        q8b.val[3]    = vld1q_s8(y[ib + 1].qs + 16);
+
+        q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits.val[0], m4b));
+        q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
+        q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits.val[1], m4b));
+        q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));
+
+        prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);
+        prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);
+
+        sumf +=
+            GGML_E8M0_TO_FP32_HALF(x[ib + 0].e) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * vaddvq_s32(prod_1) +
+            GGML_E8M0_TO_FP32_HALF(x[ib + 1].e) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * vaddvq_s32(prod_2);
+    }
+
+#endif
+    for (; ib < nb; ++ib) {
+        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
+        int sumi1 = 0;
+        int sumi2 = 0;
+        for (int j = 0; j < QK_MXFP4/2; ++j) {
+            sumi1 += y[ib].qs[j + 0]          * kvalues_mxfp4[x[ib].qs[j] & 0xf];
+            sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >> 4];
+        }
+        sumf += d * (sumi1 + sumi2);
+    }
+    *s = sumf;
+}
+
 void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     const int qk = QK8_0;
     const int nb = n / qk;
ggml/src/ggml-cpu/arch/x86/quants.c CHANGED
@@ -66,6 +66,12 @@ static inline int hsum_i32_4(const __m128i a) {
 }
 
 #if defined(__AVX2__) || defined(__AVX512F__)
+static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
+    const __m256i ax = _mm256_sign_epi8(x, x);
+    const __m256i sy = _mm256_sign_epi8(y, x);
+    return _mm256_maddubs_epi16(ax, sy);
+}
+
 // spread 32 bits to 32 bytes { 0x00, 0xFF }
 static inline __m256i bytes_from_bits_32(const uint8_t * x) {
     uint32_t x32;
@@ -261,6 +267,11 @@ static inline __m256 quad_fp16_delta_float(...)
     return _mm256_set_m128(_mm_set1_ps(GGML_CPU_FP16_TO_FP32(x1) * GGML_CPU_FP16_TO_FP32(y1)),
                            _mm_set1_ps(GGML_CPU_FP16_TO_FP32(x0) * GGML_CPU_FP16_TO_FP32(y0)));
 }
+
+static inline __m256 quad_mx_delta_float(const int8_t x0, const float y0, const int8_t x1, const float y1) {
+    return _mm256_set_m128(_mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x1) * GGML_CPU_FP16_TO_FP32(y1)),
+                           _mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x0) * GGML_CPU_FP16_TO_FP32(y0)));
+}
 #endif
 #elif defined(__SSSE3__)
 // horizontally add 4x4 floats
@@ -746,6 +757,91 @@ void ggml_vec_dot_q4_1_q8_1(...)
 #endif
 }
 
+void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK_MXFP4 == 0);
+    static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
+
+    const block_mxfp4 * GGML_RESTRICT x = vx;
+    const block_q8_0  * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_MXFP4;
+
+    int ib = 0;
+    float sumf = 0;
+
+#if defined __AVX2__
+
+    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
+    const __m128i m4b  = _mm_set1_epi8(0x0f);
+    const __m256i mone = _mm256_set1_epi16(1);
+
+    __m256 accum1 = _mm256_setzero_ps();
+    __m256 accum2 = _mm256_setzero_ps();
+    for (; ib + 1 < nb; ib += 2) {
+        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs);
+        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs);
+        const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs);
+        const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs);
+        const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
+                                              _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
+        const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
+                                              _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
+        const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
+        const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
+        const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
+        const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
+        accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 0].e)),
+                                 _mm256_cvtepi32_ps(p_1), accum1);
+        accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 1].e)),
+                                 _mm256_cvtepi32_ps(p_2), accum2);
+    }
+
+    sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
+
+#elif defined __AVX__
+    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
+    const __m128i m4b = _mm_set1_epi8(0x0f);
+
+    __m256 accum = _mm256_setzero_ps();
+    for (; ib + 1 < nb; ib += 2) {
+        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
+        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
+        const __m128i q8b_1_0  = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
+        const __m128i q8b_1_1  = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
+        const __m128i q8b_2_0  = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
+        const __m128i q8b_2_1  = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
+
+        const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
+        const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
+        const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
+        const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));
+
+        const __m256 p = mul_sum_i8_quad_float(q4b_1_0, q4b_1_1, q4b_2_0, q4b_2_1, q8b_1_0, q8b_1_1, q8b_2_0, q8b_2_1);
+        const __m256 deltas = quad_mx_delta_float(x[ib].e, y[ib].d, x[ib + 1].e, y[ib + 1].d);
+        accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
+    }
+
+    sumf = hsum_float_8(accum);
+
+#endif
+    for (; ib < nb; ++ib) {
+        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
+        int sumi1 = 0;
+        int sumi2 = 0;
+        for (int j = 0; j < QK_MXFP4/2; ++j) {
+            sumi1 += y[ib].qs[j + 0]          * kvalues_mxfp4[x[ib].qs[j] & 0xf];
+            sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >> 4];
+        }
+        sumf += d * (sumi1 + sumi2);
+    }
+    *s = sumf;
+}
+
 void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     const int qk = QK8_0;
     const int nb = n / qk;
@@ -3206,14 +3302,6 @@ void ggml_vec_dot_iq3_s_q8_K(...)
 #endif
 }
 
-#if defined(__AVX2__)
-static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
-    const __m256i ax = _mm256_sign_epi8(x, x);
-    const __m256i sy = _mm256_sign_epi8(y, x);
-    return _mm256_maddubs_epi16(ax, sy);
-}
-#endif
-
 void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
     assert(nrc == 1);
ggml/src/ggml-cpu/ggml-cpu.c CHANGED
@@ -253,6 +253,12 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
         .vec_dot_type             = GGML_TYPE_Q8_1,
         .nrows                    = 1,
     },
+    [GGML_TYPE_MXFP4] = {
+        .from_float               = quantize_row_mxfp4,
+        .vec_dot                  = ggml_vec_dot_mxfp4_q8_0,
+        .vec_dot_type             = GGML_TYPE_Q8_0,
+        .nrows                    = 1,
+    },
     [GGML_TYPE_Q2_K] = {
         .from_float               = quantize_row_q2_K,
         .vec_dot                  = ggml_vec_dot_q2_K_q8_K,
@@ -1670,6 +1676,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
             {
                 ggml_compute_forward_add(params, tensor);
             } break;
+        case GGML_OP_ADD_ID:
+            {
+                ggml_compute_forward_add_id(params, tensor);
+            } break;
         case GGML_OP_ADD1:
             {
                 ggml_compute_forward_add1(params, tensor);
@@ -1924,7 +1934,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
             } break;
         case GGML_OP_FLASH_ATTN_EXT:
             {
-                ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
+                ggml_compute_forward_flash_attn_ext(params, tensor);
             } break;
         case GGML_OP_FLASH_ATTN_BACK:
             {
@@ -2111,6 +2121,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_DUP:
        case GGML_OP_CONT:
         case GGML_OP_ADD:
+        case GGML_OP_ADD_ID:
         case GGML_OP_ADD1:
         case GGML_OP_ACC:
             {
@@ -2172,6 +2183,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 case GGML_GLU_OP_REGLU:
                 case GGML_GLU_OP_GEGLU:
                 case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_SWIGLU_OAI:
                 case GGML_GLU_OP_GEGLU_ERF:
                 case GGML_GLU_OP_GEGLU_QUICK:
                     {
@@ -2673,6 +2685,7 @@ struct ggml_cplan ggml_graph_plan(
                     }
                 } break;
             case GGML_OP_ADD:
+            case GGML_OP_ADD_ID:
             case GGML_OP_ADD1:
                 {
                     if (ggml_is_quantized(node->src[0]->type)) {
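
The new `type_traits_cpu` entry puts MXFP4 on the same CPU path as the other block formats: the f32 activation row is first converted to the `vec_dot_type` (Q8_0 here) with `from_float`, and the inner product is then computed by `ggml_vec_dot_mxfp4_q8_0`. A rough sketch of that pairing is below; it assumes the declarations from `ggml-cpu/quants.h` and `ggml-common.h` are visible and that `n` is a multiple of QK8_0, and it is only an illustration of the traits table, not the actual mul_mat driver.

// one row of MXFP4 weights dotted with one row of f32 activations
float mxfp4_row_dot(const block_mxfp4 * w_row, const float * act, int n) {
    block_q8_0 q8[n/QK8_0];            // quantized activations (from_float target type)

    quantize_row_q8_0(act, q8, n);     // .from_float for .vec_dot_type == GGML_TYPE_Q8_0

    float s = 0.0f;
    ggml_vec_dot_mxfp4_q8_0(n, &s, 0, w_row, 0, q8, 0, 1);  // .vec_dot
    return s;
}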
ggml/src/ggml-cpu/ops.cpp CHANGED
@@ -8,6 +8,7 @@
 #include "vec.h"
 
 #include <float.h>
+#include <algorithm>
 
 // ggml_compute_forward_dup
 
@@ -1283,6 +1284,7 @@ void ggml_compute_forward_add(
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -1309,6 +1311,77 @@ void ggml_compute_forward_add(
     }
 }
 
+// ggml_compute_forward_add_id
+
+static void ggml_compute_forward_add_id_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src2->type == GGML_TYPE_I32);
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    GGML_TENSOR_TERNARY_OP_LOCALS
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        // src1 indices
+        const int i11 = *(int32_t *) ((char *) src2->data + i1*nb20 + i2*nb21);
+
+        GGML_ASSERT(i11 >= 0 && i11 < ne11);
+
+        ggml_vec_add_f32(ne0,
+                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
+                (float *) ((char *) src1->data + i11*nb11));
+    }
+}
+
+void ggml_compute_forward_add_id(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_add_id_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("unsupported type for ggml_compute_forward_add_id: %s", ggml_type_name(src0->type));
+            }
+    }
+}
+
 // ggml_compute_forward_add1
 
 static void ggml_compute_forward_add1_f32(
@@ -1660,6 +1733,7 @@ void ggml_compute_forward_add1(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -1787,6 +1861,7 @@ void ggml_compute_forward_acc(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -3614,6 +3689,93 @@ static void ggml_compute_forward_swiglu(
     }
 }
 
+// ggml_compute_forward_swiglu_oai
+
+static void ggml_compute_forward_swiglu_oai_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+    const float   alpha   = ggml_get_op_params_f32(dst, 2);
+    const float   limit   = ggml_get_op_params_f32(dst, 3);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+        float * dst_p  = (float *) ((char *) dst->data + i1*(dst->nb[1]));
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        for (int k = 0; k < nc; k++) {
+            const float x = std::min(src0_p[k], limit);
+            const float y = std::clamp(src1_p[k], -limit, limit);
+            const float out_glu = x / (1.f + expf(alpha * (-x)));
+            dst_p[k] = out_glu * (y + 1.f);
+        }
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = dst_p[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_swiglu_oai(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_swiglu_oai_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_geglu_erf
 
 static void ggml_compute_forward_geglu_erf_f32(
@@ -4599,6 +4761,7 @@ void ggml_compute_forward_out_prod(
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -4873,6 +5036,7 @@ void ggml_compute_forward_set(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -5134,6 +5298,7 @@ void ggml_compute_forward_get_rows(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -5523,6 +5688,7 @@ static void ggml_compute_forward_soft_max_f32(
 
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
 
     assert(ggml_is_contiguous(dst));
     assert(ggml_are_same_shape(src0, dst));
@@ -5557,6 +5723,9 @@ static void ggml_compute_forward_soft_max_f32(
 
     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
+    // sinks
+    const float * sk = src2 ? (float *)((char *) src2->data) : nullptr;
+
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
             for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
@@ -5599,9 +5768,18 @@ static void ggml_compute_forward_soft_max_f32(
                 float max = -INFINITY;
                 ggml_vec_max_f32(ne00, &max, wp);
 
+                // if we have sinks, make a correction as if they were included in the softmax
+                if (sk) {
+                    max = MAX(max, sk[i02]);
+                }
+
                 ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
                 assert(sum > 0.0);
 
+                if (sk) {
+                    sum += (ggml_float) expf(sk[i02] - max);
+                }
+
                 sum = 1.0/sum;
                 ggml_vec_scale_f32(ne00, dp, sum);
 
@@ -5836,6 +6014,7 @@ void ggml_compute_forward_clamp(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -7989,12 +8168,14 @@ void ggml_compute_forward_argsort(
 
 static void ggml_compute_forward_flash_attn_ext_f16(
         const ggml_compute_params * params,
-        const ggml_tensor * q,
-        const ggml_tensor * k,
-        const ggml_tensor * v,
-        const ggml_tensor * mask,
         ggml_tensor * dst) {
 
+    const ggml_tensor * q     = dst->src[0];
+    const ggml_tensor * k     = dst->src[1];
+    const ggml_tensor * v     = dst->src[2];
+    const ggml_tensor * mask  = dst->src[3];
+    const ggml_tensor * sinks = dst->src[4];
+
     GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
     GGML_TENSOR_LOCALS(size_t,  nbq, q, nb)
     GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
@@ -8189,6 +8370,23 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             }
        }
 
+        // sinks
+        if (sinks) {
+            const float s = ((float *)((char *) sinks->data))[h];
+
+            float ms = 1.0f;
+            float vs = 1.0f;
+
+            if (s > M) {
+                ms = expf(M - s);
+                ggml_vec_scale_f32(DV, VKQ32, ms);
+            } else {
+                vs = expf(s - M);
+            }
+
+            S = S*ms + vs;
+        }
+
         // V /= S
         const float S_inv = 1.0f/S;
         ggml_vec_scale_f32(DV, VKQ32, S_inv);
@@ -8208,17 +8406,13 @@ static void ggml_compute_forward_flash_attn_ext_f16(
 
 void ggml_compute_forward_flash_attn_ext(
         const ggml_compute_params * params,
-        const ggml_tensor * q,
-        const ggml_tensor * k,
-        const ggml_tensor * v,
-        const ggml_tensor * mask,
         ggml_tensor * dst) {
     switch (dst->op_params[3]) {
         case GGML_PREC_DEFAULT:
         case GGML_PREC_F32:
            {
                 // uses F32 accumulators
-                ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
+                ggml_compute_forward_flash_attn_ext_f16(params, dst);
             } break;
         default:
            {
@@ -9080,6 +9274,10 @@ void ggml_compute_forward_glu(
             {
                 ggml_compute_forward_swiglu(params, dst);
             } break;
+        case GGML_GLU_OP_SWIGLU_OAI:
+            {
+                ggml_compute_forward_swiglu_oai(params, dst);
+            } break;
         case GGML_GLU_OP_GEGLU_ERF:
             {
                 ggml_compute_forward_geglu_erf(params, dst);
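
In math form, the CPU kernels above do the following; this is a restatement of the code, with `alpha` and `limit` taken from the op params. The fused SWIGLU_OAI gate, per element k with limit ℓ:

\[
\mathrm{out}_k \;=\; \frac{\min(x_k,\ \ell)}{1 + e^{-\alpha\,\min(x_k,\ \ell)}} \cdot \bigl(\mathrm{clamp}(y_k,\ -\ell,\ \ell) + 1\bigr)
\]

and the sink-corrected softmax, which renormalizes each row as if one extra logit s (the per-head sink) had participated without emitting a probability:

\[
m = \max\!\bigl(\max_j z_j,\ s\bigr), \qquad
p_j \;=\; \frac{e^{\,z_j - m}}{\,e^{\,s - m} + \sum_k e^{\,z_k - m}\,}
\]

The flash-attention path applies the same correction in streaming form: the running maximum M and accumulator S are rescaled by exp(M − s) or augmented by exp(s − M), exactly as in the `if (sinks)` block above.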
ggml/src/ggml-cpu/ops.h CHANGED
@@ -29,6 +29,7 @@ extern "C" {
 
 void ggml_compute_forward_dup(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_add(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_add_id(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_add1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_acc(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_sum(const struct ggml_compute_params * params, struct ggml_tensor * dst);
@@ -82,13 +83,7 @@
 void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-void ggml_compute_forward_flash_attn_ext(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * q,
-        const struct ggml_tensor * k,
-        const struct ggml_tensor * v,
-        const struct ggml_tensor * mask,
-        struct ggml_tensor * dst);
+void ggml_compute_forward_flash_attn_ext(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_flash_attn_back(
         const struct ggml_compute_params * params,
         const bool masked,
ggml/src/ggml-cpu/quants.c CHANGED
@@ -46,6 +46,10 @@ void quantize_row_q8_1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
     quantize_row_q8_1_ref(x, y, k);
 }
 
+void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
+    quantize_row_mxfp4_ref(x, y, k);
+}
+
 //
 // 2-6 bit quantization in super-blocks
 //
@@ -181,6 +185,37 @@ void ggml_vec_dot_q4_1_q8_1_generic(...)
     *s = sumf;
 }
 
+void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK_MXFP4 == 0);
+    static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
+
+    const block_mxfp4 * GGML_RESTRICT x = vx;
+    const block_q8_0  * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_MXFP4;
+
+    int ib = 0;
+    float sumf = 0;
+
+    for (; ib < nb; ++ib) {
+        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
+
+        int sumi1 = 0;
+        int sumi2 = 0;
+        for (int j = 0; j < QK_MXFP4/2; ++j) {
+            sumi1 += y[ib].qs[j + 0]          * kvalues_mxfp4[x[ib].qs[j] & 0xf];
+            sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >> 4];
+        }
+        sumf += d * (sumi1 + sumi2);
+    }
+    *s = sumf;
+}
+
 void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     const int qk = QK8_0;
     const int nb = n / qk;
ggml/src/ggml-cpu/quants.h CHANGED
@@ -19,6 +19,8 @@ void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 
+void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+
 void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
@@ -39,6 +41,8 @@
 void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 
+void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
 void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
@@ -67,8 +71,12 @@
 void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q5_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q8_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
+void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
 void ggml_vec_dot_tq1_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_tq2_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
 void ggml_vec_dot_q2_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q3_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q4_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
ggml/src/ggml-cpu/vec.h CHANGED
@@ -55,7 +55,22 @@ inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x)
55
 
56
  inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const ggml_fp16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
57
  inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
58
- inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
59
  inline static void ggml_vec_add_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
60
  for (int i = 0; i < n; ++i) {
61
  z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) + GGML_CPU_FP16_TO_FP32(y[i]));
@@ -992,9 +1007,9 @@ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float *
992
 
993
  inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
994
  for (int i = 0; i < n; ++i) {
995
- float v = GGML_CPU_FP16_TO_FP32(x[i]);
996
- float w = GGML_CPU_FP16_TO_FP32(g[i]);
997
- y[i] = GGML_CPU_FP32_TO_FP16((v/(1.0f + expf(-v))) * w);
998
  }
999
  }
1000
 
 
55
 
56
  inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const ggml_fp16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
57
  inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
58
+
59
+ inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) {
60
+ int i = 0;
61
+ #if defined(__AVX2__)
62
+ for (; i + 7 < n; i += 8) {
63
+ __m256 vx = _mm256_loadu_ps(x + i);
64
+ __m256 vy = _mm256_loadu_ps(y + i);
65
+ __m256 vz = _mm256_add_ps(vx, vy);
66
+ _mm256_storeu_ps(z + i, vz);
67
+ }
68
+ #endif
69
+ for (; i < n; ++i) {
70
+ z[i] = x[i] + y[i];
71
+ }
72
+ }
73
+
74
  inline static void ggml_vec_add_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
75
  for (int i = 0; i < n; ++i) {
76
  z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) + GGML_CPU_FP16_TO_FP32(y[i]));
 
1007
 
1008
  inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
1009
  for (int i = 0; i < n; ++i) {
1010
+ float xi = GGML_CPU_FP16_TO_FP32(x[i]);
1011
+ float gi = GGML_CPU_FP16_TO_FP32(g[i]);
1012
+ y[i] = GGML_CPU_FP32_TO_FP16((xi/(1.0f + expf(-xi))) * gi);
1013
  }
1014
  }
1015
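
Note on the vec.h changes above: the new ggml_vec_add_f32 adds an AVX2 fast path (eight floats per _mm256 iteration with unaligned loads, plus a scalar loop for the tail), and the ggml_vec_swiglu_f16 edit only renames the temporaries (v/w -> xi/gi); per element it still computes y_i = x_i * sigma(x_i) * g_i with sigma(t) = 1 / (1 + exp(-t)).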
 
ggml/src/ggml-cuda/add-id.cu ADDED
@@ -0,0 +1,58 @@
+ #include "add-id.cuh"
+
+ static __global__ void add_id_kernel(
+         const float * src0, const float * src1, const int32_t * src2, float * dst,
+         int64_t ne0, int64_t ne1,
+         size_t nb01, size_t nb02,
+         size_t nb11,
+         size_t nb21
+     ) {
+
+     const int64_t i1 = blockIdx.x;
+     const int64_t i2 = blockIdx.y;
+
+     const int i11 = *(int32_t *) ((char *) src2 + i1*sizeof(int32_t) + i2*nb21);
+
+     const size_t nb1 = ne0 * sizeof(float);
+     const size_t nb2 = ne1 * nb1;
+
+     float * dst_row = (float *)((char *)dst + i1*nb1 + i2*nb2);
+     const float * src0_row = (const float *)((char *)src0 + i1*nb01 + i2*nb02);
+     const float * src1_row = (const float *)((char *)src1 + i11*nb11);
+
+     for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
+         dst_row[i0] = src0_row[i0] + src1_row[i0];
+     }
+ }
+
+ void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+     const ggml_tensor * src0 = dst->src[0];
+     const ggml_tensor * src1 = dst->src[1];
+     const ggml_tensor * src2 = dst->src[2];
+
+     GGML_TENSOR_TERNARY_OP_LOCALS
+
+     GGML_ASSERT(dst->type == GGML_TYPE_F32);
+     GGML_ASSERT(src0->type == GGML_TYPE_F32);
+     GGML_ASSERT(src1->type == GGML_TYPE_F32);
+     GGML_ASSERT(src2->type == GGML_TYPE_I32);
+
+     GGML_ASSERT(nb00 == sizeof(float));
+     GGML_ASSERT(nb10 == sizeof(float));
+     GGML_ASSERT(nb20 == sizeof(int32_t));
+
+     const float * src0_d = (const float *)src0->data;
+     const float * src1_d = (const float *)src1->data;
+     const int32_t * src2_d = (const int32_t *)src2->data;
+     float * dst_d = (float *)dst->data;
+
+     int threads = std::min((int)ne00, 768); // cols
+     dim3 blocks(ne01, ne02); // n_experts_used, n_tokens
+     add_id_kernel<<<blocks, threads, 0, ctx.stream()>>>(
+         src0_d, src1_d, src2_d, dst_d,
+         ne0, ne1,
+         nb01, nb02,
+         nb11,
+         nb21
+     );
+ }
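
For context, GGML_OP_ADD_ID adds one row of src1, selected per output row by the index tensor src2, to src0 — in this commit it applies the per-expert FFN biases to each (token, selected expert) activation row. A minimal host-side sketch of the same indexing (illustrative only; add_id_ref is a hypothetical helper, and contiguous f32 tensors are assumed):

    // dst[i2][i1][:] = src0[i2][i1][:] + src1[src2[i2][i1]][:]
    // ne0 = row width, ne1 = experts used per token, ne2 = tokens
    static void add_id_ref(const float * src0, const float * src1, const int32_t * src2,
                           float * dst, int64_t ne0, int64_t ne1, int64_t ne2) {
        for (int64_t i2 = 0; i2 < ne2; ++i2) {
            for (int64_t i1 = 0; i1 < ne1; ++i1) {
                const int32_t expert = src2[i2*ne1 + i1];
                for (int64_t i0 = 0; i0 < ne0; ++i0) {
                    dst[(i2*ne1 + i1)*ne0 + i0] = src0[(i2*ne1 + i1)*ne0 + i0] + src1[expert*ne0 + i0];
                }
            }
        }
    }

The CUDA kernel above maps i1 to blockIdx.x, i2 to blockIdx.y, and strides the column loop over threadIdx.x.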
ggml/src/ggml-cuda/add-id.cuh ADDED
@@ -0,0 +1,3 @@
+ #include "common.cuh"
+
+ void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml-cuda/common.cuh CHANGED
@@ -1,6 +1,7 @@
1
  #pragma once
2
 
3
  #include "ggml.h"
 
4
  #include "ggml-cuda.h"
5
 
6
  #include <cstdint>
@@ -549,6 +550,24 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
549
  #endif // defined(GGML_USE_HIP)
550
  }
551

552
  typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
553
 
554
  static __device__ __forceinline__ float get_alibi_slope(
@@ -607,6 +626,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_Q8_0> {
607
  static constexpr int qi = QI8_0;
608
  };
609

610
  template<>
611
  struct ggml_cuda_type_traits<GGML_TYPE_Q2_K> {
612
  static constexpr int qk = QK_K;
 
1
  #pragma once
2
 
3
  #include "ggml.h"
4
+ #include "ggml-impl.h"
5
  #include "ggml-cuda.h"
6
 
7
  #include <cstdint>
 
550
  #endif // defined(GGML_USE_HIP)
551
  }
552
 
553
+ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
554
+ #if CUDART_VERSION >= 12080
555
+ const nv_bfloat16 e = __nv_cvt_e8m0_to_bf16raw(x);
556
+ return (float) e;
557
+ #else
558
+ uint32_t bits;
559
+ if (x == 0) {
560
+ bits = 0x00400000;
561
+ } else {
562
+ bits = (uint32_t) x << 23;
563
+ }
564
+
565
+ float result;
566
+ memcpy(&result, &bits, sizeof(float));
567
+ return result;
568
+ #endif // CUDART_VERSION >= 12080
569
+ }
570
+
571
  typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
572
 
573
  static __device__ __forceinline__ float get_alibi_slope(
 
626
  static constexpr int qi = QI8_0;
627
  };
628
 
629
+ template<>
630
+ struct ggml_cuda_type_traits<GGML_TYPE_MXFP4> {
631
+ static constexpr int qk = QK_MXFP4;
632
+ static constexpr int qr = QR_MXFP4;
633
+ static constexpr int qi = QI_MXFP4;
634
+ };
635
+
636
  template<>
637
  struct ggml_cuda_type_traits<GGML_TYPE_Q2_K> {
638
  static constexpr int qk = QK_K;
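
For context, E8M0 is a scale-only format: the byte holds just a biased exponent, so the value decoded by ggml_cuda_e8m0_to_fp32 is a pure power of two,

    scale(x) = 2^(x - 127).

The fallback branch builds this by writing x straight into the exponent field of an IEEE-754 float (x << 23); the x == 0 case cannot be expressed that way (it would produce +0), so it returns the subnormal bit pattern 0x00400000, i.e. 2^-127.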
ggml/src/ggml-cuda/convert.cu CHANGED
@@ -465,6 +465,24 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
465
  }
466
  }
467

468
  template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
469
  static void dequantize_block_cuda(const void * vx, dst_t * y,
470
  const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
@@ -588,6 +606,12 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t
588
  dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
589
  }
590

591
  template <typename src_t, typename dst_t>
592
  static __global__ void convert_unary(
593
  const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,
@@ -677,6 +701,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
677
  return dequantize_row_iq4_xs_cuda;
678
  case GGML_TYPE_IQ3_S:
679
  return dequantize_row_iq3_s_cuda;
 
 
680
  case GGML_TYPE_F32:
681
  return convert_unary_cont_cuda<float>;
682
  case GGML_TYPE_BF16:
@@ -726,6 +752,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
726
  return dequantize_row_iq4_xs_cuda;
727
  case GGML_TYPE_IQ3_S:
728
  return dequantize_row_iq3_s_cuda;
 
 
729
  case GGML_TYPE_F16:
730
  return convert_unary_cont_cuda<half>;
731
  case GGML_TYPE_BF16:
 
465
  }
466
  }
467
 
468
+ template<typename dst_t>
469
+ static __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restrict__ yy) {
470
+
471
+ const int64_t i = blockIdx.x;
472
+ const block_mxfp4 * x = (const block_mxfp4 *) vx + i*(QK_K/QK_MXFP4);
473
+
474
+ const int64_t tid = threadIdx.x;
475
+ const int64_t il = tid/8; // 0...3
476
+ const int64_t ib = tid%8; // 0...7
477
+ dst_t * y = yy + i*QK_K + 32*ib + 4*il;
478
+ const uint8_t * q4 = x[ib].qs + 4*il;
479
+ const float d = ggml_cuda_e8m0_to_fp32(x[ib].e);
480
+ for (int j = 0; j < 4; ++j) {
481
+ y[j+ 0] = d * kvalues_mxfp4[q4[j] & 0xf]*0.5f;
482
+ y[j+16] = d * kvalues_mxfp4[q4[j] >> 4]*0.5f;
483
+ }
484
+ }
485
+
486
  template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
487
  static void dequantize_block_cuda(const void * vx, dst_t * y,
488
  const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
 
606
  dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
607
  }
608
 
609
+ template<typename dst_t>
610
+ static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
611
+ const int nb = (k + QK_K - 1) / QK_K;
612
+ dequantize_block_mxfp4<<<nb, 32, 0, stream>>>(vx, y);
613
+ }
614
+
615
  template <typename src_t, typename dst_t>
616
  static __global__ void convert_unary(
617
  const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,
 
701
  return dequantize_row_iq4_xs_cuda;
702
  case GGML_TYPE_IQ3_S:
703
  return dequantize_row_iq3_s_cuda;
704
+ case GGML_TYPE_MXFP4:
705
+ return dequantize_row_mxfp4_cuda;
706
  case GGML_TYPE_F32:
707
  return convert_unary_cont_cuda<float>;
708
  case GGML_TYPE_BF16:
 
752
  return dequantize_row_iq4_xs_cuda;
753
  case GGML_TYPE_IQ3_S:
754
  return dequantize_row_iq3_s_cuda;
755
+ case GGML_TYPE_MXFP4:
756
+ return dequantize_row_mxfp4_cuda;
757
  case GGML_TYPE_F16:
758
  return convert_unary_cont_cuda<half>;
759
  case GGML_TYPE_BF16:
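
For context, MXFP4 packs 32 E2M1 (4-bit) values behind a single shared E8M0 scale. The block layout implied by dequantize_block_mxfp4 above is roughly the following (sketch only; the field names e/qs come from the kernel, the exact definition lives in ggml-common.h in this commit):

    // QK_MXFP4 == 32 elements per block (assumed)
    typedef struct {
        uint8_t e;                // shared scale, E8M0 (decoded by ggml_cuda_e8m0_to_fp32)
        uint8_t qs[QK_MXFP4 / 2]; // 32 x 4-bit E2M1 codes: low nibbles are elements 0..15, high nibbles 16..31
    } block_mxfp4;

Each element dequantizes as d * kvalues_mxfp4[q] * 0.5f with d the decoded E8M0 scale; the 0.5f factor suggests the lookup table stores the E2M1 code points doubled so they fit in an integer table.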
ggml/src/ggml-cuda/fattn-common.cuh CHANGED
@@ -15,6 +15,7 @@ typedef void (* fattn_kernel_t)(
15
  const char * __restrict__ K,
16
  const char * __restrict__ V,
17
  const char * __restrict__ mask,
 
18
  const int * __restrict__ KV_max,
19
  float * __restrict__ dst,
20
  float2 * __restrict__ dst_meta,
@@ -736,7 +737,8 @@ void launch_fattn(
736
 
737
  GGML_ASSERT(V || is_mla);
738
 
739
- const ggml_tensor * mask = dst->src[3];
 
740
 
741
  ggml_tensor * KQV = dst;
742
 
@@ -940,6 +942,7 @@ void launch_fattn(
940
  K_data,
941
  V_data,
942
  mask ? ((const char *) mask->data) : nullptr,
 
943
  KV_max.ptr,
944
  !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
945
  scale, max_bias, m0, m1, n_head_log2, logit_softcap,
 
15
  const char * __restrict__ K,
16
  const char * __restrict__ V,
17
  const char * __restrict__ mask,
18
+ const char * __restrict__ sinks,
19
  const int * __restrict__ KV_max,
20
  float * __restrict__ dst,
21
  float2 * __restrict__ dst_meta,
 
737
 
738
  GGML_ASSERT(V || is_mla);
739
 
740
+ const ggml_tensor * mask = dst->src[3];
741
+ const ggml_tensor * sinks = dst->src[4];
742
 
743
  ggml_tensor * KQV = dst;
744
 
 
942
  K_data,
943
  V_data,
944
  mask ? ((const char *) mask->data) : nullptr,
945
+ sinks ? ((const char *) sinks->data) : nullptr,
946
  KV_max.ptr,
947
  !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
948
  scale, max_bias, m0, m1, n_head_log2, logit_softcap,
ggml/src/ggml-cuda/fattn-mma-f16.cuh CHANGED
@@ -1206,6 +1206,7 @@ static __global__ void flash_attn_ext_f16(
1206
  const char * __restrict__ K,
1207
  const char * __restrict__ V,
1208
  const char * __restrict__ mask,
 
1209
  const int * __restrict__ KV_max,
1210
  float * __restrict__ dst,
1211
  float2 * __restrict__ dst_meta,
@@ -1267,6 +1268,7 @@ static __global__ void flash_attn_ext_f16(
1267
  // kb0 == k start index when in the output tile.
1268
  int kb0_start = kbc % iter_k;
1269
  int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
 
1270
  while (kbc < kbc_stop && kb0_stop == iter_k) {
1271
  const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
1272
  const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
@@ -1340,7 +1342,7 @@ static __global__ void flash_attn_ext_f16(
1340
  (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
1341
  ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
1342
  #else
1343
- GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
1344
  GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
1345
  GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
1346
  GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
 
1206
  const char * __restrict__ K,
1207
  const char * __restrict__ V,
1208
  const char * __restrict__ mask,
1209
+ const char * __restrict__ sinks,
1210
  const int * __restrict__ KV_max,
1211
  float * __restrict__ dst,
1212
  float2 * __restrict__ dst_meta,
 
1268
  // kb0 == k start index when in the output tile.
1269
  int kb0_start = kbc % iter_k;
1270
  int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
1271
+
1272
  while (kbc < kbc_stop && kb0_stop == iter_k) {
1273
  const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
1274
  const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
 
1342
  (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
1343
  ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
1344
  #else
1345
+ GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
1346
  GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
1347
  GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
1348
  GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
ggml/src/ggml-cuda/fattn-tile-f16.cu CHANGED
@@ -13,6 +13,7 @@ static __global__ void flash_attn_tile_ext_f16(
13
  const char * __restrict__ K,
14
  const char * __restrict__ V,
15
  const char * __restrict__ mask,
 
16
  const int * __restrict__ KV_max,
17
  float * __restrict__ dst,
18
  float2 * __restrict__ dst_meta,
@@ -272,7 +273,7 @@ static __global__ void flash_attn_tile_ext_f16(
272
  }
273
  }
274
  #else
275
- GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
276
  GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
277
  GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
278
  GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
 
13
  const char * __restrict__ K,
14
  const char * __restrict__ V,
15
  const char * __restrict__ mask,
16
+ const char * __restrict__ sinks,
17
  const int * __restrict__ KV_max,
18
  float * __restrict__ dst,
19
  float2 * __restrict__ dst_meta,
 
273
  }
274
  }
275
  #else
276
+ GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
277
  GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
278
  GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
279
  GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
ggml/src/ggml-cuda/fattn-tile-f32.cu CHANGED
@@ -13,6 +13,7 @@ static __global__ void flash_attn_tile_ext_f32(
13
  const char * __restrict__ K,
14
  const char * __restrict__ V,
15
  const char * __restrict__ mask,
 
16
  const int * __restrict__ KV_max,
17
  float * __restrict__ dst,
18
  float2 * __restrict__ dst_meta,
@@ -37,7 +38,7 @@ static __global__ void flash_attn_tile_ext_f32(
37
  return;
38
  #endif // FP16_MMA_AVAILABLE
39
  if (use_logit_softcap && !(D == 128 || D == 256)) {
40
- GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
41
  GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
42
  GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
43
  GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
 
13
  const char * __restrict__ K,
14
  const char * __restrict__ V,
15
  const char * __restrict__ mask,
16
+ const char * __restrict__ sinks,
17
  const int * __restrict__ KV_max,
18
  float * __restrict__ dst,
19
  float2 * __restrict__ dst_meta,
 
38
  return;
39
  #endif // FP16_MMA_AVAILABLE
40
  if (use_logit_softcap && !(D == 128 || D == 256)) {
41
+ GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
42
  GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
43
  GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
44
  GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
ggml/src/ggml-cuda/fattn-vec-f16.cuh CHANGED
@@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f16(
16
  const char * __restrict__ K,
17
  const char * __restrict__ V,
18
  const char * __restrict__ mask,
 
19
  const int * __restrict__ KV_max,
20
  float * __restrict__ dst,
21
  float2 * __restrict__ dst_meta,
@@ -61,7 +62,8 @@ static __global__ void flash_attn_vec_ext_f16(
61
  K += nb13*sequence + nb12*(head / gqa_ratio);
62
  V += nb23*sequence + nb22*(head / gqa_ratio);
63
 
64
- const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
 
65
 
66
  const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
67
  const half slopeh = __float2half(slopef);
@@ -75,11 +77,12 @@ static __global__ void flash_attn_vec_ext_f16(
75
  half2 * KQ2 = (half2 *) KQ;
76
 
77
  half kqmax[ncols];
 
78
  #pragma unroll
79
  for (int j = 0; j < ncols; ++j) {
80
  kqmax[j] = -HALF_MAX_HALF;
 
81
  }
82
- half kqsum[ncols] = {0.0f};
83
 
84
  __shared__ half kqmax_shared[ncols][WARP_SIZE];
85
  __shared__ half kqsum_shared[ncols][WARP_SIZE];
@@ -283,6 +286,39 @@ static __global__ void flash_attn_vec_ext_f16(
283
  __syncthreads();
284
  }
285

286
  #pragma unroll
287
  for (int j = 0; j < ncols; ++j) {
288
  kqsum[j] = warp_reduce_sum((float)kqsum[j]);
@@ -313,7 +349,7 @@ static __global__ void flash_attn_vec_ext_f16(
313
  dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
314
  }
315
  #else
316
- GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
317
  GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
318
  GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
319
  GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
 
16
  const char * __restrict__ K,
17
  const char * __restrict__ V,
18
  const char * __restrict__ mask,
19
+ const char * __restrict__ sinks,
20
  const int * __restrict__ KV_max,
21
  float * __restrict__ dst,
22
  float2 * __restrict__ dst_meta,
 
62
  K += nb13*sequence + nb12*(head / gqa_ratio);
63
  V += nb23*sequence + nb22*(head / gqa_ratio);
64
 
65
+ const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
66
+ const float * sinksf = (const float *) (sinks);
67
 
68
  const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
69
  const half slopeh = __float2half(slopef);
 
77
  half2 * KQ2 = (half2 *) KQ;
78
 
79
  half kqmax[ncols];
80
+ half kqsum[ncols];
81
  #pragma unroll
82
  for (int j = 0; j < ncols; ++j) {
83
  kqmax[j] = -HALF_MAX_HALF;
84
+ kqsum[j] = 0.0f;
85
  }
 
86
 
87
  __shared__ half kqmax_shared[ncols][WARP_SIZE];
88
  __shared__ half kqsum_shared[ncols][WARP_SIZE];
 
286
  __syncthreads();
287
  }
288
 
289
+ if (sinksf && blockIdx.y == 0) {
290
+ const half sink = __float2half(sinksf[head]);
291
+
292
+ #pragma unroll
293
+ for (int j = 0; j < ncols; ++j) {
294
+ if (threadIdx.x == 0) {
295
+ kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink);
296
+ }
297
+ }
298
+
299
+ __syncthreads();
300
+
301
+ #pragma unroll
302
+ for (int j = 0; j < ncols; ++j) {
303
+ half kqmax_new_j = kqmax_shared[j][threadIdx.x];
304
+ kqmax_new_j = warp_reduce_max(kqmax_new_j);
305
+
306
+ const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
307
+ kqmax[j] = kqmax_new_j;
308
+
309
+ const half val = hexp(sink - kqmax[j]);
310
+ kqsum[j] = kqsum[j]*KQ_max_scale;
311
+
312
+ if (tid == 0) {
313
+ kqsum[j] += val;
314
+ }
315
+
316
+ VKQ[j] *= __half2half2(KQ_max_scale);
317
+ }
318
+
319
+ __syncthreads();
320
+ }
321
+
322
  #pragma unroll
323
  for (int j = 0; j < ncols; ++j) {
324
  kqsum[j] = warp_reduce_sum((float)kqsum[j]);
 
349
  dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
350
  }
351
  #else
352
+ GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
353
  GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
354
  GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
355
  GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
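
For context, an attention sink is one extra per-head logit s that competes in the softmax but has no value vector. The block added above folds it into the streaming softmax exactly like one more KQ column: with m the running maximum, S the normalizer and O the VKQ accumulator,

    m' = max(m, s)
    S' = S * exp(m - m') + exp(s - m')
    O' = O * exp(m - m')

The exp(s - m') term is added by a single thread (tid == 0) so it enters the sum exactly once, and O receives no contribution from the sink. The f32 variant below applies the same update with expf instead of hexp.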
ggml/src/ggml-cuda/fattn-vec-f32.cuh CHANGED
@@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f32(
16
  const char * __restrict__ K,
17
  const char * __restrict__ V,
18
  const char * __restrict__ mask,
 
19
  const int * __restrict__ KV_max,
20
  float * __restrict__ dst,
21
  float2 * __restrict__ dst_meta,
@@ -72,7 +73,8 @@ static __global__ void flash_attn_vec_ext_f32(
72
  K += nb13*sequence + nb12*(head / gqa_ratio);
73
  V += nb23*sequence + nb22*(head / gqa_ratio);
74
 
75
- const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
 
76
 
77
  const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
78
 
@@ -88,11 +90,12 @@ static __global__ void flash_attn_vec_ext_f32(
88
  }
89
 
90
  float kqmax[ncols];
 
91
  #pragma unroll
92
  for (int j = 0; j < ncols; ++j) {
93
  kqmax[j] = -FLT_MAX/2.0f;
 
94
  }
95
- float kqsum[ncols] = {0.0f};
96
 
97
  __shared__ float kqmax_shared[ncols][WARP_SIZE];
98
  __shared__ float kqsum_shared[ncols][WARP_SIZE];
@@ -279,6 +282,39 @@ static __global__ void flash_attn_vec_ext_f32(
279
  __syncthreads();
280
  }
281

282
  #pragma unroll
283
  for (int j = 0; j < ncols; ++j) {
284
  kqsum[j] = warp_reduce_sum(kqsum[j]);
 
16
  const char * __restrict__ K,
17
  const char * __restrict__ V,
18
  const char * __restrict__ mask,
19
+ const char * __restrict__ sinks,
20
  const int * __restrict__ KV_max,
21
  float * __restrict__ dst,
22
  float2 * __restrict__ dst_meta,
 
73
  K += nb13*sequence + nb12*(head / gqa_ratio);
74
  V += nb23*sequence + nb22*(head / gqa_ratio);
75
 
76
+ const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
77
+ const float * sinksf = (const float *) (sinks);
78
 
79
  const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
80
 
 
90
  }
91
 
92
  float kqmax[ncols];
93
+ float kqsum[ncols];
94
  #pragma unroll
95
  for (int j = 0; j < ncols; ++j) {
96
  kqmax[j] = -FLT_MAX/2.0f;
97
+ kqsum[j] = 0.0f;
98
  }
 
99
 
100
  __shared__ float kqmax_shared[ncols][WARP_SIZE];
101
  __shared__ float kqsum_shared[ncols][WARP_SIZE];
 
282
  __syncthreads();
283
  }
284
 
285
+ if (sinksf && blockIdx.y == 0) {
286
+ const float sink = sinksf[head];
287
+
288
+ #pragma unroll
289
+ for (int j = 0; j < ncols; ++j) {
290
+ if (threadIdx.x == 0) {
291
+ kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink);
292
+ }
293
+ }
294
+
295
+ __syncthreads();
296
+
297
+ #pragma unroll
298
+ for (int j = 0; j < ncols; ++j) {
299
+ float kqmax_new_j = kqmax_shared[j][threadIdx.x];
300
+ kqmax_new_j = warp_reduce_max(kqmax_new_j);
301
+
302
+ const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j);
303
+ kqmax[j] = kqmax_new_j;
304
+
305
+ const float val = expf(sink - kqmax[j]);
306
+ kqsum[j] = kqsum[j]*KQ_max_scale;
307
+
308
+ if (tid == 0) {
309
+ kqsum[j] += val;
310
+ }
311
+
312
+ VKQ[j] *= KQ_max_scale;
313
+ }
314
+
315
+ __syncthreads();
316
+ }
317
+
318
  #pragma unroll
319
  for (int j = 0; j < ncols; ++j) {
320
  kqsum[j] = warp_reduce_sum(kqsum[j]);
ggml/src/ggml-cuda/fattn-wmma-f16.cu CHANGED
@@ -29,6 +29,7 @@ static __global__ void flash_attn_ext_f16(
29
  const char * __restrict__ K,
30
  const char * __restrict__ V,
31
  const char * __restrict__ mask,
 
32
  const int * __restrict__ KV_max,
33
  float * __restrict__ dst,
34
  float2 * __restrict__ dst_meta,
@@ -423,7 +424,7 @@ static __global__ void flash_attn_ext_f16(
423
  dst_meta[j_dst_unrolled] = dst_meta_val;
424
  }
425
  #else
426
- GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
427
  GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
428
  GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
429
  GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
 
29
  const char * __restrict__ K,
30
  const char * __restrict__ V,
31
  const char * __restrict__ mask,
32
+ const char * __restrict__ sinks,
33
  const int * __restrict__ KV_max,
34
  float * __restrict__ dst,
35
  float2 * __restrict__ dst_meta,
 
424
  dst_meta[j_dst_unrolled] = dst_meta_val;
425
  }
426
  #else
427
+ GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
428
  GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
429
  GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
430
  GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
ggml/src/ggml-cuda/fattn.cu CHANGED
@@ -269,17 +269,28 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
269
  }
270
 
271
  void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
272
- const ggml_tensor * KQV = dst;
273
- const ggml_tensor * Q = dst->src[0];
274
- const ggml_tensor * K = dst->src[1];
275
- const ggml_tensor * V = dst->src[2];
276
- const ggml_tensor * mask = dst->src[3];
 
277
 
278
  ggml_cuda_set_device(ctx.device);
279
  const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
280
  const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
281
  const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
282

283
  #if defined(GGML_HIP_ROCWMMA_FATTN)
284
  if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
285
  ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
 
269
  }
270
 
271
  void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
272
+ const ggml_tensor * KQV = dst;
273
+ const ggml_tensor * Q = dst->src[0];
274
+ const ggml_tensor * K = dst->src[1];
275
+ const ggml_tensor * V = dst->src[2];
276
+ const ggml_tensor * mask = dst->src[3];
277
+ const ggml_tensor * sinks = dst->src[4];
278
 
279
  ggml_cuda_set_device(ctx.device);
280
  const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
281
  const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
282
  const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
283
 
284
+ // TODO: currently only vec implementation for sinks is supported [TAG_ATTN_SINKS]
285
+ if (sinks) {
286
+ if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
287
+ ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
288
+ } else {
289
+ ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
290
+ }
291
+ return;
292
+ }
293
+
294
  #if defined(GGML_HIP_ROCWMMA_FATTN)
295
  if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
296
  ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
ggml/src/ggml-cuda/ggml-cuda.cu CHANGED
@@ -4,6 +4,7 @@
4
 
5
  #include "ggml-cuda/common.cuh"
6
  #include "ggml-cuda/acc.cuh"
 
7
  #include "ggml-cuda/arange.cuh"
8
  #include "ggml-cuda/argmax.cuh"
9
  #include "ggml-cuda/argsort.cuh"
@@ -2259,6 +2260,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2259
  case GGML_OP_ADD1: // TODO: more efficient implementation
2260
  ggml_cuda_op_add(ctx, dst);
2261
  break;
 
 
 
2262
  case GGML_OP_SUB:
2263
  ggml_cuda_op_sub(ctx, dst);
2264
  break;
@@ -2333,6 +2337,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2333
  case GGML_GLU_OP_SWIGLU:
2334
  ggml_cuda_op_swiglu(ctx, dst);
2335
  break;
 
 
 
2336
  case GGML_GLU_OP_GEGLU_ERF:
2337
  ggml_cuda_op_geglu_erf(ctx, dst);
2338
  break;
@@ -2607,6 +2614,9 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
2607
 
2608
  const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
2609
  const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
 
 
 
2610
 
2611
  for (int i = 0; i < cgraph->n_nodes; i++) {
2612
  ggml_tensor * node = cgraph->nodes[i];
@@ -2629,7 +2639,13 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
2629
  #endif
2630
  }
2631
 
2632
- if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1 && (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) && (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true)) {
2633
  // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
2634
  // by means of matching node names. See
2635
  // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
@@ -3227,6 +3243,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3227
  case GGML_GLU_OP_REGLU:
3228
  case GGML_GLU_OP_GEGLU:
3229
  case GGML_GLU_OP_SWIGLU:
 
3230
  case GGML_GLU_OP_GEGLU_ERF:
3231
  case GGML_GLU_OP_GEGLU_QUICK:
3232
  return ggml_is_contiguous_1(op->src[0]);
@@ -3277,6 +3294,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3277
  case GGML_TYPE_Q5_0:
3278
  case GGML_TYPE_Q5_1:
3279
  case GGML_TYPE_Q8_0:
 
3280
  case GGML_TYPE_Q2_K:
3281
  case GGML_TYPE_Q3_K:
3282
  case GGML_TYPE_Q4_K:
@@ -3423,6 +3441,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3423
  case GGML_OP_PERMUTE:
3424
  case GGML_OP_TRANSPOSE:
3425
  case GGML_OP_ADD:
 
3426
  case GGML_OP_ADD1:
3427
  case GGML_OP_SUB:
3428
  case GGML_OP_MUL:
@@ -3503,6 +3522,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3503
  const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
3504
  return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
3505
  }
 
 
 
 
3506
  if (op->src[0]->ne[0] == 192) {
3507
  return false;
3508
  }
 
4
 
5
  #include "ggml-cuda/common.cuh"
6
  #include "ggml-cuda/acc.cuh"
7
+ #include "ggml-cuda/add-id.cuh"
8
  #include "ggml-cuda/arange.cuh"
9
  #include "ggml-cuda/argmax.cuh"
10
  #include "ggml-cuda/argsort.cuh"
 
2260
  case GGML_OP_ADD1: // TODO: more efficient implementation
2261
  ggml_cuda_op_add(ctx, dst);
2262
  break;
2263
+ case GGML_OP_ADD_ID:
2264
+ ggml_cuda_op_add_id(ctx, dst);
2265
+ break;
2266
  case GGML_OP_SUB:
2267
  ggml_cuda_op_sub(ctx, dst);
2268
  break;
 
2337
  case GGML_GLU_OP_SWIGLU:
2338
  ggml_cuda_op_swiglu(ctx, dst);
2339
  break;
2340
+ case GGML_GLU_OP_SWIGLU_OAI:
2341
+ ggml_cuda_op_swiglu_oai(ctx, dst);
2342
+ break;
2343
  case GGML_GLU_OP_GEGLU_ERF:
2344
  ggml_cuda_op_geglu_erf(ctx, dst);
2345
  break;
 
2614
 
2615
  const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
2616
  const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
2617
+ const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
2618
+ const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
2619
+ const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
2620
 
2621
  for (int i = 0; i < cgraph->n_nodes; i++) {
2622
  ggml_tensor * node = cgraph->nodes[i];
 
2639
  #endif
2640
  }
2641
 
2642
+ if (node->op == GGML_OP_ADD &&
2643
+ node->src[1] && node->src[1]->ne[1] > 1 &&
2644
+ (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
2645
+ (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
2646
+ strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
2647
+ strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
2648
+ strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0) {
2649
  // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
2650
  // by means of matching node names. See
2651
  // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
 
3243
  case GGML_GLU_OP_REGLU:
3244
  case GGML_GLU_OP_GEGLU:
3245
  case GGML_GLU_OP_SWIGLU:
3246
+ case GGML_GLU_OP_SWIGLU_OAI:
3247
  case GGML_GLU_OP_GEGLU_ERF:
3248
  case GGML_GLU_OP_GEGLU_QUICK:
3249
  return ggml_is_contiguous_1(op->src[0]);
 
3294
  case GGML_TYPE_Q5_0:
3295
  case GGML_TYPE_Q5_1:
3296
  case GGML_TYPE_Q8_0:
3297
+ case GGML_TYPE_MXFP4:
3298
  case GGML_TYPE_Q2_K:
3299
  case GGML_TYPE_Q3_K:
3300
  case GGML_TYPE_Q4_K:
 
3441
  case GGML_OP_PERMUTE:
3442
  case GGML_OP_TRANSPOSE:
3443
  case GGML_OP_ADD:
3444
+ case GGML_OP_ADD_ID:
3445
  case GGML_OP_ADD1:
3446
  case GGML_OP_SUB:
3447
  case GGML_OP_MUL:
 
3522
  const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
3523
  return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
3524
  }
3525
+ // TODO: more general-purpose attention sink support [TAG_ATTN_SINKS]
3526
+ if (op->src[4] && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) { // currently only sinks for head_size 64 and 128 are supported
3527
+ return false;
3528
+ }
3529
  if (op->src[0]->ne[0] == 192) {
3530
  return false;
3531
  }
ggml/src/ggml-cuda/im2col.cu CHANGED
@@ -1,7 +1,5 @@
1
  #include "im2col.cuh"
2
 
3
- #define MIN(a, b) (a) < (b) ? (a) : (b)
4
-
5
  #define MAX_GRIDDIM_Z 65535
6
 
7
  template <typename T>
@@ -38,6 +36,9 @@ static __global__ void im2col_kernel(
38
  dst[offset_dst] = x[offset_src + iih * IW + iiw];
39
  }
40
  }
 
 
 
41
  }
42
 
43
  // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
 
1
  #include "im2col.cuh"
2
 
 
 
3
  #define MAX_GRIDDIM_Z 65535
4
 
5
  template <typename T>
 
36
  dst[offset_dst] = x[offset_src + iih * IW + iiw];
37
  }
38
  }
39
+
40
+ GGML_UNUSED(IC);
41
+ GGML_UNUSED(KH);
42
  }
43
 
44
  // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
ggml/src/ggml-cuda/mmq.cu CHANGED
@@ -20,6 +20,9 @@ static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, con
20
  case GGML_TYPE_Q8_0:
21
  mul_mat_q_case<GGML_TYPE_Q8_0>(ctx, args, stream);
22
  break;
 
 
 
23
  case GGML_TYPE_Q2_K:
24
  mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);
25
  break;
@@ -282,6 +285,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
282
  case GGML_TYPE_Q5_0:
283
  case GGML_TYPE_Q5_1:
284
  case GGML_TYPE_Q8_0:
 
285
  case GGML_TYPE_Q2_K:
286
  case GGML_TYPE_Q3_K:
287
  case GGML_TYPE_Q4_K:
 
20
  case GGML_TYPE_Q8_0:
21
  mul_mat_q_case<GGML_TYPE_Q8_0>(ctx, args, stream);
22
  break;
23
+ case GGML_TYPE_MXFP4:
24
+ mul_mat_q_case<GGML_TYPE_MXFP4>(ctx, args, stream);
25
+ break;
26
  case GGML_TYPE_Q2_K:
27
  mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);
28
  break;
 
285
  case GGML_TYPE_Q5_0:
286
  case GGML_TYPE_Q5_1:
287
  case GGML_TYPE_Q8_0:
288
+ case GGML_TYPE_MXFP4:
289
  case GGML_TYPE_Q2_K:
290
  case GGML_TYPE_Q3_K:
291
  case GGML_TYPE_Q4_K:
ggml/src/ggml-cuda/mmq.cuh CHANGED
@@ -58,6 +58,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
58
  return MMQ_Q8_1_DS_LAYOUT_DS4;
59
  case GGML_TYPE_Q8_0:
60
  return MMQ_Q8_1_DS_LAYOUT_D4;
 
 
61
  case GGML_TYPE_Q2_K:
62
  return MMQ_Q8_1_DS_LAYOUT_D2S6;
63
  case GGML_TYPE_Q3_K:
@@ -170,6 +172,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
170
  case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
171
  case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1;
172
  case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0;
 
173
  case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K;
174
  case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K;
175
  case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K;
@@ -206,6 +209,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
206
  case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
207
  case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
208
  case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
 
209
  case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
210
  case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
211
  case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
@@ -692,6 +696,71 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
692
  }
693
  }
694

695
  template <int mmq_x, int mmq_y>
696
  static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
697
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
@@ -2268,7 +2337,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2268
  const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx;
2269
 
2270
  const int aux_q4 = get_int_b2(bxi->qs, kqsx);
2271
- const int2 v = get_int_from_table_16(aux_q4);
2272
  const int k0 = kbx * (2 * QI4_NL) + kqsx;
2273
 
2274
  #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
@@ -2707,7 +2776,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2707
  const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride;
2708
 
2709
  const int aux_q4 = get_int_b4(bxi->qs, kqsx);
2710
- const int2 v = get_int_from_table_16(aux_q4);
2711
  const int k0 = 8 * (kqsx / 4) + kqsx % 4;
2712
 
2713
  #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
@@ -2863,6 +2932,14 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q8_0> {
2863
  static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
2864
  };
2865

2866
  template <int mmq_x, int mmq_y, bool need_check>
2867
  struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
2868
  static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
@@ -3642,6 +3719,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
3642
  extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
3643
  extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
3644
  extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
 
3645
  extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
3646
  extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
3647
  extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
 
58
  return MMQ_Q8_1_DS_LAYOUT_DS4;
59
  case GGML_TYPE_Q8_0:
60
  return MMQ_Q8_1_DS_LAYOUT_D4;
61
+ case GGML_TYPE_MXFP4:
62
+ return MMQ_Q8_1_DS_LAYOUT_D4;
63
  case GGML_TYPE_Q2_K:
64
  return MMQ_Q8_1_DS_LAYOUT_D2S6;
65
  case GGML_TYPE_Q3_K:
 
172
  case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
173
  case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1;
174
  case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0;
175
+ case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1;
176
  case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K;
177
  case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K;
178
  case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K;
 
209
  case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
210
  case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
211
  case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
212
+ case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
213
  case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
214
  case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
215
  case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
 
696
  }
697
  }
698
 
699
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_mxfp4(
700
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
701
+ constexpr int nwarps = mmq_get_nwarps_device();
702
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
703
+
704
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
705
+ int * x_qs = (int *) x_tile;
706
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
707
+ #else
708
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y);
709
+ int * x_qs = (int *) x_tile;
710
+ float * x_df = (float *) (x_qs + txs.qs);
711
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
712
+
713
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4);
714
+ constexpr int nrows = warp_size / threads_per_row;
715
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
716
+ const int kbx = txi / QI_MXFP4;
717
+ const int kqsx = txi % QI_MXFP4;
718
+
719
+ #pragma unroll
720
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
721
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
722
+
723
+ if (need_check) {
724
+ i = min(i, i_max);
725
+ }
726
+
727
+ const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbx;
728
+
729
+ const int aux_q4 = get_int_b1(bxi->qs, kqsx);
730
+ const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
731
+ const int k0 = kbx * (2 * QI_MXFP4) + kqsx;
732
+
733
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
734
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0] = v.x;
735
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y;
736
+ #else
737
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
738
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y;
739
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
740
+ }
741
+
742
+ constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4;
743
+ constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
744
+ const int kbxd = threadIdx.x % blocks_per_tile_x_row;
745
+
746
+ #pragma unroll
747
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
748
+ int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
749
+
750
+ if (need_check) {
751
+ i = min(i, i_max);
752
+ }
753
+
754
+ const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd;
755
+
756
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
757
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
758
+ #else
759
+ x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
760
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
761
+ }
762
+ }
763
+
764
  template <int mmq_x, int mmq_y>
765
  static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
766
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 
2337
  const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx;
2338
 
2339
  const int aux_q4 = get_int_b2(bxi->qs, kqsx);
2340
+ const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
2341
  const int k0 = kbx * (2 * QI4_NL) + kqsx;
2342
 
2343
  #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
 
2776
  const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride;
2777
 
2778
  const int aux_q4 = get_int_b4(bxi->qs, kqsx);
2779
+ const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
2780
  const int k0 = 8 * (kqsx / 4) + kqsx % 4;
2781
 
2782
  #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
 
2932
  static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
2933
  };
2934
 
2935
+ template <int mmq_x, int mmq_y, bool need_check>
2936
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
2937
+ static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ;
2938
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4<mmq_y, need_check>;
2939
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
2940
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
2941
+ };
2942
+
2943
  template <int mmq_x, int mmq_y, bool need_check>
2944
  struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
2945
  static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
 
3719
  extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
3720
  extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
3721
  extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
3722
+ extern DECL_MMQ_CASE(GGML_TYPE_MXFP4);
3723
  extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
3724
  extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
3725
  extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
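
For context, load_tiles_mxfp4 can reuse the existing Q8_0/Q8_1 dot kernels (vec_dot_q8_0_q8_1_mma / _dp4a in mmq_type_traits above) because get_int_from_table_16 expands each nibble through kvalues_mxfp4 into plain int8 values; what remains per block is one float scale, stored as

    x_df = 2^(e - 127) * 0.5f

so the doubled table entries are compensated in the scale rather than in the inner dot product.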
ggml/src/ggml-cuda/mmvq.cu CHANGED
@@ -13,6 +13,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
13
  case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1;
14
  case GGML_TYPE_Q5_1: return vec_dot_q5_1_q8_1;
15
  case GGML_TYPE_Q8_0: return vec_dot_q8_0_q8_1;
 
16
  case GGML_TYPE_Q2_K: return vec_dot_q2_K_q8_1;
17
  case GGML_TYPE_Q3_K: return vec_dot_q3_K_q8_1;
18
  case GGML_TYPE_Q4_K: return vec_dot_q4_K_q8_1;
@@ -38,6 +39,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
38
  case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ;
39
  case GGML_TYPE_Q5_1: return VDR_Q5_1_Q8_1_MMVQ;
40
  case GGML_TYPE_Q8_0: return VDR_Q8_0_Q8_1_MMVQ;
 
41
  case GGML_TYPE_Q2_K: return VDR_Q2_K_Q8_1_MMVQ;
42
  case GGML_TYPE_Q3_K: return VDR_Q3_K_Q8_1_MMVQ;
43
  case GGML_TYPE_Q4_K: return VDR_Q4_K_Q8_1_MMVQ;
@@ -384,6 +386,13 @@ static void mul_mat_vec_q_switch_type(
384
  nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
385
  stream);
386
  break;
387
  case GGML_TYPE_Q2_K:
388
  mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
389
  (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
 
13
  case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1;
14
  case GGML_TYPE_Q5_1: return vec_dot_q5_1_q8_1;
15
  case GGML_TYPE_Q8_0: return vec_dot_q8_0_q8_1;
16
+ case GGML_TYPE_MXFP4: return vec_dot_mxfp4_q8_1;
17
  case GGML_TYPE_Q2_K: return vec_dot_q2_K_q8_1;
18
  case GGML_TYPE_Q3_K: return vec_dot_q3_K_q8_1;
19
  case GGML_TYPE_Q4_K: return vec_dot_q4_K_q8_1;
 
39
  case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ;
40
  case GGML_TYPE_Q5_1: return VDR_Q5_1_Q8_1_MMVQ;
41
  case GGML_TYPE_Q8_0: return VDR_Q8_0_Q8_1_MMVQ;
42
+ case GGML_TYPE_MXFP4: return VDR_MXFP4_Q8_1_MMVQ;
43
  case GGML_TYPE_Q2_K: return VDR_Q2_K_Q8_1_MMVQ;
44
  case GGML_TYPE_Q3_K: return VDR_Q3_K_Q8_1_MMVQ;
45
  case GGML_TYPE_Q4_K: return VDR_Q4_K_Q8_1_MMVQ;
 
386
  nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
387
  stream);
388
  break;
389
+ case GGML_TYPE_MXFP4:
390
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4>
391
+ (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
392
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
393
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
394
+ stream);
395
+ break;
396
  case GGML_TYPE_Q2_K:
397
  mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
398
  (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
ggml/src/ggml-cuda/softmax.cu CHANGED
@@ -45,7 +45,7 @@ struct soft_max_params {
45
  #endif // __clang__
46
  template <bool use_shared, int ncols_template, int block_size_template, typename T>
47
  static __global__ void soft_max_f32(
48
- const float * x, const T * mask, float * dst, const soft_max_params p) {
49
  const int ncols = ncols_template == 0 ? p.ncols : ncols_template;
50
 
51
  const int tid = threadIdx.x;
@@ -77,7 +77,7 @@ static __global__ void soft_max_f32(
77
  // shared memory buffer to cache values between iterations:
78
  float * vals = use_shared ? buf_iw + WARP_SIZE : dst;
79
 
80
- float max_val = -INFINITY;
81
 
82
  #pragma unroll
83
  for (int col0 = 0; col0 < ncols; col0 += block_size) {
@@ -143,6 +143,10 @@ static __global__ void soft_max_f32(
143
  tmp = warp_reduce_sum(tmp);
144
  }
145
 
 
 
 
 
146
  const float inv_sum = 1.0f / tmp;
147
 
148
  #pragma unroll
@@ -183,7 +187,7 @@ static __global__ void soft_max_back_f32(
183
  }
184
 
185
  template<int... Ns, typename T>
186
- static void launch_soft_max_kernels(const float * x, const T * mask, float * dst,
187
  const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
188
  {
189
  const int id = ggml_cuda_get_device();
@@ -196,7 +200,7 @@ static void launch_soft_max_kernels(const float * x, const T * mask, float * dst
196
  if (p.ncols == ncols) {
197
  CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
198
  soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
199
- (x, mask, dst, p);
200
  return true;
201
  }
202
  return false;
@@ -209,12 +213,12 @@ static void launch_soft_max_kernels(const float * x, const T * mask, float * dst
209
 
210
  //default case
211
  CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
212
- soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, dst, p);
213
  }
214
 
215
 
216
  template<typename T>
217
- static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
218
  int nth = WARP_SIZE;
219
  const int64_t ncols_x = params.ncols;
220
 
@@ -230,10 +234,10 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
230
 
231
 
232
  if (nbytes_shared <= smpbo) {
233
- launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream, block_dims, block_nums, nbytes_shared);
234
  } else {
235
  const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
236
- soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
237
  }
238
  }
239
 
@@ -249,9 +253,11 @@ static void soft_max_back_f32_cuda(
249
  void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
250
  const ggml_tensor * src0 = dst->src[0];
251
  const ggml_tensor * src1 = dst->src[1];
 
252
 
253
  const float * src0_d = (const float *) src0->data;
254
  const void * src1_d = src1 ? (const void *) src1->data : nullptr;
 
255
  float * dst_d = (float *) dst->data;
256
 
257
  cudaStream_t stream = ctx.stream();
@@ -309,9 +315,9 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
309
  params.m1 = m1;
310
 
311
  if (use_f16) {
312
- soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, params, stream);
313
  } else {
314
- soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, params, stream);
315
  }
316
  }
317
 
 
45
  #endif // __clang__
46
  template <bool use_shared, int ncols_template, int block_size_template, typename T>
47
  static __global__ void soft_max_f32(
48
+ const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params p) {
49
  const int ncols = ncols_template == 0 ? p.ncols : ncols_template;
50
 
51
  const int tid = threadIdx.x;
 
77
  // shared memory buffer to cache values between iterations:
78
  float * vals = use_shared ? buf_iw + WARP_SIZE : dst;
79
 
80
+ float max_val = sinks ? sinks[i02] : -INFINITY;
81
 
82
  #pragma unroll
83
  for (int col0 = 0; col0 < ncols; col0 += block_size) {
 
143
  tmp = warp_reduce_sum(tmp);
144
  }
145
 
146
+ if (sinks) {
147
+ tmp += expf(sinks[i02] - max_val);
148
+ }
149
+
150
  const float inv_sum = 1.0f / tmp;
151
 
152
  #pragma unroll
 
187
  }
188
 
189
  template<int... Ns, typename T>
190
+ static void launch_soft_max_kernels(const float * x, const T * mask, const float * sinks, float * dst,
191
  const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
192
  {
193
  const int id = ggml_cuda_get_device();
 
200
  if (p.ncols == ncols) {
201
  CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
202
  soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
203
+ (x, mask, sinks, dst, p);
204
  return true;
205
  }
206
  return false;
 
213
 
214
  //default case
215
  CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
216
+ soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, sinks, dst, p);
217
  }
218
 
219
 
220
  template<typename T>
221
+ static void soft_max_f32_cuda(const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params & params, cudaStream_t stream) {
222
  int nth = WARP_SIZE;
223
  const int64_t ncols_x = params.ncols;
224
 
 
234
 
235
 
236
  if (nbytes_shared <= smpbo) {
237
+ launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared);
238
  } else {
239
  const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
240
+ soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
241
  }
242
  }
243
 
 
253
  void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
254
  const ggml_tensor * src0 = dst->src[0];
255
  const ggml_tensor * src1 = dst->src[1];
256
+ const ggml_tensor * src2 = dst->src[2];
257
 
258
  const float * src0_d = (const float *) src0->data;
259
  const void * src1_d = src1 ? (const void *) src1->data : nullptr;
260
+ const void * src2_d = src2 ? (const void *) src2->data : nullptr;
261
  float * dst_d = (float *) dst->data;
262
 
263
  cudaStream_t stream = ctx.stream();
 
315
  params.m1 = m1;
316
 
317
  if (use_f16) {
318
+ soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream);
319
  } else {
320
+ soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream);
321
  }
322
  }
323
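
For context, in the fused SOFT_MAX op the sink enters only the maximum and the denominator — it produces no output column. With s = sinks[i02] (one scalar per index along dimension 2, the attention head in the gpt-oss graph) the row is computed as

    y_j = exp(x_j - m) / ( exp(s - m) + sum_k exp(x_k - m) ),    m = max(s, max_k x_k)

which is what initializing max_val with sinks[i02] and adding expf(sinks[i02] - max_val) to tmp implements.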
 
ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+ #include "../mmq.cuh"
+
+ DECL_MMQ_CASE(GGML_TYPE_MXFP4);
ggml/src/ggml-cuda/unary.cu CHANGED
@@ -300,6 +300,81 @@ void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst
300
  ggml_cuda_op_unary_gated<op_gelu_quick>(ctx, dst);
301
  }
302

303
  /* silu_back */
304
 
305
  static __device__ __forceinline__ float op_silu_back(float grad, float x) {
 
300
  ggml_cuda_op_unary_gated<op_gelu_quick>(ctx, dst);
301
  }
302
 
303
+ // swiglu_oai
304
+
305
+ template <typename T>
306
+ static __global__ void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, float alpha, float limit) {
307
+ const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;
308
+
309
+ if (i >= k) {
310
+ return;
311
+ }
312
+
313
+ // perform base op and multiply with gate (either offset in same tensor or a separate one)
314
+ const int64_t j0 = (i / n) * o0 + (i % n);
315
+ const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
316
+
317
+ float xi = x[j0];
318
+ float gi = g[j1];
319
+ xi = fminf(xi, limit);
320
+ gi = fmaxf(fminf(gi, limit), -limit);
321
+
322
+ float out_glu = xi / (1.0f + expf(-xi * alpha));
323
+ out_glu = out_glu * (1.0f + gi);
324
+
325
+ dst[i] = out_glu;
326
+ }
327
+
328
+ template <typename T>
329
+ static void swiglu_oai_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream) {
330
+ const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE;
331
+ swiglu_oai_kernel<<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, g, dst, k, n, o0, o1, alpha, limit);
332
+ }
333
+
334
+ void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
335
+ const ggml_tensor * src0 = dst->src[0];
336
+ const ggml_tensor * src1 = dst->src[1];
337
+ void * src0_d = src0->data;
338
+ void * src1_d = src1 ? src1->data : src0->data;
339
+ const int64_t src0_o = src0->nb[1];
340
+ const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
341
+ void * dst_d = dst->data;
342
+ const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
343
+ cudaStream_t stream = ctx.stream();
344
+
345
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
346
+ GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));
347
+ GGML_ASSERT(ggml_is_contiguous(dst));
348
+
349
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
350
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
351
+ GGML_ASSERT(src0->type == dst->type);
352
+ GGML_ASSERT(dst->ne[0] == nc);
353
+ GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));
354
+
355
+ if (src1) {
356
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
357
+ GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
358
+ GGML_ASSERT(src1->ne[0] == nc);
359
+ GGML_ASSERT(src0->type == src1->type);
360
+ }
361
+
362
+ //const int32_t swapped = ((const int32_t *) dst->op_params)[1];
363
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
364
+ const float alpha = ggml_get_op_params_f32(dst, 2);
365
+ const float limit = ggml_get_op_params_f32(dst, 3);
366
+
367
+ float * src0_p = (float *) src0_d;
368
+ float * src1_p = (float *) src1_d;
369
+
370
+ if (!src1) {
371
+ src0_p += swapped ? nc : 0;
372
+ src1_p += swapped ? 0 : nc;
373
+ }
374
+
375
+ swiglu_oai_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream);
376
+ }
377
+
378
  /* silu_back */
379
 
380
  static __device__ __forceinline__ float op_silu_back(float grad, float x) {
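The fused swiglu_oai kernel above clips the up-projection from above at `limit`, clips the gate to [-limit, limit], and applies x·sigmoid(alpha·x)·(1 + g). A scalar restatement of the same math as a cross-check (the function name is illustrative, not part of ggml):

```cpp
#include <algorithm>
#include <cmath>

// Scalar reference for the fused swiglu_oai op, using the same clamping
// convention as swiglu_oai_kernel(): x is clipped at `limit` from above,
// the gate g is clipped to [-limit, limit], and the gate enters as (1 + g).
static float swiglu_oai_ref(float x, float g, float alpha, float limit) {
    x = std::min(x, limit);
    g = std::max(std::min(g, limit), -limit);

    const float silu = x / (1.0f + std::exp(-alpha * x)); // x * sigmoid(alpha * x)
    return silu * (1.0f + g);
}
```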
ggml/src/ggml-cuda/unary.cuh CHANGED
@@ -67,6 +67,8 @@ void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
67
 
68
  void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
69
 
 
 
70
  void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
71
 
72
  void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
67
 
68
  void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
69
 
70
+ void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
71
+
72
  void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
73
 
74
  void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml-cuda/vecdotq.cuh CHANGED
@@ -1,8 +1,20 @@
1
  #pragma once
2
 
3
  #include "common.cuh"
 
4
  #include <cstdint>
5
 
6
  static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) {
7
  const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment
8
 
@@ -16,6 +28,20 @@ static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32
16
  return ((const int *) x)[i32]; // assume at least 4 byte alignment
17
  }
18
 
19
  // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
20
  // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
21
 
@@ -211,6 +237,30 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_16_q8_1_
211
  return d8_1*sumf;
212
  }
213
 
 
214
  #define VDR_Q2_K_Q8_1_MMVQ 1
215
  #define VDR_Q2_K_Q8_1_MMQ 4
216
 
@@ -1068,20 +1118,6 @@ static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
1068
  return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1);
1069
  }
1070
 
1071
- static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) {
1072
- const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
1073
- const int8_t * q0_8 = (const int8_t *) &q0_32;
1074
- const char4 val0_8 = make_char4(
1075
- kvalues_iq4nl[q0_8[0]], kvalues_iq4nl[q0_8[1]], kvalues_iq4nl[q0_8[2]], kvalues_iq4nl[q0_8[3]]);
1076
-
1077
- const int q1_32 = (q4 >> 4) & 0x0F0F0F0F;
1078
- const int8_t * q1_8 = (const int8_t *) &q1_32;
1079
- const char4 val1_8 = make_char4(
1080
- kvalues_iq4nl[q1_8[0]], kvalues_iq4nl[q1_8[1]], kvalues_iq4nl[q1_8[2]], kvalues_iq4nl[q1_8[3]]);
1081
-
1082
- return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
1083
- }
1084
-
1085
  #define VDR_IQ4_NL_Q8_1_MMVQ 2
1086
  #define VDR_IQ4_NL_Q8_1_MMQ 4
1087
 
@@ -1096,7 +1132,7 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
1096
  #pragma unroll
1097
  for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
1098
  const int aux_q4 = get_int_b2(bq4->qs, iqs + l);
1099
- const int2 v = get_int_from_table_16(aux_q4);
1100
 
1101
  sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
1102
  sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
@@ -1118,7 +1154,7 @@ static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
1118
  #pragma unroll
1119
  for (int j = 0; j < 4; ++j) {
1120
  const int aux_q4 = get_int_b4(bq4->qs, iqs + j);
1121
- const int2 v = get_int_from_table_16(aux_q4);
1122
 
1123
  const int u0 = get_int_b4(bq8_1[iqs/4].qs, j + 0);
1124
  const int u1 = get_int_b4(bq8_1[iqs/4].qs, j + 4);
 
1
  #pragma once
2
 
3
  #include "common.cuh"
4
+
5
  #include <cstdint>
6
 
7
+ static __device__ __forceinline__ int get_int_b1(const void * x, const int & i32) {
8
+ const uint8_t * x8 = (const uint8_t *) x;
9
+
10
+ int x32 = x8[4*i32 + 0] << 0;
11
+ x32 |= x8[4*i32 + 1] << 8;
12
+ x32 |= x8[4*i32 + 2] << 16;
13
+ x32 |= x8[4*i32 + 3] << 24;
14
+
15
+ return x32;
16
+ }
17
+
18
  static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) {
19
  const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment
20
 
 
28
  return ((const int *) x)[i32]; // assume at least 4 byte alignment
29
  }
30
 
31
+ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * table) {
32
+ const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
33
+ const int8_t * q0_8 = (const int8_t *) &q0_32;
34
+ const char4 val0_8 = make_char4(
35
+ table[q0_8[0]], table[q0_8[1]], table[q0_8[2]], table[q0_8[3]]);
36
+
37
+ const int q1_32 = (q4 >> 4) & 0x0F0F0F0F;
38
+ const int8_t * q1_8 = (const int8_t *) &q1_32;
39
+ const char4 val1_8 = make_char4(
40
+ table[q1_8[0]], table[q1_8[1]], table[q1_8[2]], table[q1_8[3]]);
41
+
42
+ return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
43
+ }
44
+
45
  // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
46
  // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
47
 
 
237
  return d8_1*sumf;
238
  }
239
 
240
+ #define VDR_MXFP4_Q8_1_MMVQ 2
241
+ #define VDR_MXFP4_Q8_1_MMQ 4
242
+
243
+ static __device__ __forceinline__ float vec_dot_mxfp4_q8_1(
244
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
245
+
246
+ const block_mxfp4 * bq4 = (const block_mxfp4 *) vbq + kbx;
247
+
248
+ const int * q8 = (const int *) bq8_1->qs + iqs;
249
+
250
+ int sumi = 0;
251
+ #pragma unroll
252
+ for (int l = 0; l < VDR_MXFP4_Q8_1_MMVQ; ++l) {
253
+ const int aux_q4 = get_int_b1(bq4->qs, iqs + l);
254
+ const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
255
+
256
+ sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
257
+ sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
258
+ }
259
+
260
+ const float d = ggml_cuda_e8m0_to_fp32(bq4->e) * 0.5f * __low2float(bq8_1->ds);
261
+ return d * sumi;
262
+ }
263
+
264
  #define VDR_Q2_K_Q8_1_MMVQ 1
265
  #define VDR_Q2_K_Q8_1_MMQ 4
266
 
 
1118
  return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1);
1119
  }
1120
 
1121
  #define VDR_IQ4_NL_Q8_1_MMVQ 2
1122
  #define VDR_IQ4_NL_Q8_1_MMQ 4
1123
 
 
1132
  #pragma unroll
1133
  for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
1134
  const int aux_q4 = get_int_b2(bq4->qs, iqs + l);
1135
+ const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
1136
 
1137
  sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
1138
  sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
 
1154
  #pragma unroll
1155
  for (int j = 0; j < 4; ++j) {
1156
  const int aux_q4 = get_int_b4(bq4->qs, iqs + j);
1157
+ const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
1158
 
1159
  const int u0 = get_int_b4(bq8_1[iqs/4].qs, j + 0);
1160
  const int u1 = get_int_b4(bq8_1[iqs/4].qs, j + 4);
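get_int_from_table_16() now takes the lookup table as a parameter so the same path serves both IQ4_NL and MXFP4. A hedged host-side sketch of the per-word decode it performs: the low nibbles of the four packed bytes select four table entries, the high nibbles select four more; for MXFP4 the entries are the E2M1 magnitudes (compare kvalues_mxfp4_f in the Metal source further down), with the per-block scale applied separately by the caller:

```cpp
#include <cstdint>

// Illustrative decode of one 32-bit word of packed 4-bit codes through a
// 16-entry value table (float here for clarity; the CUDA path keeps int8
// values and folds a 0.5f factor into the block scale instead).
static void decode_nibbles_16(uint32_t q4, const float table[16], float lo[4], float hi[4]) {
    for (int b = 0; b < 4; ++b) {
        const uint8_t byte = (uint8_t) ((q4 >> (8 * b)) & 0xFF);
        lo[b] = table[byte & 0x0F]; // low nibble of byte b
        hi[b] = table[byte >> 4];   // high nibble of byte b
    }
}
```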
ggml/src/ggml-cuda/vendors/cuda.h CHANGED
@@ -6,6 +6,10 @@
6
  #include <cuda_bf16.h>
7
  #include <cuda_fp16.h>
8
 
 
 
 
 
9
  #if CUDART_VERSION < 11020
10
  #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
11
  #define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
 
6
  #include <cuda_bf16.h>
7
  #include <cuda_fp16.h>
8
 
9
+ #if CUDART_VERSION >= 12050
10
+ #include <cuda_fp8.h>
11
+ #endif // CUDART_VERSION >= 12050
12
+
13
  #if CUDART_VERSION < 11020
14
  #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
15
  #define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
ggml/src/ggml-impl.h CHANGED
@@ -410,6 +410,67 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
410
  #define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
411
  #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
412
 
 
413
  /**
414
  * Converts brain16 to float32.
415
  *
 
410
  #define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
411
  #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
412
 
413
+ static inline float ggml_e8m0_to_fp32(uint8_t x) {
414
+ uint32_t bits; // Stores the raw bit representation of the float
415
+
416
+ // Handle special case for minimum exponent (denormalized float)
417
+ if (x == 0) {
418
+ // Bit pattern for 2^(-127):
419
+ // - Sign bit: 0 (positive)
420
+ // - Exponent: 0 (denormalized number)
421
+ // - Mantissa: 0x400000 (0.5 in fractional form)
422
+ // Value = 0.5 * 2^(-126) = 2^(-127)
423
+ bits = 0x00400000;
424
+ }
425
+ // note: disabled as we don't need to handle NaNs
426
+ //// Handle special case for NaN (all bits set)
427
+ //else if (x == 0xFF) {
428
+ // // Standard quiet NaN pattern:
429
+ // // - Sign bit: 0
430
+ // // - Exponent: all 1s (0xFF)
431
+ // // - Mantissa: 0x400000 (quiet NaN flag)
432
+ // bits = 0x7FC00000;
433
+ //}
434
+ // Normalized values (most common case)
435
+ else {
436
+ // Construct normalized float by shifting exponent into position:
437
+ // - Exponent field: 8 bits (positions 30-23)
438
+ // - Mantissa: 0 (implicit leading 1)
439
+ // Value = 2^(x - 127)
440
+ bits = (uint32_t) x << 23;
441
+ }
442
+
443
+ float result; // Final float value
444
+ // Safely reinterpret bit pattern as float without type-punning issues
445
+ memcpy(&result, &bits, sizeof(float));
446
+ return result;
447
+ }
448
+
449
+ // Equal to ggml_e8m0_to_fp32/2
450
+ // Useful with MXFP4 quantization since the E2M1 values are doubled
451
+ static inline float ggml_e8m0_to_fp32_half(uint8_t x) {
452
+ uint32_t bits;
453
+
454
+ // For x < 2: use precomputed denormal patterns
455
+ if (x < 2) {
456
+ // 0x00200000 = 2^(-128), 0x00400000 = 2^(-127)
457
+ bits = 0x00200000 << x;
458
+ }
459
+ // For x >= 2: normalized exponent adjustment
460
+ else {
461
+ // 0.5 * 2^(x-127) = 2^(x-128) = normalized with exponent (x-1)
462
+ bits = (uint32_t)(x - 1) << 23;
463
+ }
464
+ // Note: NaNs are not handled here
465
+
466
+ float result;
467
+ memcpy(&result, &bits, sizeof(float));
468
+ return result;
469
+ }
470
+
471
+ #define GGML_E8M0_TO_FP32(x) ggml_e8m0_to_fp32(x)
472
+ #define GGML_E8M0_TO_FP32_HALF(x) ggml_e8m0_to_fp32_half(x)
473
+
474
  /**
475
  * Converts brain16 to float32.
476
  *
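A few spot checks of the E8M0 conversion above, assuming the convention stated in the comments (the byte is a raw biased exponent, so the value is 2^(x-127), with x == 0 mapping to the denormal 2^-127):

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// Same bit construction as ggml_e8m0_to_fp32(), reproduced for a standalone check.
static float e8m0_ref(uint8_t x) {
    const uint32_t bits = (x == 0) ? 0x00400000u : (uint32_t) x << 23;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

int main() {
    std::printf("%g\n", e8m0_ref(127)); // 1         (2^0)
    std::printf("%g\n", e8m0_ref(128)); // 2         (2^1)
    std::printf("%g\n", e8m0_ref(120)); // 0.0078125 (2^-7)
    // ggml_e8m0_to_fp32_half(x) is expected to equal e8m0_ref(x) / 2 for all x,
    // which is why the CUDA MXFP4 dot product multiplies the scale by 0.5f.
    return 0;
}
```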
ggml/src/ggml-metal/ggml-metal-impl.h CHANGED
@@ -23,6 +23,9 @@
23
  #define N_R0_Q8_0 4
24
  #define N_SG_Q8_0 2
25
 
 
 
 
26
  #define N_R0_Q2_K 4
27
  #define N_SG_Q2_K 2
28
 
@@ -129,6 +132,15 @@ typedef struct {
129
  uint64_t o1[8];
130
  } ggml_metal_kargs_bin;
131
 
 
132
  typedef struct {
133
  int32_t ne00;
134
  int32_t ne01;
@@ -444,6 +456,8 @@ typedef struct{
444
  uint64_t nb1;
445
  int32_t i00;
446
  int32_t i10;
 
 
447
  } ggml_metal_kargs_glu;
448
 
449
  typedef struct {
 
23
  #define N_R0_Q8_0 4
24
  #define N_SG_Q8_0 2
25
 
26
+ #define N_R0_MXFP4 2
27
+ #define N_SG_MXFP4 2
28
+
29
  #define N_R0_Q2_K 4
30
  #define N_SG_Q2_K 2
31
 
 
132
  uint64_t o1[8];
133
  } ggml_metal_kargs_bin;
134
 
135
+ typedef struct {
136
+ int64_t ne0;
137
+ int64_t ne1;
138
+ size_t nb01;
139
+ size_t nb02;
140
+ size_t nb11;
141
+ size_t nb21;
142
+ } ggml_metal_kargs_add_id;
143
+
144
  typedef struct {
145
  int32_t ne00;
146
  int32_t ne01;
 
456
  uint64_t nb1;
457
  int32_t i00;
458
  int32_t i10;
459
+ float alpha;
460
+ float limit;
461
  } ggml_metal_kargs_glu;
462
 
463
  typedef struct {
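The new ggml_metal_kargs_add_id arguments mirror the GGML_OP_ADD_ID op introduced in this commit. A hedged scalar sketch of what the op appears to compute, inferred from these arguments and its MoE use (adding the selected expert's FFN bias row after mul_mat_id); shapes and names are illustrative, not ggml's internal layout:

```cpp
#include <cstdint>

// For every (i1, i2) row of src0, add the row of src1 selected by ids[i1, i2].
static void add_id_ref(const float   * src0, // [ne0, ne1, ne2]
                       const float   * src1, // [ne0, n_expert]
                       const int32_t * ids,  // [ne1, ne2]
                       float         * dst,  // [ne0, ne1, ne2]
                       int64_t ne0, int64_t ne1, int64_t ne2) {
    for (int64_t i2 = 0; i2 < ne2; ++i2) {
        for (int64_t i1 = 0; i1 < ne1; ++i1) {
            const int32_t expert = ids[i2*ne1 + i1];
            const float * x = src0 + (i2*ne1 + i1)*ne0;
            const float * b = src1 + (int64_t) expert*ne0;
            float       * y = dst  + (i2*ne1 + i1)*ne0;
            for (int64_t i0 = 0; i0 < ne0; ++i0) {
                y[i0] = x[i0] + b[i0];
            }
        }
    }
}
```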
ggml/src/ggml-metal/ggml-metal.m CHANGED
@@ -195,6 +195,7 @@ enum ggml_metal_kernel_type {
195
  GGML_METAL_KERNEL_TYPE_MUL_ROW_C4,
196
  GGML_METAL_KERNEL_TYPE_DIV,
197
  GGML_METAL_KERNEL_TYPE_DIV_ROW_C4,
 
198
  GGML_METAL_KERNEL_TYPE_REPEAT_F32,
199
  GGML_METAL_KERNEL_TYPE_REPEAT_F16,
200
  GGML_METAL_KERNEL_TYPE_REPEAT_I32,
@@ -234,6 +235,7 @@ enum ggml_metal_kernel_type {
234
  GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
235
  GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,
236
  GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,
 
237
  GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,
238
  GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,
239
  GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,
@@ -286,6 +288,7 @@ enum ggml_metal_kernel_type {
286
  GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
287
  GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
288
  GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
 
289
  GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
290
  GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
291
  GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
@@ -310,6 +313,10 @@ enum ggml_metal_kernel_type {
310
  GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,
311
  GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,
312
  GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,
 
 
 
 
313
  GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2,
314
  GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3,
315
  GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4,
@@ -351,6 +358,7 @@ enum ggml_metal_kernel_type {
351
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
352
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,
353
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,
 
354
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,
355
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,
356
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,
@@ -373,6 +381,7 @@ enum ggml_metal_kernel_type {
373
  GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
374
  GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,
375
  GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,
 
376
  GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,
377
  GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,
378
  GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,
@@ -397,6 +406,7 @@ enum ggml_metal_kernel_type {
397
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16,
398
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16,
399
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16,
 
400
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16,
401
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16,
402
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16,
@@ -579,6 +589,7 @@ enum ggml_metal_kernel_type {
579
  GGML_METAL_KERNEL_TYPE_REGLU,
580
  GGML_METAL_KERNEL_TYPE_GEGLU,
581
  GGML_METAL_KERNEL_TYPE_SWIGLU,
 
582
  GGML_METAL_KERNEL_TYPE_GEGLU_ERF,
583
  GGML_METAL_KERNEL_TYPE_GEGLU_QUICK,
584
  GGML_METAL_KERNEL_TYPE_SUM_ROWS,
@@ -1199,6 +1210,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
1199
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, mul_row_c4, true);
1200
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
1201
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, div_row_c4, true);
 
1202
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
1203
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
1204
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
@@ -1238,6 +1250,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
1238
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
1239
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
1240
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
 
1241
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
1242
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
1243
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
@@ -1290,6 +1303,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
1290
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
1291
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
1292
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
 
1293
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
1294
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
1295
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
@@ -1314,6 +1328,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
1314
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction);
1315
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction);
1316
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction);
 
 
 
 
1317
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction);
1318
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction);
1319
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction);
@@ -1355,6 +1373,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
1355
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction);
1356
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, has_simdgroup_reduction);
1357
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, has_simdgroup_reduction);
 
1358
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, has_simdgroup_reduction);
1359
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, has_simdgroup_reduction);
1360
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, has_simdgroup_reduction);
@@ -1377,6 +1396,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
1377
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm);
1378
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, has_simdgroup_mm);
1379
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, has_simdgroup_mm);
 
 
1380
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, has_simdgroup_mm);
1381
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, has_simdgroup_mm);
1382
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm);
@@ -1401,6 +1422,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
1401
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16, mul_mm_id_q5_0_f16, has_simdgroup_mm);
1402
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16, mul_mm_id_q5_1_f16, has_simdgroup_mm);
1403
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16, mul_mm_id_q8_0_f16, has_simdgroup_mm);
 
1404
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16, mul_mm_id_q2_K_f16, has_simdgroup_mm);
1405
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16, mul_mm_id_q3_K_f16, has_simdgroup_mm);
1406
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16, mul_mm_id_q4_K_f16, has_simdgroup_mm);
@@ -1583,6 +1605,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
1583
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true);
1584
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true);
1585
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
 
1586
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_ERF, geglu_erf, true);
1587
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, geglu_quick, true);
1588
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
@@ -1774,6 +1797,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1774
  case GGML_GLU_OP_REGLU:
1775
  case GGML_GLU_OP_GEGLU:
1776
  case GGML_GLU_OP_SWIGLU:
 
1777
  case GGML_GLU_OP_GEGLU_ERF:
1778
  case GGML_GLU_OP_GEGLU_QUICK:
1779
  return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
@@ -1791,6 +1815,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1791
  case GGML_OP_SUB:
1792
  case GGML_OP_MUL:
1793
  case GGML_OP_DIV:
 
1794
  return op->src[0]->type == GGML_TYPE_F32;
1795
  case GGML_OP_ACC:
1796
  case GGML_OP_REPEAT:
@@ -2042,6 +2067,7 @@ static int ggml_metal_encode_node(
2042
 
2043
  const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
2044
  const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
 
2045
  const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
2046
 
2047
  size_t offs_src0 = 0;
@@ -2291,6 +2317,38 @@ static int ggml_metal_encode_node(
2291
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2292
  }
2293
  } break;
2294
  case GGML_OP_REPEAT:
2295
  {
2296
  id<MTLComputePipelineState> pipeline;
@@ -2710,6 +2768,9 @@ static int ggml_metal_encode_node(
2710
  case GGML_GLU_OP_SWIGLU:
2711
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
2712
  break;
 
 
 
2713
  case GGML_GLU_OP_GEGLU_ERF:
2714
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_ERF].pipeline;
2715
  break;
@@ -2720,7 +2781,9 @@ static int ggml_metal_encode_node(
2720
  GGML_ABORT("fatal error");
2721
  }
2722
 
2723
- const int32_t swp = ((const int32_t *) dst->op_params)[1];
 
 
2724
 
2725
  const int32_t i00 = swp ? ne0 : 0;
2726
  const int32_t i10 = swp ? 0 : ne0;
@@ -2734,6 +2797,8 @@ static int ggml_metal_encode_node(
2734
  /*.nb1 =*/ nb1,
2735
  /*.i00 =*/ src1 ? 0 : i00,
2736
  /*.i10 =*/ src1 ? 0 : i10,
 
 
2737
  };
2738
 
2739
  [encoder setComputePipelineState:pipeline];
@@ -2992,8 +3057,13 @@ static int ggml_metal_encode_node(
2992
  } else {
2993
  [encoder setBuffer:h_src0 offset:offs_src0 atIndex:1];
2994
  }
2995
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2996
- [encoder setBytes:&args length:sizeof(args) atIndex:3];
 
 
 
 
 
2997
 
2998
  [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
2999
 
@@ -3291,6 +3361,7 @@ static int ggml_metal_encode_node(
3291
  src0t == GGML_TYPE_Q5_0 ||
3292
  src0t == GGML_TYPE_Q5_1 ||
3293
  src0t == GGML_TYPE_Q8_0 ||
 
3294
  src0t == GGML_TYPE_IQ4_NL ||
3295
  false) && (ne11 >= 2 && ne11 <= 8)
3296
  ) ||
@@ -3383,6 +3454,14 @@ static int ggml_metal_encode_node(
3383
  case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break;
3384
  default: GGML_ABORT("not implemented");
3385
  } break;
3386
  case GGML_TYPE_Q4_K:
3387
  switch (r1ptg) {
3388
  case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2].pipeline; break;
@@ -3481,6 +3560,7 @@ static int ggml_metal_encode_node(
3481
  case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
3482
  case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
3483
  case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
 
3484
  case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
3485
  case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
3486
  case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
@@ -3623,6 +3703,13 @@ static int ggml_metal_encode_node(
3623
  nr0 = N_R0_Q8_0;
3624
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
3625
  } break;
3626
  case GGML_TYPE_Q2_K:
3627
  {
3628
  nsg = N_SG_Q2_K;
@@ -3756,8 +3843,6 @@ static int ggml_metal_encode_node(
3756
  case GGML_OP_MUL_MAT_ID:
3757
  {
3758
  // src2 = ids
3759
- const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
3760
-
3761
  GGML_ASSERT(src2t == GGML_TYPE_I32);
3762
 
3763
  GGML_ASSERT(!ggml_is_transposed(src0));
@@ -3883,6 +3968,7 @@ static int ggml_metal_encode_node(
3883
  case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16 ].pipeline; break;
3884
  case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16 ].pipeline; break;
3885
  case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16 ].pipeline; break;
 
3886
  case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16 ].pipeline; break;
3887
  case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16 ].pipeline; break;
3888
  case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16 ].pipeline; break;
@@ -4018,6 +4104,13 @@ static int ggml_metal_encode_node(
4018
  nr0 = N_R0_Q8_0;
4019
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
4020
  } break;
4021
  case GGML_TYPE_Q2_K:
4022
  {
4023
  nsg = N_SG_Q2_K;
@@ -4170,6 +4263,7 @@ static int ggml_metal_encode_node(
4170
  case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
4171
  case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break;
4172
  case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break;
 
4173
  case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break;
4174
  case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break;
4175
  case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break;
@@ -4980,11 +5074,14 @@ static int ggml_metal_encode_node(
4980
  GGML_ASSERT(ne11 == ne21);
4981
  GGML_ASSERT(ne12 == ne22);
4982
 
4983
- struct ggml_tensor * src3 = node->src[3];
 
4984
 
4985
  size_t offs_src3 = 0;
 
4986
 
4987
  id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
 
4988
 
4989
  GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16);
4990
  GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
@@ -5000,8 +5097,6 @@ static int ggml_metal_encode_node(
5000
  const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32);
5001
  const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33);
5002
 
5003
- const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
5004
-
5005
  float scale;
5006
  float max_bias;
5007
  float logit_softcap;
@@ -5389,7 +5484,12 @@ static int ggml_metal_encode_node(
5389
  } else {
5390
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4];
5391
  }
5392
- [encoder setBuffer:id_dst offset:offs_dst atIndex:5];
 
 
 
 
 
5393
 
5394
  if (!use_vec_kernel) {
5395
  // half8x8 kernel
 
195
  GGML_METAL_KERNEL_TYPE_MUL_ROW_C4,
196
  GGML_METAL_KERNEL_TYPE_DIV,
197
  GGML_METAL_KERNEL_TYPE_DIV_ROW_C4,
198
+ GGML_METAL_KERNEL_TYPE_ADD_ID,
199
  GGML_METAL_KERNEL_TYPE_REPEAT_F32,
200
  GGML_METAL_KERNEL_TYPE_REPEAT_F16,
201
  GGML_METAL_KERNEL_TYPE_REPEAT_I32,
 
235
  GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
236
  GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,
237
  GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,
238
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4,
239
  GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,
240
  GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,
241
  GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,
 
288
  GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
289
  GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
290
  GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
291
+ GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32,
292
  GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
293
  GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
294
  GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
 
313
  GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,
314
  GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,
315
  GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,
316
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2,
317
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3,
318
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4,
319
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5,
320
  GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2,
321
  GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3,
322
  GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4,
 
358
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
359
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,
360
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,
361
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32,
362
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,
363
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,
364
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,
 
381
  GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
382
  GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,
383
  GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,
384
+ GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32,
385
  GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,
386
  GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,
387
  GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,
 
406
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16,
407
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16,
408
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16,
409
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16,
410
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16,
411
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16,
412
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16,
 
589
  GGML_METAL_KERNEL_TYPE_REGLU,
590
  GGML_METAL_KERNEL_TYPE_GEGLU,
591
  GGML_METAL_KERNEL_TYPE_SWIGLU,
592
+ GGML_METAL_KERNEL_TYPE_SWIGLU_OAI,
593
  GGML_METAL_KERNEL_TYPE_GEGLU_ERF,
594
  GGML_METAL_KERNEL_TYPE_GEGLU_QUICK,
595
  GGML_METAL_KERNEL_TYPE_SUM_ROWS,
 
1210
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, mul_row_c4, true);
1211
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
1212
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, div_row_c4, true);
1213
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ID, add_id, true);
1214
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
1215
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
1216
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
 
1250
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
1251
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
1252
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
1253
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4, get_rows_mxfp4, true);
1254
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
1255
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
1256
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
 
1303
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
1304
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
1305
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
1306
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, mul_mv_mxfp4_f32, has_simdgroup_reduction);
1307
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
1308
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
1309
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
 
1328
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction);
1329
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction);
1330
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction);
1331
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2, mul_mv_ext_mxfp4_f32_r1_2, has_simdgroup_reduction);
1332
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3, mul_mv_ext_mxfp4_f32_r1_3, has_simdgroup_reduction);
1333
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4, mul_mv_ext_mxfp4_f32_r1_4, has_simdgroup_reduction);
1334
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5, mul_mv_ext_mxfp4_f32_r1_5, has_simdgroup_reduction);
1335
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction);
1336
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction);
1337
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction);
 
1373
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction);
1374
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, has_simdgroup_reduction);
1375
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, has_simdgroup_reduction);
1376
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32, mul_mv_id_mxfp4_f32, has_simdgroup_reduction);
1377
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, has_simdgroup_reduction);
1378
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, has_simdgroup_reduction);
1379
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, has_simdgroup_reduction);
 
1396
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm);
1397
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, has_simdgroup_mm);
1398
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, has_simdgroup_mm);
1399
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, mul_mm_mxfp4_f32, has_simdgroup_mm);
1400
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, mul_mm_mxfp4_f32, has_simdgroup_mm);
1401
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, has_simdgroup_mm);
1402
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, has_simdgroup_mm);
1403
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm);
 
1422
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16, mul_mm_id_q5_0_f16, has_simdgroup_mm);
1423
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16, mul_mm_id_q5_1_f16, has_simdgroup_mm);
1424
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16, mul_mm_id_q8_0_f16, has_simdgroup_mm);
1425
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16, mul_mm_id_mxfp4_f16, has_simdgroup_mm);
1426
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16, mul_mm_id_q2_K_f16, has_simdgroup_mm);
1427
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16, mul_mm_id_q3_K_f16, has_simdgroup_mm);
1428
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16, mul_mm_id_q4_K_f16, has_simdgroup_mm);
 
1605
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true);
1606
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true);
1607
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
1608
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU_OAI, swiglu_oai, true);
1609
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_ERF, geglu_erf, true);
1610
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, geglu_quick, true);
1611
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
 
1797
  case GGML_GLU_OP_REGLU:
1798
  case GGML_GLU_OP_GEGLU:
1799
  case GGML_GLU_OP_SWIGLU:
1800
+ case GGML_GLU_OP_SWIGLU_OAI:
1801
  case GGML_GLU_OP_GEGLU_ERF:
1802
  case GGML_GLU_OP_GEGLU_QUICK:
1803
  return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
 
1815
  case GGML_OP_SUB:
1816
  case GGML_OP_MUL:
1817
  case GGML_OP_DIV:
1818
+ case GGML_OP_ADD_ID:
1819
  return op->src[0]->type == GGML_TYPE_F32;
1820
  case GGML_OP_ACC:
1821
  case GGML_OP_REPEAT:
 
2067
 
2068
  const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
2069
  const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
2070
+ const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT;
2071
  const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
2072
 
2073
  size_t offs_src0 = 0;
 
2317
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2318
  }
2319
  } break;
2320
+ case GGML_OP_ADD_ID:
2321
+ {
2322
+ GGML_ASSERT(src0t == GGML_TYPE_F32);
2323
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
2324
+ GGML_ASSERT(src2t == GGML_TYPE_I32);
2325
+ GGML_ASSERT(dstt == GGML_TYPE_F32);
2326
+
2327
+ GGML_ASSERT(ggml_is_contiguous_rows(src0));
2328
+
2329
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ID].pipeline;
2330
+
2331
+ ggml_metal_kargs_add_id args = {
2332
+ /*.ne0 =*/ ne0,
2333
+ /*.ne1 =*/ ne1,
2334
+ /*.nb01 =*/ nb01,
2335
+ /*.nb02 =*/ nb02,
2336
+ /*.nb11 =*/ nb11,
2337
+ /*.nb21 =*/ nb21,
2338
+
2339
+ };
2340
+
2341
+ [encoder setComputePipelineState:pipeline];
2342
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
2343
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2344
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
2345
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
2346
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
2347
+
2348
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
2349
+
2350
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2351
+ } break;
2352
  case GGML_OP_REPEAT:
2353
  {
2354
  id<MTLComputePipelineState> pipeline;
 
2768
  case GGML_GLU_OP_SWIGLU:
2769
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
2770
  break;
2771
+ case GGML_GLU_OP_SWIGLU_OAI:
2772
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU_OAI].pipeline;
2773
+ break;
2774
  case GGML_GLU_OP_GEGLU_ERF:
2775
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_ERF].pipeline;
2776
  break;
 
2781
  GGML_ABORT("fatal error");
2782
  }
2783
 
2784
+ const int32_t swp = ggml_get_op_params_i32(dst, 1);
2785
+ const float alpha = ggml_get_op_params_f32(dst, 2);
2786
+ const float limit = ggml_get_op_params_f32(dst, 3);
2787
 
2788
  const int32_t i00 = swp ? ne0 : 0;
2789
  const int32_t i10 = swp ? 0 : ne0;
 
2797
  /*.nb1 =*/ nb1,
2798
  /*.i00 =*/ src1 ? 0 : i00,
2799
  /*.i10 =*/ src1 ? 0 : i10,
2800
+ /*.alpha=*/ alpha,
2801
+ /*.limit=*/ limit
2802
  };
2803
 
2804
  [encoder setComputePipelineState:pipeline];
 
3057
  } else {
3058
  [encoder setBuffer:h_src0 offset:offs_src0 atIndex:1];
3059
  }
3060
+ if (id_src2) {
3061
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
3062
+ } else {
3063
+ [encoder setBuffer:h_src0 offset:offs_src0 atIndex:2];
3064
+ }
3065
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
3066
+ [encoder setBytes:&args length:sizeof(args) atIndex:4];
3067
 
3068
  [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
3069
 
 
3361
  src0t == GGML_TYPE_Q5_0 ||
3362
  src0t == GGML_TYPE_Q5_1 ||
3363
  src0t == GGML_TYPE_Q8_0 ||
3364
+ src0t == GGML_TYPE_MXFP4 ||
3365
  src0t == GGML_TYPE_IQ4_NL ||
3366
  false) && (ne11 >= 2 && ne11 <= 8)
3367
  ) ||
 
3454
  case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break;
3455
  default: GGML_ABORT("not implemented");
3456
  } break;
3457
+ case GGML_TYPE_MXFP4:
3458
+ switch (r1ptg) {
3459
+ case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2].pipeline; break;
3460
+ case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3].pipeline; break;
3461
+ case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4].pipeline; break;
3462
+ case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5].pipeline; break;
3463
+ default: GGML_ABORT("not implemented");
3464
+ } break;
3465
  case GGML_TYPE_Q4_K:
3466
  switch (r1ptg) {
3467
  case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2].pipeline; break;
 
3560
  case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
3561
  case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
3562
  case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
3563
+ case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break;
3564
  case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
3565
  case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
3566
  case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
 
3703
  nr0 = N_R0_Q8_0;
3704
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
3705
  } break;
3706
+ case GGML_TYPE_MXFP4:
3707
+ {
3708
+ nsg = N_SG_MXFP4;
3709
+ nr0 = N_R0_MXFP4;
3710
+ smem = 32*sizeof(float);
3711
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32].pipeline;
3712
+ } break;
3713
  case GGML_TYPE_Q2_K:
3714
  {
3715
  nsg = N_SG_Q2_K;
 
3843
  case GGML_OP_MUL_MAT_ID:
3844
  {
3845
  // src2 = ids
 
 
3846
  GGML_ASSERT(src2t == GGML_TYPE_I32);
3847
 
3848
  GGML_ASSERT(!ggml_is_transposed(src0));
 
3968
  case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16 ].pipeline; break;
3969
  case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16 ].pipeline; break;
3970
  case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16 ].pipeline; break;
3971
+ case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16 ].pipeline; break;
3972
  case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16 ].pipeline; break;
3973
  case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16 ].pipeline; break;
3974
  case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16 ].pipeline; break;
 
4104
  nr0 = N_R0_Q8_0;
4105
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
4106
  } break;
4107
+ case GGML_TYPE_MXFP4:
4108
+ {
4109
+ nsg = N_SG_MXFP4;
4110
+ nr0 = N_R0_MXFP4;
4111
+ smem = 32*sizeof(float);
4112
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32].pipeline;
4113
+ } break;
4114
  case GGML_TYPE_Q2_K:
4115
  {
4116
  nsg = N_SG_Q2_K;
 
4263
  case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
4264
  case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break;
4265
  case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break;
4266
+ case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4 ].pipeline; break;
4267
  case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break;
4268
  case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break;
4269
  case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break;
 
5074
  GGML_ASSERT(ne11 == ne21);
5075
  GGML_ASSERT(ne12 == ne22);
5076
 
5077
+ struct ggml_tensor * src3 = node->src[3]; // mask
5078
+ struct ggml_tensor * src4 = node->src[4]; // sinks
5079
 
5080
  size_t offs_src3 = 0;
5081
+ size_t offs_src4 = 0;
5082
 
5083
  id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
5084
+ id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
5085
 
5086
  GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16);
5087
  GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
 
5097
  const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32);
5098
  const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33);
5099
 
 
 
5100
  float scale;
5101
  float max_bias;
5102
  float logit_softcap;
 
5484
  } else {
5485
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4];
5486
  }
5487
+ if (id_src4) {
5488
+ [encoder setBuffer:id_src4 offset:offs_src4 atIndex:5];
5489
+ } else {
5490
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:5];
5491
+ }
5492
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
5493
 
5494
  if (!use_vec_kernel) {
5495
  // half8x8 kernel
ggml/src/ggml-metal/ggml-metal.metal CHANGED
@@ -35,6 +35,10 @@ constexpr constant static float kvalues_iq4nl_f[16] = {
35
  -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
36
  };
37
 
 
 
 
 
38
  static inline int best_index_int8(int n, constant float * val, float x) {
39
  if (x <= val[0]) return 0;
40
  if (x >= val[n-1]) return n-1;
@@ -46,6 +50,18 @@ static inline int best_index_int8(int n, constant float * val, float x) {
46
  return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
47
  }
48
 
49
  // NOTE: this is not dequantizing - we are simply fitting the template
50
  template <typename type4x4>
51
  void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
@@ -242,6 +258,27 @@ void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
242
  }
243
  }
244
245
  void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
246
  #pragma METAL fp math_mode(safe)
247
  float amax = 0.0f; // absolute max
@@ -462,25 +499,34 @@ void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & re
462
  }
463
  }
464
 
465
- void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
466
- #pragma METAL fp math_mode(safe)
467
- float amax = 0.0f; // absolute max
468
 
469
- for (int j = 0; j < QK8_0; j++) {
470
- const float v = src[j];
471
- amax = MAX(amax, fabs(v));
 
 
 
 
 
472
  }
 
473
 
474
- const float d = amax / ((1 << 7) - 1);
475
- const float id = d ? 1.0f/d : 0.0f;
 
476
 
477
- dst.d = d;
 
478
 
479
- for (int j = 0; j < QK8_0; ++j) {
480
- const float x0 = src[j]*id;
481
 
482
- dst.qs[j] = round(x0);
483
- }
 
 
484
  }
485
 
486
  template <typename type4x4>
@@ -960,6 +1006,32 @@ kernel void kernel_div(
960
  }
961
  }
962
 
963
  template<typename T>
964
  kernel void kernel_repeat(
965
  constant ggml_metal_kargs_repeat & args,
@@ -1431,6 +1503,32 @@ kernel void kernel_swiglu(
1431
  }
1432
  }
1433
 
1434
  kernel void kernel_geglu_erf(
1435
  device const char * src0,
1436
  device const char * src1,
@@ -1534,6 +1632,7 @@ template<typename T>
1534
  kernel void kernel_soft_max(
1535
  device const char * src0,
1536
  device const char * src1,
 
1537
  device char * dst,
1538
  constant ggml_metal_kargs_soft_max & args,
1539
  threadgroup float * buf [[threadgroup(0)]],
@@ -1552,6 +1651,7 @@ kernel void kernel_soft_max(
1552
 
1553
  device const float * psrc0 = (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
1554
  device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
 
1555
  device float * pdst = (device float *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
1556
 
1557
  float slope = 1.0f;
@@ -1567,7 +1667,7 @@ kernel void kernel_soft_max(
1567
  }
1568
 
1569
  // parallel max
1570
- float lmax = -INFINITY;
1571
 
1572
  for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
1573
  lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
@@ -1623,6 +1723,10 @@ kernel void kernel_soft_max(
1623
  sum = simd_sum(sum);
1624
  }
1625
 
 
 
 
 
1626
  const float inv_sum = 1.0f/sum;
1627
 
1628
  for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
@@ -1634,6 +1738,7 @@ template<typename T>
1634
  kernel void kernel_soft_max_4(
1635
  device const char * src0,
1636
  device const char * src1,
 
1637
  device char * dst,
1638
  constant ggml_metal_kargs_soft_max & args,
1639
  threadgroup float * buf [[threadgroup(0)]],
@@ -1652,6 +1757,7 @@ kernel void kernel_soft_max_4(
1652
 
1653
  device const float4 * psrc4 = (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
1654
  device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
 
1655
  device float4 * pdst4 = (device float4 *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
1656
 
1657
  float slope = 1.0f;
@@ -1666,7 +1772,7 @@ kernel void kernel_soft_max_4(
1666
  }
1667
 
1668
  // parallel max
1669
- float4 lmax4 = -INFINITY;
1670
 
1671
  for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
1672
  lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
@@ -1725,6 +1831,10 @@ kernel void kernel_soft_max_4(
1725
  sum = simd_sum(sum);
1726
  }
1727
 
 
 
 
 
1728
  const float inv_sum = 1.0f/sum;
1729
 
1730
  for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
@@ -3106,6 +3216,11 @@ template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_3")]] kernel mul_mv_ext_q4
3106
  template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q8_0, 32, dequantize_q8_0_t4>;
3107
  template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q8_0, 32, dequantize_q8_0_t4>;
3108
 
 
 
 
 
 
3109
  template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
3110
  template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
3111
  template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
@@ -4092,6 +4207,7 @@ kernel void kernel_flash_attn_ext(
4092
  device const char * k,
4093
  device const char * v,
4094
  device const char * mask,
 
4095
  device char * dst,
4096
  threadgroup half * shmem_f16 [[threadgroup(0)]],
4097
  uint3 tgpig[[threadgroup_position_in_grid]],
@@ -4407,6 +4523,35 @@ kernel void kernel_flash_attn_ext(
4407
  }
4408
  }
4409
 
4410
  // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
4411
  for (short j = tiisg; j < Q; j += NW) {
4412
  ss[j*TS + 0] = S[j];
@@ -4618,6 +4763,7 @@ kernel void kernel_flash_attn_ext_vec(
4618
  device const char * k,
4619
  device const char * v,
4620
  device const char * mask,
 
4621
  device char * dst,
4622
  threadgroup half * shmem_f16 [[threadgroup(0)]],
4623
  uint3 tgpig[[threadgroup_position_in_grid]],
@@ -4835,6 +4981,23 @@ kernel void kernel_flash_attn_ext_vec(
4835
  }
4836
  }
4837
 
4838
  // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
4839
  if (tiisg == 0) {
4840
  ss[0] = (s_t) S;
@@ -6940,6 +7103,95 @@ kernel void kernel_mul_mv_iq4_xs_f32(
6940
  kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, N_SG_IQ4_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
6941
  }
6942
 
6943
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
6944
  kernel void kernel_get_rows_q(
6945
  constant ggml_metal_kargs_get_rows & args,
@@ -7475,6 +7727,7 @@ template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get
7475
  template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
7476
  template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>;
7477
  template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>;
 
7478
  template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>;
7479
  template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>;
7480
  template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>;
@@ -7527,6 +7780,7 @@ template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_m
7527
  template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
7528
  template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
7529
  template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
 
7530
  template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
7531
  template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
7532
  template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
@@ -7558,6 +7812,7 @@ template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_m
7558
  template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
7559
  template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
7560
  template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
 
7561
  template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
7562
  template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
7563
  template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
@@ -7703,6 +7958,8 @@ template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t
7703
  template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SG_Q5_0, N_SIMDWIDTH>>>;
7704
  template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SG_Q5_1, N_SIMDWIDTH>>>;
7705
 
 
 
7706
  template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl <N_R0_Q2_K, N_SG_Q2_K, N_SIMDWIDTH>>>;
7707
  template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl <N_R0_Q3_K, N_SG_Q3_K, N_SIMDWIDTH>>>;
7708
  template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl <N_R0_Q4_K, N_SG_Q4_K, N_SIMDWIDTH>>>;
 
35
  -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
36
  };
37
 
38
+ constexpr constant static float kvalues_mxfp4_f[16] = {
39
+ 0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f
40
+ };
41
+
42
  static inline int best_index_int8(int n, constant float * val, float x) {
43
  if (x <= val[0]) return 0;
44
  if (x >= val[n-1]) return n-1;
 
50
  return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
51
  }
52
 
53
+ static inline float e8m0_to_fp32(uint8_t x) {
54
+ uint32_t bits;
55
+
56
+ if (x == 0) {
57
+ bits = 0x00400000;
58
+ } else {
59
+ bits = (uint32_t) x << 23;
60
+ }
61
+
62
+ return as_type<float>(bits);
63
+ }
64
+
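The helper above expands an E8M0 exponent byte into an f32 scale of 2^(e-127) by writing the byte straight into the float exponent field; e == 0 needs the subnormal bit pattern 0x00400000 because 2^-127 is not representable as a normal float. A minimal host-side sketch of the same conversion (plain C, not part of this diff):

#include <stdint.h>
#include <string.h>
#include <stdio.h>

// Reference E8M0 decode: byte e encodes the scale 2^(e - 127).
static float e8m0_to_fp32_ref(uint8_t x) {
    const uint32_t bits = (x == 0) ? 0x00400000u            // 2^-127 is subnormal in f32
                                   : ((uint32_t) x << 23);  // exponent field, zero mantissa
    float f;
    memcpy(&f, &bits, sizeof(f)); // bit-cast without violating strict aliasing
    return f;
}

int main(void) {
    printf("%g %g %g\n", e8m0_to_fp32_ref(0), e8m0_to_fp32_ref(127), e8m0_to_fp32_ref(128));
    // expected: ~5.87747e-39 (2^-127), 1, 2
    return 0;
}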
65
  // NOTE: this is not dequantizing - we are simply fitting the template
66
  template <typename type4x4>
67
  void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
 
258
  }
259
  }
260
 
261
+ void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
262
+ #pragma METAL fp math_mode(safe)
263
+ float amax = 0.0f; // absolute max
264
+
265
+ for (int j = 0; j < QK8_0; j++) {
266
+ const float v = src[j];
267
+ amax = MAX(amax, fabs(v));
268
+ }
269
+
270
+ const float d = amax / ((1 << 7) - 1);
271
+ const float id = d ? 1.0f/d : 0.0f;
272
+
273
+ dst.d = d;
274
+
275
+ for (int j = 0; j < QK8_0; ++j) {
276
+ const float x0 = src[j]*id;
277
+
278
+ dst.qs[j] = round(x0);
279
+ }
280
+ }
281
+
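The new Metal quantize_q8_0 mirrors the existing CPU reference quantize_row_q8_0_ref: one shared scale d = amax/127 per 32-value block, with each value stored as round(x/d) in int8. A standalone sketch of that block math (plain C; the names below are illustrative, not the ggml API):

#include <math.h>
#include <stdint.h>

#define QK8_0_BLOCK 32  // block size, matches QK8_0

static void quantize_block_q8_0(const float src[QK8_0_BLOCK], float *d_out, int8_t qs[QK8_0_BLOCK]) {
    float amax = 0.0f; // absolute max over the block
    for (int j = 0; j < QK8_0_BLOCK; ++j) {
        amax = fmaxf(amax, fabsf(src[j]));
    }

    const float d  = amax / 127.0f;       // one shared scale per block
    const float id = d ? 1.0f / d : 0.0f; // avoid division by zero for all-zero blocks

    *d_out = d;
    for (int j = 0; j < QK8_0_BLOCK; ++j) {
        qs[j] = (int8_t) roundf(src[j] * id);
    }
}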
282
  void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
283
  #pragma METAL fp math_mode(safe)
284
  float amax = 0.0f; // absolute max
 
499
  }
500
  }
501
 
502
+ template <typename type4x4>
503
+ void dequantize_mxfp4(device const block_mxfp4 * xb, short il, thread type4x4 & reg) {
504
+ device const uint8_t * q2 = (device const uint8_t *)xb->qs;
505
 
506
+ const float d = e8m0_to_fp32(xb->e);
507
+ const uint8_t shr = il >= 1 ? 4 : 0;
508
+
509
+ for (int i = 0; i < 4; ++i) {
510
+ reg[i][0] = d * kvalues_mxfp4_f[(q2[4*i + 0] >> shr) & 0x0F];
511
+ reg[i][1] = d * kvalues_mxfp4_f[(q2[4*i + 1] >> shr) & 0x0F];
512
+ reg[i][2] = d * kvalues_mxfp4_f[(q2[4*i + 2] >> shr) & 0x0F];
513
+ reg[i][3] = d * kvalues_mxfp4_f[(q2[4*i + 3] >> shr) & 0x0F];
514
  }
515
+ }
516
 
517
+ template <typename type4>
518
+ void dequantize_mxfp4_t4(device const block_mxfp4 * xb, short il, thread type4 & reg) {
519
+ device const uint8_t * q2 = (device const uint8_t *)xb->qs;
520
 
521
+ const float d = e8m0_to_fp32(xb->e);
522
+ const short il4 = il%4;
523
 
524
+ const uint8_t shr = il >= 4 ? 4 : 0;
 
525
 
526
+ reg[0] = d * kvalues_mxfp4_f[(q2[4*il4 + 0] >> shr) & 0x0F];
527
+ reg[1] = d * kvalues_mxfp4_f[(q2[4*il4 + 1] >> shr) & 0x0F];
528
+ reg[2] = d * kvalues_mxfp4_f[(q2[4*il4 + 2] >> shr) & 0x0F];
529
+ reg[3] = d * kvalues_mxfp4_f[(q2[4*il4 + 3] >> shr) & 0x0F];
530
  }
531
 
532
  template <typename type4x4>
 
1006
  }
1007
  }
1008
 
1009
+ kernel void kernel_add_id(
1010
+ constant ggml_metal_kargs_add_id & args,
1011
+ device const char * src0,
1012
+ device const char * src1,
1013
+ device const char * src2,
1014
+ device char * dst,
1015
+ uint3 tgpig[[threadgroup_position_in_grid]],
1016
+ ushort3 tpitg[[thread_position_in_threadgroup]],
1017
+ ushort3 ntg[[threads_per_threadgroup]]) {
1018
+ const int i1 = tgpig.x;
1019
+ const int i2 = tgpig.y;
1020
+
1021
+ const int i11 = *((device const int32_t *) (src2 + i1*sizeof(int32_t) + i2*args.nb21));
1022
+
1023
+ const size_t nb1 = args.ne0 * sizeof(float);
1024
+ const size_t nb2 = args.ne1 * nb1;
1025
+
1026
+ device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2);
1027
+ device const float * src0_row = (device const float *)((device char *)src0 + i1*args.nb01 + i2*args.nb02);
1028
+ device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11);
1029
+
1030
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
1031
+ dst_row[i0] = src0_row[i0] + src1_row[i0];
1032
+ }
1033
+ }
1034
+
1035
  template<typename T>
1036
  kernel void kernel_repeat(
1037
  constant ggml_metal_kargs_repeat & args,
 
1503
  }
1504
  }
1505
 
1506
+ kernel void kernel_swiglu_oai(
1507
+ device const char * src0,
1508
+ device const char * src1,
1509
+ device char * dst,
1510
+ constant ggml_metal_kargs_glu & args,
1511
+ uint tgpig[[threadgroup_position_in_grid]],
1512
+ uint tpitg[[thread_position_in_threadgroup]],
1513
+ uint ntg[[threads_per_threadgroup]]) {
1514
+ device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1515
+ device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1516
+ device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1517
+
1518
+ for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
1519
+ float x0 = src0_row[i0];
1520
+ float x1 = src1_row[i0];
1521
+
1522
+ x0 = min(x0, args.limit);
1523
+ x1 = max(min(x1, args.limit), -args.limit);
1524
+
1525
+ float out_glu = x0 / (1.0f + exp(-x0 * args.alpha));
1526
+ out_glu = out_glu * (1.0f + x1);
1527
+
1528
+ dst_row[i0] = out_glu;
1529
+ }
1530
+ }
1531
+
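The fused kernel_swiglu_oai evaluates the gpt-oss gate in one pass: the gate input is clamped from above, the linear input is clamped symmetrically, a SiLU with temperature alpha is applied, and the result is multiplied by (1 + x1). Element-wise, that is (a minimal C sketch, assuming the same alpha/limit parameters are passed in):

#include <math.h>

static float swiglu_oai_ref(float x0, float x1, float alpha, float limit) {
    x0 = fminf(x0, limit);                 // gate branch: clamp from above only
    x1 = fmaxf(fminf(x1, limit), -limit);  // linear branch: clamp to [-limit, limit]

    const float glu = x0 / (1.0f + expf(-x0 * alpha)); // SiLU with temperature alpha
    return glu * (1.0f + x1);
}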
1532
  kernel void kernel_geglu_erf(
1533
  device const char * src0,
1534
  device const char * src1,
 
1632
  kernel void kernel_soft_max(
1633
  device const char * src0,
1634
  device const char * src1,
1635
+ device const char * src2,
1636
  device char * dst,
1637
  constant ggml_metal_kargs_soft_max & args,
1638
  threadgroup float * buf [[threadgroup(0)]],
 
1651
 
1652
  device const float * psrc0 = (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
1653
  device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
1654
+ device const float * psrc2 = src2 != src0 ? (device const float *) (src2) : nullptr;
1655
  device float * pdst = (device float *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
1656
 
1657
  float slope = 1.0f;
 
1667
  }
1668
 
1669
  // parallel max
1670
+ float lmax = psrc2 ? psrc2[i02] : -INFINITY;
1671
 
1672
  for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
1673
  lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
 
1723
  sum = simd_sum(sum);
1724
  }
1725
 
1726
+ if (psrc2) {
1727
+ sum += exp(psrc2[i02] - max_val);
1728
+ }
1729
+
1730
  const float inv_sum = 1.0f/sum;
1731
 
1732
  for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
 
1738
  kernel void kernel_soft_max_4(
1739
  device const char * src0,
1740
  device const char * src1,
1741
+ device const char * src2,
1742
  device char * dst,
1743
  constant ggml_metal_kargs_soft_max & args,
1744
  threadgroup float * buf [[threadgroup(0)]],
 
1757
 
1758
  device const float4 * psrc4 = (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
1759
  device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
1760
+ device const float * psrc2 = src2 != src0 ? (device const float * ) (src2) : nullptr;
1761
  device float4 * pdst4 = (device float4 *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
1762
 
1763
  float slope = 1.0f;
 
1772
  }
1773
 
1774
  // parallel max
1775
+ float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
1776
 
1777
  for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
1778
  lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
 
1831
  sum = simd_sum(sum);
1832
  }
1833
 
1834
+ if (psrc2) {
1835
+ sum += exp(psrc2[i02] - max_val);
1836
+ }
1837
+
1838
  const float inv_sum = 1.0f/sum;
1839
 
1840
  for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
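Both soft-max kernels now take an optional src2 holding one sink value per head (indexed by i02): the sink participates in the running max and in the normalizer but never produces an output element, so the visible probabilities sum to slightly less than one. A scalar reference of the same computation (a sketch for a single contiguous row, with scale, mask and ALiBi slope omitted):

#include <math.h>

static void softmax_with_sink_ref(const float *x, int n, float sink, float *y) {
    float m = sink;                                 // the sink joins the max
    for (int i = 0; i < n; ++i) m = fmaxf(m, x[i]);

    float sum = expf(sink - m);                     // and the denominator
    for (int i = 0; i < n; ++i) {
        y[i] = expf(x[i] - m);
        sum += y[i];
    }
    const float inv_sum = 1.0f / sum;
    for (int i = 0; i < n; ++i) y[i] *= inv_sum;    // but gets no output slot
}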
 
3216
  template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q8_0, 32, dequantize_q8_0_t4>;
3217
  template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q8_0, 32, dequantize_q8_0_t4>;
3218
 
3219
+ template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_mxfp4, 32, dequantize_mxfp4_t4>;
3220
+ template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_mxfp4, 32, dequantize_mxfp4_t4>;
3221
+ template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_mxfp4, 32, dequantize_mxfp4_t4>;
3222
+ template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_mxfp4, 32, dequantize_mxfp4_t4>;
3223
+
3224
  template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
3225
  template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
3226
  template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
 
4207
  device const char * k,
4208
  device const char * v,
4209
  device const char * mask,
4210
+ device const char * sinks,
4211
  device char * dst,
4212
  threadgroup half * shmem_f16 [[threadgroup(0)]],
4213
  uint3 tgpig[[threadgroup_position_in_grid]],
 
4523
  }
4524
  }
4525
 
4526
+ if (sinks != q && sgitg == 0) {
4527
+ for (ushort j = 0; j < Q; ++j) {
4528
+ const float m = M[j];
4529
+ const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
4530
+
4531
+ M[j] = simd_max(max(M[j], s));
4532
+
4533
+ const float ms = exp(m - M[j]);
4534
+ const float vs = exp(s - M[j]);
4535
+
4536
+ S[j] = S[j]*ms + simd_sum(vs);
4537
+
4538
+ if (tiisg == j) {
4539
+ ss[j*TS + 2*C + j] = ms;
4540
+ }
4541
+ }
4542
+
4543
+ // O = diag(ms)*O
4544
+ {
4545
+ s8x8_t ms;
4546
+ simdgroup_load(ms, ss + 2*C, TS, 0, false);
4547
+
4548
+ #pragma unroll(DV8)
4549
+ for (short i = 0; i < DV8; ++i) {
4550
+ simdgroup_multiply(lo[i], ms, lo[i]);
4551
+ }
4552
+ }
4553
+ }
4554
+
4555
  // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
4556
  for (short j = tiisg; j < Q; j += NW) {
4557
  ss[j*TS + 0] = S[j];
 
4763
  device const char * k,
4764
  device const char * v,
4765
  device const char * mask,
4766
+ device const char * sinks,
4767
  device char * dst,
4768
  threadgroup half * shmem_f16 [[threadgroup(0)]],
4769
  uint3 tgpig[[threadgroup_position_in_grid]],
 
4981
  }
4982
  }
4983
 
4984
+ if (sinks != q && sgitg == 0) {
4985
+ const float m = M;
4986
+ const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
4987
+
4988
+ M = simd_max(max(M, s));
4989
+
4990
+ const float ms = exp(m - M);
4991
+ const float vs = exp(s - M);
4992
+
4993
+ S = S*ms + simd_sum(vs);
4994
+
4995
+ #pragma unroll(DV4/NL)
4996
+ for (short ii = 0; ii < DV4; ii += NL) {
4997
+ lo[ii/NL] *= ms;
4998
+ }
4999
+ }
5000
+
5001
  // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
5002
  if (tiisg == 0) {
5003
  ss[0] = (s_t) S;
 
7103
  kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, N_SG_IQ4_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
7104
  }
7105
 
7106
+ template<int nr0, int nsg, int nw, typename args_t>
7107
+ void kernel_mul_mv_mxfp4_f32_impl(
7108
+ args_t args,
7109
+ device const char * src0,
7110
+ device const char * src1,
7111
+ device char * dst,
7112
+ threadgroup char * shmem,
7113
+ uint3 tgpig,
7114
+ ushort tiisg,
7115
+ ushort sgitg) {
7116
+
7117
+ threadgroup float * shmem_f32 = (threadgroup float *) shmem;
7118
+ const int nb = args.ne00/QK_MXFP4;
7119
+
7120
+ const int r0 = tgpig.x;
7121
+ const int r1 = tgpig.y;
7122
+ const int im = tgpig.z;
7123
+
7124
+ const int first_row = (r0 * nsg + sgitg) * nr0;
7125
+
7126
+ const uint i12 = im%args.ne12;
7127
+ const uint i13 = im/args.ne12;
7128
+
7129
+ const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
7130
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
7131
+
7132
+ device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0);
7133
+ device const float * y = (device const float *) (src1 + offset1);
7134
+
7135
+ const short ix = tiisg/2; // 0...15
7136
+ const short it = tiisg%2; // 0 or 1
7137
+
7138
+ shmem_f32[tiisg] = kvalues_mxfp4_f[tiisg%16];
7139
+ threadgroup_barrier(mem_flags::mem_threadgroup);
7140
+
7141
+ float4 yl[4];
7142
+ float sumf[nr0]={0.f};
7143
+
7144
+ device const float * yb = y + ix * QK_MXFP4 + it * 8;
7145
+
7146
+ for (int ib = ix; ib < nb; ib += 16) {
7147
+ device const float4 * y4 = (device const float4 *)yb;
7148
+ yl[0] = y4[0];
7149
+ yl[1] = y4[4];
7150
+ yl[2] = y4[1];
7151
+ yl[3] = y4[5];
7152
+
7153
+ #pragma unroll(nr0)
7154
+ for (short row = 0; row < nr0; row++) {
7155
+ device const block_mxfp4 & xb = x[row*nb + ib];
7156
+ device const uint8_t * q2 = (device const uint8_t *)(xb.qs + 8*it);
7157
+
7158
+ float4 acc1 = yl[0]*float4(shmem_f32[q2[0] & 0x0F], shmem_f32[q2[1] & 0x0F], shmem_f32[q2[2] & 0x0F], shmem_f32[q2[3] & 0x0F]);
7159
+ float4 acc2 = yl[1]*float4(shmem_f32[q2[0] >> 4 ], shmem_f32[q2[1] >> 4 ], shmem_f32[q2[2] >> 4 ], shmem_f32[q2[3] >> 4 ]);
7160
+ float4 acc3 = yl[2]*float4(shmem_f32[q2[4] & 0x0F], shmem_f32[q2[5] & 0x0F], shmem_f32[q2[6] & 0x0F], shmem_f32[q2[7] & 0x0F]);
7161
+ float4 acc4 = yl[3]*float4(shmem_f32[q2[4] >> 4 ], shmem_f32[q2[5] >> 4 ], shmem_f32[q2[6] >> 4 ], shmem_f32[q2[7] >> 4 ]);
7162
+
7163
+ acc1 = (acc1 + acc3) + (acc2 + acc4);
7164
+
7165
+ sumf[row] += e8m0_to_fp32(xb.e) * ((acc1[0] + acc1[1]) + (acc1[2] + acc1[3]));
7166
+ }
7167
+
7168
+ yb += 16 * QK_MXFP4;
7169
+ }
7170
+
7171
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
7172
+
7173
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
7174
+ float sum_all = simd_sum(sumf[row]);
7175
+ if (tiisg == 0) {
7176
+ dst_f32[first_row + row] = sum_all;
7177
+ }
7178
+ }
7179
+ }
7180
+
7181
+ [[host_name("kernel_mul_mv_mxfp4_f32")]]
7182
+ kernel void kernel_mul_mv_mxfp4_f32(
7183
+ constant ggml_metal_kargs_mul_mv & args,
7184
+ device const char * src0,
7185
+ device const char * src1,
7186
+ device char * dst,
7187
+ threadgroup char * shmem [[threadgroup(0)]],
7188
+ uint3 tgpig[[threadgroup_position_in_grid]],
7189
+ ushort tiisg[[thread_index_in_simdgroup]],
7190
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
7191
+
7192
+ kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, N_SG_MXFP4, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
7193
+ }
7194
+
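kernel_mul_mv_mxfp4_f32 caches the 16-entry MXFP4 code table in threadgroup memory and lets each simdgroup walk the blocks of its rows; per 32-element block the math reduces to one table lookup per nibble, scaled once by the shared E8M0 exponent. A plain-C sketch of that per-block dot product (block layout and value table assumed to match this file; ldexpf(x, e - 127) is an equivalent way to apply the E8M0 scale):

#include <math.h>
#include <stdint.h>

#define QK_MXFP4 32
typedef struct { uint8_t e; uint8_t qs[QK_MXFP4/2]; } block_mxfp4_t; // assumed layout: E8M0 scale + 16 packed bytes

static const float kvalues_mxfp4_f[16] = {
    0.f, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0.f, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f
};

// Dot product of one MXFP4 block with 32 f32 activations.
static float dot_block_mxfp4(const block_mxfp4_t *b, const float *y) {
    float acc = 0.0f;
    for (int j = 0; j < QK_MXFP4/2; ++j) {
        acc += y[j]              * kvalues_mxfp4_f[b->qs[j] & 0x0F]; // low nibble  -> element j
        acc += y[j + QK_MXFP4/2] * kvalues_mxfp4_f[b->qs[j] >> 4];   // high nibble -> element j + 16
    }
    return ldexpf(acc, (int) b->e - 127); // apply the shared E8M0 scale 2^(e - 127)
}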
7195
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
7196
  kernel void kernel_get_rows_q(
7197
  constant ggml_metal_kargs_get_rows & args,
 
7727
  template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
7728
  template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>;
7729
  template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>;
7730
+ template [[host_name("kernel_get_rows_mxfp4")]] kernel get_rows_q_t kernel_get_rows_q<block_mxfp4, 2, dequantize_mxfp4>;
7731
  template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>;
7732
  template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>;
7733
  template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>;
 
7780
  template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
7781
  template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
7782
  template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
7783
+ template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4>;
7784
  template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
7785
  template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
7786
  template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
 
7812
  template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
7813
  template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
7814
  template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
7815
+ template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4>;
7816
  template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
7817
  template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
7818
  template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
 
7958
  template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SG_Q5_0, N_SIMDWIDTH>>>;
7959
  template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SG_Q5_1, N_SIMDWIDTH>>>;
7960
 
7961
+ template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, N_SG_MXFP4, N_SIMDWIDTH>>>;
7962
+
7963
  template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl <N_R0_Q2_K, N_SG_Q2_K, N_SIMDWIDTH>>>;
7964
  template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl <N_R0_Q3_K, N_SG_Q3_K, N_SIMDWIDTH>>>;
7965
  template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl <N_R0_Q4_K, N_SG_Q4_K, N_SIMDWIDTH>>>;
ggml/src/ggml-opencl/ggml-opencl.cpp CHANGED
@@ -2497,6 +2497,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
2497
  case GGML_OP_CLAMP:
2498
  return op->src[0]->type == GGML_TYPE_F32;
2499
  case GGML_OP_SOFT_MAX:
 
 
2500
  case GGML_OP_NORM:
2501
  case GGML_OP_RMS_NORM:
2502
  return true;
 
2497
  case GGML_OP_CLAMP:
2498
  return op->src[0]->type == GGML_TYPE_F32;
2499
  case GGML_OP_SOFT_MAX:
2500
+ // TODO: support attention sinks [TAG_ATTN_SINKS]
2501
+ return op->src[2] == nullptr;
2502
  case GGML_OP_NORM:
2503
  case GGML_OP_RMS_NORM:
2504
  return true;
ggml/src/ggml-quants.c CHANGED
@@ -21,6 +21,17 @@
21
 
22
  #define UNUSED GGML_UNUSED
23
 
 
24
  // reference implementation for deterministic creation of model files
25
  void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k) {
26
  static const int qk = QK4_0;
@@ -246,6 +257,53 @@ void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_REST
246
  }
247
  }
248
 
249
  void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
250
  static const int qk = QK4_0;
251
 
@@ -356,6 +414,26 @@ void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRI
356
  }
357
  }
358
 
359
  //
360
  // 2-6 bit quantization in super-blocks
361
  //
@@ -2014,6 +2092,12 @@ size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
2014
  return nrow * row_size;
2015
  }
2016
 
2017
  // ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)
2018
 
2019
  void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k) {
@@ -4551,17 +4635,6 @@ size_t quantize_iq1_m(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
4551
 
4552
  // ============================ 4-bit non-linear quants
4553
 
4554
- static inline int best_index_int8(int n, const int8_t * val, float x) {
4555
- if (x <= val[0]) return 0;
4556
- if (x >= val[n-1]) return n-1;
4557
- int ml = 0, mu = n-1;
4558
- while (mu-ml > 1) {
4559
- int mav = (ml+mu)/2;
4560
- if (x < val[mav]) mu = mav; else ml = mav;
4561
- }
4562
- return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
4563
- }
4564
-
4565
  static void quantize_row_iq4_nl_impl(const int super_block_size, const int block_size, const float * GGML_RESTRICT x,
4566
  ggml_fp16_t * dh, uint8_t * q4, uint16_t * scales_h, uint8_t * scales_l,
4567
  float * scales, float * weight, uint8_t * L,
@@ -4961,6 +5034,15 @@ static bool validate_fp16(ggml_fp16_t f, size_t i) {
4961
  return true;
4962
  }
4963
 
4964
  #define VALIDATE_ROW_DATA_D_F16_IMPL(type, data, nb) \
4965
  const type * q = (const type *) (data); \
4966
  for (size_t i = 0; i < (nb); ++i) { \
@@ -4977,6 +5059,14 @@ static bool validate_fp16(ggml_fp16_t f, size_t i) {
4977
  } \
4978
  }
4979
 
4980
  #define VALIDATE_ROW_DATA_DVEC_F16_IMPL(type, data, nb, nr) \
4981
  const type * q = (const type *) (data); \
4982
  for (size_t i = 0; i < (nb); ++i) { \
@@ -5130,6 +5220,10 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
5130
  {
5131
  VALIDATE_ROW_DATA_D_F16_IMPL(block_q8_0, data, nb);
5132
  } break;
 
5133
  case GGML_TYPE_Q2_K:
5134
  {
5135
  VALIDATE_ROW_DATA_DM_F16_IMPL(block_q2_K, data, nb, d, dmin);
 
21
 
22
  #define UNUSED GGML_UNUSED
23
 
24
+ static inline int best_index_int8(int n, const int8_t * val, float x) {
25
+ if (x <= val[0]) return 0;
26
+ if (x >= val[n-1]) return n-1;
27
+ int ml = 0, mu = n-1;
28
+ while (mu-ml > 1) {
29
+ int mav = (ml+mu)/2;
30
+ if (x < val[mav]) mu = mav; else ml = mav;
31
+ }
32
+ return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
33
+ }
34
+
35
  // reference implementation for deterministic creation of model files
36
  void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k) {
37
  static const int qk = QK4_0;
 
257
  }
258
  }
259
 
260
+ static inline int best_index_mxfp4(float x, float e) {
261
+ int best_index = 0;
262
+ float best_err = fabsf(kvalues_mxfp4[0]*e - x);
263
+ for (int i = 1; i < 16; i++) {
264
+ float err = fabsf(kvalues_mxfp4[i]*e - x);
265
+ if (err < best_err) {
266
+ best_index = i;
267
+ best_err = err;
268
+ }
269
+ }
270
+ return best_index;
271
+ }
272
+
273
+ void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k) {
274
+ static const int qk = QK_MXFP4;
275
+
276
+ assert(k % qk == 0);
277
+
278
+ const int nb = k / qk;
279
+
280
+ for (int i = 0; i < nb; i++) {
281
+ float amax = 0.0f; // absolute max
282
+
283
+ for (int j = 0; j < qk; j++) {
284
+ const float v = x[i*qk + j];
285
+
286
+ if (amax < fabsf(v)) {
287
+ amax = fabsf(v);
288
+ }
289
+ }
290
+
291
+ const uint8_t e = (uint8_t) (floorf(log2f(amax)) - 2 + 127);
292
+
293
+ const float d = GGML_E8M0_TO_FP32_HALF(e);
294
+
295
+ y[i].e = e;
296
+
297
+ for (int j = 0; j < qk/2; ++j) {
298
+ const uint8_t x0 = best_index_mxfp4(x[i*qk + 0 + j], d);
299
+ const uint8_t x1 = best_index_mxfp4(x[i*qk + qk/2 + j], d);
300
+
301
+ y[i].qs[j] = x0;
302
+ y[i].qs[j] |= x1 << 4;
303
+ }
304
+ }
305
+ }
306
+
307
  void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
308
  static const int qk = QK4_0;
309
 
 
414
  }
415
  }
416
 
417
+ void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
418
+ static const int qk = QK_MXFP4;
419
+
420
+ assert(k % qk == 0);
421
+
422
+ const int nb = k / qk;
423
+
424
+ for (int i = 0; i < nb; i++) {
425
+ const float d = GGML_E8M0_TO_FP32_HALF(x[i].e);
426
+
427
+ for (int j = 0; j < qk/2; ++j) {
428
+ const int8_t x0 = kvalues_mxfp4[x[i].qs[j] & 0x0F];
429
+ const int8_t x1 = kvalues_mxfp4[x[i].qs[j] >> 4];
430
+
431
+ y[i*qk + j + 0 ] = x0*d;
432
+ y[i*qk + j + qk/2] = x1*d;
433
+ }
434
+ }
435
+ }
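A usage sketch tying the two reference routines together (assumption: block_mxfp4, QK_MXFP4 and the prototypes above are visible, e.g. via ggml's internal quantization headers):

#include <math.h>
#include <stdio.h>

// Quantize one 32-float block to MXFP4 and report the round-trip error.
void check_mxfp4_roundtrip(const float x[QK_MXFP4]) {
    block_mxfp4 blk;
    float y[QK_MXFP4];

    quantize_row_mxfp4_ref(x, &blk, QK_MXFP4);
    dequantize_row_mxfp4(&blk, y, QK_MXFP4);

    float max_err = 0.0f;
    for (int i = 0; i < QK_MXFP4; ++i) {
        max_err = fmaxf(max_err, fabsf(y[i] - x[i]));
    }
    printf("mxfp4 round-trip max abs error: %g\n", max_err);
}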
436
+
437
  //
438
  // 2-6 bit quantization in super-blocks
439
  //
 
2092
  return nrow * row_size;
2093
  }
2094
 
2095
+ size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
2096
+ GGML_UNUSED(quant_weights);
2097
+ quantize_row_mxfp4_ref(src, dst, (int64_t)nrow*n_per_row);
2098
+ return nrow * ggml_row_size(GGML_TYPE_MXFP4, n_per_row);
2099
+ }
2100
+
2101
  // ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)
2102
 
2103
  void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k) {
 
4635
 
4636
  // ============================ 4-bit non-linear quants
4637
 
4638
  static void quantize_row_iq4_nl_impl(const int super_block_size, const int block_size, const float * GGML_RESTRICT x,
4639
  ggml_fp16_t * dh, uint8_t * q4, uint16_t * scales_h, uint8_t * scales_l,
4640
  float * scales, float * weight, uint8_t * L,
 
5034
  return true;
5035
  }
5036
 
5037
+ static bool validate_e_e8m0(uint8_t e, size_t i) {
5038
+ if (e == 0xff) {
5039
+ fprintf(stderr, "ggml_validate_row_data: found invalid e value %d at block %zu\n", e, i);
5040
+ return false;
5041
+ }
5042
+
5043
+ return true;
5044
+ }
5045
+
5046
  #define VALIDATE_ROW_DATA_D_F16_IMPL(type, data, nb) \
5047
  const type * q = (const type *) (data); \
5048
  for (size_t i = 0; i < (nb); ++i) { \
 
5059
  } \
5060
  }
5061
 
5062
+ #define VALIDATE_ROW_DATA_E_E8M0_IMPL(type, data, nb) \
5063
+ const type * q = (const type *) (data); \
5064
+ for (size_t i = 0; i < (nb); ++i) { \
5065
+ if (!validate_e_e8m0(q[i].e, i)) { \
5066
+ return false; \
5067
+ } \
5068
+ }
5069
+
5070
  #define VALIDATE_ROW_DATA_DVEC_F16_IMPL(type, data, nb, nr) \
5071
  const type * q = (const type *) (data); \
5072
  for (size_t i = 0; i < (nb); ++i) { \
 
5220
  {
5221
  VALIDATE_ROW_DATA_D_F16_IMPL(block_q8_0, data, nb);
5222
  } break;
5223
+ case GGML_TYPE_MXFP4:
5224
+ {
5225
+ VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp4, data, nb);
5226
+ } break;
5227
  case GGML_TYPE_Q2_K:
5228
  {
5229
  VALIDATE_ROW_DATA_DM_F16_IMPL(block_q2_K, data, nb, d, dmin);
ggml/src/ggml-quants.h CHANGED
@@ -21,6 +21,8 @@ GGML_API void quantize_row_q5_1_ref(const float * GGML_RESTRICT x, block_q5_1 *
21
  GGML_API void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k);
22
  GGML_API void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k);
23
 
 
 
24
  GGML_API void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k);
25
  GGML_API void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k);
26
  GGML_API void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k);
@@ -45,6 +47,8 @@ GGML_API void dequantize_row_q5_1(const block_q5_1 * GGML_RESTRICT x, float * GG
45
  GGML_API void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
46
  //GGML_API void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
47
 
 
 
48
  GGML_API void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
49
  GGML_API void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
50
  GGML_API void dequantize_row_q4_K(const block_q4_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
@@ -90,6 +94,8 @@ GGML_API size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTR
90
  GGML_API size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
91
  GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
92
 
 
 
93
  GGML_API void iq2xs_init_impl(enum ggml_type type);
94
  GGML_API void iq2xs_free_impl(enum ggml_type type);
95
  GGML_API void iq3xs_init_impl(int grid_size);
 
21
  GGML_API void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k);
22
  GGML_API void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k);
23
 
24
+ GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k);
25
+
26
  GGML_API void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k);
27
  GGML_API void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k);
28
  GGML_API void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k);
 
47
  GGML_API void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
48
  //GGML_API void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
49
 
50
+ GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
51
+
52
  GGML_API void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
53
  GGML_API void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
54
  GGML_API void dequantize_row_q4_K(const block_q4_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 
94
  GGML_API size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
95
  GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
96
 
97
+ GGML_API size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
98
+
99
  GGML_API void iq2xs_init_impl(enum ggml_type type);
100
  GGML_API void iq2xs_free_impl(enum ggml_type type);
101
  GGML_API void iq3xs_init_impl(int grid_size);
ggml/src/ggml-sycl/ggml-sycl.cpp CHANGED
@@ -4193,15 +4193,9 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4193
  case GGML_OP_MUL_MAT:
4194
  case GGML_OP_MUL_MAT_ID:
4195
  {
4196
- struct ggml_tensor * a;
4197
- struct ggml_tensor * b;
4198
- if (op->op == GGML_OP_MUL_MAT) {
4199
- a = op->src[0];
4200
- b = op->src[1];
4201
- } else {
4202
- a = op->src[2];
4203
- b = op->src[1];
4204
- }
4205
  if (a->ne[3] != b->ne[3]) {
4206
  return false;
4207
  }
@@ -4216,7 +4210,9 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4216
  }
4217
  }
4218
  ggml_type src0_type = op->src[0]->type;
4219
- if (src0_type == GGML_TYPE_BF16) {
 
 
4220
  return false;
4221
  }
4222
  return true;
@@ -4361,6 +4357,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4361
  if (op->src[0]->ne[3] != 1) {
4362
  return false;
4363
  }
4364
  // TODO: support broadcast
4365
  // ref: https://github.com/ggml-org/llama.cpp/pull/14435
4366
  return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
 
4193
  case GGML_OP_MUL_MAT:
4194
  case GGML_OP_MUL_MAT_ID:
4195
  {
4196
+ struct ggml_tensor * a = op->src[0];
4197
+ struct ggml_tensor * b = op->src[1];
4198
+
4199
  if (a->ne[3] != b->ne[3]) {
4200
  return false;
4201
  }
 
4210
  }
4211
  }
4212
  ggml_type src0_type = op->src[0]->type;
4213
+ if (src0_type == GGML_TYPE_BF16 || src0_type == GGML_TYPE_MXFP4) {
4214
+ // TODO: support MXFP4
4215
+ // FIXME: keep a list of supported types to avoid breaking the backend when a new type is added
4216
  return false;
4217
  }
4218
  return true;
 
4357
  if (op->src[0]->ne[3] != 1) {
4358
  return false;
4359
  }
4360
+ // TODO: support attention sinks [TAG_ATTN_SINKS]
4361
+ if (op->src[2]) {
4362
+ return false;
4363
+ }
4364
  // TODO: support broadcast
4365
  // ref: https://github.com/ggml-org/llama.cpp/pull/14435
4366
  return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
ggml/src/ggml-vulkan/ggml-vulkan.cpp CHANGED
@@ -449,6 +449,8 @@ struct vk_device_struct {
449
  vk_pipeline pipeline_div[2][2][2];
450
  vk_pipeline pipeline_div_norepeat[2][2][2];
451
 
 
 
452
  vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
453
  vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32;
454
  vk_pipeline pipeline_scale_f32;
@@ -483,6 +485,7 @@ struct vk_device_struct {
483
  vk_pipeline pipeline_geglu[2];
484
  vk_pipeline pipeline_reglu[2];
485
  vk_pipeline pipeline_swiglu[2];
 
486
  vk_pipeline pipeline_geglu_erf[2];
487
  vk_pipeline pipeline_geglu_quick[2];
488
 
@@ -705,6 +708,8 @@ struct vk_op_glu_push_constants {
705
  uint32_t ne00;
706
  uint32_t ne20;
707
  uint32_t mode; // 0: default, 1: swapped, 2: split
 
 
708
  };
709
 
710
  struct vk_op_unary_push_constants {
@@ -794,6 +799,15 @@ struct vk_op_binary_push_constants {
794
  float param1; float param2; int32_t param3;
795
  };
796
 
797
  struct vk_op_diag_mask_push_constants {
798
  uint32_t ncols;
799
  uint32_t rows_per_channel;
@@ -835,6 +849,7 @@ struct vk_op_soft_max_push_constants {
835
  float m1;
836
  uint32_t n_head_log2;
837
  uint32_t nrows_x;
 
838
  };
839
 
840
  struct vk_op_argsort_push_constants {
@@ -1977,6 +1992,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
1977
  break;
1978
  case GGML_TYPE_IQ4_NL:
1979
  case GGML_TYPE_IQ4_XS:
 
1980
  lut_size = 4*16;
1981
  break;
1982
  default:
@@ -2353,6 +2369,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2353
  CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S], matmul_iq3_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
2354
  CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
2355
  CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
 
2356
 
2357
  CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
2358
  #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
@@ -2379,6 +2396,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2379
  CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
2380
  CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
2381
  CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
 
2382
  #undef CREATE_MM
2383
  #undef CREATE_MM2
2384
  } else
@@ -2440,6 +2458,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2440
  CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2441
  CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2442
  CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 
2443
  } else {
2444
  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2445
  CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -2461,6 +2480,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2461
  CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2462
  CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2463
  CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 
2464
  }
2465
 
2466
  CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
@@ -2493,6 +2513,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2493
  CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2494
  CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2495
  CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
 
2496
  } else {
2497
  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2498
  CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
@@ -2514,6 +2535,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2514
  CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2515
  CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2516
  CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
 
2517
  }
2518
  #undef CREATE_MM2
2519
  #undef CREATE_MM
@@ -2581,6 +2603,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2581
  CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2582
  CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2583
  CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 
2584
 
2585
  #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
2586
  if (device->integer_dot_product) {
@@ -2618,6 +2641,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2618
  CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2619
  CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2620
  CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
 
2621
  #undef CREATE_MM2
2622
  #undef CREATE_MMQ
2623
  #undef CREATE_MM
@@ -2672,6 +2696,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2672
  CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2673
  CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2674
  CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 
2675
 
2676
  #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
2677
  if (device->integer_dot_product) {
@@ -2709,6 +2734,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2709
  CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2710
  CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2711
  CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
 
2712
  }
2713
  // reusing CREATE_MM from the fp32 path
2714
  if ((device->coopmat2 || device->coopmat_support)
@@ -2767,6 +2793,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2767
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
2768
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f32_f32_len, mul_mat_vec_iq4_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
2769
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
 
2770
 
2771
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
2772
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
@@ -2790,6 +2817,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2790
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
2791
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f16_f32_len, mul_mat_vec_iq4_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
2792
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
 
2793
  }
2794
 
2795
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
@@ -2814,6 +2842,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2814
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
2815
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
2816
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
 
2817
 
2818
  // dequant shaders
2819
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
@@ -2836,6 +2865,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2836
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_S], "dequant_iq3_s", dequant_iq3_s_len, dequant_iq3_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
2837
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
2838
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
 
2839
 
2840
  // get_rows
2841
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
@@ -2855,6 +2885,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2855
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_S], "get_rows_iq3_s", get_rows_iq3_s_len, get_rows_iq3_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2856
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2857
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 
2858
 
2859
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
2860
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
@@ -2873,6 +2904,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2873
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_S], "get_rows_iq3_s_f32", get_rows_iq3_s_f32_len, get_rows_iq3_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2874
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs_f32", get_rows_iq4_xs_f32_len, get_rows_iq4_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2875
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 
2876
 
2877
  ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
2878
  ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 4 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
@@ -2976,6 +3008,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
2976
  CREATE_BINARY(div, _norepeat, {1})
2977
  #undef CREATE_BINARY
2978
 
 
 
2979
  ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
2980
 
2981
  ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
@@ -3026,6 +3060,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
3026
  CREATE_GLU(geglu)
3027
  CREATE_GLU(reglu)
3028
  CREATE_GLU(swiglu)
 
3029
  CREATE_GLU(geglu_erf)
3030
  CREATE_GLU(geglu_quick)
3031
  #undef CREATE_GLU
@@ -3035,10 +3070,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
3035
 
3036
  ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
3037
 
3038
- ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
3039
- ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
3040
- ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
3041
- ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
3042
  ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
3043
 
3044
  ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
@@ -4244,6 +4279,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
4244
  case GGML_TYPE_IQ3_S:
4245
  case GGML_TYPE_IQ4_XS:
4246
  case GGML_TYPE_IQ4_NL:
 
4247
  break;
4248
  default:
4249
  return nullptr;
@@ -4314,6 +4350,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
4314
  case GGML_TYPE_IQ3_S:
4315
  case GGML_TYPE_IQ4_XS:
4316
  case GGML_TYPE_IQ4_NL:
 
4317
  break;
4318
  default:
4319
  return nullptr;
@@ -4357,6 +4394,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
4357
  case GGML_TYPE_IQ3_S:
4358
  case GGML_TYPE_IQ4_XS:
4359
  case GGML_TYPE_IQ4_NL:
 
4360
  break;
4361
  default:
4362
  return nullptr;
@@ -4411,6 +4449,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
4411
  case GGML_TYPE_IQ3_S:
4412
  case GGML_TYPE_IQ4_XS:
4413
  case GGML_TYPE_IQ4_NL:
 
4414
  break;
4415
  default:
4416
  return nullptr;
@@ -4446,6 +4485,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
4446
  case GGML_TYPE_IQ3_S:
4447
  case GGML_TYPE_IQ4_XS:
4448
  case GGML_TYPE_IQ4_NL:
 
4449
  break;
4450
  default:
4451
  return nullptr;
@@ -4631,6 +4671,7 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context&
4631
  std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))");
4632
  GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size());
4633
  GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT);
4634
 
4635
  vk::DescriptorSet& descriptor_set = ctx->descriptor_sets[ctx->descriptor_set_idx++];
4636
  vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() };
@@ -6847,6 +6888,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
6847
  break;
6848
  }
6849
  return nullptr;
6850
  case GGML_OP_CONCAT:
6851
  if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6852
  return ctx->device->pipeline_concat_f32;
@@ -6992,6 +7038,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
6992
  return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
6993
  case GGML_GLU_OP_SWIGLU:
6994
  return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
6995
  case GGML_GLU_OP_GEGLU_ERF:
6996
  return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16];
6997
  case GGML_GLU_OP_GEGLU_QUICK:
@@ -7007,6 +7055,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
7007
  return nullptr;
7008
  case GGML_OP_SOFT_MAX:
7009
  GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
7010
 
7011
  if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
7012
  return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
@@ -7177,6 +7226,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
7177
  case GGML_OP_SUB:
7178
  case GGML_OP_MUL:
7179
  case GGML_OP_DIV:
7180
  case GGML_OP_CONCAT:
7181
  case GGML_OP_UPSCALE:
7182
  case GGML_OP_SQR:
@@ -7523,6 +7573,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
7523
  elements = { ne, 1, 1 };
7524
  }
7525
  } break;
7526
  case GGML_OP_SET_ROWS:
7527
  {
7528
  uint32_t ne = ggml_nelements(src0);
@@ -7562,8 +7616,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
7562
  }
7563
  }
7564
 
7565
- if (op == GGML_OP_SOFT_MAX || op == GGML_OP_GLU) {
7566
- // Empty src1 is possible in soft_max, but the shader needs a buffer
7567
  vk_subbuffer subbuf_y;
7568
  if (use_src1) {
7569
  subbuf_y = { d_Y, y_buf_offset, y_sz };
@@ -7573,6 +7627,24 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
7573
 
7574
  ggml_vk_sync_buffers(subctx);
7575
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
7576
  } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
7577
  // Empty src2 is possible in rope, but the shader needs a buffer
7578
  vk_subbuffer subbuf_z;
@@ -7701,6 +7773,21 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
7701
  }, dryrun);
7702
  }
7703
 
7704
  static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version, bool dryrun = false) {
7705
  GGML_ASSERT(version == 6 || version == 7);
7706
  int num_srcs = version == 6 ? 6 : 7;
@@ -8119,8 +8206,12 @@ static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, con
8119
  }
8120
 
8121
  static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
8122
  const bool swapped = (bool)dst->op_params[1];
8123
  const bool split = src1 != nullptr;
8124
 
8125
  GGML_ASSERT(ggml_is_contiguous(src0));
8126
 
@@ -8134,7 +8225,15 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const
8134
 
8135
  const uint32_t mode = split ? 2 : (swapped ? 1 : 0);
8136
 
8137
- ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, { (uint32_t)ggml_nelements(dst), (uint32_t)src0->ne[0], (uint32_t)dst->ne[0], mode }, dryrun);
8138
  }
8139
 
8140
  static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
@@ -8142,7 +8241,7 @@ static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& sub
8142
  ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
8143
  }
8144
 
8145
- static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
8146
  float * op_params = (float *)dst->op_params;
8147
 
8148
  float scale = op_params[0];
@@ -8164,7 +8263,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
8164
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
8165
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
8166
 
8167
- ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
8168
  ncols,
8169
  src1 != nullptr ? nrows_y : (uint32_t)0,
8170
  (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
@@ -8174,6 +8273,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
8174
  m0, m1,
8175
  n_head_log2,
8176
  nrows_x,
8177
  }, dryrun);
8178
  }
8179
 
@@ -9413,6 +9513,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
9413
  case GGML_GLU_OP_GEGLU:
9414
  case GGML_GLU_OP_REGLU:
9415
  case GGML_GLU_OP_SWIGLU:
9416
  case GGML_GLU_OP_GEGLU_ERF:
9417
  case GGML_GLU_OP_GEGLU_QUICK:
9418
  break;
@@ -9424,6 +9525,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
9424
  case GGML_OP_REPEAT_BACK:
9425
  case GGML_OP_GET_ROWS:
9426
  case GGML_OP_ADD:
9427
  case GGML_OP_ACC:
9428
  case GGML_OP_SUB:
9429
  case GGML_OP_MUL:
@@ -9578,6 +9680,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
9578
  case GGML_OP_DIV:
9579
  ggml_vk_div(ctx, compute_ctx, src0, src1, node, dryrun);
9580
 
9581
  break;
9582
  case GGML_OP_CONCAT:
9583
  ggml_vk_concat(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -9675,6 +9781,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
9675
  case GGML_GLU_OP_GEGLU:
9676
  case GGML_GLU_OP_REGLU:
9677
  case GGML_GLU_OP_SWIGLU:
9678
  case GGML_GLU_OP_GEGLU_ERF:
9679
  case GGML_GLU_OP_GEGLU_QUICK:
9680
  ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -9688,7 +9795,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
9688
 
9689
  break;
9690
  case GGML_OP_SOFT_MAX:
9691
- ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun);
9692
 
9693
  break;
9694
  case GGML_OP_SOFT_MAX_BACK:
@@ -9834,6 +9941,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
9834
  case GGML_OP_SUB:
9835
  case GGML_OP_MUL:
9836
  case GGML_OP_DIV:
9837
  case GGML_OP_CONCAT:
9838
  case GGML_OP_UPSCALE:
9839
  case GGML_OP_SCALE:
@@ -9903,6 +10011,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
9903
  case GGML_GLU_OP_GEGLU:
9904
  case GGML_GLU_OP_REGLU:
9905
  case GGML_GLU_OP_SWIGLU:
9906
  case GGML_GLU_OP_GEGLU_ERF:
9907
  case GGML_GLU_OP_GEGLU_QUICK:
9908
  buf = tensor->buffer;
@@ -10752,6 +10861,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
10752
  case GGML_GLU_OP_GEGLU:
10753
  case GGML_GLU_OP_REGLU:
10754
  case GGML_GLU_OP_SWIGLU:
10755
  case GGML_GLU_OP_GEGLU_ERF:
10756
  case GGML_GLU_OP_GEGLU_QUICK:
10757
  return ggml_is_contiguous(op->src[0]) &&
@@ -10797,6 +10907,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
10797
  case GGML_TYPE_IQ3_S:
10798
  case GGML_TYPE_IQ4_XS:
10799
  case GGML_TYPE_IQ4_NL:
10800
  break;
10801
  default:
10802
  return false;
@@ -10834,6 +10945,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
10834
  if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) {
10835
  return false;
10836
  }
10837
  if (op->src[0]->type != GGML_TYPE_F32) {
10838
  return false;
10839
  }
@@ -10906,6 +11021,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
10906
  case GGML_TYPE_IQ3_S:
10907
  case GGML_TYPE_IQ4_XS:
10908
  case GGML_TYPE_IQ4_NL:
10909
  return true;
10910
  default:
10911
  return false;
@@ -11004,6 +11120,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
11004
  return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
11005
  (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
11006
  (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
11007
  case GGML_OP_SILU_BACK:
11008
  case GGML_OP_RMS_NORM_BACK:
11009
  case GGML_OP_SQR:
 
449
  vk_pipeline pipeline_div[2][2][2];
450
  vk_pipeline pipeline_div_norepeat[2][2][2];
451
 
452
+ vk_pipeline pipeline_add_id_f32;
453
+
454
  vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
455
  vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32;
456
  vk_pipeline pipeline_scale_f32;
 
485
  vk_pipeline pipeline_geglu[2];
486
  vk_pipeline pipeline_reglu[2];
487
  vk_pipeline pipeline_swiglu[2];
488
+ vk_pipeline pipeline_swiglu_oai[2];
489
  vk_pipeline pipeline_geglu_erf[2];
490
  vk_pipeline pipeline_geglu_quick[2];
491
 
 
708
  uint32_t ne00;
709
  uint32_t ne20;
710
  uint32_t mode; // 0: default, 1: swapped, 2: split
711
+ float alpha; // for swiglu_oai
712
+ float limit;
713
  };
714
 
715
  struct vk_op_unary_push_constants {
 
799
  float param1; float param2; int32_t param3;
800
  };
801
 
802
+ struct vk_op_add_id_push_constants {
803
+ uint32_t ne0;
804
+ uint32_t ne1;
805
+ uint32_t s01;
806
+ uint32_t s02;
807
+ uint32_t s11;
808
+ uint32_t s21;
809
+ };
810
+
811
  struct vk_op_diag_mask_push_constants {
812
  uint32_t ncols;
813
  uint32_t rows_per_channel;
 
849
  float m1;
850
  uint32_t n_head_log2;
851
  uint32_t nrows_x;
852
+ uint32_t has_sinks;
853
  };
854
 
855
  struct vk_op_argsort_push_constants {
 
1992
  break;
1993
  case GGML_TYPE_IQ4_NL:
1994
  case GGML_TYPE_IQ4_XS:
1995
+ case GGML_TYPE_MXFP4:
1996
  lut_size = 4*16;
1997
  break;
1998
  default:
 
2369
  CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S], matmul_iq3_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
2370
  CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
2371
  CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
2372
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4], matmul_mxfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
2373
 
2374
  CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
2375
  #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
 
2396
  CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
2397
  CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
2398
  CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
2399
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
2400
  #undef CREATE_MM
2401
  #undef CREATE_MM2
2402
  } else
 
2458
  CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2459
  CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2460
  CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2461
+ CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2462
  } else {
2463
  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2464
  CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 
2480
  CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2481
  CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2482
  CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2483
+ CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2484
  }
2485
 
2486
  CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
 
2513
  CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2514
  CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2515
  CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2516
+ CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2517
  } else {
2518
  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2519
  CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
 
2535
  CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2536
  CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2537
  CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2538
+ CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2539
  }
2540
  #undef CREATE_MM2
2541
  #undef CREATE_MM
 
2603
  CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2604
  CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2605
  CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2606
+ CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2607
 
2608
  #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
2609
  if (device->integer_dot_product) {
 
2641
  CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2642
  CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2643
  CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2644
+ CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2645
  #undef CREATE_MM2
2646
  #undef CREATE_MMQ
2647
  #undef CREATE_MM
 
2696
  CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2697
  CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2698
  CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2699
+ CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2700
 
2701
  #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
2702
  if (device->integer_dot_product) {
 
2734
  CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2735
  CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2736
  CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2737
+ CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2738
  }
2739
  // reusing CREATE_MM from the fp32 path
2740
  if ((device->coopmat2 || device->coopmat_support)
 
2793
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
2794
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f32_f32_len, mul_mat_vec_iq4_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
2795
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
2796
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32_"+std::to_string(i+1), mul_mat_vec_mxfp4_f32_f32_len, mul_mat_vec_mxfp4_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
2797
 
2798
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
2799
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
 
2817
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
2818
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f16_f32_len, mul_mat_vec_iq4_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
2819
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
2820
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32_"+std::to_string(i+1), mul_mat_vec_mxfp4_f16_f32_len, mul_mat_vec_mxfp4_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
2821
  }
2822
 
2823
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
 
2842
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
2843
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
2844
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
2845
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", mul_mat_vec_id_mxfp4_f32_len, mul_mat_vec_id_mxfp4_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
2846
 
2847
  // dequant shaders
2848
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
 
2865
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_S], "dequant_iq3_s", dequant_iq3_s_len, dequant_iq3_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
2866
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
2867
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
2868
+ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4], "dequant_mxfp4", dequant_mxfp4_len, dequant_mxfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
2869
 
2870
  // get_rows
2871
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
 
2885
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_S], "get_rows_iq3_s", get_rows_iq3_s_len, get_rows_iq3_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2886
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2887
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2888
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2889
 
2890
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
2891
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
 
2904
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_S], "get_rows_iq3_s_f32", get_rows_iq3_s_f32_len, get_rows_iq3_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2905
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs_f32", get_rows_iq4_xs_f32_len, get_rows_iq4_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2906
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2907
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2908
 
2909
  ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
2910
  ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 4 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
 
3008
  CREATE_BINARY(div, _norepeat, {1})
3009
  #undef CREATE_BINARY
3010
 
3011
+ ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);
3012
+
3013
  ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
3014
 
3015
  ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
 
3060
  CREATE_GLU(geglu)
3061
  CREATE_GLU(reglu)
3062
  CREATE_GLU(swiglu)
3063
+ CREATE_GLU(swiglu_oai)
3064
  CREATE_GLU(geglu_erf)
3065
  CREATE_GLU(geglu_quick)
3066
  #undef CREATE_GLU
 
3070
 
3071
  ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
3072
 
3073
+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
3074
+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
3075
+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
3076
+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
3077
  ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
3078
 
3079
  ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
 
4279
  case GGML_TYPE_IQ3_S:
4280
  case GGML_TYPE_IQ4_XS:
4281
  case GGML_TYPE_IQ4_NL:
4282
+ case GGML_TYPE_MXFP4:
4283
  break;
4284
  default:
4285
  return nullptr;
 
4350
  case GGML_TYPE_IQ3_S:
4351
  case GGML_TYPE_IQ4_XS:
4352
  case GGML_TYPE_IQ4_NL:
4353
+ case GGML_TYPE_MXFP4:
4354
  break;
4355
  default:
4356
  return nullptr;
 
4394
  case GGML_TYPE_IQ3_S:
4395
  case GGML_TYPE_IQ4_XS:
4396
  case GGML_TYPE_IQ4_NL:
4397
+ case GGML_TYPE_MXFP4:
4398
  break;
4399
  default:
4400
  return nullptr;
 
4449
  case GGML_TYPE_IQ3_S:
4450
  case GGML_TYPE_IQ4_XS:
4451
  case GGML_TYPE_IQ4_NL:
4452
+ case GGML_TYPE_MXFP4:
4453
  break;
4454
  default:
4455
  return nullptr;
 
4485
  case GGML_TYPE_IQ3_S:
4486
  case GGML_TYPE_IQ4_XS:
4487
  case GGML_TYPE_IQ4_NL:
4488
+ case GGML_TYPE_MXFP4:
4489
  break;
4490
  default:
4491
  return nullptr;
 
4671
  std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))");
4672
  GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size());
4673
  GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT);
4674
+ GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size());
4675
 
4676
  vk::DescriptorSet& descriptor_set = ctx->descriptor_sets[ctx->descriptor_set_idx++];
4677
  vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() };
 
6888
  break;
6889
  }
6890
  return nullptr;
6891
+ case GGML_OP_ADD_ID:
6892
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && src2->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_F32) {
6893
+ return ctx->device->pipeline_add_id_f32;
6894
+ }
6895
+ return nullptr;
6896
  case GGML_OP_CONCAT:
6897
  if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6898
  return ctx->device->pipeline_concat_f32;
 
7038
  return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
7039
  case GGML_GLU_OP_SWIGLU:
7040
  return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
7041
+ case GGML_GLU_OP_SWIGLU_OAI:
7042
+ return ctx->device->pipeline_swiglu_oai[dst->type == GGML_TYPE_F16];
7043
  case GGML_GLU_OP_GEGLU_ERF:
7044
  return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16];
7045
  case GGML_GLU_OP_GEGLU_QUICK:
 
7055
  return nullptr;
7056
  case GGML_OP_SOFT_MAX:
7057
  GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
7058
+ GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32);
7059
 
7060
  if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
7061
  return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
 
7226
  case GGML_OP_SUB:
7227
  case GGML_OP_MUL:
7228
  case GGML_OP_DIV:
7229
+ case GGML_OP_ADD_ID:
7230
  case GGML_OP_CONCAT:
7231
  case GGML_OP_UPSCALE:
7232
  case GGML_OP_SQR:
 
7573
  elements = { ne, 1, 1 };
7574
  }
7575
  } break;
7576
+ case GGML_OP_ADD_ID:
7577
+ {
7578
+ elements = { (uint32_t)ne01, (uint32_t)ne02, 1 };
7579
+ } break;
7580
  case GGML_OP_SET_ROWS:
7581
  {
7582
  uint32_t ne = ggml_nelements(src0);
 
7616
  }
7617
  }
7618
 
7619
+ if (op == GGML_OP_GLU) {
7620
+ // Empty src1 is possible in glu, but the shader needs a buffer
7621
  vk_subbuffer subbuf_y;
7622
  if (use_src1) {
7623
  subbuf_y = { d_Y, y_buf_offset, y_sz };
 
7627
 
7628
  ggml_vk_sync_buffers(subctx);
7629
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
7630
+ } else if (op == GGML_OP_SOFT_MAX) {
7631
+ // Empty src1 and src2 is possible in soft_max, but the shader needs a buffer
7632
+ vk_subbuffer subbuf_y;
7633
+ if (use_src1) {
7634
+ subbuf_y = { d_Y, y_buf_offset, y_sz };
7635
+ } else {
7636
+ subbuf_y = { d_X, 0, x_sz };
7637
+ }
7638
+
7639
+ vk_subbuffer subbuf_z;
7640
+ if (use_src2) {
7641
+ subbuf_z = { d_Z, z_buf_offset, z_sz };
7642
+ } else {
7643
+ subbuf_z = { d_X, 0, x_sz };
7644
+ }
7645
+
7646
+ ggml_vk_sync_buffers(subctx);
7647
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
7648
  } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
7649
  // Empty src2 is possible in rope, but the shader needs a buffer
7650
  vk_subbuffer subbuf_z;
 
7773
  }, dryrun);
7774
  }
7775
 
7776
+ static void ggml_vk_add_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
7777
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
7778
+ const uint32_t src1_type_size = ggml_type_size(src1->type);
7779
+ const uint32_t src2_type_size = ggml_type_size(src2->type);
7780
+
7781
+ ggml_vk_op_f32<vk_op_add_id_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ADD_ID, {
7782
+ (uint32_t)dst->ne[0],
7783
+ (uint32_t)dst->ne[1],
7784
+ (uint32_t)src0->nb[1] / src0_type_size,
7785
+ (uint32_t)src0->nb[2] / src0_type_size,
7786
+ (uint32_t)src1->nb[1] / src1_type_size,
7787
+ (uint32_t)src2->nb[1] / src2_type_size,
7788
+ }, dryrun);
7789
+ }
7790
+
7791
  static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version, bool dryrun = false) {
7792
  GGML_ASSERT(version == 6 || version == 7);
7793
  int num_srcs = version == 6 ? 6 : 7;
 
8206
  }
8207
 
8208
  static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
8209
+ const float * op_params_f = (const float *)dst->op_params;
8210
+
8211
  const bool swapped = (bool)dst->op_params[1];
8212
  const bool split = src1 != nullptr;
8213
+ const float alpha = op_params_f[2];
8214
+ const float limit = op_params_f[3];
8215
 
8216
  GGML_ASSERT(ggml_is_contiguous(src0));
8217
 
 
8225
 
8226
  const uint32_t mode = split ? 2 : (swapped ? 1 : 0);
8227
 
8228
+ ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU,
8229
+ {
8230
+ (uint32_t)ggml_nelements(dst),
8231
+ (uint32_t)src0->ne[0],
8232
+ (uint32_t)dst->ne[0],
8233
+ mode,
8234
+ alpha,
8235
+ limit
8236
+ }, dryrun);
8237
  }
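The alpha and limit values packed above feed the new fused SWIGLU_OAI gate. As a rough CPU-side sketch of the gating this op is expected to perform, assuming it follows the gpt-oss reference activation (gate clamped from above, up clamped to [-limit, limit], sigmoid scaled by alpha; the helper name below is illustrative, not part of this patch):

    #include <math.h>

    // Sketch only: per-element swiglu_oai using the alpha/limit read from dst->op_params above.
    static void swiglu_oai_ref(const float * gate, const float * up, float * out,
                               int n, float alpha, float limit) {
        for (int i = 0; i < n; ++i) {
            const float g = gate[i] < limit ? gate[i] : limit;      // clamp gate from above
            const float u = fminf(fmaxf(up[i], -limit), limit);     // clamp up to [-limit, limit]
            const float s = 1.0f / (1.0f + expf(-alpha * g));       // sigmoid(alpha * g)
            out[i] = (g * s) * (u + 1.0f);
        }
    }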
8238
 
8239
  static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
 
8241
  ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
8242
  }
8243
 
8244
+ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
8245
  float * op_params = (float *)dst->op_params;
8246
 
8247
  float scale = op_params[0];
 
8263
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
8264
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
8265
 
8266
+ ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, {
8267
  ncols,
8268
  src1 != nullptr ? nrows_y : (uint32_t)0,
8269
  (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
 
8273
  m0, m1,
8274
  n_head_log2,
8275
  nrows_x,
8276
+ src2 != nullptr
8277
  }, dryrun);
8278
  }
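The optional src2 forwarded here carries the attention sinks, and the new has_sinks push constant tells the shader whether that buffer is bound. As an illustration of the intended math, assuming the sink acts as one extra per-head logit that only enters the normalization term (a sketch, not the shader code):

    #include <math.h>

    // Sketch: soft-max over one row with an attention sink.
    // The sink contributes to the max and the denominator but produces no output element.
    static void softmax_row_with_sink(const float * x, float * y, int n, float sink) {
        float m = sink;
        for (int i = 0; i < n; ++i) m = fmaxf(m, x[i]);

        float denom = expf(sink - m);
        for (int i = 0; i < n; ++i) {
            y[i] = expf(x[i] - m);
            denom += y[i];
        }
        for (int i = 0; i < n; ++i) y[i] /= denom;
    }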
8279
 
 
9513
  case GGML_GLU_OP_GEGLU:
9514
  case GGML_GLU_OP_REGLU:
9515
  case GGML_GLU_OP_SWIGLU:
9516
+ case GGML_GLU_OP_SWIGLU_OAI:
9517
  case GGML_GLU_OP_GEGLU_ERF:
9518
  case GGML_GLU_OP_GEGLU_QUICK:
9519
  break;
 
9525
  case GGML_OP_REPEAT_BACK:
9526
  case GGML_OP_GET_ROWS:
9527
  case GGML_OP_ADD:
9528
+ case GGML_OP_ADD_ID:
9529
  case GGML_OP_ACC:
9530
  case GGML_OP_SUB:
9531
  case GGML_OP_MUL:
 
9680
  case GGML_OP_DIV:
9681
  ggml_vk_div(ctx, compute_ctx, src0, src1, node, dryrun);
9682
 
9683
+ break;
9684
+ case GGML_OP_ADD_ID:
9685
+ ggml_vk_add_id(ctx, compute_ctx, src0, src1, src2, node, dryrun);
9686
+
9687
  break;
9688
  case GGML_OP_CONCAT:
9689
  ggml_vk_concat(ctx, compute_ctx, src0, src1, node, dryrun);
 
9781
  case GGML_GLU_OP_GEGLU:
9782
  case GGML_GLU_OP_REGLU:
9783
  case GGML_GLU_OP_SWIGLU:
9784
+ case GGML_GLU_OP_SWIGLU_OAI:
9785
  case GGML_GLU_OP_GEGLU_ERF:
9786
  case GGML_GLU_OP_GEGLU_QUICK:
9787
  ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
 
9795
 
9796
  break;
9797
  case GGML_OP_SOFT_MAX:
9798
+ ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node, dryrun);
9799
 
9800
  break;
9801
  case GGML_OP_SOFT_MAX_BACK:
 
9941
  case GGML_OP_SUB:
9942
  case GGML_OP_MUL:
9943
  case GGML_OP_DIV:
9944
+ case GGML_OP_ADD_ID:
9945
  case GGML_OP_CONCAT:
9946
  case GGML_OP_UPSCALE:
9947
  case GGML_OP_SCALE:
 
10011
  case GGML_GLU_OP_GEGLU:
10012
  case GGML_GLU_OP_REGLU:
10013
  case GGML_GLU_OP_SWIGLU:
10014
+ case GGML_GLU_OP_SWIGLU_OAI:
10015
  case GGML_GLU_OP_GEGLU_ERF:
10016
  case GGML_GLU_OP_GEGLU_QUICK:
10017
  buf = tensor->buffer;
 
10861
  case GGML_GLU_OP_GEGLU:
10862
  case GGML_GLU_OP_REGLU:
10863
  case GGML_GLU_OP_SWIGLU:
10864
+ case GGML_GLU_OP_SWIGLU_OAI:
10865
  case GGML_GLU_OP_GEGLU_ERF:
10866
  case GGML_GLU_OP_GEGLU_QUICK:
10867
  return ggml_is_contiguous(op->src[0]) &&
 
10907
  case GGML_TYPE_IQ3_S:
10908
  case GGML_TYPE_IQ4_XS:
10909
  case GGML_TYPE_IQ4_NL:
10910
+ case GGML_TYPE_MXFP4:
10911
  break;
10912
  default:
10913
  return false;
 
10945
  if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) {
10946
  return false;
10947
  }
10948
+ // TODO: support attention sinks [TAG_ATTN_SINKS]
10949
+ if (op->src[4]) {
10950
+ return false;
10951
+ }
10952
  if (op->src[0]->type != GGML_TYPE_F32) {
10953
  return false;
10954
  }
 
11021
  case GGML_TYPE_IQ3_S:
11022
  case GGML_TYPE_IQ4_XS:
11023
  case GGML_TYPE_IQ4_NL:
11024
+ case GGML_TYPE_MXFP4:
11025
  return true;
11026
  default:
11027
  return false;
 
11120
  return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
11121
  (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
11122
  (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
11123
+ case GGML_OP_ADD_ID:
11124
+ return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->src[2]->type == GGML_TYPE_I32 &&
11125
+ op->type == GGML_TYPE_F32;
11126
  case GGML_OP_SILU_BACK:
11127
  case GGML_OP_RMS_NORM_BACK:
11128
  case GGML_OP_SQR:
ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp ADDED
@@ -0,0 +1,42 @@
1
+ #version 450
2
+
3
+ #extension GL_EXT_control_flow_attributes : require
4
+
5
+ #include "types.comp"
6
+
7
+ layout (push_constant) uniform parameter
8
+ {
9
+ uint ne0;
10
+ uint ne1;
11
+ uint s01;
12
+ uint s02;
13
+ uint s11;
14
+ uint s21;
15
+ } p;
16
+
17
+ #define BLOCK_SIZE 512
18
+
19
+ layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
20
+
21
+ layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
22
+ layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
23
+ layout (binding = 2) readonly buffer Z {int32_t data_c[];};
24
+ layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
25
+
26
+ void main() {
27
+ const uint i1 = gl_WorkGroupID.x;
28
+ const uint i2 = gl_WorkGroupID.y;
29
+
30
+ const uint i11 = data_c[i1 + i2 * p.s21];
31
+
32
+ const uint s1 = p.ne0;
33
+ const uint s2 = p.ne0 * p.ne1;
34
+
35
+ const uint d0 = i1 * s1 + i2 * s2;
36
+ const uint a0 = i1 * p.s01 + i2 * p.s02;
37
+ const uint b0 = i11 * p.s11;
38
+
39
+ for (uint i0 = gl_LocalInvocationID.x; i0 < p.ne0; i0 += BLOCK_SIZE) {
40
+ data_d[d0 + i0] = data_a[a0 + i0] + data_b[b0 + i0];
41
+ }
42
+ }
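The shader above assigns one dst row per workgroup (i1, i2). Its indexing is equivalent to the following CPU reference for GGML_OP_ADD_ID, shown as a sketch for contiguous f32 data with strides expressed in elements, matching the s01/s02/s11/s21 push constants (the helper name is illustrative):

    #include <stdint.h>

    // Sketch: dst[i2][i1][:] = src0[i2][i1][:] + src1[row][:], with row = src2[i2][i1].
    static void add_id_ref(const float * src0, const float * src1, const int32_t * src2,
                           float * dst, int ne0, int ne1, int ne2,
                           int s01, int s02, int s11, int s21) {
        for (int i2 = 0; i2 < ne2; ++i2) {
            for (int i1 = 0; i1 < ne1; ++i1) {
                const int row = src2[i1 + i2 * s21];
                const float * a = src0 + i1 * s01 + i2 * s02;
                const float * b = src1 + row * s11;
                float       * d = dst  + i1 * ne0 + i2 * ne0 * ne1;
                for (int i0 = 0; i0 < ne0; ++i0) {
                    d[i0] = a[i0] + b[i0];
                }
            }
        }
    }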
ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp CHANGED
@@ -4,8 +4,8 @@
4
  #include "generic_unary_head.comp"
5
  #include "dequant_funcs.comp"
6
 
7
- #if defined(DATA_A_IQ4_NL)
8
- // 16 invocations needed for init_iq4nl_shmem
9
  layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
10
  #else
11
  layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
 
4
  #include "generic_unary_head.comp"
5
  #include "dequant_funcs.comp"
6
 
7
+ #if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4)
8
+ // 16 invocations needed for init_iq_shmem
9
  layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
10
  #else
11
  layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp CHANGED
@@ -434,6 +434,18 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
434
  }
435
  #endif
436
 
437
  #if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
438
  vec2 get_dm(uint ib, uint a_offset) {
439
  return vec2(0, 0);
@@ -455,6 +467,12 @@ vec2 get_dm(uint ib, uint a_offset) {
455
  }
456
  #endif
457
 
458
  #if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
459
  vec2 get_dm(uint ib, uint a_offset) {
460
  return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m));
 
434
  }
435
  #endif
436
 
437
+ #if defined(DATA_A_MXFP4)
438
+ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
439
+ const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
440
+ return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]);
441
+ }
442
+ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
443
+ vec2 v0 = dequantize(ib, iqs, a_offset);
444
+ vec2 v1 = dequantize(ib, iqs + 1, a_offset);
445
+ return vec4(v0.x, v0.y, v1.x, v1.y);
446
+ }
447
+ #endif
448
+
449
  #if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
450
  vec2 get_dm(uint ib, uint a_offset) {
451
  return vec2(0, 0);
 
467
  }
468
  #endif
469
 
470
+ #if defined(DATA_A_MXFP4)
471
+ vec2 get_dm(uint ib, uint a_offset) {
472
+ return vec2(e8m0_to_fp32(data_a[a_offset + ib].e), 0);
473
+ }
474
+ #endif
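e8m0_to_fp32() decodes the shared block scale. E8M0 is a sign-less, mantissa-less encoding: a single byte holding a biased power-of-two exponent. A plausible reference decoder, ignoring the reserved/NaN encoding, would be the following (illustrative only; the actual helper used by these shaders and by the CUDA/CPU paths may handle edge cases differently):

    #include <math.h>
    #include <stdint.h>

    // Sketch: E8M0 scale byte -> float, using the IEEE-754 single-precision bias of 127.
    static inline float e8m0_to_fp32_ref(uint8_t e) {
        return ldexpf(1.0f, (int) e - 127);   // 2^(e - 127)
    }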
475
+
476
  #if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
477
  vec2 get_dm(uint ib, uint a_offset) {
478
  return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m));
ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp CHANGED
@@ -654,6 +654,25 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
654
  }
655
  #endif
656
 
  #if defined(DATA_A_Q4_0)
658
  #define dequantFuncA dequantFuncQ4_0
659
  #elif defined(DATA_A_Q4_1)
@@ -696,4 +715,6 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
696
  #define dequantFuncA dequantFuncIQ4_XS
697
  #elif defined(DATA_A_IQ4_NL)
698
  #define dequantFuncA dequantFuncIQ4_NL
699
  #endif
 
654
  }
655
  #endif
656
 
657
+ #if defined(DATA_A_MXFP4)
658
+ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufMXFP4 {
659
+ block_mxfp4 block;
660
+ };
661
+
662
+ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
663
+ {
664
+ const float d = e8m0_to_fp32(bl.block.e);
665
+ const uint idx = coordInBlock[1];
666
+ const uint iqs = idx & 0xF;
667
+ const uint shift = (idx & 0x10) >> 2;
668
+ uint32_t qs = bl.block.qs[iqs];
669
+ qs >>= shift;
670
+ qs &= 0xF;
671
+ float16_t ret = float16_t(kvalues_mxfp4[qs] * d);
672
+ return ret;
673
+ }
674
+ #endif
675
+
676
  #if defined(DATA_A_Q4_0)
677
  #define dequantFuncA dequantFuncQ4_0
678
  #elif defined(DATA_A_Q4_1)
 
715
  #define dequantFuncA dequantFuncIQ4_XS
716
  #elif defined(DATA_A_IQ4_NL)
717
  #define dequantFuncA dequantFuncIQ4_NL
718
+ #elif defined(DATA_A_MXFP4)
719
+ #define dequantFuncA dequantFuncMXFP4
720
  #endif
ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp ADDED
@@ -0,0 +1,32 @@
1
+ #version 450
2
+
3
+ #include "dequant_head.comp"
4
+
5
+ layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
6
+
7
+ layout (binding = 0) readonly buffer A {block_mxfp4 data_a[];};
8
+ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
9
+
10
+ void main() {
11
+ const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
12
+
13
+ init_iq_shmem(gl_WorkGroupSize);
14
+
15
+ const uint tid = gl_LocalInvocationID.x % 64;
16
+ const uint il = tid/32;
17
+ const uint ir = tid%32;
18
+ const uint ib = 32*i + ir;
19
+ if (ib >= p.nel / 32) {
20
+ return;
21
+ }
22
+
23
+ const uint q_idx = 8*il;
24
+ const uint b_idx = 1024*i + 32*ir + q_idx;
25
+
26
+ const float d = e8m0_to_fp32(data_a[ib].e);
27
+
28
+ [[unroll]] for (uint l = 0; l < 8; ++l) {
29
+ data_b[b_idx + l + 0] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]);
30
+ data_b[b_idx + l + 16] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]);
31
+ }
32
+ }
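Putting the pieces together, an MXFP4 block as used above is one E8M0 scale byte plus 16 bytes of packed 4-bit E2M1 codes (32 values, low nibbles first). A CPU-side reference dequantizer might look roughly like this; the struct and value table are written out for illustration rather than taken from ggml-common.h, and the E2M1 magnitudes are 0, 0.5, 1, 1.5, 2, 3, 4, 6 with the sign in the top code bit:

    #include <math.h>
    #include <stdint.h>

    // Illustrative block layout: 32 4-bit values sharing one E8M0 scale.
    typedef struct {
        uint8_t e;        // shared scale, E8M0 (power of two)
        uint8_t qs[16];   // 32 packed E2M1 codes, two per byte
    } block_mxfp4_ref;

    static const float kvalues_mxfp4_ref[16] = {
         0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,
        -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f,
    };

    // Sketch: mirrors the output layout of dequant_mxfp4.comp above.
    static void dequant_mxfp4_block_ref(const block_mxfp4_ref * b, float * out /* 32 floats */) {
        const float d = ldexpf(1.0f, (int) b->e - 127);   // E8M0 scale, 2^(e - 127)
        for (int j = 0; j < 16; ++j) {
            out[j     ] = d * kvalues_mxfp4_ref[b->qs[j] & 0xF];  // low nibbles: values 0..15
            out[j + 16] = d * kvalues_mxfp4_ref[b->qs[j] >> 4];   // high nibbles: values 16..31
        }
    }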