llama : add gpt-oss (llama/15091)
* oai moe
* compat with new checkpoint
* add attn sink impl
* add rope scaling yarn
* logits match with latest transformers code
* wip chat template
* rm trailing space
* use ggml_scale_bias
* rm redundant is_swa_all
* convert interleaved gate_up
* graph : fix activation function to match reference (llama/7)
* vocab : handle o200k_harmony special tokens
* ggml : add attention sinks support (llama/1)
* llama : add attn sinks
* ggml : add attn sinks
* cuda : add attn sinks
* vulkan : add support for sinks in softmax
remove unnecessary return
* ggml : add fused swiglu_oai op (llama/11)
* ggml : add fused swiglu_oai op
* Update ggml/src/ggml-cpu/ops.cpp
Co-authored-by: Georgi Gerganov <[email protected]>
* update CUDA impl
* cont : metal impl
* add vulkan impl
* test-backend-ops : more test cases, clean up
* llama : remove unfused impl
* remove extra lines
---------
Co-authored-by: Georgi Gerganov <[email protected]>
---------
Co-authored-by: slaren <[email protected]>
* repack mxfp4 upon conversion
* clean up a bit
* enable thinking
* add quick hack to render only some special tokens
* fix bf16 conversion
* remove vocab hack
* webui ok
* support chat parsing for gpt-oss
* fix webui
* direct mapping mxfp4, FINALLY
* force using mxfp4
* properly use lazy tensor
* ggml : add mxfp4
ggml : use e8m0 conversion instead of powf
Co-authored-by: Diego Devesa <[email protected]>
change kvalues_mxfp4 table to match e2m1 (llama/6)
metal : remove quantization for now (not used)
cuda : fix disabled CUDA graphs due to ffn moe bias
vulkan : add support for mxfp4
cont : add cm2 dequant
* ggml : add ggml_add_id (llama/13)
* ggml : add ggml_add_id
* add cuda impl
* llama : add weight support check for add_id
* perf opt
* add vulkan impl
* rename cuda files
* add metal impl
* allow in-place ggml_add_id
* llama : keep biases on CPU with --cpu-moe
* llama : fix compile error
ggml-ci
* cuda : add fallback for __nv_cvt_e8m0_to_bf16raw
ggml-ci
* cleanup
ggml-ci
* sycl : fix supports_op for MXFP4
ggml-ci
* fix Unknown reasoning format
* ggml-cpu : fix AVX build
ggml-ci
* fix hip build
ggml-ci
* cuda : add mxfp4 dequantization support for cuBLAS
ggml-ci
* ggml-cpu : fix mxfp4 fallback definitions for some architectures
ggml-ci
* cuda : fix version required for __nv_cvt_e8m0_to_bf16raw
---------
Co-authored-by: Xuan Son Nguyen <[email protected]>
Co-authored-by: slaren <[email protected]>
- ggml/include/ggml.h +37 -1
- ggml/src/ggml-alloc.c +1 -0
- ggml/src/ggml-cann/ggml-cann.cpp +8 -0
- ggml/src/ggml-common.h +17 -0
- ggml/src/ggml-cpu/arch-fallback.h +6 -0
- ggml/src/ggml-cpu/arch/arm/quants.c +61 -0
- ggml/src/ggml-cpu/arch/x86/quants.c +96 -8
- ggml/src/ggml-cpu/ggml-cpu.c +14 -1
- ggml/src/ggml-cpu/ops.cpp +207 -9
- ggml/src/ggml-cpu/ops.h +2 -7
- ggml/src/ggml-cpu/quants.c +35 -0
- ggml/src/ggml-cpu/quants.h +8 -0
- ggml/src/ggml-cpu/vec.h +19 -4
- ggml/src/ggml-cuda/add-id.cu +58 -0
- ggml/src/ggml-cuda/add-id.cuh +3 -0
- ggml/src/ggml-cuda/common.cuh +26 -0
- ggml/src/ggml-cuda/convert.cu +28 -0
- ggml/src/ggml-cuda/fattn-common.cuh +4 -1
- ggml/src/ggml-cuda/fattn-mma-f16.cuh +3 -1
- ggml/src/ggml-cuda/fattn-tile-f16.cu +2 -1
- ggml/src/ggml-cuda/fattn-tile-f32.cu +2 -1
- ggml/src/ggml-cuda/fattn-vec-f16.cuh +39 -3
- ggml/src/ggml-cuda/fattn-vec-f32.cuh +38 -2
- ggml/src/ggml-cuda/fattn-wmma-f16.cu +2 -1
- ggml/src/ggml-cuda/fattn.cu +16 -5
- ggml/src/ggml-cuda/ggml-cuda.cu +24 -1
- ggml/src/ggml-cuda/im2col.cu +3 -2
- ggml/src/ggml-cuda/mmq.cu +4 -0
- ggml/src/ggml-cuda/mmq.cuh +80 -2
- ggml/src/ggml-cuda/mmvq.cu +9 -0
- ggml/src/ggml-cuda/softmax.cu +16 -10
- ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
- ggml/src/ggml-cuda/unary.cu +75 -0
- ggml/src/ggml-cuda/unary.cuh +2 -0
- ggml/src/ggml-cuda/vecdotq.cuh +52 -16
- ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- ggml/src/ggml-impl.h +61 -0
- ggml/src/ggml-metal/ggml-metal-impl.h +14 -0
- ggml/src/ggml-metal/ggml-metal.m +109 -9
- ggml/src/ggml-metal/ggml-metal.metal +272 -15
- ggml/src/ggml-opencl/ggml-opencl.cpp +2 -0
- ggml/src/ggml-quants.c +105 -11
- ggml/src/ggml-quants.h +6 -0
- ggml/src/ggml-sycl/ggml-sycl.cpp +10 -10
- ggml/src/ggml-vulkan/ggml-vulkan.cpp +129 -10
- ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
- ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
- ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +18 -0
- ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
- ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
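Background for the diffs below: the per-block scale in MXFP4 is an E8M0 value, a bare 8-bit exponent worth 2^(e-127), which is why the commit message prefers "e8m0 conversion instead of powf". A minimal standalone sketch of that conversion (our illustration; the GGML_E8M0_TO_FP32 machinery in ggml-impl.h is the authoritative version):

#include <stdint.h>
#include <string.h>

// E8M0 stores only an exponent: value = 2^(e - 127). Writing e directly into
// the IEEE-754 exponent field avoids a powf() call.
static inline float e8m0_to_fp32(uint8_t e) {
    uint32_t bits = (uint32_t) e << 23; // sign = 0, exponent = e, mantissa = 0
    float f;
    memcpy(&f, &bits, sizeof(f));
    return f; // exact 2^(e - 127) for 0 < e < 255
}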
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -304,6 +304,16 @@
     GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
     GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
 
+#define GGML_TENSOR_TERNARY_OP_LOCALS \
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne2, src2, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb2, src2, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
+
 #define GGML_TENSOR_BINARY_OP_LOCALS01 \
     GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
     GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
@@ -395,7 +405,8 @@ extern "C" {
         // GGML_TYPE_IQ4_NL_4_4 = 36,
         // GGML_TYPE_IQ4_NL_4_8 = 37,
         // GGML_TYPE_IQ4_NL_8_8 = 38,
-        GGML_TYPE_COUNT = 39,
+        GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block)
+        GGML_TYPE_COUNT = 40,
     };
 
     // precision
@@ -430,6 +441,7 @@ extern "C" {
         GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ1_M  = 23, // except 1d tensors
         GGML_FTYPE_MOSTLY_BF16   = 24, // except 1d tensors
+        GGML_FTYPE_MOSTLY_MXFP4  = 25, // except 1d tensors
     };
 
     // available tensor operations:
@@ -438,6 +450,7 @@ extern "C" {
 
         GGML_OP_DUP,
         GGML_OP_ADD,
+        GGML_OP_ADD_ID,
         GGML_OP_ADD1,
         GGML_OP_ACC,
         GGML_OP_SUB,
@@ -557,6 +570,7 @@ extern "C" {
         GGML_GLU_OP_REGLU,
         GGML_GLU_OP_GEGLU,
         GGML_GLU_OP_SWIGLU,
+        GGML_GLU_OP_SWIGLU_OAI,
         GGML_GLU_OP_GEGLU_ERF,
         GGML_GLU_OP_GEGLU_QUICK,
 
@@ -831,6 +845,13 @@ extern "C" {
             struct ggml_tensor  * b,
             enum   ggml_type      type);
 
+    // dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]]
+    GGML_API struct ggml_tensor * ggml_add_id(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * ids);
+
     GGML_API struct ggml_tensor * ggml_add1(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -1198,6 +1219,13 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    GGML_API struct ggml_tensor * ggml_swiglu_oai(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            float                 alpha,
+            float                 limit);
+
     // normalize along rows
     GGML_API struct ggml_tensor * ggml_norm(
             struct ggml_context * ctx,
@@ -1570,6 +1598,10 @@ extern "C" {
             float                 scale,
             float                 max_bias);
 
+    GGML_API void ggml_soft_max_add_sinks(
+            struct ggml_tensor * a,
+            struct ggml_tensor * sinks);
+
     GGML_API struct ggml_tensor * ggml_soft_max_ext_back(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -2052,6 +2084,10 @@ extern "C" {
     GGML_API enum ggml_prec ggml_flash_attn_ext_get_prec(
             const struct ggml_tensor * a);
 
+    GGML_API void ggml_flash_attn_ext_add_sinks(
+            struct ggml_tensor * a,
+            struct ggml_tensor * sinks);
+
     // TODO: needs to be adapted to ggml_flash_attn_ext
     GGML_API struct ggml_tensor * ggml_flash_attn_back(
             struct ggml_context * ctx,
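A rough sketch (not code from this commit) of how the two sink entry points are meant to be used when building a graph: sinks are attached to an already-created softmax or flash-attention node, which stores them as an extra source (src[2] of GGML_OP_SOFT_MAX, src[4] of GGML_OP_FLASH_ATTN_EXT, as the backend checks below show). All tensor names here are placeholders.

// Hypothetical graph fragment: attach one F32 sink logit per head to the
// attention softmax. ctx, q, k, v, kq_mask, kq_scale and sinks are assumed
// to have been built earlier.
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);

kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, 0.0f);
ggml_soft_max_add_sinks(kq, sinks); // sinks becomes src[2] of the soft_max node

struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);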
--- a/ggml/src/ggml-alloc.c
+++ b/ggml/src/ggml-alloc.c
@@ -29,6 +29,7 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
         case GGML_OP_DIAG_MASK_ZERO:
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_ADD:
+        case GGML_OP_ADD_ID:
         case GGML_OP_ADD1:
         case GGML_OP_SUB:
         case GGML_OP_MUL:
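Per the header comment, ggml_add_id computes dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]], and the in-place entry added above lets dst alias a. A hedged usage sketch for the MoE case this commit targets (per-expert FFN biases; the tensor names are ours, not from the commit):

// Hypothetical MoE fragment: "cur" holds n_expert_used rows per token,
// "ffn_bias" one row per expert, "selected_experts" (I32) maps each row of
// cur to the expert it came from. All three are assumed built earlier.
struct ggml_tensor * cur;              // [n_embd, n_expert_used, n_tokens] F32
struct ggml_tensor * ffn_bias;         // [n_embd, n_expert]               F32
struct ggml_tensor * selected_experts; // [n_expert_used, n_tokens]        I32

cur = ggml_add_id(ctx, cur, ffn_bias, selected_experts);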
--- a/ggml/src/ggml-cann/ggml-cann.cpp
+++ b/ggml/src/ggml-cann/ggml-cann.cpp
@@ -2340,6 +2340,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
             memcpy(&bias, (float*)op->op_params + 1, sizeof(float));
             return bias == 0.0f; // TODO: support bias != 0.0f
         case GGML_OP_SOFT_MAX:
+            // TODO: support attention sinks [TAG_ATTN_SINKS]
+            if (op->src[2]) {
+                return false;
+            }
             // TODO: support broadcast
             // ref: https://github.com/ggml-org/llama.cpp/pull/14435
             return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
@@ -2354,6 +2358,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
             if(op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_BF16){
                 return false;
             }
+            // TODO: support attention sinks [TAG_ATTN_SINKS]
+            if (op->src[4]) {
+                return false;
+            }
             if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
                 // different head sizes of K and V are not supported yet
                 return false;
--- a/ggml/src/ggml-common.h
+++ b/ggml/src/ggml-common.h
@@ -99,6 +99,9 @@ typedef sycl::half2 ggml_half2;
 #define QI4_1 (QK4_1 / (4 * QR4_1))
 #define QR4_1 2
 
+#define QI_MXFP4 (QK_MXFP4 / (4 * QR_MXFP4))
+#define QR_MXFP4 2
+
 #define QI5_0 (QK5_0 / (4 * QR5_0))
 #define QR5_0 2
 
@@ -184,6 +187,13 @@ typedef struct {
 } block_q4_1;
 static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding");
 
+#define QK_MXFP4 32
+typedef struct {
+    uint8_t e; // E8M0
+    uint8_t qs[QK_MXFP4/2];
+} block_mxfp4;
+static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + QK_MXFP4/2, "wrong mxfp4 block size/padding");
+
 #define QK5_0 32
 typedef struct {
     ggml_half d; // delta
@@ -1074,10 +1084,17 @@ GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512)
 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
 GGML_TABLE_END()
 
+// TODO: fix name to kvalues_iq4_nl
 GGML_TABLE_BEGIN(int8_t, kvalues_iq4nl, 16)
     -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
 GGML_TABLE_END()
 
+// e2m1 values (doubled)
+// ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+GGML_TABLE_BEGIN(int8_t, kvalues_mxfp4, 16)
+    0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12,
+GGML_TABLE_END()
+
 #define NGRID_IQ1S 2048
 #define IQ1S_DELTA 0.125f
 #define IQ1M_DELTA 0.125f
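Putting the block layout and the value table together, decoding one block means taking the doubled e2m1 value selected by each nibble and multiplying by half the E8M0 scale (the halving is what the GGML_E8M0_TO_FP32_HALF name in the CPU kernels below refers to). A minimal sketch, ours rather than commit code:

#include <stdint.h>
#include <string.h>

// 2^(e - 127) / 2, i.e. the E8M0 scale pre-halved to cancel the doubled table.
static inline float e8m0_half_to_fp32(uint8_t e) {
    uint32_t bits = (uint32_t)(e > 0 ? e - 1 : 0) << 23;
    float f;
    memcpy(&f, &bits, sizeof(f));
    return f;
}

static void dequant_block_mxfp4(const block_mxfp4 * b, float out[QK_MXFP4]) {
    const float d = e8m0_half_to_fp32(b->e);
    for (int j = 0; j < QK_MXFP4/2; ++j) {
        out[j]              = d * kvalues_mxfp4[b->qs[j] & 0xf]; // low nibbles: elements 0..15
        out[j + QK_MXFP4/2] = d * kvalues_mxfp4[b->qs[j] >> 4];  // high nibbles: elements 16..31
    }
}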
--- a/ggml/src/ggml-cpu/arch-fallback.h
+++ b/ggml/src/ggml-cpu/arch-fallback.h
@@ -13,6 +13,7 @@
 #define ggml_vec_dot_q5_0_q8_0_generic ggml_vec_dot_q5_0_q8_0
 #define ggml_vec_dot_q5_1_q8_1_generic ggml_vec_dot_q5_1_q8_1
 #define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
 #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
 #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
@@ -68,6 +69,7 @@
 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
 #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -90,6 +92,7 @@
 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
 #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -120,6 +123,7 @@
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
 #define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
 #define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -149,6 +153,7 @@
 #define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K
 #define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -179,6 +184,7 @@
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
 #define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
 #define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
--- a/ggml/src/ggml-cpu/arch/arm/quants.c
+++ b/ggml/src/ggml-cpu/arch/arm/quants.c
@@ -589,6 +589,67 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
     *s = sumf;
 }
 
+void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK_MXFP4 == 0);
+    static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
+
+    const block_mxfp4 * GGML_RESTRICT x = vx;
+    const block_q8_0  * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_MXFP4;
+
+    int ib = 0;
+    float sumf = 0;
+
+#if defined __ARM_NEON
+    const int8x16_t  values = vld1q_s8(kvalues_mxfp4);
+    const uint8x16_t m4b    = vdupq_n_u8(0x0f);
+    uint8x16x2_t q4bits;
+    int8x16x4_t  q4b;
+    int8x16x4_t  q8b;
+    int32x4_t    prod_1;
+    int32x4_t    prod_2;
+
+    for (; ib + 1 < nb; ib += 2) {
+        q4bits.val[0] = vld1q_u8(x[ib + 0].qs);
+        q4bits.val[1] = vld1q_u8(x[ib + 1].qs);
+        q8b.val[0]    = vld1q_s8(y[ib + 0].qs);
+        q8b.val[1]    = vld1q_s8(y[ib + 0].qs + 16);
+        q8b.val[2]    = vld1q_s8(y[ib + 1].qs);
+        q8b.val[3]    = vld1q_s8(y[ib + 1].qs + 16);
+
+        q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b));
+        q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
+        q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b));
+        q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));
+
+        prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);
+        prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);
+
+        sumf +=
+            GGML_E8M0_TO_FP32_HALF(x[ib + 0].e) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * vaddvq_s32(prod_1) +
+            GGML_E8M0_TO_FP32_HALF(x[ib + 1].e) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * vaddvq_s32(prod_2);
+    }
+
+#endif
+    for (; ib < nb; ++ib) {
+        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
+        int sumi1 = 0;
+        int sumi2 = 0;
+        for (int j = 0; j < QK_MXFP4/2; ++j) {
+            sumi1 += y[ib].qs[j + 0]         * kvalues_mxfp4[x[ib].qs[j] & 0xf];
+            sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >>  4];
+        }
+        sumf += d * (sumi1 + sumi2);
+    }
+    *s = sumf;
+}
+
 void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     const int qk = QK8_0;
     const int nb = n / qk;
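Both the NEON fast path above and the x86 paths below reduce to the same per-block arithmetic as the scalar tail: one integer dot product through the 16-entry value table, scaled once per 32-element block,

$\text{sumf} = \sum_{b} \tfrac{1}{2}\, 2^{\,e_b - 127}\, d_b \sum_{j=0}^{31} q^{(8)}_{b,j}\, \text{kvalues\_mxfp4}\big[q^{(4)}_{b,j}\big]$

where $e_b$ is the MXFP4 block's E8M0 byte, $d_b$ the Q8_0 block's FP16 delta, and the factor of $\tfrac{1}{2}$ cancels the doubled table entries.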
--- a/ggml/src/ggml-cpu/arch/x86/quants.c
+++ b/ggml/src/ggml-cpu/arch/x86/quants.c
@@ -66,6 +66,12 @@ static inline int hsum_i32_4(const __m128i a) {
 }
 
 #if defined(__AVX2__) || defined(__AVX512F__)
+static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
+    const __m256i ax = _mm256_sign_epi8(x, x);
+    const __m256i sy = _mm256_sign_epi8(y, x);
+    return _mm256_maddubs_epi16(ax, sy);
+}
+
 // spread 32 bits to 32 bytes { 0x00, 0xFF }
 static inline __m256i bytes_from_bits_32(const uint8_t * x) {
     uint32_t x32;
@@ -261,6 +267,11 @@ static inline __m256 quad_fp16_delta_float(const float x0, const float y0, const
     return _mm256_set_m128(_mm_set1_ps(GGML_CPU_FP16_TO_FP32(x1) * GGML_CPU_FP16_TO_FP32(y1)),
                            _mm_set1_ps(GGML_CPU_FP16_TO_FP32(x0) * GGML_CPU_FP16_TO_FP32(y0)));
 }
+
+static inline __m256 quad_mx_delta_float(const int8_t x0, const float y0, const int8_t x1, const float y1) {
+    return _mm256_set_m128(_mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x1) * GGML_CPU_FP16_TO_FP32(y1)),
+                           _mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x0) * GGML_CPU_FP16_TO_FP32(y0)));
+}
 #endif
 #elif defined(__SSSE3__)
 // horizontally add 4x4 floats
@@ -746,6 +757,91 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
 #endif
 }
 
+void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK_MXFP4 == 0);
+    static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
+
+    const block_mxfp4 * GGML_RESTRICT x = vx;
+    const block_q8_0  * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_MXFP4;
+
+    int ib = 0;
+    float sumf = 0;
+
+#if defined __AVX2__
+
+    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
+    const __m128i m4b  = _mm_set1_epi8(0x0f);
+    const __m256i mone = _mm256_set1_epi16(1);
+
+    __m256 accum1 = _mm256_setzero_ps();
+    __m256 accum2 = _mm256_setzero_ps();
+    for (; ib + 1 < nb; ib += 2) {
+        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs);
+        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs);
+        const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs);
+        const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs);
+        const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
+                                              _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
+        const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
+                                              _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
+        const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
+        const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
+        const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
+        const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
+        accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 0].e)),
+                                 _mm256_cvtepi32_ps(p_1), accum1);
+        accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 1].e)),
+                                 _mm256_cvtepi32_ps(p_2), accum2);
+    }
+
+    sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
+
+#elif defined __AVX__
+    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
+    const __m128i m4b = _mm_set1_epi8(0x0f);
+
+    __m256 accum = _mm256_setzero_ps();
+    for (; ib + 1 < nb; ib += 2) {
+        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
+        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
+        const __m128i q8b_1_0  = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
+        const __m128i q8b_1_1  = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
+        const __m128i q8b_2_0  = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
+        const __m128i q8b_2_1  = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
+
+        const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
+        const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
+        const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
+        const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));
+
+        const __m256 p = mul_sum_i8_quad_float(q4b_1_0, q4b_1_1, q4b_2_0, q4b_2_1, q8b_1_0, q8b_1_1, q8b_2_0, q8b_2_1);
+        const __m256 deltas = quad_mx_delta_float(x[ib].e, y[ib].d, x[ib + 1].e, y[ib + 1].d);
+        accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
+    }
+
+    sumf = hsum_float_8(accum);
+
+#endif
+    for (; ib < nb; ++ib) {
+        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
+        int sumi1 = 0;
+        int sumi2 = 0;
+        for (int j = 0; j < QK_MXFP4/2; ++j) {
+            sumi1 += y[ib].qs[j + 0]         * kvalues_mxfp4[x[ib].qs[j] & 0xf];
+            sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >>  4];
+        }
+        sumf += d * (sumi1 + sumi2);
+    }
+    *s = sumf;
+}
+
 void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     const int qk = QK8_0;
     const int nb = n / qk;
@@ -3206,14 +3302,6 @@ void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
 #endif
 }
 
-#if defined(__AVX2__)
-static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
-    const __m256i ax = _mm256_sign_epi8(x, x);
-    const __m256i sy = _mm256_sign_epi8(y, x);
-    return _mm256_maddubs_epi16(ax, sy);
-}
-#endif
-
 void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
     assert(nrc == 1);
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -253,6 +253,12 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
         .vec_dot_type             = GGML_TYPE_Q8_1,
         .nrows                    = 1,
     },
+    [GGML_TYPE_MXFP4] = {
+        .from_float               = quantize_row_mxfp4,
+        .vec_dot                  = ggml_vec_dot_mxfp4_q8_0,
+        .vec_dot_type             = GGML_TYPE_Q8_0,
+        .nrows                    = 1,
+    },
     [GGML_TYPE_Q2_K] = {
         .from_float               = quantize_row_q2_K,
         .vec_dot                  = ggml_vec_dot_q2_K_q8_K,
@@ -1670,6 +1676,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_add(params, tensor);
             } break;
+        case GGML_OP_ADD_ID:
+            {
+                ggml_compute_forward_add_id(params, tensor);
+            } break;
         case GGML_OP_ADD1:
             {
                 ggml_compute_forward_add1(params, tensor);
@@ -1924,7 +1934,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_FLASH_ATTN_EXT:
             {
-                ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
+                ggml_compute_forward_flash_attn_ext(params, tensor);
             } break;
         case GGML_OP_FLASH_ATTN_BACK:
             {
@@ -2111,6 +2121,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_DUP:
        case GGML_OP_CONT:
         case GGML_OP_ADD:
+        case GGML_OP_ADD_ID:
         case GGML_OP_ADD1:
         case GGML_OP_ACC:
             {
@@ -2172,6 +2183,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 case GGML_GLU_OP_REGLU:
                 case GGML_GLU_OP_GEGLU:
                 case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_SWIGLU_OAI:
                 case GGML_GLU_OP_GEGLU_ERF:
                 case GGML_GLU_OP_GEGLU_QUICK:
                     {
@@ -2673,6 +2685,7 @@ struct ggml_cplan ggml_graph_plan(
                     }
                 } break;
             case GGML_OP_ADD:
+            case GGML_OP_ADD_ID:
             case GGML_OP_ADD1:
                 {
                     if (ggml_is_quantized(node->src[0]->type)) {
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -8,6 +8,7 @@
 #include "vec.h"
 
 #include <float.h>
+#include <algorithm>
 
 // ggml_compute_forward_dup
 
@@ -1283,6 +1284,7 @@ void ggml_compute_forward_add(
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -1309,6 +1311,77 @@ void ggml_compute_forward_add(
     }
 }
 
+// ggml_compute_forward_add_id
+
+static void ggml_compute_forward_add_id_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src2->type == GGML_TYPE_I32);
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    GGML_TENSOR_TERNARY_OP_LOCALS
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        // src1 indices
+        const int i11 = *(int32_t *) ((char *) src2->data + i1*nb20 + i2*nb21);
+
+        GGML_ASSERT(i11 >= 0 && i11 < ne11);
+
+        ggml_vec_add_f32(ne0,
+                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
+                (float *) ((char *) src1->data + i11*nb11));
+    }
+}
+
+void ggml_compute_forward_add_id(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_add_id_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("unsupported type for ggml_compute_forward_add_id: %s", ggml_type_name(src0->type));
+            }
+    }
+}
+
 // ggml_compute_forward_add1
 
 static void ggml_compute_forward_add1_f32(
@@ -1660,6 +1733,7 @@ void ggml_compute_forward_add1(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -1787,6 +1861,7 @@ void ggml_compute_forward_acc(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -3614,6 +3689,93 @@ static void ggml_compute_forward_swiglu(
     }
 }
 
+// ggml_compute_forward_swiglu_oai
+
+static void ggml_compute_forward_swiglu_oai_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+    const float alpha = ggml_get_op_params_f32(dst, 2);
+    const float limit = ggml_get_op_params_f32(dst, 3);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+        float * dst_p  = (float *) ((char *) dst->data + i1*(dst->nb[1]));
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        for (int k = 0; k < nc; k++) {
+            const float x = std::min(src0_p[k], limit);
+            const float y = std::clamp(src1_p[k], -limit, limit);
+            const float out_glu = x / (1.f + expf(alpha * (-x)));
+            dst_p[k] = out_glu * (y + 1.f);
+        }
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = dst_p[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_swiglu_oai(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_swiglu_oai_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_geglu_erf
 
 static void ggml_compute_forward_geglu_erf_f32(
@@ -4599,6 +4761,7 @@ void ggml_compute_forward_out_prod(
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -4873,6 +5036,7 @@ void ggml_compute_forward_set(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -5134,6 +5298,7 @@ void ggml_compute_forward_get_rows(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -5523,6 +5688,7 @@ static void ggml_compute_forward_soft_max_f32(
 
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
 
     assert(ggml_is_contiguous(dst));
     assert(ggml_are_same_shape(src0, dst));
@@ -5557,6 +5723,9 @@ static void ggml_compute_forward_soft_max_f32(
 
     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
+    // sinks
+    const float * sk = src2 ? (float *)((char *) src2->data) : nullptr;
+
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
             for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
@@ -5599,9 +5768,18 @@ static void ggml_compute_forward_soft_max_f32(
                 float max = -INFINITY;
                 ggml_vec_max_f32(ne00, &max, wp);
 
+                // if we have sinks, make a correction as if they were included in the softmax
+                if (sk) {
+                    max = MAX(max, sk[i02]);
+                }
+
                 ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
                 assert(sum > 0.0);
 
+                if (sk) {
+                    sum += (ggml_float) expf(sk[i02] - max);
+                }
+
                 sum = 1.0/sum;
                 ggml_vec_scale_f32(ne00, dp, sum);
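The two sk corrections above amount to running the row softmax as if one extra "sink" logit $s = \text{sk}[i02]$ were appended, except that the sink gets no output slot; with $m = \max(s, \max_j x_j)$,

$p_i = \dfrac{e^{x_i - m}}{e^{s - m} + \sum_j e^{x_j - m}}$

so the row now sums to less than 1 and the sink silently absorbs the remaining probability mass.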
@@ -5836,6 +6014,7 @@ void ggml_compute_forward_clamp(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -7989,12 +8168,14 @@ void ggml_compute_forward_argsort(
 
 static void ggml_compute_forward_flash_attn_ext_f16(
         const ggml_compute_params * params,
-        const ggml_tensor * q,
-        const ggml_tensor * k,
-        const ggml_tensor * v,
-        const ggml_tensor * mask,
         ggml_tensor * dst) {
 
+    const ggml_tensor * q     = dst->src[0];
+    const ggml_tensor * k     = dst->src[1];
+    const ggml_tensor * v     = dst->src[2];
+    const ggml_tensor * mask  = dst->src[3];
+    const ggml_tensor * sinks = dst->src[4];
+
     GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
     GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
     GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
@@ -8189,6 +8370,23 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             }
         }
 
+        // sinks
+        if (sinks) {
+            const float s = ((float *)((char *) sinks->data))[h];
+
+            float ms = 1.0f;
+            float vs = 1.0f;
+
+            if (s > M) {
+                ms = expf(M - s);
+                ggml_vec_scale_f32(DV, VKQ32, ms);
+            } else {
+                vs = expf(s - M);
+            }
+
+            S = S*ms + vs;
+        }
+
         // V /= S
         const float S_inv = 1.0f/S;
         ggml_vec_scale_f32(DV, VKQ32, S_inv);
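In the streaming accumulator above, M is the running maximum logit and S the running sum of exponentials for one head h. Folding in the sink logit $s$ after the KV loop is the usual online-softmax rescaling step:

$s \le M:\quad S \leftarrow S + e^{\,s - M}$
$s > M:\quad V \leftarrow V \cdot e^{\,M - s}, \qquad S \leftarrow S \cdot e^{\,M - s} + 1$

which matches the ms/vs bookkeeping in the code: the accumulated V and S are rescaled only when the sink becomes the new maximum.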
@@ -8208,17 +8406,13 @@ static void ggml_compute_forward_flash_attn_ext_f16(
 
 void ggml_compute_forward_flash_attn_ext(
         const ggml_compute_params * params,
-        const ggml_tensor * q,
-        const ggml_tensor * k,
-        const ggml_tensor * v,
-        const ggml_tensor * mask,
         ggml_tensor * dst) {
     switch (dst->op_params[3]) {
         case GGML_PREC_DEFAULT:
         case GGML_PREC_F32:
             {
                 // uses F32 accumulators
-                ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
+                ggml_compute_forward_flash_attn_ext_f16(params, dst);
             } break;
         default:
             {
@@ -9080,6 +9274,10 @@ void ggml_compute_forward_glu(
             {
                 ggml_compute_forward_swiglu(params, dst);
             } break;
+        case GGML_GLU_OP_SWIGLU_OAI:
+            {
+                ggml_compute_forward_swiglu_oai(params, dst);
+            } break;
        case GGML_GLU_OP_GEGLU_ERF:
             {
                 ggml_compute_forward_geglu_erf(params, dst);
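For reference, the per-element function dispatched above (implemented in ggml_compute_forward_swiglu_oai_f32 earlier in this file) is a clipped SWiGLU variant; with gate input $x$, linear input $y$, and the alpha/limit op params:

$x' = \min(x, \text{limit}), \qquad y' = \text{clamp}(y, -\text{limit}, \text{limit})$
$\text{out} = \dfrac{x'}{1 + e^{-\alpha x'}} \cdot (y' + 1)$

The $(y' + 1)$ term and the asymmetric clipping are what distinguish it from plain SWiGLU.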
--- a/ggml/src/ggml-cpu/ops.h
+++ b/ggml/src/ggml-cpu/ops.h
@@ -29,6 +29,7 @@ extern "C" {
 
 void ggml_compute_forward_dup(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_add(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_add_id(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_add1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_acc(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_sum(const struct ggml_compute_params * params, struct ggml_tensor * dst);
@@ -82,13 +83,7 @@ void ggml_compute_forward_arange(const struct ggml_compute_params * params, stru
 void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-void ggml_compute_forward_flash_attn_ext(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * q,
-        const struct ggml_tensor * k,
-        const struct ggml_tensor * v,
-        const struct ggml_tensor * mask,
-        struct ggml_tensor * dst);
+void ggml_compute_forward_flash_attn_ext(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_flash_attn_back(
         const struct ggml_compute_params * params,
         const bool masked,
ggml/src/ggml-cpu/quants.c:

@@ -46,6 +46,10 @@ void quantize_row_q8_1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
     quantize_row_q8_1_ref(x, y, k);
 }
 
+void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
+    quantize_row_mxfp4_ref(x, y, k);
+}
+
 //
 // 2-6 bit quantization in super-blocks
 //

@@ -181,6 +185,37 @@ void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     *s = sumf;
 }
 
+void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK_MXFP4 == 0);
+    static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
+
+    const block_mxfp4 * GGML_RESTRICT x = vx;
+    const block_q8_0  * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_MXFP4;
+
+    int ib = 0;
+    float sumf = 0;
+
+    for (; ib < nb; ++ib) {
+        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
+
+        int sumi1 = 0;
+        int sumi2 = 0;
+        for (int j = 0; j < QK_MXFP4/2; ++j) {
+            sumi1 += y[ib].qs[j + 0]          * kvalues_mxfp4[x[ib].qs[j] & 0xf];
+            sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >>  4];
+        }
+        sumf += d * (sumi1 + sumi2);
+    }
+    *s = sumf;
+}
+
 void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     const int qk = QK8_0;
     const int nb = n / qk;
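For orientation, a self-contained sketch of what the generic dot product above assumes about the MXFP4 format: a block is one E8M0 scale byte plus 16 bytes of packed nibbles covering QK_MXFP4 == 32 values, and kvalues_mxfp4 stores the E2M1 code points {0, 0.5, 1, 1.5, 2, 3, 4, 6} (and their negatives) doubled to integers, which is why the scale is halved (GGML_E8M0_TO_FP32_HALF). The table values and helper names below are illustrative, not the ggml definitions:

#include <stdint.h>
#include <string.h>

#define QK_MXFP4 32

// E2M1 values {0, 0.5, 1, 1.5, 2, 3, 4, 6} (and negatives), stored times two.
static const int8_t kvalues_mxfp4_sketch[16] = {
    0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12,
};

// E8M0: the byte is a raw biased exponent; shift it into the fp32 exponent
// field (mirrors the CUDA fallback shown later in this diff).
static float e8m0_to_fp32_sketch(uint8_t e) {
    uint32_t bits = e == 0 ? 0x00400000u : (uint32_t) e << 23;
    float f;
    memcpy(&f, &bits, sizeof(f));
    return f;
}

static void dequant_mxfp4_block_sketch(const uint8_t qs[QK_MXFP4/2], uint8_t e, float out[QK_MXFP4]) {
    const float d = 0.5f * e8m0_to_fp32_sketch(e); // the "HALF" scale: undoes the doubled table
    for (int j = 0; j < QK_MXFP4/2; ++j) {
        out[j]              = d * kvalues_mxfp4_sketch[qs[j] & 0xf];
        out[j + QK_MXFP4/2] = d * kvalues_mxfp4_sketch[qs[j] >>  4];
    }
}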
ggml/src/ggml-cpu/quants.h:

@@ -19,6 +19,8 @@ void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 
+void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+
 void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);

@@ -39,6 +41,8 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 
+void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
 void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);

@@ -67,8 +71,12 @@ void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q5_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q8_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
+void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
 void ggml_vec_dot_tq1_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_tq2_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
 void ggml_vec_dot_q2_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q3_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q4_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
ggml/src/ggml-cpu/vec.h:

@@ -55,7 +55,22 @@ inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x)
 
 inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const ggml_fp16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
 inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
 
+inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) {
+    int i = 0;
+#if defined(__AVX2__)
+    for (; i + 7 < n; i += 8) {
+        __m256 vx = _mm256_loadu_ps(x + i);
+        __m256 vy = _mm256_loadu_ps(y + i);
+        __m256 vz = _mm256_add_ps(vx, vy);
+        _mm256_storeu_ps(z + i, vz);
+    }
+#endif
+    for (; i < n; ++i) {
+        z[i] = x[i] + y[i];
+    }
+}
+
 inline static void ggml_vec_add_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
     for (int i = 0; i < n; ++i) {
         z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) + GGML_CPU_FP16_TO_FP32(y[i]));

@@ -992,9 +1007,9 @@ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g);
 
 inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
     for (int i = 0; i < n; ++i) {
-        float …
-        float …
-        y[i] = GGML_CPU_FP32_TO_FP16((…
+        float xi = GGML_CPU_FP16_TO_FP32(x[i]);
+        float gi = GGML_CPU_FP16_TO_FP32(g[i]);
+        y[i] = GGML_CPU_FP32_TO_FP16((xi/(1.0f + expf(-xi))) * gi);
     }
 }
 
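The f16 path above now routes both operands through f32 before combining, computing y = silu(x) * g with silu(x) = x / (1 + e^-x). A tiny standalone check of that identity in C:

#include <math.h>
#include <stdio.h>

// Scalar reference for the SwiGLU formula used by ggml_vec_swiglu_f16 above.
static float swiglu_ref(float x, float g) {
    return (x / (1.0f + expf(-x))) * g;
}

int main(void) {
    // silu(1) = 1/(1 + e^-1) ≈ 0.7311, so swiglu(1, 2) ≈ 1.4622
    printf("%f\n", swiglu_ref(1.0f, 2.0f));
    return 0;
}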
ggml/src/ggml-cuda/add-id.cu (new file):

@@ -0,0 +1,58 @@
+#include "add-id.cuh"
+
+static __global__ void add_id_kernel(
+        const float * src0, const float * src1, const int32_t * src2, float * dst,
+        int64_t ne0, int64_t ne1,
+        size_t nb01, size_t nb02,
+        size_t nb11,
+        size_t nb21
+    ) {
+
+    const int64_t i1 = blockIdx.x;
+    const int64_t i2 = blockIdx.y;
+
+    const int i11 = *(int32_t *) ((char *) src2 + i1*sizeof(int32_t) + i2*nb21);
+
+    const size_t nb1 = ne0 * sizeof(float);
+    const size_t nb2 = ne1 * nb1;
+
+    float * dst_row = (float *)((char *)dst + i1*nb1 + i2*nb2);
+    const float * src0_row = (const float *)((char *)src0 + i1*nb01 + i2*nb02);
+    const float * src1_row = (const float *)((char *)src1 + i11*nb11);
+
+    for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
+        dst_row[i0] = src0_row[i0] + src1_row[i0];
+    }
+}
+
+void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
+    GGML_TENSOR_TERNARY_OP_LOCALS
+
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src2->type == GGML_TYPE_I32);
+
+    GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
+    GGML_ASSERT(nb20 == sizeof(int32_t));
+
+    const float   * src0_d = (const float *)src0->data;
+    const float   * src1_d = (const float *)src1->data;
+    const int32_t * src2_d = (const int32_t *)src2->data;
+    float * dst_d = (float *)dst->data;
+
+    int threads = std::min((int)ne00, 768); // cols
+    dim3 blocks(ne01, ne02); // n_experts_used, n_tokens
+    add_id_kernel<<<blocks, threads, 0, ctx.stream()>>>(
+        src0_d, src1_d, src2_d, dst_d,
+        ne0, ne1,
+        nb01, nb02,
+        nb11,
+        nb21
+    );
+}
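What ADD_ID computes, as a plain-C reference derived from the kernel above: for every (expert slot i1, token i2), look up an expert id in src2 and add that expert's bias row from src1 to the corresponding row of src0. Contiguous f32 layout is assumed here for brevity; the real op works on strided tensors:

#include <stdint.h>

static void add_id_ref(const float * src0,   // [ne2][ne1][ne0] activations
                       const float * src1,   // bias table: one ne0-wide row per expert
                       const int32_t * src2, // [ne2][ne1] expert ids
                       float * dst,
                       int64_t ne0, int64_t ne1, int64_t ne2) {
    for (int64_t i2 = 0; i2 < ne2; ++i2) {         // tokens
        for (int64_t i1 = 0; i1 < ne1; ++i1) {     // experts used per token
            const int32_t expert = src2[i2*ne1 + i1];
            const float * bias = src1 + (int64_t) expert * ne0;
            const float * in   = src0 + (i2*ne1 + i1) * ne0;
            float       * out  = dst  + (i2*ne1 + i1) * ne0;
            for (int64_t i0 = 0; i0 < ne0; ++i0) {
                out[i0] = in[i0] + bias[i0];
            }
        }
    }
}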
ggml/src/ggml-cuda/add-id.cuh (new file):

@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml-cuda/common.cuh:

@@ -1,6 +1,7 @@
 #pragma once
 
 #include "ggml.h"
+#include "ggml-impl.h"
 #include "ggml-cuda.h"
 
 #include <cstdint>

@@ -549,6 +550,24 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
 #endif // defined(GGML_USE_HIP)
 }
 
+static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
+#if CUDART_VERSION >= 12080
+    const nv_bfloat16 e = __nv_cvt_e8m0_to_bf16raw(x);
+    return (float) e;
+#else
+    uint32_t bits;
+    if (x == 0) {
+        bits = 0x00400000;
+    } else {
+        bits = (uint32_t) x << 23;
+    }
+
+    float result;
+    memcpy(&result, &bits, sizeof(float));
+    return result;
+#endif // CUDART_VERSION >= 12080
+}
+
 typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
 
 static __device__ __forceinline__ float get_alibi_slope(

@@ -607,6 +626,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_Q8_0> {
     static constexpr int qi = QI8_0;
 };
 
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_MXFP4> {
+    static constexpr int qk = QK_MXFP4;
+    static constexpr int qr = QR_MXFP4;
+    static constexpr int qi = QI_MXFP4;
+};
+
 template<>
 struct ggml_cuda_type_traits<GGML_TYPE_Q2_K> {
     static constexpr int qk = QK_K;
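For reference, E8M0 is a biased exponent with no mantissa: value(e) = 2^(e - 127), so e = 127 decodes to 1.0 and e = 134 to 128.0. The x == 0 special case cannot be produced by the shift (an all-zero exponent field with a zero mantissa decodes to 0.0f), so the fallback returns the subnormal bit pattern 0x00400000, which is exactly 2^-127.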
ggml/src/ggml-cuda/convert.cu:

@@ -465,6 +465,24 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
     }
 }
 
+template<typename dst_t>
+static __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int64_t i = blockIdx.x;
+    const block_mxfp4 * x = (const block_mxfp4 *) vx + i*(QK_K/QK_MXFP4);
+
+    const int64_t tid = threadIdx.x;
+    const int64_t il  = tid/8; // 0...3
+    const int64_t ib  = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+    const uint8_t * q4 = x[ib].qs + 4*il;
+    const float d = ggml_cuda_e8m0_to_fp32(x[ib].e);
+    for (int j = 0; j < 4; ++j) {
+        y[j+ 0] = d * kvalues_mxfp4[q4[j] & 0xf]*0.5f;
+        y[j+16] = d * kvalues_mxfp4[q4[j] >>  4]*0.5f;
+    }
+}
+
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
 static void dequantize_block_cuda(const void * vx, dst_t * y,
     const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,

@@ -588,6 +606,12 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
 }
 
+template<typename dst_t>
+static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = (k + QK_K - 1) / QK_K;
+    dequantize_block_mxfp4<<<nb, 32, 0, stream>>>(vx, y);
+}
+
 template <typename src_t, typename dst_t>
 static __global__ void convert_unary(
     const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,

@@ -677,6 +701,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_row_iq4_xs_cuda;
         case GGML_TYPE_IQ3_S:
             return dequantize_row_iq3_s_cuda;
+        case GGML_TYPE_MXFP4:
+            return dequantize_row_mxfp4_cuda;
         case GGML_TYPE_F32:
             return convert_unary_cont_cuda<float>;
         case GGML_TYPE_BF16:

@@ -726,6 +752,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_iq4_xs_cuda;
         case GGML_TYPE_IQ3_S:
             return dequantize_row_iq3_s_cuda;
+        case GGML_TYPE_MXFP4:
+            return dequantize_row_mxfp4_cuda;
         case GGML_TYPE_F16:
             return convert_unary_cont_cuda<half>;
         case GGML_TYPE_BF16:
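To make the indexing in dequantize_block_mxfp4 concrete: each CUDA block covers one QK_K = 256-wide slice, i.e. eight 32-value MXFP4 blocks. With 32 threads, tid%8 selects the sub-block and tid/8 selects a 4-byte group of packed nibbles inside it, so every thread expands 4 low nibbles and 4 high nibbles: 32 threads x 8 values = 256 outputs per block. The 0.5f factor undoes the doubled E2M1 table, matching the CPU path's GGML_E8M0_TO_FP32_HALF.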
ggml/src/ggml-cuda/fattn-common.cuh:

@@ -15,6 +15,7 @@ typedef void (* fattn_kernel_t)(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const char * __restrict__ sinks,
         const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,

@@ -736,7 +737,8 @@ void launch_fattn(
 
     GGML_ASSERT(V || is_mla);
 
-    const ggml_tensor * mask = dst->src[3];
+    const ggml_tensor * mask  = dst->src[3];
+    const ggml_tensor * sinks = dst->src[4];
 
     ggml_tensor * KQV = dst;

@@ -940,6 +942,7 @@ void launch_fattn(
         K_data,
         V_data,
         mask ? ((const char *) mask->data) : nullptr,
+        sinks ? ((const char *) sinks->data) : nullptr,
         KV_max.ptr,
         !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
ggml/src/ggml-cuda/fattn-mma-f16.cuh:

@@ -1206,6 +1206,7 @@ static __global__ void flash_attn_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const char * __restrict__ sinks,
         const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,

@@ -1267,6 +1268,7 @@ static __global__ void flash_attn_ext_f16(
     // kb0 == k start index when in the output tile.
     int kb0_start = kbc % iter_k;
     int kb0_stop  = min(iter_k, kb0_start + kbc_stop - kbc);
+
     while (kbc < kbc_stop && kb0_stop == iter_k) {
         const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
         const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);

@@ -1340,7 +1342,7 @@ static __global__ void flash_attn_ext_f16(
         (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
          ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
 #else
-    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
     GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
     GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
ggml/src/ggml-cuda/fattn-tile-f16.cu:

@@ -13,6 +13,7 @@ static __global__ void flash_attn_tile_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const char * __restrict__ sinks,
         const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,

@@ -272,7 +273,7 @@ static __global__ void flash_attn_tile_ext_f16(
         }
     }
 #else
-    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
     GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
     GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
ggml/src/ggml-cuda/fattn-tile-f32.cu:

@@ -13,6 +13,7 @@ static __global__ void flash_attn_tile_ext_f32(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const char * __restrict__ sinks,
         const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,

@@ -37,7 +38,7 @@ static __global__ void flash_attn_tile_ext_f32(
     return;
 #endif // FP16_MMA_AVAILABLE
     if (use_logit_softcap && !(D == 128 || D == 256)) {
-        GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+        GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
         GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
         GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
         GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
ggml/src/ggml-cuda/fattn-vec-f16.cuh:

@@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const char * __restrict__ sinks,
         const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,

@@ -61,7 +62,8 @@ static __global__ void flash_attn_vec_ext_f16(
     K += nb13*sequence + nb12*(head / gqa_ratio);
     V += nb23*sequence + nb22*(head / gqa_ratio);
 
-    const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+    const half  * maskh  = (const half  *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+    const float * sinksf = (const float *) (sinks);
 
     const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);

@@ -75,11 +77,12 @@ static __global__ void flash_attn_vec_ext_f16(
     half2 * KQ2 = (half2 *) KQ;
 
     half kqmax[ncols];
+    half kqsum[ncols];
 #pragma unroll
     for (int j = 0; j < ncols; ++j) {
         kqmax[j] = -HALF_MAX_HALF;
+        kqsum[j] = 0.0f;
     }
-    half kqsum[ncols] = {0.0f};
 
     __shared__ half kqmax_shared[ncols][WARP_SIZE];
     __shared__ half kqsum_shared[ncols][WARP_SIZE];

@@ -283,6 +286,39 @@ static __global__ void flash_attn_vec_ext_f16(
         __syncthreads();
     }
 
+    if (sinksf && blockIdx.y == 0) {
+        const half sink = __float2half(sinksf[head]);
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            if (threadIdx.x == 0) {
+                kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink);
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            half kqmax_new_j = kqmax_shared[j][threadIdx.x];
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+            const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
+            kqmax[j] = kqmax_new_j;
+
+            const half val = hexp(sink - kqmax[j]);
+            kqsum[j] = kqsum[j]*KQ_max_scale;
+
+            if (tid == 0) {
+                kqsum[j] += val;
+            }
+
+            VKQ[j] *= __half2half2(KQ_max_scale);
+        }
+
+        __syncthreads();
+    }
+
 #pragma unroll
     for (int j = 0; j < ncols; ++j) {
         kqsum[j] = warp_reduce_sum((float)kqsum[j]);

@@ -313,7 +349,7 @@ static __global__ void flash_attn_vec_ext_f16(
         dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
     }
 #else
-    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
     GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
     GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
ggml/src/ggml-cuda/fattn-vec-f32.cuh:

@@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f32(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const char * __restrict__ sinks,
         const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,

@@ -72,7 +73,8 @@ static __global__ void flash_attn_vec_ext_f32(
     K += nb13*sequence + nb12*(head / gqa_ratio);
     V += nb23*sequence + nb22*(head / gqa_ratio);
 
-    const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+    const half  * maskh  = (const half  *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+    const float * sinksf = (const float *) (sinks);
 
     const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
 
@@ -88,11 +90,12 @@ static __global__ void flash_attn_vec_ext_f32(
     }
 
     float kqmax[ncols];
+    float kqsum[ncols];
 #pragma unroll
     for (int j = 0; j < ncols; ++j) {
         kqmax[j] = -FLT_MAX/2.0f;
+        kqsum[j] = 0.0f;
     }
-    float kqsum[ncols] = {0.0f};
 
     __shared__ float kqmax_shared[ncols][WARP_SIZE];
     __shared__ float kqsum_shared[ncols][WARP_SIZE];

@@ -279,6 +282,39 @@ static __global__ void flash_attn_vec_ext_f32(
         __syncthreads();
     }
 
+    if (sinksf && blockIdx.y == 0) {
+        const float sink = sinksf[head];
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            if (threadIdx.x == 0) {
+                kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink);
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            float kqmax_new_j = kqmax_shared[j][threadIdx.x];
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+            const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j);
+            kqmax[j] = kqmax_new_j;
+
+            const float val = expf(sink - kqmax[j]);
+            kqsum[j] = kqsum[j]*KQ_max_scale;
+
+            if (tid == 0) {
+                kqsum[j] += val;
+            }
+
+            VKQ[j] *= KQ_max_scale;
+        }
+
+        __syncthreads();
+    }
+
 #pragma unroll
     for (int j = 0; j < ncols; ++j) {
         kqsum[j] = warp_reduce_sum(kqsum[j]);
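Both vec kernels fold the sink in the same way: after the KV loop, the per-head sink value joins the running maximum, the accumulated normalizer is rescaled by exp(m_old - m_new), exp(sink - m_new) is added to the sum exactly once (thread 0), and the accumulated V product is only rescaled, receiving no new contribution. In online-softmax terms: s <- s*e^(m - m') + e^(sink - m'), o <- o*e^(m - m'). A scalar C sketch of the resulting semantics (illustrative names, not the kernel code):

#include <math.h>

// Softmax where the sink acts as one extra logit that enters the
// normalizer but contributes no V row: it only dampens the weights.
static void softmax_with_sink_ref(const float * logits, int n, float sink, float * weights) {
    float m = sink; // running max includes the sink
    for (int i = 0; i < n; ++i) {
        if (logits[i] > m) m = logits[i];
    }
    float denom = expf(sink - m); // sink mass joins the denominator ...
    for (int i = 0; i < n; ++i) {
        denom += expf(logits[i] - m);
    }
    for (int i = 0; i < n; ++i) {
        weights[i] = expf(logits[i] - m) / denom; // ... but gets no output row
    }
}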
ggml/src/ggml-cuda/fattn-wmma-f16.cu:

@@ -29,6 +29,7 @@ static __global__ void flash_attn_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const char * __restrict__ sinks,
         const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,

@@ -423,7 +424,7 @@ static __global__ void flash_attn_ext_f16(
         dst_meta[j_dst_unrolled] = dst_meta_val;
     }
 #else
-    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
     GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
     GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
ggml/src/ggml-cuda/fattn.cu:

@@ -269,17 +269,28 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 }
 
 void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * KQV  = dst;
-    const ggml_tensor * Q    = dst->src[0];
-    const ggml_tensor * K    = dst->src[1];
-    const ggml_tensor * V    = dst->src[2];
-    const ggml_tensor * mask = dst->src[3];
+    const ggml_tensor * KQV   = dst;
+    const ggml_tensor * Q     = dst->src[0];
+    const ggml_tensor * K     = dst->src[1];
+    const ggml_tensor * V     = dst->src[2];
+    const ggml_tensor * mask  = dst->src[3];
+    const ggml_tensor * sinks = dst->src[4];
 
     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
+    // TODO: currently only vec implementation for sinks is supported [TAG_ATTN_SINKS]
+    if (sinks) {
+        if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
+            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+        } else {
+            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+        }
+        return;
+    }
+
 #if defined(GGML_HIP_ROCWMMA_FATTN)
     if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
         ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
ggml/src/ggml-cuda/ggml-cuda.cu:

@@ -4,6 +4,7 @@
 
 #include "ggml-cuda/common.cuh"
 #include "ggml-cuda/acc.cuh"
+#include "ggml-cuda/add-id.cuh"
 #include "ggml-cuda/arange.cuh"
 #include "ggml-cuda/argmax.cuh"
 #include "ggml-cuda/argsort.cuh"

@@ -2259,6 +2260,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
         case GGML_OP_ADD1: // TODO: more efficient implementation
             ggml_cuda_op_add(ctx, dst);
             break;
+        case GGML_OP_ADD_ID:
+            ggml_cuda_op_add_id(ctx, dst);
+            break;
         case GGML_OP_SUB:
             ggml_cuda_op_sub(ctx, dst);
             break;

@@ -2333,6 +2337,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
                 case GGML_GLU_OP_SWIGLU:
                     ggml_cuda_op_swiglu(ctx, dst);
                     break;
+                case GGML_GLU_OP_SWIGLU_OAI:
+                    ggml_cuda_op_swiglu_oai(ctx, dst);
+                    break;
                 case GGML_GLU_OP_GEGLU_ERF:
                     ggml_cuda_op_geglu_erf(ctx, dst);
                     break;

@@ -2607,6 +2614,9 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context & ctx, ggml_cgraph * cgraph) {
 
     const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
     const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
+    const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
+    const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
+    const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
 
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];

@@ -2629,7 +2639,13 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context & ctx, ggml_cgraph * cgraph) {
 #endif
         }
 
-        if (node->op == GGML_OP_ADD &&
-            …
+        if (node->op == GGML_OP_ADD &&
+            node->src[1] && node->src[1]->ne[1] > 1 &&
+            (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
+            (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
+            strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
+            strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
+            strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0) {
             // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
             // by means of matching node names. See
             // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and

@@ -3227,6 +3243,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
             case GGML_GLU_OP_REGLU:
             case GGML_GLU_OP_GEGLU:
             case GGML_GLU_OP_SWIGLU:
+            case GGML_GLU_OP_SWIGLU_OAI:
             case GGML_GLU_OP_GEGLU_ERF:
             case GGML_GLU_OP_GEGLU_QUICK:
                 return ggml_is_contiguous_1(op->src[0]);

@@ -3277,6 +3294,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
                 case GGML_TYPE_Q5_0:
                 case GGML_TYPE_Q5_1:
                 case GGML_TYPE_Q8_0:
+                case GGML_TYPE_MXFP4:
                 case GGML_TYPE_Q2_K:
                 case GGML_TYPE_Q3_K:
                 case GGML_TYPE_Q4_K:

@@ -3423,6 +3441,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
         case GGML_OP_ADD:
+        case GGML_OP_ADD_ID:
         case GGML_OP_ADD1:
         case GGML_OP_SUB:
         case GGML_OP_MUL:

@@ -3503,6 +3522,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
                 const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
                 return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
             }
+            // TODO: more general-purpose attention sink support [TAG_ATTN_SINKS]
+            if (op->src[4] && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) { // currently only sinks for head_size 64 and 128 are supported
+                return false;
+            }
             if (op->src[0]->ne[0] == 192) {
                 return false;
             }
ggml/src/ggml-cuda/im2col.cu:

@@ -1,7 +1,5 @@
 #include "im2col.cuh"
 
-#define MIN(a, b) (a) < (b) ? (a) : (b)
-
 #define MAX_GRIDDIM_Z 65535
 
 template <typename T>

@@ -38,6 +36,9 @@ static __global__ void im2col_kernel(
             dst[offset_dst] = x[offset_src + iih * IW + iiw];
         }
     }
+
+    GGML_UNUSED(IC);
+    GGML_UNUSED(KH);
 }
 
 // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
ggml/src/ggml-cuda/mmq.cu:

@@ -20,6 +20,9 @@ static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
         case GGML_TYPE_Q8_0:
             mul_mat_q_case<GGML_TYPE_Q8_0>(ctx, args, stream);
             break;
+        case GGML_TYPE_MXFP4:
+            mul_mat_q_case<GGML_TYPE_MXFP4>(ctx, args, stream);
+            break;
         case GGML_TYPE_Q2_K:
             mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);
             break;

@@ -282,6 +285,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
+        case GGML_TYPE_MXFP4:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:
@@ -58,6 +58,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
|
|
| 58 |
return MMQ_Q8_1_DS_LAYOUT_DS4;
|
| 59 |
case GGML_TYPE_Q8_0:
|
| 60 |
return MMQ_Q8_1_DS_LAYOUT_D4;
|
|
|
|
|
|
|
| 61 |
case GGML_TYPE_Q2_K:
|
| 62 |
return MMQ_Q8_1_DS_LAYOUT_D2S6;
|
| 63 |
case GGML_TYPE_Q3_K:
|
|
@@ -170,6 +172,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
|
|
| 170 |
case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
|
| 171 |
case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1;
|
| 172 |
case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0;
|
|
|
|
| 173 |
case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K;
|
| 174 |
case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K;
|
| 175 |
case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K;
|
|
@@ -206,6 +209,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
|
|
| 206 |
case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
|
| 207 |
case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
|
| 208 |
case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
|
|
|
|
| 209 |
case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
|
| 210 |
case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
|
| 211 |
case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
|
|
@@ -692,6 +696,71 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 692 |
}
|
| 693 |
}
|
| 694 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 695 |
template <int mmq_x, int mmq_y>
|
| 696 |
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
|
| 697 |
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
|
@@ -2268,7 +2337,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 2268 |
const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx;
|
| 2269 |
|
| 2270 |
const int aux_q4 = get_int_b2(bxi->qs, kqsx);
|
| 2271 |
-
const int2 v = get_int_from_table_16(aux_q4);
|
| 2272 |
const int k0 = kbx * (2 * QI4_NL) + kqsx;
|
| 2273 |
|
| 2274 |
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
|
@@ -2707,7 +2776,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 2707 |
const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride;
|
| 2708 |
|
| 2709 |
const int aux_q4 = get_int_b4(bxi->qs, kqsx);
|
| 2710 |
-
const int2 v = get_int_from_table_16(aux_q4);
|
| 2711 |
const int k0 = 8 * (kqsx / 4) + kqsx % 4;
|
| 2712 |
|
| 2713 |
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
|
@@ -2863,6 +2932,14 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q8_0> {
|
|
| 2863 |
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
|
| 2864 |
};
|
| 2865 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2866 |
template <int mmq_x, int mmq_y, bool need_check>
|
| 2867 |
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
|
| 2868 |
static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
|
|
@@ -3642,6 +3719,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
|
|
| 3642 |
extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
|
| 3643 |
extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
|
| 3644 |
extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
|
|
|
|
| 3645 |
extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
|
| 3646 |
extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
|
| 3647 |
extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
|
|
|
|
| 58 |
return MMQ_Q8_1_DS_LAYOUT_DS4;
|
| 59 |
case GGML_TYPE_Q8_0:
|
| 60 |
return MMQ_Q8_1_DS_LAYOUT_D4;
|
| 61 |
+
case GGML_TYPE_MXFP4:
|
| 62 |
+
return MMQ_Q8_1_DS_LAYOUT_D4;
|
| 63 |
case GGML_TYPE_Q2_K:
|
| 64 |
return MMQ_Q8_1_DS_LAYOUT_D2S6;
|
| 65 |
case GGML_TYPE_Q3_K:
|
|
|
|
| 172 |
case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
|
| 173 |
case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1;
|
| 174 |
case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0;
|
| 175 |
+
case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1;
|
| 176 |
case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K;
|
| 177 |
case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K;
|
| 178 |
case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K;
|
|
|
|
| 209 |
case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
|
| 210 |
case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
|
| 211 |
case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
|
| 212 |
+
case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
|
| 213 |
case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
|
| 214 |
case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
|
| 215 |
case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
|
|
|
|
 }
 }

+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_mxfp4(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+    int   * x_qs = (int   *) x_tile;
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y);
+    int   * x_qs = (int   *) x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+    const int kbx = txi / QI_MXFP4;
+    const int kqsx = txi % QI_MXFP4;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbx;
+
+        const int aux_q4 = get_int_b1(bxi->qs, kqsx);
+        const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
+        const int k0 = kbx * (2 * QI_MXFP4) + kqsx;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0]        = v.x;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y;
+#else
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0]        = v.x;
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+    }
+
+    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4;
+    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
+#else
+        x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+    }
+}
+
 template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
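For reference, the 4-bit expansion that load_tiles_mxfp4 and the vec-dot kernels lean on can be restated on the host: each 32-bit word packs eight 4-bit indices, and get_int_from_table_16 maps the four low nibbles and the four high nibbles through a 16-entry signed table. A minimal C sketch, assuming the table holds the doubled e2m1 magnitudes (which is why the kernels multiply the block scale by 0.5f); the table literal below is an assumption for illustration, not a copy of ggml-common.h:

    #include <stdint.h>
    #include <stdio.h>

    /* assumed e2m1 magnitudes, doubled so they fit in int8 - this is
     * why the scale carries an extra 0.5f factor in the kernels */
    static const int8_t table_mxfp4[16] = {
        0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12,
    };

    /* expand one 32-bit word holding eight 4-bit indices into eight
     * table values: low nibbles of each byte first, then high nibbles,
     * mirroring the val0/val1 split in get_int_from_table_16 */
    static void unpack_nibbles(uint32_t q4, int8_t out[8]) {
        for (int i = 0; i < 4; ++i) {
            out[i]     = table_mxfp4[(q4 >> (8*i))     & 0x0F];
            out[i + 4] = table_mxfp4[(q4 >> (8*i + 4)) & 0x0F];
        }
    }

    int main(void) {
        int8_t v[8];
        unpack_nibbles(0x9F10E2A3u, v);
        for (int i = 0; i < 8; ++i) {
            printf("%d ", v[i]);
        }
        printf("\n");
        return 0;
    }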
         const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx;

         const int aux_q4 = get_int_b2(bxi->qs, kqsx);
-        const int2 v = get_int_from_table_16(aux_q4);
+        const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
         const int k0 = kbx * (2 * QI4_NL) + kqsx;

 #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)

         const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride;

         const int aux_q4 = get_int_b4(bxi->qs, kqsx);
-        const int2 v = get_int_from_table_16(aux_q4);
+        const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
         const int k0 = 8 * (kqsx / 4) + kqsx % 4;

 #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
 };

+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
+    static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4<mmq_y, need_check>;
+    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
+};
+
 template <int mmq_x, int mmq_y, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
     static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
 extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
 extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
 extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
+extern DECL_MMQ_CASE(GGML_TYPE_MXFP4);
 extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
 extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
 extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
@@ -13,6 +13,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
     case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1;
     case GGML_TYPE_Q5_1: return vec_dot_q5_1_q8_1;
     case GGML_TYPE_Q8_0: return vec_dot_q8_0_q8_1;
+    case GGML_TYPE_MXFP4: return vec_dot_mxfp4_q8_1;
     case GGML_TYPE_Q2_K: return vec_dot_q2_K_q8_1;
     case GGML_TYPE_Q3_K: return vec_dot_q3_K_q8_1;
     case GGML_TYPE_Q4_K: return vec_dot_q4_K_q8_1;

@@ -38,6 +39,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
     case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ;
     case GGML_TYPE_Q5_1: return VDR_Q5_1_Q8_1_MMVQ;
     case GGML_TYPE_Q8_0: return VDR_Q8_0_Q8_1_MMVQ;
+    case GGML_TYPE_MXFP4: return VDR_MXFP4_Q8_1_MMVQ;
    case GGML_TYPE_Q2_K: return VDR_Q2_K_Q8_1_MMVQ;
     case GGML_TYPE_Q3_K: return VDR_Q3_K_Q8_1_MMVQ;
     case GGML_TYPE_Q4_K: return VDR_Q4_K_Q8_1_MMVQ;

@@ -384,6 +386,13 @@ static void mul_mat_vec_q_switch_type(
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
                 stream);
             break;
+        case GGML_TYPE_MXFP4:
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
+            break;
         case GGML_TYPE_Q2_K:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
                 (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
@@ -45,7 +45,7 @@ struct soft_max_params {
 #endif // __clang__
 template <bool use_shared, int ncols_template, int block_size_template, typename T>
 static __global__ void soft_max_f32(
-        const float * x, const T * mask, float * dst, const soft_max_params p) {
+        const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params p) {
     const int ncols = ncols_template == 0 ? p.ncols : ncols_template;

     const int tid = threadIdx.x;

@@ -77,7 +77,7 @@ static __global__ void soft_max_f32(
     // shared memory buffer to cache values between iterations:
     float * vals = use_shared ? buf_iw + WARP_SIZE : dst;

-    float max_val = -INFINITY;
+    float max_val = sinks ? sinks[i02] : -INFINITY;

 #pragma unroll
     for (int col0 = 0; col0 < ncols; col0 += block_size) {

@@ -143,6 +143,10 @@ static __global__ void soft_max_f32(
         tmp = warp_reduce_sum(tmp);
     }

+    if (sinks) {
+        tmp += expf(sinks[i02] - max_val);
+    }
+
     const float inv_sum = 1.0f / tmp;

 #pragma unroll

@@ -183,7 +187,7 @@ static __global__ void soft_max_back_f32(
 }

 template<int... Ns, typename T>
-static void launch_soft_max_kernels(const float * x, const T * mask, float * dst,
+static void launch_soft_max_kernels(const float * x, const T * mask, const float * sinks, float * dst,
     const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
 {
     const int id = ggml_cuda_get_device();

@@ -196,7 +200,7 @@ static void launch_soft_max_kernels(const float * x, const T * mask, float * dst
         if (p.ncols == ncols) {
             CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
             soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
-                (x, mask, dst, p);
+                (x, mask, sinks, dst, p);
             return true;
         }
         return false;

@@ -209,12 +213,12 @@ static void launch_soft_max_kernels(const float * x, const T * mask, float * dst

     //default case
     CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
-    soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, dst, p);
+    soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, sinks, dst, p);
 }


 template<typename T>
-static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
+static void soft_max_f32_cuda(const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params & params, cudaStream_t stream) {
     int nth = WARP_SIZE;
     const int64_t ncols_x = params.ncols;

@@ -230,10 +234,10 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons

     if (nbytes_shared <= smpbo) {
-        launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream, block_dims, block_nums, nbytes_shared);
+        launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared);
     } else {
         const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
-        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
+        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
     }
 }

@@ -249,9 +253,11 @@ static void soft_max_back_f32_cuda(
 void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];

     const float * src0_d = (const float *) src0->data;
     const void  * src1_d = src1 ? (const void *) src1->data : nullptr;
+    const void  * src2_d = src2 ? (const void *) src2->data : nullptr;
     float * dst_d = (float *) dst->data;

     cudaStream_t stream = ctx.stream();

@@ -309,9 +315,9 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     params.m1 = m1;

     if (use_f16) {
-        soft_max_f32_cuda(src0_d, (const half  *) src1_d, dst_d, params, stream);
+        soft_max_f32_cuda(src0_d, (const half  *) src1_d, (const float *) src2_d, dst_d, params, stream);
     } else {
-        soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, params, stream);
+        soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream);
     }
 }
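The sinks change is easiest to read in scalar form: the sink logit takes part in the running maximum and in the normalizer, but no output element is ever written for it, so it only absorbs probability mass from the real columns. A minimal C restatement of that behavior for one row (no warp reductions or shared memory):

    #include <math.h>
    #include <stddef.h>

    /* softmax over one row with an optional attention-sink logit:
     * the sink joins the max and the denominator, but produces no output */
    static void softmax_row_with_sink(const float *x, float *dst, size_t n,
                                      const float *sink) {
        float max_val = sink ? *sink : -INFINITY;
        for (size_t i = 0; i < n; ++i) {
            if (x[i] > max_val) {
                max_val = x[i];
            }
        }

        float sum = 0.0f;
        for (size_t i = 0; i < n; ++i) {
            dst[i] = expf(x[i] - max_val);
            sum += dst[i];
        }
        if (sink) {
            sum += expf(*sink - max_val);  /* sink mass, never written out */
        }

        const float inv_sum = 1.0f / sum;
        for (size_t i = 0; i < n; ++i) {
            dst[i] *= inv_sum;
        }
    }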
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_MXFP4);
@@ -300,6 +300,81 @@ void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     ggml_cuda_op_unary_gated<op_gelu_quick>(ctx, dst);
 }

+// swiglu_oai
+
+template <typename T>
+static __global__ void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, float alpha, float limit) {
+    const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    // perform base op and multiply with gate (either offset in same tensor or a separate one)
+    const int64_t j0 = (i / n) * o0 + (i % n);
+    const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
+
+    float xi = x[j0];
+    float gi = g[j1];
+    xi = fminf(xi, limit);
+    gi = fmaxf(fminf(gi, limit), -limit);
+
+    float out_glu = xi / (1.0f + expf(-xi * alpha));
+    out_glu = out_glu * (1.0f + gi);
+
+    dst[i] = out_glu;
+}
+
+template <typename T>
+static void swiglu_oai_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream) {
+    const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE;
+    swiglu_oai_kernel<<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, g, dst, k, n, o0, o1, alpha, limit);
+}
+
+void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    void * src0_d = src0->data;
+    void * src1_d = src1 ? src1->data : src0->data;
+    const int64_t src0_o = src0->nb[1];
+    const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+    void * dst_d = dst->data;
+    const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == dst->type);
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
+        GGML_ASSERT(src1->ne[0] == nc);
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    //const int32_t swapped = ((const int32_t *) dst->op_params)[1];
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+    const float alpha = ggml_get_op_params_f32(dst, 2);
+    const float limit = ggml_get_op_params_f32(dst, 3);
+
+    float * src0_p = (float *) src0_d;
+    float * src1_p = (float *) src1_d;
+
+    if (!src1) {
+        src0_p += swapped ? nc : 0;
+        src1_p += swapped ? 0 : nc;
+    }
+
+    swiglu_oai_cuda(src0_p, src1_p, (float *) dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream);
+}
+
 /* silu_back */

 static __device__ __forceinline__ float op_silu_back(float grad, float x) {
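Element-wise, the fused op computes min(x, limit) * sigmoid(alpha * min(x, limit)) * (1 + clamp(g, -limit, limit)). A scalar C restatement of the kernel body above, handy as a CPU cross-check:

    #include <math.h>

    /* scalar reference for one element of swiglu_oai:
     * clamp the up-projection from above, the gate from both sides,
     * then apply SiLU with slope alpha and the (1 + gate) multiplier */
    static float swiglu_oai_ref(float x, float g, float alpha, float limit) {
        x = fminf(x, limit);
        g = fmaxf(fminf(g, limit), -limit);

        const float silu = x / (1.0f + expf(-x * alpha));
        return silu * (1.0f + g);
    }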
@@ -67,6 +67,8 @@ void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

 void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
 void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

 void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -1,8 +1,20 @@
 #pragma once

 #include "common.cuh"
+
 #include <cstdint>

+static __device__ __forceinline__ int get_int_b1(const void * x, const int & i32) {
+    const uint8_t * x8 = (const uint8_t *) x;
+
+    int x32  = x8[4*i32 + 0] <<  0;
+    x32     |= x8[4*i32 + 1] <<  8;
+    x32     |= x8[4*i32 + 2] << 16;
+    x32     |= x8[4*i32 + 3] << 24;
+
+    return x32;
+}
+
 static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) {
     const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment

@@ -16,6 +28,20 @@ static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32
     return ((const int *) x)[i32]; // assume at least 4 byte alignment
 }

+static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * table) {
+    const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
+    const int8_t * q0_8 = (const int8_t *) &q0_32;
+    const char4 val0_8 = make_char4(
+        table[q0_8[0]], table[q0_8[1]], table[q0_8[2]], table[q0_8[3]]);
+
+    const int q1_32 = (q4 >> 4) & 0x0F0F0F0F;
+    const int8_t * q1_8 = (const int8_t *) &q1_32;
+    const char4 val1_8 = make_char4(
+        table[q1_8[0]], table[q1_8[1]], table[q1_8[2]], table[q1_8[3]]);
+
+    return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
+}
+
 // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
 // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q

@@ -211,6 +237,30 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_16_q8_1_
     return d8_1*sumf;
 }

+#define VDR_MXFP4_Q8_1_MMVQ 2
+#define VDR_MXFP4_Q8_1_MMQ  4
+
+static __device__ __forceinline__ float vec_dot_mxfp4_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+    const block_mxfp4 * bq4 = (const block_mxfp4 *) vbq + kbx;
+
+    const int * q8 = (const int *) bq8_1->qs + iqs;
+
+    int sumi = 0;
+#pragma unroll
+    for (int l = 0; l < VDR_MXFP4_Q8_1_MMVQ; ++l) {
+        const int aux_q4 = get_int_b1(bq4->qs, iqs + l);
+        const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
+
+        sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
+        sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
+    }
+
+    const float d = ggml_cuda_e8m0_to_fp32(bq4->e) * 0.5f * __low2float(bq8_1->ds);
+    return d * sumi;
+}
+
 #define VDR_Q2_K_Q8_1_MMVQ 1
 #define VDR_Q2_K_Q8_1_MMQ  4

@@ -1068,20 +1118,6 @@ static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
     return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1);
 }

-static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) {
-    const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
-    const int8_t * q0_8 = (const int8_t *) &q0_32;
-    const char4 val0_8 = make_char4(
-        kvalues_iq4nl[q0_8[0]], kvalues_iq4nl[q0_8[1]], kvalues_iq4nl[q0_8[2]], kvalues_iq4nl[q0_8[3]]);
-
-    const int q1_32 = (q4 >> 4) & 0x0F0F0F0F;
-    const int8_t * q1_8 = (const int8_t *) &q1_32;
-    const char4 val1_8 = make_char4(
-        kvalues_iq4nl[q1_8[0]], kvalues_iq4nl[q1_8[1]], kvalues_iq4nl[q1_8[2]], kvalues_iq4nl[q1_8[3]]);
-
-    return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
-}
-
 #define VDR_IQ4_NL_Q8_1_MMVQ 2
 #define VDR_IQ4_NL_Q8_1_MMQ  4

@@ -1096,7 +1132,7 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
 #pragma unroll
     for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
         const int aux_q4 = get_int_b2(bq4->qs, iqs + l);
-        const int2 v = get_int_from_table_16(aux_q4);
+        const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);

         sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
         sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);

@@ -1118,7 +1154,7 @@ static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
 #pragma unroll
     for (int j = 0; j < 4; ++j) {
         const int aux_q4 = get_int_b4(bq4->qs, iqs + j);
-        const int2 v = get_int_from_table_16(aux_q4);
+        const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);

         const int u0 = get_int_b4(bq8_1[iqs/4].qs, j + 0);
         const int u1 = get_int_b4(bq8_1[iqs/4].qs, j + 4);
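Putting the pieces together, an MXFP4 element dequantizes as table[nibble] * 0.5 * e8m0(scale). A host-side C sketch under the layout this diff implies (32 elements per block, 16 packed nibble bytes, one shared e8m0 scale byte); the struct and table are illustrative stand-ins, not the ggml-common.h definitions:

    #include <stdint.h>
    #include <string.h>

    #define QK_MXFP4 32

    /* assumed block layout: one e8m0 scale byte + 16 bytes of packed nibbles */
    typedef struct {
        uint8_t e;                 /* e8m0 shared scale */
        uint8_t qs[QK_MXFP4/2];    /* two 4-bit indices per byte */
    } block_mxfp4_ref;

    /* assumed doubled e2m1 value table */
    static const int8_t kvalues_mxfp4_ref[16] = {
        0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12,
    };

    /* same bit construction as ggml_cuda_e8m0_to_fp32 / ggml_e8m0_to_fp32 */
    static float e8m0_to_fp32_ref(uint8_t x) {
        uint32_t bits = x == 0 ? 0x00400000u : (uint32_t) x << 23;
        float f;
        memcpy(&f, &bits, sizeof(f));
        return f;
    }

    /* dequantize one block: low nibbles fill the first half of the row,
     * high nibbles the second half, matching the v.x / v.y split above */
    static void dequant_block_mxfp4(const block_mxfp4_ref *b, float *y) {
        const float d = e8m0_to_fp32_ref(b->e) * 0.5f;  /* table values are doubled */
        for (int j = 0; j < QK_MXFP4/2; ++j) {
            y[j]              = d * kvalues_mxfp4_ref[b->qs[j] & 0x0F];
            y[j + QK_MXFP4/2] = d * kvalues_mxfp4_ref[b->qs[j] >> 4];
        }
    }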
@@ -6,6 +6,10 @@
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>

+#if CUDART_VERSION >= 12050
+#include <cuda_fp8.h>
+#endif // CUDART_VERSION >= 12050
+
 #if CUDART_VERSION < 11020
 #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
 #define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
@@ -410,6 +410,67 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
 #define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
 #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)

+static inline float ggml_e8m0_to_fp32(uint8_t x) {
+    uint32_t bits;  // Stores the raw bit representation of the float
+
+    // Handle special case for minimum exponent (denormalized float)
+    if (x == 0) {
+        // Bit pattern for 2^(-127):
+        // - Sign bit: 0 (positive)
+        // - Exponent: 0 (denormalized number)
+        // - Mantissa: 0x400000 (0.5 in fractional form)
+        // Value = 0.5 * 2^(-126) = 2^(-127)
+        bits = 0x00400000;
+    }
+    // note: disabled as we don't need to handle NaNs
+    //// Handle special case for NaN (all bits set)
+    //else if (x == 0xFF) {
+    //    // Standard quiet NaN pattern:
+    //    // - Sign bit: 0
+    //    // - Exponent: all 1s (0xFF)
+    //    // - Mantissa: 0x400000 (quiet NaN flag)
+    //    bits = 0x7FC00000;
+    //}
+    // Normalized values (most common case)
+    else {
+        // Construct normalized float by shifting exponent into position:
+        // - Exponent field: 8 bits (positions 30-23)
+        // - Mantissa: 0 (implicit leading 1)
+        // Value = 2^(x - 127)
+        bits = (uint32_t) x << 23;
+    }
+
+    float result;  // Final float value
+    // Safely reinterpret bit pattern as float without type-punning issues
+    memcpy(&result, &bits, sizeof(float));
+    return result;
+}
+
+// Equal to ggml_e8m0_to_fp32/2
+// Useful with MXFP4 quantization since the E0M2 values are doubled
+static inline float ggml_e8m0_to_fp32_half(uint8_t x) {
+    uint32_t bits;
+
+    // For x < 2: use precomputed denormal patterns
+    if (x < 2) {
+        // 0x00200000 = 2^(-128), 0x00400000 = 2^(-127)
+        bits = 0x00200000 << x;
+    }
+    // For x >= 2: normalized exponent adjustment
+    else {
+        // 0.5 * 2^(x-127) = 2^(x-128) = normalized with exponent (x-1)
+        bits = (uint32_t)(x - 1) << 23;
+    }
+    // Note: NaNs are not handled here
+
+    float result;
+    memcpy(&result, &bits, sizeof(float));
+    return result;
+}
+
+#define GGML_E8M0_TO_FP32(x) ggml_e8m0_to_fp32(x)
+#define GGML_E8M0_TO_FP32_HALF(x) ggml_e8m0_to_fp32_half(x)
+
 /**
  * Converts brain16 to float32.
  *
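A quick sanity check of the construction: x = 127 encodes 2^0 = 1.0, x = 130 encodes 2^3 = 8.0, and x = 0 takes the denormal path to 2^-127. A small C test of those values, assuming the same bit construction as above:

    #include <assert.h>
    #include <math.h>
    #include <stdint.h>
    #include <string.h>

    /* same bit construction as ggml_e8m0_to_fp32 above */
    static float e8m0(uint8_t x) {
        uint32_t bits = x == 0 ? 0x00400000u : (uint32_t) x << 23;
        float f;
        memcpy(&f, &bits, sizeof(f));
        return f;
    }

    int main(void) {
        assert(e8m0(127) == 1.0f);                 /* 2^(127-127)          */
        assert(e8m0(130) == 8.0f);                 /* 2^(130-127)          */
        assert(e8m0(0)   == ldexpf(1.0f, -127));   /* denormal 2^(-127)    */
        return 0;
    }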
@@ -23,6 +23,9 @@
 #define N_R0_Q8_0 4
 #define N_SG_Q8_0 2

+#define N_R0_MXFP4 2
+#define N_SG_MXFP4 2
+
 #define N_R0_Q2_K 4
 #define N_SG_Q2_K 2

@@ -129,6 +132,15 @@ typedef struct {
     uint64_t o1[8];
 } ggml_metal_kargs_bin;

+typedef struct {
+    int64_t ne0;
+    int64_t ne1;
+    size_t  nb01;
+    size_t  nb02;
+    size_t  nb11;
+    size_t  nb21;
+} ggml_metal_kargs_add_id;
+
 typedef struct {
     int32_t ne00;
     int32_t ne01;

@@ -444,6 +456,8 @@ typedef struct{
     uint64_t nb1;
     int32_t  i00;
     int32_t  i10;
+    float    alpha;
+    float    limit;
 } ggml_metal_kargs_glu;

 typedef struct {
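For context on what the new add_id arguments drive: ggml_add_id adds a bias row, selected per row by an ids tensor, on top of src0, which is how the MoE expert biases are applied after expert routing. A scalar C sketch of that semantics under a flattened-rows view (a restatement under assumptions, not the Metal kernel):

    #include <stdint.h>

    /* add_id semantics: for each row r of src0, add the row of src1
     * whose index is given by ids[r]; ne0 is the row width */
    static void add_id_ref(const float *src0, const float *src1,
                           const int32_t *ids, float *dst,
                           int64_t nrows, int64_t ne0) {
        for (int64_t r = 0; r < nrows; ++r) {
            const float *bias = src1 + (int64_t) ids[r] * ne0;
            for (int64_t i = 0; i < ne0; ++i) {
                dst[r*ne0 + i] = src0[r*ne0 + i] + bias[i];
            }
        }
    }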
@@ -195,6 +195,7 @@ enum ggml_metal_kernel_type {
|
|
| 195 |
GGML_METAL_KERNEL_TYPE_MUL_ROW_C4,
|
| 196 |
GGML_METAL_KERNEL_TYPE_DIV,
|
| 197 |
GGML_METAL_KERNEL_TYPE_DIV_ROW_C4,
|
|
|
|
| 198 |
GGML_METAL_KERNEL_TYPE_REPEAT_F32,
|
| 199 |
GGML_METAL_KERNEL_TYPE_REPEAT_F16,
|
| 200 |
GGML_METAL_KERNEL_TYPE_REPEAT_I32,
|
|
@@ -234,6 +235,7 @@ enum ggml_metal_kernel_type {
|
|
| 234 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
|
| 235 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,
|
| 236 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,
|
|
|
|
| 237 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,
|
| 238 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,
|
| 239 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,
|
|
@@ -286,6 +288,7 @@ enum ggml_metal_kernel_type {
|
|
| 286 |
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
|
| 287 |
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
|
| 288 |
GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
|
|
|
|
| 289 |
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
|
| 290 |
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
|
| 291 |
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
|
|
@@ -310,6 +313,10 @@ enum ggml_metal_kernel_type {
|
|
| 310 |
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,
|
| 311 |
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,
|
| 312 |
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2,
|
| 314 |
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3,
|
| 315 |
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4,
|
|
@@ -351,6 +358,7 @@ enum ggml_metal_kernel_type {
|
|
| 351 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
|
| 352 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,
|
| 353 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,
|
|
|
|
| 354 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,
|
| 355 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,
|
| 356 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,
|
|
@@ -373,6 +381,7 @@ enum ggml_metal_kernel_type {
|
|
| 373 |
GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
|
| 374 |
GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,
|
| 375 |
GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,
|
|
|
|
| 376 |
GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,
|
| 377 |
GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,
|
| 378 |
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,
|
|
@@ -397,6 +406,7 @@ enum ggml_metal_kernel_type {
|
|
| 397 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16,
|
| 398 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16,
|
| 399 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16,
|
|
|
|
| 400 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16,
|
| 401 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16,
|
| 402 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16,
|
|
@@ -579,6 +589,7 @@ enum ggml_metal_kernel_type {
|
|
| 579 |
GGML_METAL_KERNEL_TYPE_REGLU,
|
| 580 |
GGML_METAL_KERNEL_TYPE_GEGLU,
|
| 581 |
GGML_METAL_KERNEL_TYPE_SWIGLU,
|
|
|
|
| 582 |
GGML_METAL_KERNEL_TYPE_GEGLU_ERF,
|
| 583 |
GGML_METAL_KERNEL_TYPE_GEGLU_QUICK,
|
| 584 |
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
|
@@ -1199,6 +1210,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
| 1199 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, mul_row_c4, true);
|
| 1200 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
|
| 1201 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, div_row_c4, true);
|
|
|
|
| 1202 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
|
| 1203 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
|
| 1204 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
|
|
@@ -1238,6 +1250,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
| 1238 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
|
| 1239 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
|
| 1240 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
|
|
|
|
| 1241 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
|
| 1242 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
|
| 1243 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
|
|
@@ -1290,6 +1303,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
| 1290 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
|
| 1291 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
|
| 1292 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
|
|
|
|
| 1293 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
|
| 1294 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
|
| 1295 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
|
|
@@ -1314,6 +1328,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
| 1314 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction);
|
| 1315 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction);
|
| 1316 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction);
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1317 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction);
|
| 1318 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction);
|
| 1319 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction);
|
|
@@ -1355,6 +1373,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
| 1355 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction);
|
| 1356 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, has_simdgroup_reduction);
|
| 1357 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, has_simdgroup_reduction);
|
|
|
|
| 1358 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, has_simdgroup_reduction);
|
| 1359 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, has_simdgroup_reduction);
|
| 1360 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, has_simdgroup_reduction);
|
|
@@ -1377,6 +1396,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
| 1377 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm);
|
| 1378 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, has_simdgroup_mm);
|
| 1379 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, has_simdgroup_mm);
|
|
|
|
|
|
|
| 1380 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, has_simdgroup_mm);
|
| 1381 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, has_simdgroup_mm);
|
| 1382 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm);
|
|
@@ -1401,6 +1422,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
| 1401 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16, mul_mm_id_q5_0_f16, has_simdgroup_mm);
|
| 1402 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16, mul_mm_id_q5_1_f16, has_simdgroup_mm);
|
| 1403 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16, mul_mm_id_q8_0_f16, has_simdgroup_mm);
|
|
|
|
| 1404 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16, mul_mm_id_q2_K_f16, has_simdgroup_mm);
|
| 1405 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16, mul_mm_id_q3_K_f16, has_simdgroup_mm);
|
| 1406 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16, mul_mm_id_q4_K_f16, has_simdgroup_mm);
|
|
@@ -1583,6 +1605,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
| 1583 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true);
|
| 1584 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true);
|
| 1585 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
|
|
|
|
| 1586 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_ERF, geglu_erf, true);
|
| 1587 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, geglu_quick, true);
|
| 1588 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
|
@@ -1774,6 +1797,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
| 1774 |
case GGML_GLU_OP_REGLU:
|
| 1775 |
case GGML_GLU_OP_GEGLU:
|
| 1776 |
case GGML_GLU_OP_SWIGLU:
|
|
|
|
| 1777 |
case GGML_GLU_OP_GEGLU_ERF:
|
| 1778 |
case GGML_GLU_OP_GEGLU_QUICK:
|
| 1779 |
return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
|
@@ -1791,6 +1815,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
| 1791 |
case GGML_OP_SUB:
|
| 1792 |
case GGML_OP_MUL:
|
| 1793 |
case GGML_OP_DIV:
|
|
|
|
| 1794 |
return op->src[0]->type == GGML_TYPE_F32;
|
| 1795 |
case GGML_OP_ACC:
|
| 1796 |
case GGML_OP_REPEAT:
|
|
@@ -2042,6 +2067,7 @@ static int ggml_metal_encode_node(
|
|
| 2042 |
|
| 2043 |
const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
|
| 2044 |
const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
|
|
|
|
| 2045 |
const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
|
| 2046 |
|
| 2047 |
size_t offs_src0 = 0;
|
|
@@ -2291,6 +2317,38 @@ static int ggml_metal_encode_node(
|
|
| 2291 |
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
| 2292 |
}
|
| 2293 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2294 |
case GGML_OP_REPEAT:
|
| 2295 |
{
|
| 2296 |
id<MTLComputePipelineState> pipeline;
|
|
@@ -2710,6 +2768,9 @@ static int ggml_metal_encode_node(
|
|
| 2710 |
case GGML_GLU_OP_SWIGLU:
|
| 2711 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
|
| 2712 |
break;
|
|
|
|
|
|
|
|
|
|
| 2713 |
case GGML_GLU_OP_GEGLU_ERF:
|
| 2714 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_ERF].pipeline;
|
| 2715 |
break;
|
|
@@ -2720,7 +2781,9 @@ static int ggml_metal_encode_node(
|
|
| 2720 |
GGML_ABORT("fatal error");
|
| 2721 |
}
|
| 2722 |
|
| 2723 |
-
const int32_t swp = (
|
|
|
|
|
|
|
| 2724 |
|
| 2725 |
const int32_t i00 = swp ? ne0 : 0;
|
| 2726 |
const int32_t i10 = swp ? 0 : ne0;
|
|
@@ -2734,6 +2797,8 @@ static int ggml_metal_encode_node(
|
|
| 2734 |
/*.nb1 =*/ nb1,
|
| 2735 |
/*.i00 =*/ src1 ? 0 : i00,
|
| 2736 |
/*.i10 =*/ src1 ? 0 : i10,
|
|
|
|
|
|
|
| 2737 |
};
|
| 2738 |
|
| 2739 |
[encoder setComputePipelineState:pipeline];
|
|
@@ -2992,8 +3057,13 @@ static int ggml_metal_encode_node(
|
|
| 2992 |
} else {
|
| 2993 |
[encoder setBuffer:h_src0 offset:offs_src0 atIndex:1];
|
| 2994 |
}
|
| 2995 |
-
|
| 2996 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2997 |
|
| 2998 |
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
| 2999 |
|
|
@@ -3291,6 +3361,7 @@ static int ggml_metal_encode_node(
|
|
| 3291 |
src0t == GGML_TYPE_Q5_0 ||
|
| 3292 |
src0t == GGML_TYPE_Q5_1 ||
|
| 3293 |
src0t == GGML_TYPE_Q8_0 ||
|
|
|
|
| 3294 |
src0t == GGML_TYPE_IQ4_NL ||
|
| 3295 |
false) && (ne11 >= 2 && ne11 <= 8)
|
| 3296 |
) ||
|
|
@@ -3383,6 +3454,14 @@ static int ggml_metal_encode_node(
|
|
| 3383 |
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break;
|
| 3384 |
default: GGML_ABORT("not implemented");
|
| 3385 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3386 |
case GGML_TYPE_Q4_K:
|
| 3387 |
switch (r1ptg) {
|
| 3388 |
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2].pipeline; break;
|
|
@@ -3481,6 +3560,7 @@ static int ggml_metal_encode_node(
|
|
| 3481 |
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
|
| 3482 |
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
|
| 3483 |
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
|
|
|
|
| 3484 |
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
|
| 3485 |
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
|
| 3486 |
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
|
|
@@ -3623,6 +3703,13 @@ static int ggml_metal_encode_node(
|
|
| 3623 |
nr0 = N_R0_Q8_0;
|
| 3624 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
|
| 3625 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3626 |
case GGML_TYPE_Q2_K:
|
| 3627 |
{
|
| 3628 |
nsg = N_SG_Q2_K;
|
|
@@ -3756,8 +3843,6 @@ static int ggml_metal_encode_node(
|
|
| 3756 |
case GGML_OP_MUL_MAT_ID:
|
| 3757 |
{
|
| 3758 |
// src2 = ids
|
| 3759 |
-
const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
|
| 3760 |
-
|
| 3761 |
GGML_ASSERT(src2t == GGML_TYPE_I32);
|
| 3762 |
|
| 3763 |
GGML_ASSERT(!ggml_is_transposed(src0));
|
|
@@ -3883,6 +3968,7 @@ static int ggml_metal_encode_node(
|
|
| 3883 |
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16 ].pipeline; break;
|
| 3884 |
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16 ].pipeline; break;
|
| 3885 |
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16 ].pipeline; break;
|
|
|
|
| 3886 |
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16 ].pipeline; break;
|
| 3887 |
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16 ].pipeline; break;
|
| 3888 |
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16 ].pipeline; break;
|
|
@@ -4018,6 +4104,13 @@ static int ggml_metal_encode_node(
|
|
| 4018 |
nr0 = N_R0_Q8_0;
|
| 4019 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
|
| 4020 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4021 |
case GGML_TYPE_Q2_K:
|
| 4022 |
{
|
| 4023 |
nsg = N_SG_Q2_K;
|
|
@@ -4170,6 +4263,7 @@ static int ggml_metal_encode_node(
|
|
| 4170 |
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
|
| 4171 |
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break;
|
| 4172 |
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break;
|
|
|
|
| 4173 |
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break;
|
| 4174 |
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break;
|
| 4175 |
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break;
|
|
@@ -4980,11 +5074,14 @@ static int ggml_metal_encode_node(
|
|
| 4980 |
GGML_ASSERT(ne11 == ne21);
|
| 4981 |
GGML_ASSERT(ne12 == ne22);
|
| 4982 |
|
| 4983 |
-
struct ggml_tensor * src3 = node->src[3];
|
|
|
|
| 4984 |
|
| 4985 |
size_t offs_src3 = 0;
|
|
|
|
| 4986 |
|
| 4987 |
id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
|
|
|
|
| 4988 |
|
| 4989 |
GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16);
|
| 4990 |
GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
|
|
@@ -5000,8 +5097,6 @@ static int ggml_metal_encode_node(
|
|
| 5000 |
const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32);
|
| 5001 |
const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33);
|
| 5002 |
|
| 5003 |
-
const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
|
| 5004 |
-
|
| 5005 |
float scale;
|
| 5006 |
float max_bias;
|
| 5007 |
float logit_softcap;
|
|
@@ -5389,7 +5484,12 @@ static int ggml_metal_encode_node(
|
|
| 5389 |
} else {
|
| 5390 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:4];
|
| 5391 |
}
|
| 5392 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5393 |
|
| 5394 |
if (!use_vec_kernel) {
|
| 5395 |
// half8x8 kernel
|
|
|
|
| 195 |
GGML_METAL_KERNEL_TYPE_MUL_ROW_C4,
|
| 196 |
GGML_METAL_KERNEL_TYPE_DIV,
|
| 197 |
GGML_METAL_KERNEL_TYPE_DIV_ROW_C4,
|
| 198 |
+
GGML_METAL_KERNEL_TYPE_ADD_ID,
|
| 199 |
GGML_METAL_KERNEL_TYPE_REPEAT_F32,
|
| 200 |
GGML_METAL_KERNEL_TYPE_REPEAT_F16,
|
| 201 |
GGML_METAL_KERNEL_TYPE_REPEAT_I32,
|
|
|
|
| 235 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
|
| 236 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,
|
| 237 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,
|
| 238 |
+
GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4,
|
| 239 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,
|
| 240 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,
|
| 241 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,
|
|
|
|
| 288 |
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
|
| 289 |
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
|
| 290 |
GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
|
| 291 |
+
GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32,
|
| 292 |
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
|
| 293 |
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
|
| 294 |
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
|
|
|
|
| 313 |
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,
|
| 314 |
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,
|
| 315 |
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,
|
| 316 |
+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2,
|
| 317 |
+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3,
|
| 318 |
+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4,
|
| 319 |
+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5,
|
| 320 |
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2,
|
| 321 |
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3,
|
| 322 |
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4,
|
|
|
|
| 358 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
|
| 359 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,
|
| 360 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,
|
| 361 |
+
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32,
|
| 362 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,
|
| 363 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,
|
| 364 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,
|
|
|
|
| 381 |
GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
|
| 382 |
GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,
|
| 383 |
GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,
|
| 384 |
+
GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32,
|
| 385 |
GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,
|
| 386 |
GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,
|
| 387 |
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,
|
|
|
|
| 406 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16,
|
| 407 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16,
|
| 408 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16,
|
| 409 |
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16,
|
| 410 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16,
|
| 411 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16,
|
| 412 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16,
|
|
|
|
| 589 |
GGML_METAL_KERNEL_TYPE_REGLU,
|
| 590 |
GGML_METAL_KERNEL_TYPE_GEGLU,
|
| 591 |
GGML_METAL_KERNEL_TYPE_SWIGLU,
|
| 592 |
+
GGML_METAL_KERNEL_TYPE_SWIGLU_OAI,
|
| 593 |
GGML_METAL_KERNEL_TYPE_GEGLU_ERF,
|
| 594 |
GGML_METAL_KERNEL_TYPE_GEGLU_QUICK,
|
| 595 |
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
|
|
|
| 1210 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, mul_row_c4, true);
|
| 1211 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
|
| 1212 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, div_row_c4, true);
|
| 1213 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ID, add_id, true);
|
| 1214 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
|
| 1215 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
|
| 1216 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
|
|
|
|
| 1250 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
|
| 1251 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
|
| 1252 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
|
| 1253 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4, get_rows_mxfp4, true);
|
| 1254 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
|
| 1255 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
|
| 1256 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
|
|
|
|
| 1303 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
|
| 1304 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
|
| 1305 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
|
| 1306 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, mul_mv_mxfp4_f32, has_simdgroup_reduction);
|
| 1307 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
|
| 1308 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
|
| 1309 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
|
|
|
|
| 1328 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction);
|
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,   mul_mv_ext_q8_0_f32_r1_4,   has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,   mul_mv_ext_q8_0_f32_r1_5,   has_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2,  mul_mv_ext_mxfp4_f32_r1_2,  has_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3,  mul_mv_ext_mxfp4_f32_r1_3,  has_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4,  mul_mv_ext_mxfp4_f32_r1_4,  has_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5,  mul_mv_ext_mxfp4_f32_r1_5,  has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2,   mul_mv_ext_q4_K_f32_r1_2,   has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3,   mul_mv_ext_q4_K_f32_r1_3,   has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4,   mul_mv_ext_q4_K_f32_r1_4,   has_simdgroup_reduction);

     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,  mul_mv_id_q5_0_f32,  has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,  mul_mv_id_q5_1_f32,  has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,  mul_mv_id_q8_0_f32,  has_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32, mul_mv_id_mxfp4_f32, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,  mul_mv_id_q2_K_f32,  has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,  mul_mv_id_q3_K_f32,  has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,  mul_mv_id_q4_K_f32,  has_simdgroup_reduction);

     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,  mul_mm_q5_0_f32,  has_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,  mul_mm_q5_1_f32,  has_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,  mul_mm_q8_0_f32,  has_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, mul_mm_mxfp4_f32, has_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,  mul_mm_q2_K_f32,  has_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,  mul_mm_q3_K_f32,  has_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,  mul_mm_q4_K_f32,  has_simdgroup_mm);

     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16,  mul_mm_id_q5_0_f16,  has_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16,  mul_mm_id_q5_1_f16,  has_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16,  mul_mm_id_q8_0_f16,  has_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16, mul_mm_id_mxfp4_f16, has_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16,  mul_mm_id_q2_K_f16,  has_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16,  mul_mm_id_q3_K_f16,  has_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16,  mul_mm_id_q4_K_f16,  has_simdgroup_mm);

     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU,       reglu,       true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU,       geglu,       true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU,      swiglu,      true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU_OAI,  swiglu_oai,  true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_ERF,   geglu_erf,   true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, geglu_quick, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,    sum_rows,    true);

                 case GGML_GLU_OP_REGLU:
                 case GGML_GLU_OP_GEGLU:
                 case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_SWIGLU_OAI:
                 case GGML_GLU_OP_GEGLU_ERF:
                 case GGML_GLU_OP_GEGLU_QUICK:
                     return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;

         case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
+        case GGML_OP_ADD_ID:
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_ACC:
         case GGML_OP_REPEAT:

     const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
     const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
+    const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT;
     const enum ggml_type dstt  = dst  ? dst->type  : GGML_TYPE_COUNT;

     size_t offs_src0 = 0;
                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                 }
             } break;
+        case GGML_OP_ADD_ID:
+            {
+                GGML_ASSERT(src0t == GGML_TYPE_F32);
+                GGML_ASSERT(src1t == GGML_TYPE_F32);
+                GGML_ASSERT(src2t == GGML_TYPE_I32);
+                GGML_ASSERT(dstt  == GGML_TYPE_F32);
+
+                GGML_ASSERT(ggml_is_contiguous_rows(src0));
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ID].pipeline;
+
+                ggml_metal_kargs_add_id args = {
+                    /*.ne0  =*/ ne0,
+                    /*.ne1  =*/ ne1,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.nb11 =*/ nb11,
+                    /*.nb21 =*/ nb21,
+                };
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:4];
+
+                const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
         case GGML_OP_REPEAT:
             {
                 id<MTLComputePipelineState> pipeline;
                 case GGML_GLU_OP_SWIGLU:
                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
                     break;
+                case GGML_GLU_OP_SWIGLU_OAI:
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU_OAI].pipeline;
+                    break;
                 case GGML_GLU_OP_GEGLU_ERF:
                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_ERF].pipeline;
                     break;

                     GGML_ABORT("fatal error");
             }

+            const int32_t swp   = ggml_get_op_params_i32(dst, 1);
+            const float   alpha = ggml_get_op_params_f32(dst, 2);
+            const float   limit = ggml_get_op_params_f32(dst, 3);

             const int32_t i00 = swp ? ne0 : 0;
             const int32_t i10 = swp ? 0 : ne0;

                 /*.nb1   =*/ nb1,
                 /*.i00   =*/ src1 ? 0 : i00,
                 /*.i10   =*/ src1 ? 0 : i10,
+                /*.alpha =*/ alpha,
+                /*.limit =*/ limit
             };

             [encoder setComputePipelineState:pipeline];
             } else {
                 [encoder setBuffer:h_src0 offset:offs_src0 atIndex:1];
             }
+            if (id_src2) {
+                [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+            } else {
+                [encoder setBuffer:h_src0 offset:offs_src0 atIndex:2];
+            }
+            [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
+            [encoder setBytes:&args length:sizeof(args) atIndex:4];

             [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
                     src0t == GGML_TYPE_Q5_0 ||
                     src0t == GGML_TYPE_Q5_1 ||
                     src0t == GGML_TYPE_Q8_0 ||
+                    src0t == GGML_TYPE_MXFP4 ||
                     src0t == GGML_TYPE_IQ4_NL ||
                     false) && (ne11 >= 2 && ne11 <= 8)
                 ) ||

                         case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break;
                         default: GGML_ABORT("not implemented");
                     } break;
+                case GGML_TYPE_MXFP4:
+                    switch (r1ptg) {
+                        case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2].pipeline; break;
+                        case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3].pipeline; break;
+                        case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4].pipeline; break;
+                        case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5].pipeline; break;
+                        default: GGML_ABORT("not implemented");
+                    } break;
                 case GGML_TYPE_Q4_K:
                     switch (r1ptg) {
                         case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2].pipeline; break;
                 case GGML_TYPE_Q5_0:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
                 case GGML_TYPE_Q5_1:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
                 case GGML_TYPE_Q8_0:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
+                case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32].pipeline; break;
                 case GGML_TYPE_Q2_K:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
                 case GGML_TYPE_Q3_K:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
                 case GGML_TYPE_Q4_K:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;

                         nr0 = N_R0_Q8_0;
                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
                     } break;
+                case GGML_TYPE_MXFP4:
+                    {
+                        nsg = N_SG_MXFP4;
+                        nr0 = N_R0_MXFP4;
+                        smem = 32*sizeof(float);
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32].pipeline;
+                    } break;
                 case GGML_TYPE_Q2_K:
                     {
                         nsg = N_SG_Q2_K;
         case GGML_OP_MUL_MAT_ID:
             {
                 // src2 = ids
                 GGML_ASSERT(src2t == GGML_TYPE_I32);

                 GGML_ASSERT(!ggml_is_transposed(src0));

                 case GGML_TYPE_Q5_0:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16 ].pipeline; break;
                 case GGML_TYPE_Q5_1:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16 ].pipeline; break;
                 case GGML_TYPE_Q8_0:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16 ].pipeline; break;
+                case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16].pipeline; break;
                 case GGML_TYPE_Q2_K:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16 ].pipeline; break;
                 case GGML_TYPE_Q3_K:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16 ].pipeline; break;
                 case GGML_TYPE_Q4_K:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16 ].pipeline; break;

                         nr0 = N_R0_Q8_0;
                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
                     } break;
+                case GGML_TYPE_MXFP4:
+                    {
+                        nsg = N_SG_MXFP4;
+                        nr0 = N_R0_MXFP4;
+                        smem = 32*sizeof(float);
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32].pipeline;
+                    } break;
                 case GGML_TYPE_Q2_K:
                     {
                         nsg = N_SG_Q2_K;

                 case GGML_TYPE_Q5_0:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
                 case GGML_TYPE_Q5_1:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break;
                 case GGML_TYPE_Q8_0:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break;
+                case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4].pipeline; break;
                 case GGML_TYPE_Q2_K:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break;
                 case GGML_TYPE_Q3_K:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break;
                 case GGML_TYPE_Q4_K:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break;
             GGML_ASSERT(ne11 == ne21);
             GGML_ASSERT(ne12 == ne22);

+            struct ggml_tensor * src3 = node->src[3]; // mask
+            struct ggml_tensor * src4 = node->src[4]; // sinks

             size_t offs_src3 = 0;
+            size_t offs_src4 = 0;

             id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
+            id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;

             GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16);
             GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&

             const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32);
             const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33);

             float scale;
             float max_bias;
             float logit_softcap;

             } else {
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4];
             }
+            if (id_src4) {
+                [encoder setBuffer:id_src4 offset:offs_src4 atIndex:5];
+            } else {
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:5];
+            }
+            [encoder setBuffer:id_dst offset:offs_dst atIndex:6];

             if (!use_vec_kernel) {
                 // half8x8 kernel
@@ -35,6 +35,10 @@ constexpr constant static float kvalues_iq4nl_f[16] = {
     -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
 };
 
+constexpr constant static float kvalues_mxfp4_f[16] = {
+    0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f
+};
+
 static inline int best_index_int8(int n, constant float * val, float x) {
     if (x <= val[0]) return 0;
     if (x >= val[n-1]) return n-1;
@@ -46,6 +50,18 @@ static inline int best_index_int8(int n, constant float * val, float x) {
     return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
 }
 
+static inline float e8m0_to_fp32(uint8_t x) {
+    uint32_t bits;
+
+    if (x == 0) {
+        bits = 0x00400000;
+    } else {
+        bits = (uint32_t) x << 23;
+    }
+
+    return as_type<float>(bits);
+}
+
 // NOTE: this is not dequantizing - we are simply fitting the template
 template <typename type4x4>
 void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
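As a reading aid (not part of the diff): a host-side C sketch of the same E8M0 decode. The byte is a biased power-of-two exponent, so the decoded value is 2^(e - 127); e == 0 is mapped to the subnormal bit pattern 0x00400000 so that 2^-127 remains representable.

    #include <stdint.h>
    #include <string.h>

    static float e8m0_to_fp32_ref(uint8_t e) {
        // place the biased exponent directly into the float exponent field
        uint32_t bits = (e == 0) ? 0x00400000u : ((uint32_t) e << 23);
        float f;
        memcpy(&f, &bits, sizeof(f)); // bit-cast, the C analogue of Metal's as_type<float>
        return f;                     // = 2^(e - 127)
    }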
@@ -242,6 +258,27 @@ void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
     }
 }
 
+void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
+#pragma METAL fp math_mode(safe)
+    float amax = 0.0f; // absolute max
+
+    for (int j = 0; j < QK8_0; j++) {
+        const float v = src[j];
+        amax = MAX(amax, fabs(v));
+    }
+
+    const float d  = amax / ((1 << 7) - 1);
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dst.d = d;
+
+    for (int j = 0; j < QK8_0; ++j) {
+        const float x0 = src[j]*id;
+
+        dst.qs[j] = round(x0);
+    }
+}
+
 void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
 #pragma METAL fp math_mode(safe)
     float amax = 0.0f; // absolute max
@@ -462,25 +499,34 @@ void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & reg) {
     }
 }
 
-void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
-#pragma METAL fp math_mode(safe)
-    float amax = 0.0f; // absolute max
-
-    for (int j = 0; j < QK8_0; j++) {
-        const float v = src[j];
-        amax = MAX(amax, fabs(v));
-    }
-
-    const float d  = amax / ((1 << 7) - 1);
-    const float id = d ? 1.0f/d : 0.0f;
-
-    dst.d = d;
-
-    for (int j = 0; j < QK8_0; ++j) {
-        const float x0 = src[j]*id;
-
-        dst.qs[j] = round(x0);
-    }
-}
-
+template <typename type4x4>
+void dequantize_mxfp4(device const block_mxfp4 * xb, short il, thread type4x4 & reg) {
+    device const uint8_t * q2 = (device const uint8_t *)xb->qs;
+
+    const float d = e8m0_to_fp32(xb->e);
+    const uint8_t shr = il >= 1 ? 4 : 0;
+
+    for (int i = 0; i < 4; ++i) {
+        reg[i][0] = d * kvalues_mxfp4_f[(q2[4*i + 0] >> shr) & 0x0F];
+        reg[i][1] = d * kvalues_mxfp4_f[(q2[4*i + 1] >> shr) & 0x0F];
+        reg[i][2] = d * kvalues_mxfp4_f[(q2[4*i + 2] >> shr) & 0x0F];
+        reg[i][3] = d * kvalues_mxfp4_f[(q2[4*i + 3] >> shr) & 0x0F];
+    }
+}
+
+template <typename type4>
+void dequantize_mxfp4_t4(device const block_mxfp4 * xb, short il, thread type4 & reg) {
+    device const uint8_t * q2 = (device const uint8_t *)xb->qs;
+
+    const float d = e8m0_to_fp32(xb->e);
+    const short il4 = il%4;
+
+    const uint8_t shr = il >= 4 ? 4 : 0;
+
+    reg[0] = d * kvalues_mxfp4_f[(q2[4*il4 + 0] >> shr) & 0x0F];
+    reg[1] = d * kvalues_mxfp4_f[(q2[4*il4 + 1] >> shr) & 0x0F];
+    reg[2] = d * kvalues_mxfp4_f[(q2[4*il4 + 2] >> shr) & 0x0F];
+    reg[3] = d * kvalues_mxfp4_f[(q2[4*il4 + 3] >> shr) & 0x0F];
+}
+
 template <typename type4x4>
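For orientation, the block layout these dequantizers assume (mirroring the block_mxfp4 added to ggml-common.h by this commit; the _ref name below is illustrative): one E8M0 scale byte plus 16 bytes of packed 4-bit E2M1 codes, so 32 values fit in 17 bytes.

    #include <stdint.h>

    #define QK_MXFP4 32

    typedef struct {
        uint8_t e;               // E8M0 block scale (biased power of two)
        uint8_t qs[QK_MXFP4/2];  // 32 E2M1 codes, two per byte:
                                 // low nibbles -> values 0..15, high nibbles -> values 16..31
    } block_mxfp4_ref;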
@@ -960,6 +1006,32 @@ kernel void kernel_div(
     }
 }
 
+kernel void kernel_add_id(
+        constant ggml_metal_kargs_add_id & args,
+        device const char * src0,
+        device const char * src1,
+        device const char * src2,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3 ntg[[threads_per_threadgroup]]) {
+    const int i1 = tgpig.x;
+    const int i2 = tgpig.y;
+
+    const int i11 = *((device const int32_t *) (src2 + i1*sizeof(int32_t) + i2*args.nb21));
+
+    const size_t nb1 = args.ne0 * sizeof(float);
+    const size_t nb2 = args.ne1 * nb1;
+
+    device       float * dst_row  = (device       float *)((device char *)dst  +  i1*nb1       + i2*nb2);
+    device const float * src0_row = (device const float *)((device char *)src0 +  i1*args.nb01 + i2*args.nb02);
+    device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11);
+
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        dst_row[i0] = src0_row[i0] + src1_row[i0];
+    }
+}
+
 template<typename T>
 kernel void kernel_repeat(
         constant ggml_metal_kargs_repeat & args,
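In ggml terms, GGML_OP_ADD_ID adds a row of src1, selected by an integer id from src2, to each row of src0; in the gpt-oss MoE path this applies per-expert biases after mul_mat_id. A plain C sketch of the same semantics (illustrative names, contiguous layout assumed):

    #include <stdint.h>

    // dst[i2][i1][:] = src0[i2][i1][:] + src1[ids[i2][i1]][:]
    static void add_id_ref(float * dst, const float * src0, const float * src1,
                           const int32_t * ids, int ne0, int ne1, int ne2) {
        for (int i2 = 0; i2 < ne2; ++i2) {
            for (int i1 = 0; i1 < ne1; ++i1) {
                const int32_t row = ids[i2*ne1 + i1]; // e.g. the selected expert
                for (int i0 = 0; i0 < ne0; ++i0) {
                    dst[(i2*ne1 + i1)*ne0 + i0] = src0[(i2*ne1 + i1)*ne0 + i0] + src1[row*ne0 + i0];
                }
            }
        }
    }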
@@ -1431,6 +1503,32 @@ kernel void kernel_swiglu(
     }
 }
 
+kernel void kernel_swiglu_oai(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant ggml_metal_kargs_glu & args,
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint ntg[[threads_per_threadgroup]]) {
+    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
+    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
+    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);
+
+    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
+        float x0 = src0_row[i0];
+        float x1 = src1_row[i0];
+
+        x0 = min(x0, args.limit);
+        x1 = max(min(x1, args.limit), -args.limit);
+
+        float out_glu = x0 / (1.0f + exp(-x0 * args.alpha));
+        out_glu = out_glu * (1.0f + x1);
+
+        dst_row[i0] = out_glu;
+    }
+}
+
 kernel void kernel_geglu_erf(
         device const char * src0,
         device const char * src1,
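The fused op computes the gpt-oss SwiGLU variant: the gate x0 is clamped from above by limit, the linear term x1 is clamped to [-limit, limit], the sigmoid gets slope alpha, and the linear term enters as (1 + x1). A scalar C reference matching the kernel above:

    #include <math.h>

    static float swiglu_oai_ref(float x0, float x1, float alpha, float limit) {
        x0 = fminf(x0, limit);                // clamp the gate from above
        x1 = fmaxf(fminf(x1, limit), -limit); // clamp the linear part to [-limit, limit]
        const float glu = x0 / (1.0f + expf(-x0 * alpha)); // x0 * sigmoid(alpha * x0)
        return glu * (1.0f + x1);
    }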
@@ -1534,6 +1632,7 @@ template<typename T>
 kernel void kernel_soft_max(
         device const  char * src0,
         device const  char * src1,
+        device const  char * src2,
         device        char * dst,
         constant ggml_metal_kargs_soft_max & args,
         threadgroup  float * buf [[threadgroup(0)]],
@@ -1552,6 +1651,7 @@ kernel void kernel_soft_max(
 
     device const float * psrc0 = (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
     device const T     * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
+    device const float * psrc2 = src2 != src0 ? (device const float *) (src2) : nullptr;
     device       float * pdst  = (device float *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
 
     float slope = 1.0f;
@@ -1567,7 +1667,7 @@ kernel void kernel_soft_max(
     }
 
     // parallel max
-    float lmax = -INFINITY;
+    float lmax = psrc2 ? psrc2[i02] : -INFINITY;
 
     for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
         lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
@@ -1623,6 +1723,10 @@ kernel void kernel_soft_max(
         sum = simd_sum(sum);
     }
 
+    if (psrc2) {
+        sum += exp(psrc2[i02] - max_val);
+    }
+
     const float inv_sum = 1.0f/sum;
 
     for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
@@ -1634,6 +1738,7 @@ template<typename T>
 kernel void kernel_soft_max_4(
         device const  char * src0,
         device const  char * src1,
+        device const  char * src2,
         device        char * dst,
         constant ggml_metal_kargs_soft_max & args,
         threadgroup  float * buf [[threadgroup(0)]],
@@ -1652,6 +1757,7 @@ kernel void kernel_soft_max_4(
 
     device const float4 * psrc4 = (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
     device const T      * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
+    device const float  * psrc2 = src2 != src0 ? (device const float * ) (src2) : nullptr;
     device       float4 * pdst4 = (device float4 *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
 
     float slope = 1.0f;
@@ -1666,7 +1772,7 @@ kernel void kernel_soft_max_4(
     }
 
     // parallel max
-    float4 lmax4 = -INFINITY;
+    float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
 
     for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
         lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
@@ -1725,6 +1831,10 @@ kernel void kernel_soft_max_4(
         sum = simd_sum(sum);
     }
 
+    if (psrc2) {
+        sum += exp(psrc2[i02] - max_val);
+    }
+
     const float inv_sum = 1.0f/sum;
 
     for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
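The sink is one extra per-head logit that joins the max and the denominator but never emits an output weight, so every real attention weight is scaled down by the sink's probability mass. A C sketch of the resulting softmax (illustrative helper, not the ggml API):

    #include <math.h>

    // softmax over n logits x[] with one sink logit s:
    // w[i] = exp(x[i] - m) / (sum_j exp(x[j] - m) + exp(s - m)), m = max(s, max_i x[i])
    static void softmax_with_sink(float * w, const float * x, int n, float s) {
        float m = s;
        for (int i = 0; i < n; ++i) m = fmaxf(m, x[i]);
        float sum = expf(s - m); // the sink's share of the denominator
        for (int i = 0; i < n; ++i) sum += expf(x[i] - m);
        for (int i = 0; i < n; ++i) w[i] = expf(x[i] - m) / sum;
    }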
@@ -3106,6 +3216,11 @@ template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q8_0, 32, dequantize_q8_0_t4>;
 template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_4")]]  kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q8_0,   32, dequantize_q8_0_t4>;
 template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_5")]]  kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q8_0,   32, dequantize_q8_0_t4>;
 
+template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_mxfp4,  32, dequantize_mxfp4_t4>;
+template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_mxfp4,  32, dequantize_mxfp4_t4>;
+template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_mxfp4,  32, dequantize_mxfp4_t4>;
+template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_mxfp4,  32, dequantize_mxfp4_t4>;
+
 template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
 template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
 template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
@@ -4092,6 +4207,7 @@ kernel void kernel_flash_attn_ext(
         device const char * k,
         device const char * v,
         device const char * mask,
+        device const char * sinks,
         device       char * dst,
         threadgroup  half * shmem_f16 [[threadgroup(0)]],
         uint3  tgpig[[threadgroup_position_in_grid]],
@@ -4407,6 +4523,35 @@ kernel void kernel_flash_attn_ext(
             }
         }
 
+        if (sinks != q && sgitg == 0) {
+            for (ushort j = 0; j < Q; ++j) {
+                const float m = M[j];
+                const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
+
+                M[j] = simd_max(max(M[j], s));
+
+                const float ms = exp(m - M[j]);
+                const float vs = exp(s - M[j]);
+
+                S[j] = S[j]*ms + simd_sum(vs);
+
+                if (tiisg == j) {
+                    ss[j*TS + 2*C + j] = ms;
+                }
+            }
+
+            // O = diag(ms)*O
+            {
+                s8x8_t ms;
+                simdgroup_load(ms, ss + 2*C, TS, 0, false);
+
+#pragma unroll(DV8)
+                for (short i = 0; i < DV8; ++i) {
+                    simdgroup_multiply(lo[i], ms, lo[i]);
+                }
+            }
+        }
+
         // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
         for (short j = tiisg; j < Q; j += NW) {
             ss[j*TS + 0] = S[j];
@@ -4618,6 +4763,7 @@ kernel void kernel_flash_attn_ext_vec(
         device const char * k,
         device const char * v,
         device const char * mask,
+        device const char * sinks,
         device       char * dst,
         threadgroup  half * shmem_f16 [[threadgroup(0)]],
         uint3  tgpig[[threadgroup_position_in_grid]],
@@ -4835,6 +4981,23 @@ kernel void kernel_flash_attn_ext_vec(
         }
     }
 
+    if (sinks != q && sgitg == 0) {
+        const float m = M;
+        const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
+
+        M = simd_max(max(M, s));
+
+        const float ms = exp(m - M);
+        const float vs = exp(s - M);
+
+        S = S*ms + simd_sum(vs);
+
+#pragma unroll(DV4/NL)
+        for (short ii = 0; ii < DV4; ii += NL) {
+            lo[ii/NL] *= ms;
+        }
+    }
+
     // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
     if (tiisg == 0) {
        ss[0] = (s_t) S;
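Flash attention keeps a running maximum M and a running sum S, so both kernels fold the sink in afterwards as one more online-softmax step: lift M if the sink logit is larger, rescale the accumulated sum and output by exp(M_old - M_new), and add exp(s - M_new) to the sum. A scalar C sketch of that merge (illustrative names):

    #include <math.h>

    // fold one sink logit s into an online-softmax state (m, sum, out[0..dv))
    static void merge_sink(float * m, float * sum, float * out, int dv, float s) {
        const float m_new = fmaxf(*m, s);
        const float ms = expf(*m - m_new); // rescale what has been accumulated so far
        const float vs = expf(s - m_new);  // the sink's own weight

        *sum = *sum * ms + vs;             // the sink enters the denominator...
        for (int i = 0; i < dv; ++i) {
            out[i] *= ms;                  // ...but adds no value vector to the output
        }
        *m = m_new;
    }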
@@ -6940,6 +7103,95 @@ kernel void kernel_mul_mv_iq4_xs_f32(
     kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, N_SG_IQ4_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
+template<int nr0, int nsg, int nw, typename args_t>
+void kernel_mul_mv_mxfp4_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    threadgroup float * shmem_f32 = (threadgroup float *) shmem;
+    const int nb = args.ne00/QK_MXFP4;
+
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * nsg + sgitg) * nr0;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 = r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0);
+    device const float       * y = (device const float       *) (src1 + offset1);
+
+    const short ix = tiisg/2;  // 0...15
+    const short it = tiisg%2;  // 0 or 1
+
+    shmem_f32[tiisg] = kvalues_mxfp4_f[tiisg%16];
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    float4 yl[4];
+    float sumf[nr0]={0.f};
+
+    device const float * yb = y + ix * QK_MXFP4 + it * 8;
+
+    for (int ib = ix; ib < nb; ib += 16) {
+        device const float4 * y4 = (device const float4 *)yb;
+        yl[0] = y4[0];
+        yl[1] = y4[4];
+        yl[2] = y4[1];
+        yl[3] = y4[5];
+
+#pragma unroll(nr0)
+        for (short row = 0; row < nr0; row++) {
+            device const block_mxfp4 & xb = x[row*nb + ib];
+            device const uint8_t * q2 = (device const uint8_t *)(xb.qs + 8*it);
+
+            float4 acc1 = yl[0]*float4(shmem_f32[q2[0] & 0x0F], shmem_f32[q2[1] & 0x0F], shmem_f32[q2[2] & 0x0F], shmem_f32[q2[3] & 0x0F]);
+            float4 acc2 = yl[1]*float4(shmem_f32[q2[0] >> 4  ], shmem_f32[q2[1] >> 4  ], shmem_f32[q2[2] >> 4  ], shmem_f32[q2[3] >> 4  ]);
+            float4 acc3 = yl[2]*float4(shmem_f32[q2[4] & 0x0F], shmem_f32[q2[5] & 0x0F], shmem_f32[q2[6] & 0x0F], shmem_f32[q2[7] & 0x0F]);
+            float4 acc4 = yl[3]*float4(shmem_f32[q2[4] >> 4  ], shmem_f32[q2[5] >> 4  ], shmem_f32[q2[6] >> 4  ], shmem_f32[q2[7] >> 4  ]);
+
+            acc1 = (acc1 + acc3) + (acc2 + acc4);
+
+            sumf[row] += e8m0_to_fp32(xb.e) * ((acc1[0] + acc1[1]) + (acc1[2] + acc1[3]));
+        }
+
+        yb += 16 * QK_MXFP4;
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst_f32[first_row + row] = sum_all;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_mxfp4_f32")]]
+kernel void kernel_mul_mv_mxfp4_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, N_SG_MXFP4, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
 kernel void kernel_get_rows_q(
         constant ggml_metal_kargs_get_rows & args,
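The hot loop above is a nibble lookup: the 16 MXFP4 code values are staged once in threadgroup memory, each packed byte yields two table indices, and the E8M0 scale is applied once per 32-value block instead of per element. A compact scalar C sketch of one block's dot product (reusing the e8m0_to_fp32_ref helper sketched earlier; names are illustrative):

    #include <stdint.h>

    static const float kvalues_mxfp4_ref[16] = {
         0.f,  .5f,  1.f,  1.5f,  2.f,  3.f,  4.f,  6.f,
        -0.f, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f,
    };

    // dot of one MXFP4 block (16 packed bytes qs[], E8M0 scale e) with 32 floats y[]
    static float mxfp4_block_dot(const uint8_t qs[16], uint8_t e, const float y[32]) {
        float acc = 0.0f;
        for (int j = 0; j < 16; ++j) {
            acc += y[j]      * kvalues_mxfp4_ref[qs[j] & 0x0F]; // low nibble, first half
            acc += y[j + 16] * kvalues_mxfp4_ref[qs[j] >> 4];   // high nibble, second half
        }
        return acc * e8m0_to_fp32_ref(e); // block scale applied once
    }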
@@ -7475,6 +7727,7 @@ template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
 template [[host_name("kernel_get_rows_q5_0")]]  kernel get_rows_q_t kernel_get_rows_q<block_q5_0,  2, dequantize_q5_0>;
 template [[host_name("kernel_get_rows_q5_1")]]  kernel get_rows_q_t kernel_get_rows_q<block_q5_1,  2, dequantize_q5_1>;
 template [[host_name("kernel_get_rows_q8_0")]]  kernel get_rows_q_t kernel_get_rows_q<block_q8_0,  2, dequantize_q8_0>;
+template [[host_name("kernel_get_rows_mxfp4")]] kernel get_rows_q_t kernel_get_rows_q<block_mxfp4, 2, dequantize_mxfp4>;
 template [[host_name("kernel_get_rows_q2_K")]]  kernel get_rows_q_t kernel_get_rows_q<block_q2_K,  QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_get_rows_q3_K")]]  kernel get_rows_q_t kernel_get_rows_q<block_q3_K,  QK_NL, dequantize_q3_K>;
 template [[host_name("kernel_get_rows_q4_K")]]  kernel get_rows_q_t kernel_get_rows_q<block_q4_K,  QK_NL, dequantize_q4_K>;
@@ -7527,6 +7780,7 @@ template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
 template [[host_name("kernel_mul_mm_q5_0_f32")]]  kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0,  2, dequantize_q5_0>;
 template [[host_name("kernel_mul_mm_q5_1_f32")]]  kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1,  2, dequantize_q5_1>;
 template [[host_name("kernel_mul_mm_q8_0_f32")]]  kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0,  2, dequantize_q8_0>;
+template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4>;
 template [[host_name("kernel_mul_mm_q2_K_f32")]]  kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K,  QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_mul_mm_q3_K_f32")]]  kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K,  QK_NL, dequantize_q3_K>;
 template [[host_name("kernel_mul_mm_q4_K_f32")]]  kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K,  QK_NL, dequantize_q4_K>;
@@ -7558,6 +7812,7 @@ template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
 template [[host_name("kernel_mul_mm_id_q5_0_f16")]]  kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_0,  2, dequantize_q5_0>;
 template [[host_name("kernel_mul_mm_id_q5_1_f16")]]  kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_1,  2, dequantize_q5_1>;
 template [[host_name("kernel_mul_mm_id_q8_0_f16")]]  kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q8_0,  2, dequantize_q8_0>;
+template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4>;
 template [[host_name("kernel_mul_mm_id_q2_K_f16")]]  kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q2_K,  QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_mul_mm_id_q3_K_f16")]]  kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q3_K,  QK_NL, dequantize_q3_K>;
 template [[host_name("kernel_mul_mm_id_q4_K_f16")]]  kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_K,  QK_NL, dequantize_q4_K>;
@@ -7703,6 +7958,8 @@ template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SG_Q4_1, N_SIMDWIDTH>>>;
 template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SG_Q5_0, N_SIMDWIDTH>>>;
 template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SG_Q5_1, N_SIMDWIDTH>>>;
 
+template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, N_SG_MXFP4, N_SIMDWIDTH>>>;
+
 template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl <N_R0_Q2_K, N_SG_Q2_K, N_SIMDWIDTH>>>;
 template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl <N_R0_Q3_K, N_SG_Q3_K, N_SIMDWIDTH>>>;
 template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl <N_R0_Q4_K, N_SG_Q4_K, N_SIMDWIDTH>>>;
@@ -2497,6 +2497,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
         case GGML_OP_CLAMP:
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_SOFT_MAX:
+            // TODO: support attention sinks [TAG_ATTN_SINKS]
+            return op->src[2] == nullptr;
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
             return true;
@@ -21,6 +21,17 @@
 
 #define UNUSED GGML_UNUSED
 
+static inline int best_index_int8(int n, const int8_t * val, float x) {
+    if (x <= val[0]) return 0;
+    if (x >= val[n-1]) return n-1;
+    int ml = 0, mu = n-1;
+    while (mu-ml > 1) {
+        int mav = (ml+mu)/2;
+        if (x < val[mav]) mu = mav; else ml = mav;
+    }
+    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+}
+
 // reference implementation for deterministic creation of model files
 void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k) {
     static const int qk = QK4_0;
@@ -246,6 +257,53 @@ void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k) {
     }
 }
 
+static inline int best_index_mxfp4(float x, float e) {
+    int best_index = 0;
+    float best_err = fabsf(kvalues_mxfp4[0]*e - x);
+    for (int i = 1; i < 16; i++) {
+        float err = fabsf(kvalues_mxfp4[i]*e - x);
+        if (err < best_err) {
+            best_index = i;
+            best_err = err;
+        }
+    }
+    return best_index;
+}
+
+void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k) {
+    static const int qk = QK_MXFP4;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        for (int j = 0; j < qk; j++) {
+            const float v = x[i*qk + j];
+
+            if (amax < fabsf(v)) {
+                amax = fabsf(v);
+            }
+        }
+
+        const uint8_t e = (uint8_t) (floorf(log2f(amax)) - 2 + 127);
+
+        const float d = GGML_E8M0_TO_FP32_HALF(e);
+
+        y[i].e = e;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const uint8_t x0 = best_index_mxfp4(x[i*qk + 0    + j], d);
+            const uint8_t x1 = best_index_mxfp4(x[i*qk + qk/2 + j], d);
+
+            y[i].qs[j]  = x0;
+            y[i].qs[j] |= x1 << 4;
+        }
+    }
+}
+
 void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
     static const int qk = QK4_0;
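The scale choice e = floor(log2(amax)) - 2 + 127 anchors the code grid so the block maximum lands inside the ±6 code range (the int8 kvalues_mxfp4 table stores doubled code values, which the HALF decode compensates for). A small worked example under those assumptions:

    #include <math.h>
    #include <stdio.h>

    int main(void) {
        const float amax = 10.0f;
        const int   e    = (int) floorf(log2f(amax)) - 2 + 127; // = 3 - 2 + 127 = 128
        const float s    = ldexpf(1.0f, e - 127);               // effective scale = 2.0
        // representable magnitudes: {0, .5, 1, 1.5, 2, 3, 4, 6} * s, up to 12.0
        printf("e = %d, scale = %g, largest code value = %g\n", e, s, 6.0f*s);
        return 0; // amax = 10.0 rounds to the code for 8.0 on this grid
    }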
@@ -356,6 +414,26 @@ void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
     }
 }
 
+void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
+    static const int qk = QK_MXFP4;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = GGML_E8M0_TO_FP32_HALF(x[i].e);
+
+        for (int j = 0; j < qk/2; ++j) {
+            const int8_t x0 = kvalues_mxfp4[x[i].qs[j] & 0x0F];
+            const int8_t x1 = kvalues_mxfp4[x[i].qs[j] >> 4];
+
+            y[i*qk + j + 0   ] = x0*d;
+            y[i*qk + j + qk/2] = x1*d;
+        }
+    }
+}
+
 //
 // 2-6 bit quantization in super-blocks
 //
@@ -2014,6 +2092,12 @@ size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     return nrow * row_size;
 }
 
+size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    GGML_UNUSED(quant_weights);
+    quantize_row_mxfp4_ref(src, dst, (int64_t)nrow*n_per_row);
+    return nrow * ggml_row_size(GGML_TYPE_MXFP4, n_per_row);
+}
+
 // ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)
 
 void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k) {
@@ -4551,17 +4635,6 @@ size_t quantize_iq1_m(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
 
 // ============================ 4-bit non-linear quants
 
-static inline int best_index_int8(int n, const int8_t * val, float x) {
-    if (x <= val[0]) return 0;
-    if (x >= val[n-1]) return n-1;
-    int ml = 0, mu = n-1;
-    while (mu-ml > 1) {
-        int mav = (ml+mu)/2;
-        if (x < val[mav]) mu = mav; else ml = mav;
-    }
-    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
-}
-
 static void quantize_row_iq4_nl_impl(const int super_block_size, const int block_size, const float * GGML_RESTRICT x,
         ggml_fp16_t * dh, uint8_t * q4, uint16_t * scales_h, uint8_t * scales_l,
         float * scales, float * weight, uint8_t * L,
@@ -4961,6 +5034,15 @@ static bool validate_fp16(ggml_fp16_t f, size_t i) {
     return true;
 }
 
+static bool validate_e_e8m0(uint8_t e, size_t i) {
+    if (e == 0xff) {
+        fprintf(stderr, "ggml_validate_row_data: found invalid e value %d at block %zu\n", e, i);
+        return false;
+    }
+
+    return true;
+}
+
 #define VALIDATE_ROW_DATA_D_F16_IMPL(type, data, nb) \
     const type * q = (const type *) (data); \
     for (size_t i = 0; i < (nb); ++i) { \
@@ -4977,6 +5059,14 @@ static bool validate_fp16(ggml_fp16_t f, size_t i) {
     } \
 }
 
+#define VALIDATE_ROW_DATA_E_E8M0_IMPL(type, data, nb) \
+    const type * q = (const type *) (data); \
+    for (size_t i = 0; i < (nb); ++i) { \
+        if (!validate_e_e8m0(q[i].e, i)) { \
+            return false; \
+        } \
+    }
+
 #define VALIDATE_ROW_DATA_DVEC_F16_IMPL(type, data, nb, nr) \
     const type * q = (const type *) (data); \
     for (size_t i = 0; i < (nb); ++i) { \
@@ -5130,6 +5220,10 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes) {
         {
             VALIDATE_ROW_DATA_D_F16_IMPL(block_q8_0, data, nb);
         } break;
+        case GGML_TYPE_MXFP4:
+            {
+                VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp4, data, nb);
+            } break;
         case GGML_TYPE_Q2_K:
             {
                 VALIDATE_ROW_DATA_DM_F16_IMPL(block_q2_K, data, nb, d, dmin);
| 21 |
GGML_API void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k);
|
| 22 |
GGML_API void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k);
|
| 23 |
|
| 24 |
+
GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k);
|
| 25 |
+
|
| 26 |
GGML_API void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k);
|
| 27 |
GGML_API void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k);
|
| 28 |
GGML_API void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k);
|
|
|
|
| 47 |
GGML_API void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
| 48 |
//GGML_API void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
| 49 |
|
| 50 |
+
GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
| 51 |
+
|
| 52 |
GGML_API void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
| 53 |
GGML_API void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
| 54 |
GGML_API void dequantize_row_q4_K(const block_q4_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
|
|
|
| 94 |
GGML_API size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
| 95 |
GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
| 96 |
|
| 97 |
+
GGML_API size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
| 98 |
+
|
| 99 |
GGML_API void iq2xs_init_impl(enum ggml_type type);
|
| 100 |
GGML_API void iq2xs_free_impl(enum ggml_type type);
|
| 101 |
GGML_API void iq3xs_init_impl(int grid_size);
|
|
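
A minimal usage sketch for the three new declarations, assuming a single row whose length is a multiple of the MXFP4 block size (32); passing a NULL imatrix is fine here because, as shown earlier, the implementation ignores it:

    #include <stdint.h>
    #include <stdlib.h>

    static void mxfp4_roundtrip(const float * src, float * dst, int64_t n) {
        void * q = malloc(ggml_row_size(GGML_TYPE_MXFP4, n));
        quantize_mxfp4(src, q, /*nrows=*/1, /*n_per_row=*/n, /*imatrix=*/NULL);
        dequantize_row_mxfp4((const block_mxfp4 *) q, dst, n);
        free(q);
    }
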
@@ -4193,15 +4193,9 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             {
-                struct ggml_tensor * a;
-                struct ggml_tensor * b;
-                if (op->op == GGML_OP_MUL_MAT) {
-                    a = op->src[0];
-                    b = op->src[1];
-                } else {
-                    a = op->src[2];
-                    b = op->src[1];
-                }
+                struct ggml_tensor * a = op->src[0];
+                struct ggml_tensor * b = op->src[1];
+
                 if (a->ne[3] != b->ne[3]) {
                     return false;
                 }

@@ -4216,7 +4210,9 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 }
             }
             ggml_type src0_type = op->src[0]->type;
-            if (src0_type == GGML_TYPE_BF16) {
+            if (src0_type == GGML_TYPE_BF16 || src0_type == GGML_TYPE_MXFP4) {
+                // TODO: support MXFP4
+                // FIXME: keep a list of supported types to avoid breaking the backend when a new type is added
                 return false;
             }
             return true;

@@ -4361,6 +4357,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
             if (op->src[0]->ne[3] != 1) {
                 return false;
             }
+            // TODO: support attention sinks [TAG_ATTN_SINKS]
+            if (op->src[2]) {
+                return false;
+            }
             // TODO: support broadcast
             // ref: https://github.com/ggml-org/llama.cpp/pull/14435
             return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
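
The FIXME above points at the real hazard: the SYCL backend deny-lists unsupported source types, so every newly added ggml type silently falls through to kernels that cannot handle it until someone extends the check. An allow-list inverts the default. A sketch of that shape (the listed types are illustrative, not an audit of what SYCL actually supports):

    static bool sycl_mul_mat_src0_supported(enum ggml_type t) {
        switch (t) {
            case GGML_TYPE_F32:
            case GGML_TYPE_F16:
            case GGML_TYPE_Q4_0:
            case GGML_TYPE_Q8_0:
                return true;   // grown deliberately, one audited type at a time
            default:
                return false;  // BF16, MXFP4, and any future type fail safe here
        }
    }
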
@@ -449,6 +449,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_div[2][2][2];
     vk_pipeline pipeline_div_norepeat[2][2][2];
 
+    vk_pipeline pipeline_add_id_f32;
+
     vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
     vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32;
     vk_pipeline pipeline_scale_f32;

@@ -483,6 +485,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_geglu[2];
     vk_pipeline pipeline_reglu[2];
     vk_pipeline pipeline_swiglu[2];
+    vk_pipeline pipeline_swiglu_oai[2];
     vk_pipeline pipeline_geglu_erf[2];
     vk_pipeline pipeline_geglu_quick[2];

@@ -705,6 +708,8 @@ struct vk_op_glu_push_constants {
     uint32_t ne00;
     uint32_t ne20;
     uint32_t mode; // 0: default, 1: swapped, 2: split
+    float alpha; // for swiglu_oai
+    float limit;
 };
 
 struct vk_op_unary_push_constants {

@@ -794,6 +799,15 @@ struct vk_op_binary_push_constants {
     float param1; float param2; int32_t param3;
 };
 
+struct vk_op_add_id_push_constants {
+    uint32_t ne0;
+    uint32_t ne1;
+    uint32_t s01;
+    uint32_t s02;
+    uint32_t s11;
+    uint32_t s21;
+};
+
 struct vk_op_diag_mask_push_constants {
     uint32_t ncols;
     uint32_t rows_per_channel;

@@ -835,6 +849,7 @@ struct vk_op_soft_max_push_constants {
     float m1;
     uint32_t n_head_log2;
     uint32_t nrows_x;
+    uint32_t has_sinks;
 };
 
 struct vk_op_argsort_push_constants {
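
The two new push constants feed the fused swiglu_oai kernel. Numerically, the activation clamps the gate branch from above at `limit`, clamps the linear branch to `[-limit, limit]`, and scales the sigmoid by `alpha`. A scalar reference, written from the gpt-oss formulation as I read it (treat the clamping details as assumptions rather than a transcript of the shader):

    #include <math.h>

    static float swiglu_oai_ref(float gate, float linear, float alpha, float limit) {
        gate   = fminf(gate, limit);                  // clamp gate from above only
        linear = fmaxf(fminf(linear, limit), -limit); // clamp linear branch both ways
        const float sig = 1.0f / (1.0f + expf(-alpha * gate));
        return (gate * sig) * (linear + 1.0f);        // note the +1 on the linear term
    }
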
@@ -1977,6 +1992,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
             break;
         case GGML_TYPE_IQ4_NL:
         case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_MXFP4:
             lut_size = 4*16;
             break;
         default:

@@ -2353,6 +2369,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S], matmul_iq3_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
     CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
     CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+    CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4], matmul_mxfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
 
     CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
 #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)

@@ -2379,6 +2396,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
     CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
     CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+    CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
 #undef CREATE_MM
 #undef CREATE_MM2
 } else

@@ -2440,6 +2458,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+    CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 } else {
     CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );

@@ -2461,6 +2480,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+    CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 }
 
 CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);

@@ -2493,6 +2513,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+    CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
 } else {
     CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);

@@ -2514,6 +2535,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+    CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
 }
 #undef CREATE_MM2
 #undef CREATE_MM

@@ -2581,6 +2603,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+    CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 
 #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
     if (device->integer_dot_product) {

@@ -2618,6 +2641,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+    CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
 #undef CREATE_MM2
 #undef CREATE_MMQ
 #undef CREATE_MM

@@ -2672,6 +2696,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+    CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 
 #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
     if (device->integer_dot_product) {

@@ -2709,6 +2734,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+    CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
 }
 // reusing CREATE_MM from the fp32 path
 if ((device->coopmat2 || device->coopmat_support)

@@ -2767,6 +2793,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f32_f32_len, mul_mat_vec_iq4_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32_"+std::to_string(i+1), mul_mat_vec_mxfp4_f32_f32_len, mul_mat_vec_mxfp4_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
 
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);

@@ -2790,6 +2817,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f16_f32_len, mul_mat_vec_iq4_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32_"+std::to_string(i+1), mul_mat_vec_mxfp4_f16_f32_len, mul_mat_vec_mxfp4_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
 }
 
 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);

@@ -2814,6 +2842,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", mul_mat_vec_id_mxfp4_f32_len, mul_mat_vec_id_mxfp4_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
 
 // dequant shaders
 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);

@@ -2836,6 +2865,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_S], "dequant_iq3_s", dequant_iq3_s_len, dequant_iq3_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4], "dequant_mxfp4", dequant_mxfp4_len, dequant_mxfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
 
 // get_rows
 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);

@@ -2855,6 +2885,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_S], "get_rows_iq3_s", get_rows_iq3_s_len, get_rows_iq3_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 
 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);

@@ -2873,6 +2904,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_S], "get_rows_iq3_s_f32", get_rows_iq3_s_f32_len, get_rows_iq3_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs_f32", get_rows_iq4_xs_f32_len, get_rows_iq4_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 
 ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 4 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);

@@ -2976,6 +3008,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_BINARY(div, _norepeat, {1})
 #undef CREATE_BINARY
 
+    ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);

@@ -3026,6 +3060,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_GLU(geglu)
     CREATE_GLU(reglu)
     CREATE_GLU(swiglu)
+    CREATE_GLU(swiglu_oai)
     CREATE_GLU(geglu_erf)
     CREATE_GLU(geglu_quick)
 #undef CREATE_GLU

@@ -3035,10 +3070,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
 
-    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main",
-    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main",
-    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main",
-    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main",
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);

@@ -4244,6 +4279,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_MXFP4:
             break;
         default:
             return nullptr;

@@ -4314,6 +4350,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_MXFP4:
            break;
         default:
             return nullptr;

@@ -4357,6 +4394,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_MXFP4:
             break;
         default:
             return nullptr;

@@ -4411,6 +4449,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_MXFP4:
             break;
         default:
             return nullptr;

@@ -4446,6 +4485,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_MXFP4:
             break;
         default:
             return nullptr;

@@ -4631,6 +4671,7 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context&
     std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))");
     GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size());
     GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT);
+    GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size());
 
     vk::DescriptorSet& descriptor_set = ctx->descriptor_sets[ctx->descriptor_set_idx++];
     vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() };
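
The new GGML_ASSERT ties the descriptor count to the pipeline's declared parameter_count, which matters now that several pipelines grew a binding (soft_max, for instance, is created above with 4 descriptors to make room for the sinks tensor). In miniature, the invariant is just the following; names here are illustrative, not the backend's types:

    #include <assert.h>
    #include <stddef.h>

    struct pipeline_info { unsigned parameter_count; };

    // A pipeline built for N storage buffers must be dispatched with exactly
    // N buffer infos, or the descriptor write reads past the caller's array.
    static void dispatch_checked(const struct pipeline_info * p, size_t n_buf_infos) {
        assert(p->parameter_count == n_buf_infos);
        /* ... fill the descriptor set and record the dispatch ... */
    }
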
@@ -6847,6 +6888,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             break;
         }
         return nullptr;
+    case GGML_OP_ADD_ID:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && src2->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_add_id_f32;
+        }
+        return nullptr;
     case GGML_OP_CONCAT:
         if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_concat_f32;

@@ -6992,6 +7038,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
     case GGML_GLU_OP_SWIGLU:
         return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
+    case GGML_GLU_OP_SWIGLU_OAI:
+        return ctx->device->pipeline_swiglu_oai[dst->type == GGML_TYPE_F16];
     case GGML_GLU_OP_GEGLU_ERF:
         return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16];
     case GGML_GLU_OP_GEGLU_QUICK:

@@ -7007,6 +7055,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         return nullptr;
     case GGML_OP_SOFT_MAX:
         GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
 
         if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
             return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
@@ -7177,6 +7226,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_SUB:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
     case GGML_OP_CONCAT:
     case GGML_OP_UPSCALE:
     case GGML_OP_SQR:

@@ -7523,6 +7573,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
             elements = { ne, 1, 1 };
         }
     } break;
     case GGML_OP_SET_ROWS:
     {
         uint32_t ne = ggml_nelements(src0);

@@ -7562,8 +7616,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         }
     }
 
-    if (op ==
-        // Empty src1 is possible in
     vk_subbuffer subbuf_y;
     if (use_src1) {
         subbuf_y = { d_Y, y_buf_offset, y_sz };

@@ -7573,6 +7627,24 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
 
         ggml_vk_sync_buffers(subctx);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
     } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
         // Empty src2 is possible in rope, but the shader needs a buffer
         vk_subbuffer subbuf_z;
@@ -7701,6 +7773,21 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
     }, dryrun);
 }
 
 static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version, bool dryrun = false) {
     GGML_ASSERT(version == 6 || version == 7);
     int num_srcs = version == 6 ? 6 : 7;
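
The fifteen lines added in the hunk above are not recoverable from this view, but given the add_id pipeline and push constants introduced earlier they presumably implement the Vulkan ADD_ID dispatch. As a semantic reference, GGML_OP_ADD_ID adds a bias row selected per token by an i32 id tensor. A CPU sketch, simplified to one selected expert per token and a contiguous layout, with illustrative names:

    #include <stdint.h>

    static void add_id_ref(float * dst, const float * src0, const float * bias,
                           const int32_t * ids, int64_t n_tokens, int64_t n_embd) {
        for (int64_t t = 0; t < n_tokens; ++t) {
            const float * b = bias + (int64_t) ids[t] * n_embd; // expert bias row for this token
            for (int64_t c = 0; c < n_embd; ++c) {
                dst[t*n_embd + c] = src0[t*n_embd + c] + b[c];
            }
        }
    }
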
@@ -8119,8 +8206,12 @@ static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, con
 }
 
 static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     const bool swapped = (bool)dst->op_params[1];
     const bool split = src1 != nullptr;
 
     GGML_ASSERT(ggml_is_contiguous(src0));

@@ -8134,7 +8225,15 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const
 
     const uint32_t mode = split ? 2 : (swapped ? 1 : 0);
 
-    ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU,
 }
 
 static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
@@ -8142,7 +8241,7 @@ static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& sub
     ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
 }
 
-static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
     float * op_params = (float *)dst->op_params;
 
     float scale = op_params[0];

@@ -8164,7 +8263,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
     const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1,
         ncols,
         src1 != nullptr ? nrows_y : (uint32_t)0,
         (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],

@@ -8174,6 +8273,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
         m0, m1,
         n_head_log2,
         nrows_x,
+        src2 != nullptr,
     }, dryrun);
 }
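
The has_sinks flag switches the shader into the attention-sinks variant. The arithmetic change is small: each head contributes one extra learned logit to the softmax normalizer without producing an output element, which caps how much probability mass the real positions must absorb. A scalar sketch of softmax over n logits with a single sink value (my reading of the scheme, not the shader code):

    #include <math.h>

    static void softmax_with_sink(const float * x, float * y, int n, float sink) {
        float m = sink;                                // track the max incl. the sink
        for (int i = 0; i < n; ++i) m = fmaxf(m, x[i]);

        float denom = expf(sink - m);                  // sink joins the normalizer only
        for (int i = 0; i < n; ++i) denom += expf(x[i] - m);

        for (int i = 0; i < n; ++i) y[i] = expf(x[i] - m) / denom; // row sums to < 1
    }
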
@@ -9413,6 +9513,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_GLU_OP_GEGLU:
     case GGML_GLU_OP_REGLU:
     case GGML_GLU_OP_SWIGLU:
+    case GGML_GLU_OP_SWIGLU_OAI:
     case GGML_GLU_OP_GEGLU_ERF:
     case GGML_GLU_OP_GEGLU_QUICK:
         break;

@@ -9424,6 +9525,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_REPEAT_BACK:
     case GGML_OP_GET_ROWS:
     case GGML_OP_ADD:
+    case GGML_OP_ADD_ID:
     case GGML_OP_ACC:
     case GGML_OP_SUB:
     case GGML_OP_MUL:

@@ -9578,6 +9680,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_DIV:
         ggml_vk_div(ctx, compute_ctx, src0, src1, node, dryrun);
 
         break;
     case GGML_OP_CONCAT:
         ggml_vk_concat(ctx, compute_ctx, src0, src1, node, dryrun);

@@ -9675,6 +9781,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_GLU_OP_GEGLU:
     case GGML_GLU_OP_REGLU:
     case GGML_GLU_OP_SWIGLU:
+    case GGML_GLU_OP_SWIGLU_OAI:
     case GGML_GLU_OP_GEGLU_ERF:
     case GGML_GLU_OP_GEGLU_QUICK:
         ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);

@@ -9688,7 +9795,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 
         break;
     case GGML_OP_SOFT_MAX:
-        ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun);
+        ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node, dryrun);
 
         break;
     case GGML_OP_SOFT_MAX_BACK:

@@ -9834,6 +9941,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
     case GGML_OP_SUB:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
+    case GGML_OP_ADD_ID:
     case GGML_OP_CONCAT:
     case GGML_OP_UPSCALE:
     case GGML_OP_SCALE:

@@ -9903,6 +10011,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
     case GGML_GLU_OP_GEGLU:
     case GGML_GLU_OP_REGLU:
     case GGML_GLU_OP_SWIGLU:
+    case GGML_GLU_OP_SWIGLU_OAI:
     case GGML_GLU_OP_GEGLU_ERF:
     case GGML_GLU_OP_GEGLU_QUICK:
         buf = tensor->buffer;
@@ -10752,6 +10861,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_GLU_OP_GEGLU:
     case GGML_GLU_OP_REGLU:
     case GGML_GLU_OP_SWIGLU:
+    case GGML_GLU_OP_SWIGLU_OAI:
     case GGML_GLU_OP_GEGLU_ERF:
     case GGML_GLU_OP_GEGLU_QUICK:
         return ggml_is_contiguous(op->src[0]) &&

@@ -10797,6 +10907,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_TYPE_IQ3_S:
     case GGML_TYPE_IQ4_XS:
     case GGML_TYPE_IQ4_NL:
+    case GGML_TYPE_MXFP4:
         break;
     default:
         return false;

@@ -10834,6 +10945,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) {
         return false;
     }
     if (op->src[0]->type != GGML_TYPE_F32) {
         return false;
     }

@@ -10906,6 +11021,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_TYPE_IQ3_S:
     case GGML_TYPE_IQ4_XS:
     case GGML_TYPE_IQ4_NL:
+    case GGML_TYPE_MXFP4:
         return true;
     default:
         return false;

@@ -11004,6 +11120,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
            (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
            (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
 case GGML_OP_SILU_BACK:
 case GGML_OP_RMS_NORM_BACK:
 case GGML_OP_SQR:
| 449 |
vk_pipeline pipeline_div[2][2][2];
|
| 450 |
vk_pipeline pipeline_div_norepeat[2][2][2];
|
| 451 |
|
| 452 |
+
vk_pipeline pipeline_add_id_f32;
|
| 453 |
+
|
| 454 |
vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
|
| 455 |
vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32;
|
| 456 |
vk_pipeline pipeline_scale_f32;
|
|
|
|
| 485 |
vk_pipeline pipeline_geglu[2];
|
| 486 |
vk_pipeline pipeline_reglu[2];
|
| 487 |
vk_pipeline pipeline_swiglu[2];
|
| 488 |
+
vk_pipeline pipeline_swiglu_oai[2];
|
| 489 |
vk_pipeline pipeline_geglu_erf[2];
|
| 490 |
vk_pipeline pipeline_geglu_quick[2];
|
| 491 |
|
|
|
|
| 708 |
uint32_t ne00;
|
| 709 |
uint32_t ne20;
|
| 710 |
uint32_t mode; // 0: default, 1: swapped, 2: split
|
| 711 |
+
float alpha; // for swiglu_oai
|
| 712 |
+
float limit;
|
| 713 |
};
|
| 714 |
|
| 715 |
struct vk_op_unary_push_constants {
|
|
|
|
| 799 |
float param1; float param2; int32_t param3;
|
| 800 |
};
|
| 801 |
|
| 802 |
+
struct vk_op_add_id_push_constants {
|
| 803 |
+
uint32_t ne0;
|
| 804 |
+
uint32_t ne1;
|
| 805 |
+
uint32_t s01;
|
| 806 |
+
uint32_t s02;
|
| 807 |
+
uint32_t s11;
|
| 808 |
+
uint32_t s21;
|
| 809 |
+
};
|
| 810 |
+
|
| 811 |
struct vk_op_diag_mask_push_constants {
|
| 812 |
uint32_t ncols;
|
| 813 |
uint32_t rows_per_channel;
|
|
|
|
| 849 |
float m1;
|
| 850 |
uint32_t n_head_log2;
|
| 851 |
uint32_t nrows_x;
|
| 852 |
+
uint32_t has_sinks;
|
| 853 |
};
|
| 854 |
|
| 855 |
struct vk_op_argsort_push_constants {
|
|
|
|
| 1992 |
break;
|
| 1993 |
case GGML_TYPE_IQ4_NL:
|
| 1994 |
case GGML_TYPE_IQ4_XS:
|
| 1995 |
+
case GGML_TYPE_MXFP4:
|
| 1996 |
lut_size = 4*16;
|
| 1997 |
break;
|
| 1998 |
default:
|
|
|
|
| 2369 |
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S], matmul_iq3_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
| 2370 |
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
| 2371 |
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
| 2372 |
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4], matmul_mxfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
| 2373 |
|
| 2374 |
CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
|
| 2375 |
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
|
|
|
| 2396 |
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 2397 |
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 2398 |
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 2399 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 2400 |
#undef CREATE_MM
|
| 2401 |
#undef CREATE_MM2
|
| 2402 |
} else
|
|
|
|
| 2458 |
CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 2459 |
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 2460 |
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 2461 |
+
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 2462 |
} else {
|
| 2463 |
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 2464 |
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
|
|
| 2480 |
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 2481 |
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 2482 |
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 2483 |
+
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 2484 |
}
|
| 2485 |
|
| 2486 |
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
|
|
|
| 2513 |
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2514 |
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2515 |
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2516 |
+
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2517 |
} else {
|
| 2518 |
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2519 |
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
|
|
| 2535 |
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2536 |
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2537 |
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2538 |
+
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2539 |
}
|
| 2540 |
#undef CREATE_MM2
|
| 2541 |
#undef CREATE_MM
|
|
|
|
| 2603 |
CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 2604 |
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 2605 |
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 2606 |
+
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 2607 |
|
| 2608 |
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
| 2609 |
if (device->integer_dot_product) {
|
|
|
|
| 2641 |
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2642 |
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2643 |
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2644 |
+
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2645 |
#undef CREATE_MM2
|
| 2646 |
#undef CREATE_MMQ
|
| 2647 |
#undef CREATE_MM
|
|
|
|
| 2696 |
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 2697 |
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 2698 |
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 2699 |
+
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 2700 |
|
| 2701 |
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
| 2702 |
if (device->integer_dot_product) {
|
|
|
|
| 2734 |
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2735 |
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2736 |
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2737 |
+
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2738 |
}
|
| 2739 |
// reusing CREATE_MM from the fp32 path
|
| 2740 |
if ((device->coopmat2 || device->coopmat_support)
|
|
|
|
| 2793 |
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
|
| 2794 |
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f32_f32_len, mul_mat_vec_iq4_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
|
| 2795 |
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
|
| 2796 |
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32_"+std::to_string(i+1), mul_mat_vec_mxfp4_f32_f32_len, mul_mat_vec_mxfp4_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
|
| 2797 |
|
| 2798 |
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
| 2799 |
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
|
|
|
| 2817 |
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
|
| 2818 |
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f16_f32_len, mul_mat_vec_iq4_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
|
| 2819 |
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
|
| 2820 |
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32_"+std::to_string(i+1), mul_mat_vec_mxfp4_f16_f32_len, mul_mat_vec_mxfp4_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
|
| 2821 |
}
|
| 2822 |
|
| 2823 |
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
|
|
|
| 2842 |
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
|
| 2843 |
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
|
| 2844 |
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
|
| 2845 |
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", mul_mat_vec_id_mxfp4_f32_len, mul_mat_vec_id_mxfp4_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
|
| 2846 |
|
| 2847 |
// dequant shaders
|
| 2848 |
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
|
|
|
|
| 2865 |
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_S], "dequant_iq3_s", dequant_iq3_s_len, dequant_iq3_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
|
| 2866 |
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
|
| 2867 |
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
|
| 2868 |
+
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4], "dequant_mxfp4", dequant_mxfp4_len, dequant_mxfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
|
| 2869 |
|
| 2870 |
// get_rows
|
| 2871 |
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
|
|
|
| 2885 |
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_S], "get_rows_iq3_s", get_rows_iq3_s_len, get_rows_iq3_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
@@ ... @@
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_S], "get_rows_iq3_s_f32", get_rows_iq3_s_f32_len, get_rows_iq3_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs_f32", get_rows_iq4_xs_f32_len, get_rows_iq4_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 4 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
@@ ... @@
     CREATE_BINARY(div, _norepeat, {1})
 #undef CREATE_BINARY
 
+    ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
@@ ... @@
     CREATE_GLU(geglu)
     CREATE_GLU(reglu)
     CREATE_GLU(swiglu)
+    CREATE_GLU(swiglu_oai)
     CREATE_GLU(geglu_erf)
     CREATE_GLU(geglu_quick)
 #undef CREATE_GLU
@@ ... @@
     ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
 
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
@@ ... @@
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_MXFP4:
             break;
         default:
             return nullptr;
@@ ... @@
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_MXFP4:
             break;
         default:
             return nullptr;
@@ ... @@
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_MXFP4:
             break;
         default:
             return nullptr;
@@ ... @@
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_MXFP4:
             break;
         default:
             return nullptr;
@@ ... @@
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_MXFP4:
             break;
         default:
             return nullptr;
@@ ... @@
     std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))");
     GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size());
     GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT);
+    GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size());
 
     vk::DescriptorSet& descriptor_set = ctx->descriptor_sets[ctx->descriptor_set_idx++];
     vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() };
@@ ... @@
             break;
         }
         return nullptr;
+    case GGML_OP_ADD_ID:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && src2->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_add_id_f32;
+        }
+        return nullptr;
     case GGML_OP_CONCAT:
         if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_concat_f32;
@@ ... @@
         return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
     case GGML_GLU_OP_SWIGLU:
         return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
+    case GGML_GLU_OP_SWIGLU_OAI:
+        return ctx->device->pipeline_swiglu_oai[dst->type == GGML_TYPE_F16];
     case GGML_GLU_OP_GEGLU_ERF:
         return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16];
     case GGML_GLU_OP_GEGLU_QUICK:
@@ ... @@
         return nullptr;
     case GGML_OP_SOFT_MAX:
         GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
+        GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32);
 
         if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
             return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
@@ ... @@
     case GGML_OP_SUB:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
+    case GGML_OP_ADD_ID:
     case GGML_OP_CONCAT:
     case GGML_OP_UPSCALE:
     case GGML_OP_SQR:
@@ ... @@
                 elements = { ne, 1, 1 };
             }
         } break;
+    case GGML_OP_ADD_ID:
+        {
+            elements = { (uint32_t)ne01, (uint32_t)ne02, 1 };
+        } break;
     case GGML_OP_SET_ROWS:
         {
             uint32_t ne = ggml_nelements(src0);
@@ ... @@
         }
     }
 
+    if (op == GGML_OP_GLU) {
+        // Empty src1 is possible in glu, but the shader needs a buffer
         vk_subbuffer subbuf_y;
         if (use_src1) {
             subbuf_y = { d_Y, y_buf_offset, y_sz };
@@ ... @@
 
         ggml_vk_sync_buffers(subctx);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+    } else if (op == GGML_OP_SOFT_MAX) {
+        // Empty src1 and src2 is possible in soft_max, but the shader needs a buffer
+        vk_subbuffer subbuf_y;
+        if (use_src1) {
+            subbuf_y = { d_Y, y_buf_offset, y_sz };
+        } else {
+            subbuf_y = { d_X, 0, x_sz };
+        }
+
+        vk_subbuffer subbuf_z;
+        if (use_src2) {
+            subbuf_z = { d_Z, z_buf_offset, z_sz };
+        } else {
+            subbuf_z = { d_X, 0, x_sz };
+        }
+
+        ggml_vk_sync_buffers(subctx);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
     } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
         // Empty src2 is possible in rope, but the shader needs a buffer
         vk_subbuffer subbuf_z;
@@ ... @@
     }, dryrun);
 }
 
+static void ggml_vk_add_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t src1_type_size = ggml_type_size(src1->type);
+    const uint32_t src2_type_size = ggml_type_size(src2->type);
+
+    ggml_vk_op_f32<vk_op_add_id_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ADD_ID, {
+        (uint32_t)dst->ne[0],
+        (uint32_t)dst->ne[1],
+        (uint32_t)src0->nb[1] / src0_type_size,
+        (uint32_t)src0->nb[2] / src0_type_size,
+        (uint32_t)src1->nb[1] / src1_type_size,
+        (uint32_t)src2->nb[1] / src2_type_size,
+    }, dryrun);
+}
+
 static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version, bool dryrun = false) {
     GGML_ASSERT(version == 6 || version == 7);
     int num_srcs = version == 6 ? 6 : 7;
@@ ... @@
 }
 
 static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    const float * op_params_f = (const float *)dst->op_params;
+
     const bool swapped = (bool)dst->op_params[1];
     const bool split = src1 != nullptr;
+    const float alpha = op_params_f[2];
+    const float limit = op_params_f[3];
 
     GGML_ASSERT(ggml_is_contiguous(src0));
@@ ... @@
 
     const uint32_t mode = split ? 2 : (swapped ? 1 : 0);
 
+    ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU,
+        {
+            (uint32_t)ggml_nelements(dst),
+            (uint32_t)src0->ne[0],
+            (uint32_t)dst->ne[0],
+            mode,
+            alpha,
+            limit
+        }, dryrun);
 }
 
 static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
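The two new operands, `alpha` and `limit`, are what distinguish `swiglu_oai` from plain `swiglu`. As a rough scalar reference (a sketch assuming the published gpt-oss formulation; the authoritative fused kernels are in `ggml-cpu/ops.cpp` and the generated `swiglu_oai` shader), the activation clamps both halves and biases the linear half by one:

```cpp
#include <algorithm>
#include <cmath>

// Scalar sketch of the fused swiglu_oai activation, assuming the gpt-oss
// formulation (clamped gate/up, alpha-scaled sigmoid, +1 bias on the linear
// half). x is the gate half, g the linear ("up") half.
static float swiglu_oai_ref(float x, float g, float alpha, float limit) {
    x = std::min(x, limit);            // gate is clamped from above only
    g = std::clamp(g, -limit, limit);  // linear half is clamped symmetrically
    const float sig = 1.0f / (1.0f + std::exp(-alpha * x));
    return (x * sig) * (g + 1.0f);
}
```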
@@ ... @@
     ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
 }
 
+static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
     float * op_params = (float *)dst->op_params;
 
     float scale = op_params[0];
@@ ... @@
     const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
+    ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, {
         ncols,
         src1 != nullptr ? nrows_y : (uint32_t)0,
         (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
@@ ... @@
         m0, m1,
         n_head_log2,
         nrows_x,
+        src2 != nullptr
     }, dryrun);
 }
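The trailing `src2 != nullptr` push constant tells the shader whether a sinks tensor is bound. In scalar form the sink behaves like one extra logit per row that joins the max and the normalizer but produces no output weight; a sketch of the assumed attention-sink semantics (not the actual Vulkan kernel) looks like this:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Scalar sketch of softmax with one attention-sink logit per row (assumed
// semantics: the sink is normalized over but emits no output weight).
static void softmax_with_sink_ref(std::vector<float> & x, float sink) {
    float mx = sink;
    for (float v : x) mx = std::max(mx, v);
    float sum = std::exp(sink - mx);   // the sink's share of the probability mass
    for (float & v : x) { v = std::exp(v - mx); sum += v; }
    for (float & v : x) v /= sum;      // the row now sums to less than 1
}
```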
@@ ... @@
         case GGML_GLU_OP_GEGLU:
         case GGML_GLU_OP_REGLU:
         case GGML_GLU_OP_SWIGLU:
+        case GGML_GLU_OP_SWIGLU_OAI:
         case GGML_GLU_OP_GEGLU_ERF:
         case GGML_GLU_OP_GEGLU_QUICK:
             break;
@@ ... @@
     case GGML_OP_REPEAT_BACK:
     case GGML_OP_GET_ROWS:
     case GGML_OP_ADD:
+    case GGML_OP_ADD_ID:
     case GGML_OP_ACC:
     case GGML_OP_SUB:
     case GGML_OP_MUL:
@@ ... @@
     case GGML_OP_DIV:
         ggml_vk_div(ctx, compute_ctx, src0, src1, node, dryrun);
 
+        break;
+    case GGML_OP_ADD_ID:
+        ggml_vk_add_id(ctx, compute_ctx, src0, src1, src2, node, dryrun);
+
         break;
     case GGML_OP_CONCAT:
         ggml_vk_concat(ctx, compute_ctx, src0, src1, node, dryrun);
@@ ... @@
         case GGML_GLU_OP_GEGLU:
         case GGML_GLU_OP_REGLU:
         case GGML_GLU_OP_SWIGLU:
+        case GGML_GLU_OP_SWIGLU_OAI:
         case GGML_GLU_OP_GEGLU_ERF:
         case GGML_GLU_OP_GEGLU_QUICK:
             ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
@@ ... @@
 
         break;
     case GGML_OP_SOFT_MAX:
+        ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node, dryrun);
 
         break;
     case GGML_OP_SOFT_MAX_BACK:
@@ ... @@
     case GGML_OP_SUB:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
+    case GGML_OP_ADD_ID:
     case GGML_OP_CONCAT:
     case GGML_OP_UPSCALE:
     case GGML_OP_SCALE:
@@ ... @@
         case GGML_GLU_OP_GEGLU:
         case GGML_GLU_OP_REGLU:
         case GGML_GLU_OP_SWIGLU:
+        case GGML_GLU_OP_SWIGLU_OAI:
         case GGML_GLU_OP_GEGLU_ERF:
         case GGML_GLU_OP_GEGLU_QUICK:
             buf = tensor->buffer;
@@ ... @@
         case GGML_GLU_OP_GEGLU:
         case GGML_GLU_OP_REGLU:
         case GGML_GLU_OP_SWIGLU:
+        case GGML_GLU_OP_SWIGLU_OAI:
         case GGML_GLU_OP_GEGLU_ERF:
         case GGML_GLU_OP_GEGLU_QUICK:
             return ggml_is_contiguous(op->src[0]) &&
@@ ... @@
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_MXFP4:
             break;
         default:
             return false;
@@ ... @@
     if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) {
         return false;
     }
+    // TODO: support attention sinks [TAG_ATTN_SINKS]
+    if (op->src[4]) {
+        return false;
+    }
     if (op->src[0]->type != GGML_TYPE_F32) {
         return false;
     }
@@ ... @@
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_MXFP4:
             return true;
         default:
             return false;
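For orientation, the `GGML_TYPE_MXFP4` these switches now admit is a 32-weight block format. A hypothetical mirror of the layout (the real definition is `block_mxfp4` in `ggml-common.h`, which this commit touches):

```cpp
#include <cstdint>

// Assumed mirror of the MXFP4 block layout: 32 weights share one e8m0 scale.
struct block_mxfp4_sketch {
    uint8_t e;      // shared e8m0 exponent byte
    uint8_t qs[16]; // 32 x 4-bit e2m1 codes, two per byte
};
static_assert(sizeof(block_mxfp4_sketch) == 17, "17 bytes / 32 weights = 4.25 bpw");
```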
@@ ... @@
             return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
                    (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
                    (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
+        case GGML_OP_ADD_ID:
+            return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->src[2]->type == GGML_TYPE_I32 &&
+                   op->type == GGML_TYPE_F32;
         case GGML_OP_SILU_BACK:
         case GGML_OP_RMS_NORM_BACK:
         case GGML_OP_SQR:
@@ -0,0 +1,42 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : require
+
+#include "types.comp"
+
+layout (push_constant) uniform parameter
+{
+    uint ne0;
+    uint ne1;
+    uint s01;
+    uint s02;
+    uint s11;
+    uint s21;
+} p;
+
+#define BLOCK_SIZE 512
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
+layout (binding = 2) readonly buffer Z {int32_t data_c[];};
+layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+    const uint i1 = gl_WorkGroupID.x;
+    const uint i2 = gl_WorkGroupID.y;
+
+    const uint i11 = data_c[i1 + i2 * p.s21];
+
+    const uint s1 = p.ne0;
+    const uint s2 = p.ne0 * p.ne1;
+
+    const uint d0 = i1 * s1 + i2 * s2;
+    const uint a0 = i1 * p.s01 + i2 * p.s02;
+    const uint b0 = i11 * p.s11;
+
+    for (uint i0 = gl_LocalInvocationID.x; i0 < p.ne0; i0 += BLOCK_SIZE) {
+        data_d[d0 + i0] = data_a[a0 + i0] + data_b[b0 + i0];
+    }
+}
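Read back into scalar code, this new shader (presumably `add_id.comp`) implements `GGML_OP_ADD_ID`: each workgroup handles one destination row `(i1, i2)`, fetches a row selector from `src2`, and adds the selected `src1` row (in the MoE case, a per-expert bias) to the corresponding `src0` row. A contiguous-f32 CPU sketch (`add_id_ref` is a hypothetical helper, not a ggml API):

```cpp
#include <cstdint>

// CPU sketch of GGML_OP_ADD_ID for contiguous f32 tensors:
//   dst[i2][i1][:] = src0[i2][i1][:] + src1[ids[i2][i1]][:]
static void add_id_ref(float * dst, const float * src0, const float * src1,
                       const int32_t * ids, int64_t ne0, int64_t ne1, int64_t ne2) {
    for (int64_t i2 = 0; i2 < ne2; ++i2) {
        for (int64_t i1 = 0; i1 < ne1; ++i1) {
            const int32_t row = ids[i2*ne1 + i1];          // row selector (expert id)
            const float * a  = src0 + (i2*ne1 + i1)*ne0;
            const float * b  = src1 + (int64_t) row * ne0; // selected bias row
            float       * d  = dst  + (i2*ne1 + i1)*ne0;
            for (int64_t i0 = 0; i0 < ne0; ++i0) {
                d[i0] = a[i0] + b[i0];
            }
        }
    }
}
```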
@@ -4,8 +4,8 @@
 #include "generic_unary_head.comp"
 #include "dequant_funcs.comp"
 
-#if defined(DATA_A_IQ4_NL)
+#if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4)
 // 16 invocations needed for init_iq_shmem
 layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
 #else
 layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
@@ -434,6 +434,18 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
 }
 #endif
 
+#if defined(DATA_A_MXFP4)
+vec2 dequantize(uint ib, uint iqs, uint a_offset) {
+    const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
+    return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]);
+}
+vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
+    vec2 v0 = dequantize(ib, iqs, a_offset);
+    vec2 v1 = dequantize(ib, iqs + 1, a_offset);
+    return vec4(v0.x, v0.y, v1.x, v1.y);
+}
+#endif
+
 #if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
 vec2 get_dm(uint ib, uint a_offset) {
     return vec2(0, 0);
@@ -455,6 +467,12 @@ vec2 get_dm(uint ib, uint a_offset) {
 }
 #endif
 
+#if defined(DATA_A_MXFP4)
+vec2 get_dm(uint ib, uint a_offset) {
+    return vec2(e8m0_to_fp32(data_a[a_offset + ib].e), 0);
+}
+#endif
+
 #if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
 vec2 get_dm(uint ib, uint a_offset) {
     return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m));
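Both dequant paths share the same decode rule: a block stores one e8m0 scale byte `e` plus 16 bytes `qs` of packed 4-bit e2m1 codes, and each value is a table lookup times the shared scale. A CPU sketch follows; the float table and the zero-exponent behavior are assumptions (ggml's actual `kvalues_mxfp4` table is integer-valued and may fold constant factors into the scale conversion):

```cpp
#include <cstdint>
#include <cstring>

// FP4 (e2m1) magnitudes 0, 0.5, 1, 1.5, 2, 3, 4, 6 and their negations:
// an assumed plain-float variant of the kvalues_mxfp4 lookup table.
static const float kvalues_mxfp4_f[16] = {
    0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f,
   -0.0f,-0.5f,-1.0f,-1.5f,-2.0f,-3.0f,-4.0f,-6.0f,
};

// e8m0 is an 8-bit biased exponent with no mantissa: value = 2^(e - 127).
// Sketch via the fp32 bit layout; the e == 0 case would need subnormal handling.
static float e8m0_to_fp32_ref(uint8_t e) {
    const uint32_t bits = (uint32_t) e << 23;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

// Decode one 32-value MXFP4 block: qs[i] holds values i (low nibble)
// and i + 16 (high nibble), all scaled by the shared exponent.
static void dequant_mxfp4_block_ref(uint8_t e, const uint8_t qs[16], float out[32]) {
    const float d = e8m0_to_fp32_ref(e);
    for (int i = 0; i < 16; ++i) {
        out[i +  0] = d * kvalues_mxfp4_f[qs[i] & 0xF];
        out[i + 16] = d * kvalues_mxfp4_f[qs[i] >> 4];
    }
}
```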
@@ -654,6 +654,25 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
 }
 #endif
 
+#if defined(DATA_A_MXFP4)
+layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufMXFP4 {
+    block_mxfp4 block;
+};
+
+float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
+{
+    const float d = e8m0_to_fp32(bl.block.e);
+    const uint idx = coordInBlock[1];
+    const uint iqs = idx & 0xF;
+    const uint shift = (idx & 0x10) >> 2;
+    uint32_t qs = bl.block.qs[iqs];
+    qs >>= shift;
+    qs &= 0xF;
+    float16_t ret = float16_t(kvalues_mxfp4[qs] * d);
+    return ret;
+}
+#endif
+
 #if defined(DATA_A_Q4_0)
 #define dequantFuncA dequantFuncQ4_0
 #elif defined(DATA_A_Q4_1)
@@ -696,4 +715,6 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
 #define dequantFuncA dequantFuncIQ4_XS
 #elif defined(DATA_A_IQ4_NL)
 #define dequantFuncA dequantFuncIQ4_NL
+#elif defined(DATA_A_MXFP4)
+#define dequantFuncA dequantFuncMXFP4
 #endif
@@ -0,0 +1,32 @@
+#version 450
+
+#include "dequant_head.comp"
+
+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {block_mxfp4 data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
+
+void main() {
+    const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
+
+    init_iq_shmem(gl_WorkGroupSize);
+
+    const uint tid = gl_LocalInvocationID.x % 64;
+    const uint il  = tid/32;
+    const uint ir  = tid%32;
+    const uint ib  = 32*i + ir;
+    if (ib >= p.nel / 32) {
+        return;
+    }
+
+    const uint q_idx = 8*il;
+    const uint b_idx = 1024*i + 32*ir + q_idx;
+
+    const float d = e8m0_to_fp32(data_a[ib].e);
+
+    [[unroll]] for (uint l = 0; l < 8; ++l) {
+        data_b[b_idx + l +  0] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]);
+        data_b[b_idx + l + 16] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] >>  4]);
+    }
+}
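The index arithmetic in this dequant shader is easy to sanity-check on the host: 256 invocations split into 4 groups of 64, each group covering 32 blocks, with `il` selecting the low or high run of each block's outputs. A quick standalone check (hypothetical, mirroring the shader's math for workgroup 0) verifies that every one of the 4 * 1024 output values is written exactly once:

```cpp
#include <cassert>
#include <vector>

// Emulate one workgroup of the mxfp4 dequant shader and check that the
// 256 threads x 8-iteration loop covers all 4 * 1024 outputs without overlap.
int main() {
    std::vector<int> hits(4 * 1024, 0);
    for (unsigned local = 0; local < 256; ++local) {
        const unsigned i   = local / 64;  // gl_WorkGroupID.x * 4 + local/64, with wg 0
        const unsigned tid = local % 64;
        const unsigned il  = tid / 32;
        const unsigned ir  = tid % 32;
        const unsigned q_idx = 8 * il;
        const unsigned b_idx = 1024 * i + 32 * ir + q_idx;
        for (unsigned l = 0; l < 8; ++l) {
            ++hits[b_idx + l +  0];       // low-nibble outputs
            ++hits[b_idx + l + 16];       // high-nibble outputs
        }
    }
    for (int h : hits) assert(h == 1);    // full, non-overlapping coverage
    return 0;
}
```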