Spaces:
Running
SOTA 3-bit quants (llama/5196)
* iq3_xxs: quantize/dequantize
RMSE seems a bit high-ish at about half-way between q2_K and
q3_K, so need to check more.
* iq3_xxs: CUDA dequantize works
* iq2_xxs: tuning quantization
* iq3_xxs: starting to look better
PPL on wiki.test.raw
LLaMA-v1-7B: 6.4218
LLaMA-v2-7B: 6.3560
Mistral-7B : 6.0717
This is better than Q3_K_XS, with a 5% reduction in quantized model
size.
* iq3_xxs: CUDA dot product
We have
PP-512: 5891 t/s
TG-128: 143.9 t/s
* iq3_xxs: scalar and AVX2 dot products
* iq3_xxs: ARM_NEON and Metal
Metal performance is decent, ARM_NEON is pathetic
* iq3_xxs: slightly better grid points
* Faster iq3_xxs and iq2_xs dot products on CUDA
* iq3_xxs: add some quant mix
* iq3_xxs: fix failing quantization test
Dot product still fails. Is this real?
* iq3_xxs: hopefully fix ROCm
* iq3_xxs: failing tests
This time the dot product accuracy did find an actual bug
in the AVX2 implementation.
* Add IQ3_XXS to test-backend-ops
---------
Co-authored-by: Iwan Kawrakow <[email protected]>
- ggml-cuda.cu +189 -11
- ggml-metal.m +35 -0
- ggml-metal.metal +274 -0
- ggml-quants.c +630 -0
- ggml-quants.h +17 -1
- ggml.c +30 -0
- ggml.h +2 -0
|
@@ -191,6 +191,10 @@ static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
|
|
| 191 |
#endif // __has_builtin(__builtin_elementwise_sub_sat)
|
| 192 |
}
|
| 193 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
|
| 195 |
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
|
| 196 |
c = __builtin_amdgcn_sdot4(a, b, c, false);
|
|
@@ -505,6 +509,14 @@ typedef struct {
|
|
| 505 |
} block_iq2_xs;
|
| 506 |
static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding");
|
| 507 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 508 |
#define WARP_SIZE 32
|
| 509 |
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
|
| 510 |
|
|
@@ -1613,6 +1625,41 @@ static const __device__ uint64_t iq2xs_grid[512] = {
|
|
| 1613 |
0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
|
| 1614 |
};
|
| 1615 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1616 |
static const __device__ uint8_t ksigns_iq2xs[128] = {
|
| 1617 |
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
|
| 1618 |
144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
|
|
@@ -1624,6 +1671,43 @@ static const __device__ uint8_t ksigns_iq2xs[128] = {
|
|
| 1624 |
240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
|
| 1625 |
};
|
| 1626 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1627 |
static const __device__ uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
|
| 1628 |
|
| 1629 |
inline bool ggml_cuda_supports_mmq(enum ggml_type type) {
|
|
@@ -1690,6 +1774,34 @@ static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst
|
|
| 1690 |
|
| 1691 |
}
|
| 1692 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1693 |
static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
|
| 1694 |
|
| 1695 |
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
|
|
@@ -4313,6 +4425,7 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
|
|
| 4313 |
|
| 4314 |
static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
|
| 4315 |
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
|
|
|
|
| 4316 |
#if QK_K == 256
|
| 4317 |
const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq;
|
| 4318 |
|
|
@@ -4323,20 +4436,22 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
|
|
| 4323 |
const uint8_t ls2 = bq2->scales[ib32] >> 4;
|
| 4324 |
int sumi1 = 0;
|
| 4325 |
for (int l = 0; l < 2; ++l) {
|
| 4326 |
-
const
|
| 4327 |
-
const
|
| 4328 |
-
|
| 4329 |
-
|
| 4330 |
-
|
|
|
|
| 4331 |
q8 += 8;
|
| 4332 |
}
|
| 4333 |
int sumi2 = 0;
|
| 4334 |
for (int l = 2; l < 4; ++l) {
|
| 4335 |
-
const
|
| 4336 |
-
const
|
| 4337 |
-
|
| 4338 |
-
|
| 4339 |
-
|
|
|
|
| 4340 |
q8 += 8;
|
| 4341 |
}
|
| 4342 |
const float d = (float)bq2->d * __low2float(bq8_1[ib32].ds) * 0.25f;
|
|
@@ -4345,6 +4460,45 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
|
|
| 4345 |
assert(false);
|
| 4346 |
return 0.f;
|
| 4347 |
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4348 |
}
|
| 4349 |
|
| 4350 |
template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
|
|
@@ -6394,6 +6548,12 @@ static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int k,
|
|
| 6394 |
dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
|
| 6395 |
}
|
| 6396 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6397 |
template <typename src_t, typename dst_t>
|
| 6398 |
static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
|
| 6399 |
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
|
|
@@ -6431,6 +6591,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
|
|
| 6431 |
return dequantize_row_iq2_xxs_cuda;
|
| 6432 |
case GGML_TYPE_IQ2_XS:
|
| 6433 |
return dequantize_row_iq2_xs_cuda;
|
|
|
|
|
|
|
| 6434 |
case GGML_TYPE_F32:
|
| 6435 |
return convert_unary_cuda<float>;
|
| 6436 |
default:
|
|
@@ -6464,6 +6626,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
|
|
| 6464 |
return dequantize_row_iq2_xxs_cuda;
|
| 6465 |
case GGML_TYPE_IQ2_XS:
|
| 6466 |
return dequantize_row_iq2_xs_cuda;
|
|
|
|
|
|
|
| 6467 |
case GGML_TYPE_F16:
|
| 6468 |
return convert_unary_cuda<half>;
|
| 6469 |
default:
|
|
@@ -6676,6 +6840,15 @@ static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, float
|
|
| 6676 |
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
| 6677 |
}
|
| 6678 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6679 |
static void ggml_mul_mat_q4_0_q8_1_cuda(
|
| 6680 |
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
| 6681 |
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
|
@@ -8239,6 +8412,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
|
|
| 8239 |
case GGML_TYPE_Q6_K:
|
| 8240 |
case GGML_TYPE_IQ2_XXS:
|
| 8241 |
case GGML_TYPE_IQ2_XS:
|
|
|
|
| 8242 |
return max_compute_capability >= CC_RDNA2 ? 128 : 64;
|
| 8243 |
default:
|
| 8244 |
GGML_ASSERT(false);
|
|
@@ -8261,6 +8435,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
|
|
| 8261 |
case GGML_TYPE_Q5_K:
|
| 8262 |
case GGML_TYPE_IQ2_XXS:
|
| 8263 |
case GGML_TYPE_IQ2_XS:
|
|
|
|
| 8264 |
return max_compute_capability >= CC_VOLTA ? 128 : 64;
|
| 8265 |
case GGML_TYPE_Q6_K:
|
| 8266 |
return 64;
|
|
@@ -8332,6 +8507,9 @@ static void ggml_cuda_op_mul_mat_vec_q(
|
|
| 8332 |
case GGML_TYPE_IQ2_XS:
|
| 8333 |
mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
| 8334 |
break;
|
|
|
|
|
|
|
|
|
|
| 8335 |
default:
|
| 8336 |
GGML_ASSERT(false);
|
| 8337 |
break;
|
|
@@ -10968,7 +11146,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
|
| 10968 |
return false;
|
| 10969 |
}
|
| 10970 |
ggml_type a_type = a->type;
|
| 10971 |
-
if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS) {
|
| 10972 |
if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
|
| 10973 |
return false;
|
| 10974 |
}
|
|
|
|
| 191 |
#endif // __has_builtin(__builtin_elementwise_sub_sat)
|
| 192 |
}
|
| 193 |
|
| 194 |
+
// Compatibility shim for the ROCm/HIP path (this file also defines __dp4a for
// gfx9xx/gfx1030 above): provide __vsub4 by forwarding to the saturating
// per-byte subtraction __vsubss4 defined earlier in this file.
// NOTE(review): the callers below only subtract 0x00/0xff sign masks from
// small grid bytes, where saturating and wrapping subtraction agree —
// confirm before reusing this shim in other contexts.
static __device__ __forceinline__ int __vsub4(const int a, const int b) {
    return __vsubss4(a, b);
}
|
| 197 |
+
|
| 198 |
static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
|
| 199 |
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
|
| 200 |
c = __builtin_amdgcn_sdot4(a, b, c, false);
|
|
|
|
| 509 |
} block_iq2_xs;
|
| 510 |
static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding");
|
| 511 |
|
| 512 |
+
// IQ3_XXS: 3-bit quantization in super-blocks of QK_K weights.
// Per super-block: one fp16 scale `d` plus 3*(QK_K/8) bytes of packed data,
// i.e. 3 bits/weight + 16/QK_K bits of scale overhead.
#define QR3_XXS 8
#define QI3_XXS (QK_K / (4*QR3_XXS))
typedef struct {
    half d;                   // super-block scale
    // Packed data. As read by the kernels below: the first QK_K/4 bytes are
    // indices into iq3xxs_grid (one byte per 4 weights); the remaining bytes
    // are consumed as uint32 words holding 28 bits of sign-group indices and
    // a 4-bit sub-block scale in the top nibble.
    uint8_t qs[3*(QK_K/8)];
} block_iq3_xxs;
static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
|
| 519 |
+
|
| 520 |
#define WARP_SIZE 32
|
| 521 |
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
|
| 522 |
|
|
|
|
| 1625 |
0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
|
| 1626 |
};
|
| 1627 |
|
| 1628 |
+
// IQ3_XXS codebook: 256 grid points, each uint32 packing 4 quantized byte
// values. Every byte is drawn from the small set {0x04, 0x0c, 0x14, 0x1c,
// 0x24, 0x2c, 0x34, 0x3e} (magnitudes only — signs are applied separately
// via ksigns_iq2xs / ksigns64).
static const __device__ uint32_t iq3xxs_grid[256] = {
    0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
    0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
    0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404,
    0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e,
    0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c,
    0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c,
    0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34,
    0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c,
    0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c,
    0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04,
    0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c,
    0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414,
    0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434,
    0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c,
    0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e,
    0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24,
    0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24,
    0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c,
    0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c,
    0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14,
    0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414,
    0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e,
    0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404,
    0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c,
    0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c,
    0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14,
    0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c,
    0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c,
    0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14,
    0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14,
    0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
    0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
};
|
| 1662 |
+
|
| 1663 |
static const __device__ uint8_t ksigns_iq2xs[128] = {
|
| 1664 |
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
|
| 1665 |
144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
|
|
|
|
| 1671 |
240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
|
| 1672 |
};
|
| 1673 |
|
| 1674 |
+
//#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
// 64-bit expansion of the 7-bit sign codes (cf. ksigns_iq2xs): entry i holds
// eight sign bytes, 0x00 for +1 and 0xff for -1. The dot-product kernels
// apply signs branchlessly as (x ^ s) - s per byte (xor with 0xff negates
// modulo 256 when followed by the __vsub4 subtraction).
static const __device__ uint64_t ksigns64[128] = {
    0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff,
    0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff,
    0xff000000ff000000, 0x00000000ff0000ff, 0x00000000ff00ff00, 0xff000000ff00ffff,
    0x00000000ffff0000, 0xff000000ffff00ff, 0xff000000ffffff00, 0x00000000ffffffff,
    0xff0000ff00000000, 0x000000ff000000ff, 0x000000ff0000ff00, 0xff0000ff0000ffff,
    0x000000ff00ff0000, 0xff0000ff00ff00ff, 0xff0000ff00ffff00, 0x000000ff00ffffff,
    0x000000ffff000000, 0xff0000ffff0000ff, 0xff0000ffff00ff00, 0x000000ffff00ffff,
    0xff0000ffffff0000, 0x000000ffffff00ff, 0x000000ffffffff00, 0xff0000ffffffffff,
    0xff00ff0000000000, 0x0000ff00000000ff, 0x0000ff000000ff00, 0xff00ff000000ffff,
    0x0000ff0000ff0000, 0xff00ff0000ff00ff, 0xff00ff0000ffff00, 0x0000ff0000ffffff,
    0x0000ff00ff000000, 0xff00ff00ff0000ff, 0xff00ff00ff00ff00, 0x0000ff00ff00ffff,
    0xff00ff00ffff0000, 0x0000ff00ffff00ff, 0x0000ff00ffffff00, 0xff00ff00ffffffff,
    0x0000ffff00000000, 0xff00ffff000000ff, 0xff00ffff0000ff00, 0x0000ffff0000ffff,
    0xff00ffff00ff0000, 0x0000ffff00ff00ff, 0x0000ffff00ffff00, 0xff00ffff00ffffff,
    0xff00ffffff000000, 0x0000ffffff0000ff, 0x0000ffffff00ff00, 0xff00ffffff00ffff,
    0x0000ffffffff0000, 0xff00ffffffff00ff, 0xff00ffffffffff00, 0x0000ffffffffffff,
    0xffff000000000000, 0x00ff0000000000ff, 0x00ff00000000ff00, 0xffff00000000ffff,
    0x00ff000000ff0000, 0xffff000000ff00ff, 0xffff000000ffff00, 0x00ff000000ffffff,
    0x00ff0000ff000000, 0xffff0000ff0000ff, 0xffff0000ff00ff00, 0x00ff0000ff00ffff,
    0xffff0000ffff0000, 0x00ff0000ffff00ff, 0x00ff0000ffffff00, 0xffff0000ffffffff,
    0x00ff00ff00000000, 0xffff00ff000000ff, 0xffff00ff0000ff00, 0x00ff00ff0000ffff,
    0xffff00ff00ff0000, 0x00ff00ff00ff00ff, 0x00ff00ff00ffff00, 0xffff00ff00ffffff,
    0xffff00ffff000000, 0x00ff00ffff0000ff, 0x00ff00ffff00ff00, 0xffff00ffff00ffff,
    0x00ff00ffffff0000, 0xffff00ffffff00ff, 0xffff00ffffffff00, 0x00ff00ffffffffff,
    0x00ffff0000000000, 0xffffff00000000ff, 0xffffff000000ff00, 0x00ffff000000ffff,
    0xffffff0000ff0000, 0x00ffff0000ff00ff, 0x00ffff0000ffff00, 0xffffff0000ffffff,
    0xffffff00ff000000, 0x00ffff00ff0000ff, 0x00ffff00ff00ff00, 0xffffff00ff00ffff,
    0x00ffff00ffff0000, 0xffffff00ffff00ff, 0xffffff00ffffff00, 0x00ffff00ffffffff,
    0xffffffff00000000, 0x00ffffff000000ff, 0x00ffffff0000ff00, 0xffffffff0000ffff,
    0x00ffffff00ff0000, 0xffffffff00ff00ff, 0xffffffff00ffff00, 0x00ffffff00ffffff,
    0x00ffffffff000000, 0xffffffffff0000ff, 0xffffffffff00ff00, 0x00ffffffff00ffff,
    0xffffffffffff0000, 0x00ffffffffff00ff, 0x00ffffffffffff00, 0xffffffffffffffff,
};
//#endif
|
| 1710 |
+
|
| 1711 |
static const __device__ uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
|
| 1712 |
|
| 1713 |
inline bool ggml_cuda_supports_mmq(enum ggml_type type) {
|
|
|
|
| 1774 |
|
| 1775 |
}
|
| 1776 |
|
| 1777 |
+
// Dequantize one IQ3_XXS super-block per CUDA block into dst_t values.
// Launch layout (see dequantize_row_iq3_xxs_cuda): grid = number of
// super-blocks, block = 32 threads; each thread writes 8 consecutive outputs.
template<typename dst_t>
static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int i = blockIdx.x;                       // super-block index
    const block_iq3_xxs * x = (const block_iq3_xxs *) vx;

    const int tid = threadIdx.x;
#if QK_K == 256
    const int il = tid/8; // 0...3  quarter of the 32-value sub-block
    const int ib = tid%8; // 0...7  sub-block of 32 values
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
    // First QK_K/4 bytes of qs: grid indices, 8 per sub-block.
    const uint8_t * q3 = x[i].qs + 8*ib;
    // Remainder of qs, read as uint16 pairs: per-sub-block word with
    // 28 sign bits (4 groups of 7) and a 4-bit scale in the top nibble.
    const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;
    const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*il+0]);
    const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*il+1]);
    const uint32_t aux32 = gas[0] | (gas[1] << 16);
    // Effective scale: fp16 super-block scale times (2*nibble + 1)/4.
    const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.5f;
    // 7-bit sign-group index for this thread's 8 values, expanded to a byte mask.
    const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
    for (int j = 0; j < 4; ++j) {
        y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
        y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
    }
#else
    // Only QK_K == 256 is implemented for IQ3_XXS.
    assert(false);
#endif

}
|
| 1804 |
+
|
| 1805 |
static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
|
| 1806 |
|
| 1807 |
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
|
|
|
|
| 4425 |
|
| 4426 |
static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
|
| 4427 |
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
|
| 4428 |
+
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
|
| 4429 |
#if QK_K == 256
|
| 4430 |
const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq;
|
| 4431 |
|
|
|
|
| 4436 |
const uint8_t ls2 = bq2->scales[ib32] >> 4;
|
| 4437 |
int sumi1 = 0;
|
| 4438 |
for (int l = 0; l < 2; ++l) {
|
| 4439 |
+
const uint32_t * grid = (const uint32_t *)(iq2xs_grid + (q2[l] & 511));
|
| 4440 |
+
const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
|
| 4441 |
+
const int grid_l = __vsub4(grid[0] ^ signs[0], signs[0]);
|
| 4442 |
+
const int grid_h = __vsub4(grid[1] ^ signs[1], signs[1]);
|
| 4443 |
+
sumi1 = __dp4a(grid_l, *((const int *)q8 + 0), sumi1);
|
| 4444 |
+
sumi1 = __dp4a(grid_h, *((const int *)q8 + 1), sumi1);
|
| 4445 |
q8 += 8;
|
| 4446 |
}
|
| 4447 |
int sumi2 = 0;
|
| 4448 |
for (int l = 2; l < 4; ++l) {
|
| 4449 |
+
const uint32_t * grid = (const uint32_t *)(iq2xs_grid + (q2[l] & 511));
|
| 4450 |
+
const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
|
| 4451 |
+
const int grid_l = __vsub4(grid[0] ^ signs[0], signs[0]);
|
| 4452 |
+
const int grid_h = __vsub4(grid[1] ^ signs[1], signs[1]);
|
| 4453 |
+
sumi2 = __dp4a(grid_l, *((const int *)q8 + 0), sumi2);
|
| 4454 |
+
sumi2 = __dp4a(grid_h, *((const int *)q8 + 1), sumi2);
|
| 4455 |
q8 += 8;
|
| 4456 |
}
|
| 4457 |
const float d = (float)bq2->d * __low2float(bq8_1[ib32].ds) * 0.25f;
|
|
|
|
| 4460 |
assert(false);
|
| 4461 |
return 0.f;
|
| 4462 |
#endif
|
| 4463 |
+
#else
|
| 4464 |
+
assert(false);
|
| 4465 |
+
return 0.f;
|
| 4466 |
+
#endif
|
| 4467 |
+
}
|
| 4468 |
+
|
| 4469 |
+
// Dot product of one 32-value IQ3_XXS sub-block (selected by iqs) with the
// matching q8_1 block, for mul_mat_vec_q. Requires DP4A (SM61+ / MIN_CC_DP4A).
static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
#if QK_K == 256
    const block_iq3_xxs * bq2 = (const block_iq3_xxs *) vbq;

    const int ib32 = iqs;                            // sub-block index, 0..7
    const uint8_t * q3 = bq2->qs + 8*ib32;           // 8 grid indices
    // Per-sub-block word: 28 sign-group bits + 4-bit scale in the top nibble.
    const uint16_t * gas = (const uint16_t *)(bq2->qs + QK_K/4) + 2*ib32;
    const int8_t * q8 = bq8_1[ib32].qs;
    uint32_t aux32 = gas[0] | (gas[1] << 16);
    int sumi = 0;
    for (int l = 0; l < 4; ++l) {
        const uint32_t * grid1 = iq3xxs_grid + q3[2*l+0];
        const uint32_t * grid2 = iq3xxs_grid + q3[2*l+1];
        // Expand the low 7 sign bits into byte masks (0x00/0xff per lane).
        const uint32_t * signs = (const uint32_t *)(ksigns64 + (aux32 & 127));
        // (x ^ 0xff) - 0xff == -x per byte: branchless sign application.
        const int grid_l = __vsub4(grid1[0] ^ signs[0], signs[0]);
        const int grid_h = __vsub4(grid2[0] ^ signs[1], signs[1]);
        sumi = __dp4a(grid_l, *((int *)q8+0), sumi);
        sumi = __dp4a(grid_h, *((int *)q8+1), sumi);
        q8 += 8;
        aux32 >>= 7;   // after 4 shifts only the 4-bit scale nibble remains
    }
    const float d = (float)bq2->d * (0.5f + aux32) * __low2float(bq8_1[ib32].ds) * 0.5f;
    return d * sumi;
#else
    // Only QK_K == 256 is implemented for IQ3_XXS.
    assert(false);
    return 0.f;
#endif
#else
    // No DP4A on this architecture: kernel must not be reached.
    assert(false);
    return 0.f;
#endif
}
|
| 4503 |
|
| 4504 |
template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
|
|
|
|
| 6548 |
dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
|
| 6549 |
}
|
| 6550 |
|
| 6551 |
+
// Host-side launcher: dequantize k IQ3_XXS-quantized values into y on the
// given stream. k must be a multiple of QK_K; one 32-thread block handles
// one super-block.
template<typename dst_t>
static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int n_superblocks = k / QK_K;
    dequantize_block_iq3_xxs<<<n_superblocks, 32, 0, stream>>>(vx, y);
}
|
| 6556 |
+
|
| 6557 |
template <typename src_t, typename dst_t>
|
| 6558 |
static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
|
| 6559 |
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
|
|
|
|
| 6591 |
return dequantize_row_iq2_xxs_cuda;
|
| 6592 |
case GGML_TYPE_IQ2_XS:
|
| 6593 |
return dequantize_row_iq2_xs_cuda;
|
| 6594 |
+
case GGML_TYPE_IQ3_XXS:
|
| 6595 |
+
return dequantize_row_iq3_xxs_cuda;
|
| 6596 |
case GGML_TYPE_F32:
|
| 6597 |
return convert_unary_cuda<float>;
|
| 6598 |
default:
|
|
|
|
| 6626 |
return dequantize_row_iq2_xxs_cuda;
|
| 6627 |
case GGML_TYPE_IQ2_XS:
|
| 6628 |
return dequantize_row_iq2_xs_cuda;
|
| 6629 |
+
case GGML_TYPE_IQ3_XXS:
|
| 6630 |
+
return dequantize_row_iq3_xxs_cuda;
|
| 6631 |
case GGML_TYPE_F16:
|
| 6632 |
return convert_unary_cuda<half>;
|
| 6633 |
default:
|
|
|
|
| 6840 |
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
| 6841 |
}
|
| 6842 |
|
| 6843 |
+
// Host-side launcher for the IQ3_XXS x q8_1 matrix-vector product.
// ncols must be a multiple of QK_K; rows are distributed over blocks of
// WARP_SIZE x GGML_CUDA_MMV_Y threads, GGML_CUDA_MMV_Y rows per block.
static void mul_mat_vec_iq3_xxs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % QK_K == 0);
    const dim3 threads_per_block(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
    const int ny = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;  // ceil-div over rows
    const dim3 num_blocks(ny, 1, 1);
    mul_mat_vec_q<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
        <<<num_blocks, threads_per_block, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
|
| 6851 |
+
|
| 6852 |
static void ggml_mul_mat_q4_0_q8_1_cuda(
|
| 6853 |
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
| 6854 |
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
|
|
|
| 8412 |
case GGML_TYPE_Q6_K:
|
| 8413 |
case GGML_TYPE_IQ2_XXS:
|
| 8414 |
case GGML_TYPE_IQ2_XS:
|
| 8415 |
+
case GGML_TYPE_IQ3_XXS:
|
| 8416 |
return max_compute_capability >= CC_RDNA2 ? 128 : 64;
|
| 8417 |
default:
|
| 8418 |
GGML_ASSERT(false);
|
|
|
|
| 8435 |
case GGML_TYPE_Q5_K:
|
| 8436 |
case GGML_TYPE_IQ2_XXS:
|
| 8437 |
case GGML_TYPE_IQ2_XS:
|
| 8438 |
+
case GGML_TYPE_IQ3_XXS:
|
| 8439 |
return max_compute_capability >= CC_VOLTA ? 128 : 64;
|
| 8440 |
case GGML_TYPE_Q6_K:
|
| 8441 |
return 64;
|
|
|
|
| 8507 |
case GGML_TYPE_IQ2_XS:
|
| 8508 |
mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
| 8509 |
break;
|
| 8510 |
+
case GGML_TYPE_IQ3_XXS:
|
| 8511 |
+
mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
| 8512 |
+
break;
|
| 8513 |
default:
|
| 8514 |
GGML_ASSERT(false);
|
| 8515 |
break;
|
|
|
|
| 11146 |
return false;
|
| 11147 |
}
|
| 11148 |
ggml_type a_type = a->type;
|
| 11149 |
+
if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS) {
|
| 11150 |
if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
|
| 11151 |
return false;
|
| 11152 |
}
|
|
@@ -60,6 +60,7 @@ enum ggml_metal_kernel_type {
|
|
| 60 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,
|
| 61 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
|
| 62 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
|
|
|
|
| 63 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
|
| 64 |
GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
| 65 |
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
|
@@ -81,6 +82,7 @@ enum ggml_metal_kernel_type {
|
|
| 81 |
GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,
|
| 82 |
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
|
| 83 |
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
|
|
|
|
| 84 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
|
| 85 |
//GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
|
| 86 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
|
|
@@ -98,6 +100,7 @@ enum ggml_metal_kernel_type {
|
|
| 98 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,
|
| 99 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
|
| 100 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
|
|
|
|
| 101 |
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
|
| 102 |
GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
|
| 103 |
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
|
|
@@ -112,6 +115,7 @@ enum ggml_metal_kernel_type {
|
|
| 112 |
GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,
|
| 113 |
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
|
| 114 |
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
|
|
|
|
| 115 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
|
| 116 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
|
| 117 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
|
|
@@ -126,6 +130,7 @@ enum ggml_metal_kernel_type {
|
|
| 126 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
|
| 127 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
|
| 128 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
|
|
|
|
| 129 |
GGML_METAL_KERNEL_TYPE_ROPE_F32,
|
| 130 |
GGML_METAL_KERNEL_TYPE_ROPE_F16,
|
| 131 |
GGML_METAL_KERNEL_TYPE_ALIBI_F32,
|
|
@@ -426,6 +431,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
| 426 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true);
|
| 427 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
|
| 428 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
|
|
|
|
| 429 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
| 430 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
|
| 431 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
|
|
@@ -447,6 +453,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
| 447 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction);
|
| 448 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
| 449 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
|
|
|
| 450 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
|
| 451 |
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
|
| 452 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
|
|
@@ -464,6 +471,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
| 464 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction);
|
| 465 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
| 466 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
|
|
|
| 467 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
|
| 468 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
|
| 469 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
|
|
@@ -478,6 +486,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
| 478 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm);
|
| 479 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
| 480 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
|
|
|
|
| 481 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
|
| 482 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
|
| 483 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
|
|
@@ -492,6 +501,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
| 492 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm);
|
| 493 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
| 494 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
|
|
|
|
| 495 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
|
| 496 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
|
| 497 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
|
|
@@ -1279,6 +1289,7 @@ static bool ggml_metal_graph_compute(
|
|
| 1279 |
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break;
|
| 1280 |
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
|
| 1281 |
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
|
|
|
|
| 1282 |
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
|
| 1283 |
}
|
| 1284 |
|
|
@@ -1407,6 +1418,12 @@ static bool ggml_metal_graph_compute(
|
|
| 1407 |
nth1 = 16;
|
| 1408 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
|
| 1409 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1410 |
default:
|
| 1411 |
{
|
| 1412 |
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
|
@@ -1449,6 +1466,11 @@ static bool ggml_metal_graph_compute(
|
|
| 1449 |
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
| 1450 |
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 1451 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1452 |
else if (src0t == GGML_TYPE_Q4_K) {
|
| 1453 |
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 1454 |
}
|
|
@@ -1543,6 +1565,7 @@ static bool ggml_metal_graph_compute(
|
|
| 1543 |
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break;
|
| 1544 |
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
|
| 1545 |
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
|
|
|
|
| 1546 |
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
|
| 1547 |
}
|
| 1548 |
|
|
@@ -1674,6 +1697,12 @@ static bool ggml_metal_graph_compute(
|
|
| 1674 |
nth1 = 16;
|
| 1675 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
|
| 1676 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1677 |
default:
|
| 1678 |
{
|
| 1679 |
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
|
|
@@ -1732,6 +1761,11 @@ static bool ggml_metal_graph_compute(
|
|
| 1732 |
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
| 1733 |
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 1734 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1735 |
else if (src2t == GGML_TYPE_Q4_K) {
|
| 1736 |
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 1737 |
}
|
|
@@ -1772,6 +1806,7 @@ static bool ggml_metal_graph_compute(
|
|
| 1772 |
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break;
|
| 1773 |
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
|
| 1774 |
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
|
|
|
|
| 1775 |
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
|
| 1776 |
default: GGML_ASSERT(false && "not implemented");
|
| 1777 |
}
|
|
|
|
| 60 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,
|
| 61 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
|
| 62 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
|
| 63 |
+
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
|
| 64 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
|
| 65 |
GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
| 66 |
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
|
|
|
| 82 |
GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,
|
| 83 |
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
|
| 84 |
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
|
| 85 |
+
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
|
| 86 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
|
| 87 |
//GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
|
| 88 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
|
|
|
|
| 100 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,
|
| 101 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
|
| 102 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
|
| 103 |
+
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
|
| 104 |
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
|
| 105 |
GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
|
| 106 |
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
|
|
|
|
| 115 |
GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,
|
| 116 |
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
|
| 117 |
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
|
| 118 |
+
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
|
| 119 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
|
| 120 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
|
| 121 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
|
|
|
|
| 130 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
|
| 131 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
|
| 132 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
|
| 133 |
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
|
| 134 |
GGML_METAL_KERNEL_TYPE_ROPE_F32,
|
| 135 |
GGML_METAL_KERNEL_TYPE_ROPE_F16,
|
| 136 |
GGML_METAL_KERNEL_TYPE_ALIBI_F32,
|
|
|
|
| 431 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true);
|
| 432 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
|
| 433 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
|
| 434 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
|
| 435 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
| 436 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
|
| 437 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
|
|
|
|
| 453 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction);
|
| 454 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
| 455 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
| 456 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
|
| 457 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
|
| 458 |
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
|
| 459 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
|
|
|
|
| 471 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction);
|
| 472 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
| 473 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
| 474 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
|
| 475 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
|
| 476 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
|
| 477 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
|
|
|
|
| 486 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm);
|
| 487 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
| 488 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
|
| 489 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
| 490 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
|
| 491 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
|
| 492 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
|
|
|
|
| 501 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm);
|
| 502 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
| 503 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
|
| 504 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
| 505 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
|
| 506 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
|
| 507 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
|
|
|
|
| 1289 |
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break;
|
| 1290 |
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
|
| 1291 |
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
|
| 1292 |
+
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
|
| 1293 |
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
|
| 1294 |
}
|
| 1295 |
|
|
|
|
| 1418 |
nth1 = 16;
|
| 1419 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
|
| 1420 |
} break;
|
| 1421 |
+
case GGML_TYPE_IQ3_XXS:
|
| 1422 |
+
{
|
| 1423 |
+
nth0 = 4;
|
| 1424 |
+
nth1 = 16;
|
| 1425 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
|
| 1426 |
+
} break;
|
| 1427 |
default:
|
| 1428 |
{
|
| 1429 |
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
|
|
|
| 1466 |
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
| 1467 |
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 1468 |
}
|
| 1469 |
+
else if (src0t == GGML_TYPE_IQ3_XXS) {
|
| 1470 |
+
const int mem_size = 256*4+128;
|
| 1471 |
+
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
| 1472 |
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 1473 |
+
}
|
| 1474 |
else if (src0t == GGML_TYPE_Q4_K) {
|
| 1475 |
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 1476 |
}
|
|
|
|
| 1565 |
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break;
|
| 1566 |
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
|
| 1567 |
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
|
| 1568 |
+
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
|
| 1569 |
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
|
| 1570 |
}
|
| 1571 |
|
|
|
|
| 1697 |
nth1 = 16;
|
| 1698 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
|
| 1699 |
} break;
|
| 1700 |
+
case GGML_TYPE_IQ3_XXS:
|
| 1701 |
+
{
|
| 1702 |
+
nth0 = 4;
|
| 1703 |
+
nth1 = 16;
|
| 1704 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
|
| 1705 |
+
} break;
|
| 1706 |
default:
|
| 1707 |
{
|
| 1708 |
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
|
|
|
|
| 1761 |
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
| 1762 |
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 1763 |
}
|
| 1764 |
+
else if (src2t == GGML_TYPE_IQ3_XXS) {
|
| 1765 |
+
const int mem_size = 256*4+128;
|
| 1766 |
+
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
| 1767 |
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 1768 |
+
}
|
| 1769 |
else if (src2t == GGML_TYPE_Q4_K) {
|
| 1770 |
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 1771 |
}
|
|
|
|
| 1806 |
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break;
|
| 1807 |
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
|
| 1808 |
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
|
| 1809 |
+
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
|
| 1810 |
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
|
| 1811 |
default: GGML_ASSERT(false && "not implemented");
|
| 1812 |
}
|
|
@@ -2459,6 +2459,12 @@ typedef struct {
|
|
| 2459 |
} block_iq2_xs;
|
| 2460 |
// 74 bytes / block for QK_K = 256, so 2.3125 bpw
|
| 2461 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2462 |
//====================================== dot products =========================
|
| 2463 |
|
| 2464 |
void kernel_mul_mv_q2_K_f32_impl(
|
|
@@ -3681,6 +3687,42 @@ constexpr constant static uint64_t iq2xs_grid[512] = {
|
|
| 3681 |
0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
|
| 3682 |
};
|
| 3683 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3684 |
constexpr constant static uint8_t ksigns_iq2xs[128] = {
|
| 3685 |
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
|
| 3686 |
144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
|
|
@@ -3970,6 +4012,143 @@ kernel void kernel_mul_mv_iq2_xs_f32(
|
|
| 3970 |
kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
| 3971 |
}
|
| 3972 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3973 |
//============================= templates and their specializations =============================
|
| 3974 |
|
| 3975 |
// NOTE: this is not dequantizing - we are simply fitting the template
|
|
@@ -4287,6 +4466,33 @@ void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4
|
|
| 4287 |
}
|
| 4288 |
}
|
| 4289 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4290 |
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
| 4291 |
kernel void kernel_get_rows(
|
| 4292 |
device const void * src0,
|
|
@@ -4828,6 +5034,7 @@ template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows
|
|
| 4828 |
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
|
| 4829 |
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
| 4830 |
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
|
|
| 4831 |
|
| 4832 |
//
|
| 4833 |
// matrix-matrix multiplication
|
|
@@ -4866,6 +5073,7 @@ template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
|
|
| 4866 |
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
|
| 4867 |
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
| 4868 |
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
|
|
| 4869 |
|
| 4870 |
//
|
| 4871 |
// indirect matrix-matrix multiplication
|
|
@@ -4916,6 +5124,7 @@ template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mu
|
|
| 4916 |
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
|
| 4917 |
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
| 4918 |
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
|
|
| 4919 |
|
| 4920 |
//
|
| 4921 |
// matrix-vector multiplication
|
|
@@ -5818,3 +6027,68 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
|
|
| 5818 |
tiisg,
|
| 5819 |
sgitg);
|
| 5820 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2459 |
} block_iq2_xs;
|
| 2460 |
// 74 bytes / block for QK_K = 256, so 2.3125 bpw
|
| 2461 |
|
| 2462 |
+
typedef struct {
|
| 2463 |
+
half d;
|
| 2464 |
+
uint8_t qs[3*QK_K/8];
|
| 2465 |
+
} block_iq3_xxs;
|
| 2466 |
+
// 98 bytes / block for QK_K = 256, so 3.0625 bpw
|
| 2467 |
+
|
| 2468 |
//====================================== dot products =========================
|
| 2469 |
|
| 2470 |
void kernel_mul_mv_q2_K_f32_impl(
|
|
|
|
| 3687 |
0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
|
| 3688 |
};
|
| 3689 |
|
| 3690 |
+
constexpr constant static uint32_t iq3xxs_grid[256] = {
|
| 3691 |
+
0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3c, 0x04041404, 0x04041414,
|
| 3692 |
+
0x04041c0c, 0x04042414, 0x04043c1c, 0x04043c2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
|
| 3693 |
+
0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3c04, 0x04140404,
|
| 3694 |
+
0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3c,
|
| 3695 |
+
0x04142c0c, 0x04142c3c, 0x04143c2c, 0x041c040c, 0x041c043c, 0x041c0c04, 0x041c0c14, 0x041c142c,
|
| 3696 |
+
0x041c3c04, 0x04240c1c, 0x04241c3c, 0x04242424, 0x04242c3c, 0x04243c1c, 0x04243c2c, 0x042c040c,
|
| 3697 |
+
0x042c043c, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043c0c04, 0x043c0c24, 0x043c0c34,
|
| 3698 |
+
0x043c241c, 0x043c340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c,
|
| 3699 |
+
0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243c, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c,
|
| 3700 |
+
0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04,
|
| 3701 |
+
0x0c143c14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c,
|
| 3702 |
+
0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3c0c, 0x0c34042c, 0x0c3c1414,
|
| 3703 |
+
0x0c3c2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434,
|
| 3704 |
+
0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c,
|
| 3705 |
+
0x140c1c04, 0x140c341c, 0x140c343c, 0x140c3c04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3c,
|
| 3706 |
+
0x14141404, 0x14141414, 0x14141c3c, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24,
|
| 3707 |
+
0x141c3c04, 0x141c3c24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143c, 0x142c240c, 0x142c3c24,
|
| 3708 |
+
0x143c040c, 0x143c041c, 0x143c0c34, 0x143c242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c,
|
| 3709 |
+
0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043c14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c,
|
| 3710 |
+
0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143c14,
|
| 3711 |
+
0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243c, 0x1c243c14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414,
|
| 3712 |
+
0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3c1c1c, 0x1c3c3404, 0x24040424, 0x24040c3c,
|
| 3713 |
+
0x24041c2c, 0x24041c3c, 0x24042c1c, 0x24042c3c, 0x240c3c24, 0x24141404, 0x24141c3c, 0x24142404,
|
| 3714 |
+
0x24143404, 0x24143434, 0x241c043c, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c,
|
| 3715 |
+
0x242c241c, 0x242c3c04, 0x243c042c, 0x243c0c04, 0x243c0c14, 0x243c1c04, 0x2c040c14, 0x2c04240c,
|
| 3716 |
+
0x2c043c04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143c14,
|
| 3717 |
+
0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143c, 0x2c243c14, 0x2c2c0414, 0x2c2c1c0c,
|
| 3718 |
+
0x2c342c04, 0x2c3c1424, 0x2c3c2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c,
|
| 3719 |
+
0x340c340c, 0x34140c3c, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14,
|
| 3720 |
+
0x34341c1c, 0x343c041c, 0x343c140c, 0x3c04041c, 0x3c04042c, 0x3c04043c, 0x3c040c04, 0x3c041c14,
|
| 3721 |
+
0x3c042c14, 0x3c0c1434, 0x3c0c2404, 0x3c140c14, 0x3c14242c, 0x3c142c14, 0x3c1c0404, 0x3c1c0c2c,
|
| 3722 |
+
0x3c1c1c1c, 0x3c1c3404, 0x3c24140c, 0x3c24240c, 0x3c2c0404, 0x3c2c0414, 0x3c2c1424, 0x3c341c04,
|
| 3723 |
+
};
|
| 3724 |
+
|
| 3725 |
+
|
| 3726 |
constexpr constant static uint8_t ksigns_iq2xs[128] = {
|
| 3727 |
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
|
| 3728 |
144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
|
|
|
|
| 4012 |
kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
| 4013 |
}
|
| 4014 |
|
| 4015 |
+
void kernel_mul_mv_iq3_xxs_f32_impl(
|
| 4016 |
+
device const void * src0,
|
| 4017 |
+
device const float * src1,
|
| 4018 |
+
device float * dst,
|
| 4019 |
+
constant int64_t & ne00,
|
| 4020 |
+
constant int64_t & ne01,
|
| 4021 |
+
constant int64_t & ne02,
|
| 4022 |
+
constant int64_t & ne10,
|
| 4023 |
+
constant int64_t & ne12,
|
| 4024 |
+
constant int64_t & ne0,
|
| 4025 |
+
constant int64_t & ne1,
|
| 4026 |
+
constant uint & r2,
|
| 4027 |
+
constant uint & r3,
|
| 4028 |
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
| 4029 |
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 4030 |
+
uint tiisg[[thread_index_in_simdgroup]],
|
| 4031 |
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 4032 |
+
|
| 4033 |
+
const int nb = ne00/QK_K;
|
| 4034 |
+
const int r0 = tgpig.x;
|
| 4035 |
+
const int r1 = tgpig.y;
|
| 4036 |
+
const int im = tgpig.z;
|
| 4037 |
+
|
| 4038 |
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
| 4039 |
+
const int ib_row = first_row * nb;
|
| 4040 |
+
|
| 4041 |
+
const uint i12 = im%ne12;
|
| 4042 |
+
const uint i13 = im/ne12;
|
| 4043 |
+
|
| 4044 |
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
| 4045 |
+
|
| 4046 |
+
device const block_iq3_xxs * x = (device const block_iq3_xxs *) src0 + ib_row + offset0;
|
| 4047 |
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
| 4048 |
+
|
| 4049 |
+
float yl[32];
|
| 4050 |
+
float sumf[N_DST]={0.f}, all_sum;
|
| 4051 |
+
|
| 4052 |
+
const int nb32 = nb * (QK_K / 32);
|
| 4053 |
+
|
| 4054 |
+
threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
|
| 4055 |
+
threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256);
|
| 4056 |
+
{
|
| 4057 |
+
int nval = 4;
|
| 4058 |
+
int pos = (32*sgitg + tiisg)*nval;
|
| 4059 |
+
for (int i = 0; i < nval; ++i) values[pos + i] = iq3xxs_grid[pos + i];
|
| 4060 |
+
nval = 2;
|
| 4061 |
+
pos = (32*sgitg + tiisg)*nval;
|
| 4062 |
+
for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
|
| 4063 |
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 4064 |
+
}
|
| 4065 |
+
|
| 4066 |
+
#if QK_K == 256
|
| 4067 |
+
const int ix = tiisg;
|
| 4068 |
+
|
| 4069 |
+
device const float * y4 = y + 32 * ix;
|
| 4070 |
+
|
| 4071 |
+
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
| 4072 |
+
|
| 4073 |
+
for (int i = 0; i < 32; ++i) {
|
| 4074 |
+
yl[i] = y4[i];
|
| 4075 |
+
}
|
| 4076 |
+
|
| 4077 |
+
const int ibl = ib32 / (QK_K / 32);
|
| 4078 |
+
const int ib = ib32 % (QK_K / 32);
|
| 4079 |
+
|
| 4080 |
+
device const block_iq3_xxs * xr = x + ibl;
|
| 4081 |
+
device const uint8_t * q3 = xr->qs + 8 * ib;
|
| 4082 |
+
device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib;
|
| 4083 |
+
device const half * dh = &xr->d;
|
| 4084 |
+
|
| 4085 |
+
for (int row = 0; row < N_DST; row++) {
|
| 4086 |
+
|
| 4087 |
+
const float db = dh[0];
|
| 4088 |
+
const uint32_t aux32 = gas[0] | (gas[1] << 16);
|
| 4089 |
+
const float d = db * (0.5f + (aux32 >> 28));
|
| 4090 |
+
|
| 4091 |
+
float2 sum = {0};
|
| 4092 |
+
for (int l = 0; l < 4; ++l) {
|
| 4093 |
+
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + q3[2*l+0]);
|
| 4094 |
+
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + q3[2*l+1]);
|
| 4095 |
+
const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
|
| 4096 |
+
for (int j = 0; j < 4; ++j) {
|
| 4097 |
+
sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
|
| 4098 |
+
sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
|
| 4099 |
+
}
|
| 4100 |
+
}
|
| 4101 |
+
sumf[row] += d * (sum[0] + sum[1]);
|
| 4102 |
+
|
| 4103 |
+
dh += nb*sizeof(block_iq3_xxs)/2;
|
| 4104 |
+
q3 += nb*sizeof(block_iq3_xxs);
|
| 4105 |
+
gas += nb*sizeof(block_iq3_xxs)/2;
|
| 4106 |
+
}
|
| 4107 |
+
|
| 4108 |
+
y4 += 32 * 32;
|
| 4109 |
+
}
|
| 4110 |
+
#else
|
| 4111 |
+
// TODO
|
| 4112 |
+
#endif
|
| 4113 |
+
|
| 4114 |
+
for (int row = 0; row < N_DST; ++row) {
|
| 4115 |
+
all_sum = simd_sum(sumf[row]);
|
| 4116 |
+
if (tiisg == 0) {
|
| 4117 |
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f;
|
| 4118 |
+
}
|
| 4119 |
+
}
|
| 4120 |
+
}
|
| 4121 |
+
|
| 4122 |
+
[[host_name("kernel_mul_mv_iq3_xxs_f32")]]
|
| 4123 |
+
kernel void kernel_mul_mv_iq3_xxs_f32(
|
| 4124 |
+
device const void * src0,
|
| 4125 |
+
device const float * src1,
|
| 4126 |
+
device float * dst,
|
| 4127 |
+
constant int64_t & ne00,
|
| 4128 |
+
constant int64_t & ne01,
|
| 4129 |
+
constant int64_t & ne02,
|
| 4130 |
+
constant uint64_t & nb00,
|
| 4131 |
+
constant uint64_t & nb01,
|
| 4132 |
+
constant uint64_t & nb02,
|
| 4133 |
+
constant int64_t & ne10,
|
| 4134 |
+
constant int64_t & ne11,
|
| 4135 |
+
constant int64_t & ne12,
|
| 4136 |
+
constant uint64_t & nb10,
|
| 4137 |
+
constant uint64_t & nb11,
|
| 4138 |
+
constant uint64_t & nb12,
|
| 4139 |
+
constant int64_t & ne0,
|
| 4140 |
+
constant int64_t & ne1,
|
| 4141 |
+
constant uint & r2,
|
| 4142 |
+
constant uint & r3,
|
| 4143 |
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
| 4144 |
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 4145 |
+
uint tiisg[[thread_index_in_simdgroup]],
|
| 4146 |
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 4147 |
+
|
| 4148 |
+
kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
| 4149 |
+
}
|
| 4150 |
+
|
| 4151 |
+
|
| 4152 |
//============================= templates and their specializations =============================
|
| 4153 |
|
| 4154 |
// NOTE: this is not dequantizing - we are simply fitting the template
|
|
|
|
| 4466 |
}
|
| 4467 |
}
|
| 4468 |
|
| 4469 |
+
template <typename type4x4>
|
| 4470 |
+
void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) {
|
| 4471 |
+
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
| 4472 |
+
const float d = xb->d;
|
| 4473 |
+
const int ib32 = il/2;
|
| 4474 |
+
il = il%2;
|
| 4475 |
+
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
|
| 4476 |
+
device const uint8_t * q3 = xb->qs + 8*ib32;
|
| 4477 |
+
device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32;
|
| 4478 |
+
const uint32_t aux32 = gas[0] | (gas[1] << 16);
|
| 4479 |
+
const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f;
|
| 4480 |
+
constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]);
|
| 4481 |
+
constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]);
|
| 4482 |
+
uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127];
|
| 4483 |
+
for (int i = 0; i < 4; ++i) {
|
| 4484 |
+
reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
|
| 4485 |
+
reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
|
| 4486 |
+
}
|
| 4487 |
+
grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]);
|
| 4488 |
+
grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]);
|
| 4489 |
+
signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127];
|
| 4490 |
+
for (int i = 0; i < 4; ++i) {
|
| 4491 |
+
reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
|
| 4492 |
+
reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
|
| 4493 |
+
}
|
| 4494 |
+
}
|
| 4495 |
+
|
| 4496 |
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
| 4497 |
kernel void kernel_get_rows(
|
| 4498 |
device const void * src0,
|
|
|
|
| 5034 |
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
|
| 5035 |
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
| 5036 |
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
| 5037 |
+
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
| 5038 |
|
| 5039 |
//
|
| 5040 |
// matrix-matrix multiplication
|
|
|
|
| 5073 |
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
|
| 5074 |
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
| 5075 |
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
| 5076 |
+
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
| 5077 |
|
| 5078 |
//
|
| 5079 |
// indirect matrix-matrix multiplication
|
|
|
|
| 5124 |
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
|
| 5125 |
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
| 5126 |
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
| 5127 |
+
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
| 5128 |
|
| 5129 |
//
|
| 5130 |
// matrix-vector multiplication
|
|
|
|
| 6027 |
tiisg,
|
| 6028 |
sgitg);
|
| 6029 |
}
|
| 6030 |
+
|
| 6031 |
+
[[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
|
| 6032 |
+
kernel void kernel_mul_mv_id_iq3_xxs_f32(
|
| 6033 |
+
device const char * ids,
|
| 6034 |
+
device const char * src1,
|
| 6035 |
+
device float * dst,
|
| 6036 |
+
constant uint64_t & nbi1,
|
| 6037 |
+
constant int64_t & ne00,
|
| 6038 |
+
constant int64_t & ne01,
|
| 6039 |
+
constant int64_t & ne02,
|
| 6040 |
+
constant uint64_t & nb00,
|
| 6041 |
+
constant uint64_t & nb01,
|
| 6042 |
+
constant uint64_t & nb02,
|
| 6043 |
+
constant int64_t & ne10,
|
| 6044 |
+
constant int64_t & ne11,
|
| 6045 |
+
constant int64_t & ne12,
|
| 6046 |
+
constant int64_t & ne13,
|
| 6047 |
+
constant uint64_t & nb10,
|
| 6048 |
+
constant uint64_t & nb11,
|
| 6049 |
+
constant uint64_t & nb12,
|
| 6050 |
+
constant int64_t & ne0,
|
| 6051 |
+
constant int64_t & ne1,
|
| 6052 |
+
constant uint64_t & nb1,
|
| 6053 |
+
constant uint & r2,
|
| 6054 |
+
constant uint & r3,
|
| 6055 |
+
constant int & idx,
|
| 6056 |
+
device const char * src00,
|
| 6057 |
+
device const char * src01,
|
| 6058 |
+
device const char * src02,
|
| 6059 |
+
device const char * src03,
|
| 6060 |
+
device const char * src04,
|
| 6061 |
+
device const char * src05,
|
| 6062 |
+
device const char * src06,
|
| 6063 |
+
device const char * src07,
|
| 6064 |
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
| 6065 |
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6066 |
+
uint tiitg[[thread_index_in_threadgroup]],
|
| 6067 |
+
uint tiisg[[thread_index_in_simdgroup]],
|
| 6068 |
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 6069 |
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
| 6070 |
+
|
| 6071 |
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
| 6072 |
+
|
| 6073 |
+
tgpig.z = tgpig.z%(ne12*ne13);
|
| 6074 |
+
|
| 6075 |
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
| 6076 |
+
|
| 6077 |
+
kernel_mul_mv_iq3_xxs_f32_impl(
|
| 6078 |
+
src0[id],
|
| 6079 |
+
(device const float *) (src1 + bid*nb11),
|
| 6080 |
+
dst + bid*ne0,
|
| 6081 |
+
ne00,
|
| 6082 |
+
ne01,
|
| 6083 |
+
ne02,
|
| 6084 |
+
ne10,
|
| 6085 |
+
ne12,
|
| 6086 |
+
ne0,
|
| 6087 |
+
ne1,
|
| 6088 |
+
r2,
|
| 6089 |
+
r3,
|
| 6090 |
+
shared_values,
|
| 6091 |
+
tgpig,
|
| 6092 |
+
tiisg,
|
| 6093 |
+
sgitg);
|
| 6094 |
+
}
|
|
@@ -3441,6 +3441,41 @@ static const uint64_t iq2xs_grid[512] = {
|
|
| 3441 |
0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
|
| 3442 |
};
|
| 3443 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3444 |
static const uint8_t ksigns_iq2xs[128] = {
|
| 3445 |
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
|
| 3446 |
144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
|
|
@@ -3507,6 +3542,38 @@ void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y,
|
|
| 3507 |
}
|
| 3508 |
}
|
| 3509 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3510 |
//===================================== Q8_K ==============================================
|
| 3511 |
|
| 3512 |
void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
|
|
@@ -8551,6 +8618,136 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
|
|
| 8551 |
#endif
|
| 8552 |
}
|
| 8553 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8554 |
// ================================ IQ2 quantization =============================================
|
| 8555 |
|
| 8556 |
typedef struct {
|
|
@@ -9189,3 +9386,436 @@ size_t quantize_iq2_xs(const float * src, void * dst, int nrow, int n_per_row, i
|
|
| 9189 |
return nrow * nblock * sizeof(block_iq2_xs);
|
| 9190 |
}
|
| 9191 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3441 |
0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
|
| 3442 |
};
|
| 3443 |
|
| 3444 |
+
static const uint32_t iq3xxs_grid[256] = {
|
| 3445 |
+
0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
|
| 3446 |
+
0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
|
| 3447 |
+
0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404,
|
| 3448 |
+
0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e,
|
| 3449 |
+
0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c,
|
| 3450 |
+
0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c,
|
| 3451 |
+
0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34,
|
| 3452 |
+
0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c,
|
| 3453 |
+
0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c,
|
| 3454 |
+
0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04,
|
| 3455 |
+
0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c,
|
| 3456 |
+
0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414,
|
| 3457 |
+
0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434,
|
| 3458 |
+
0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c,
|
| 3459 |
+
0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e,
|
| 3460 |
+
0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24,
|
| 3461 |
+
0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24,
|
| 3462 |
+
0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c,
|
| 3463 |
+
0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c,
|
| 3464 |
+
0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14,
|
| 3465 |
+
0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414,
|
| 3466 |
+
0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e,
|
| 3467 |
+
0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404,
|
| 3468 |
+
0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c,
|
| 3469 |
+
0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c,
|
| 3470 |
+
0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14,
|
| 3471 |
+
0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c,
|
| 3472 |
+
0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c,
|
| 3473 |
+
0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14,
|
| 3474 |
+
0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14,
|
| 3475 |
+
0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
|
| 3476 |
+
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
|
| 3477 |
+
};
|
| 3478 |
+
|
| 3479 |
static const uint8_t ksigns_iq2xs[128] = {
|
| 3480 |
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
|
| 3481 |
144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
|
|
|
|
| 3542 |
}
|
| 3543 |
}
|
| 3544 |
|
| 3545 |
+
// ====================== 3.0625 bpw (de)-quantization
|
| 3546 |
+
|
| 3547 |
+
void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y, int k) {
|
| 3548 |
+
assert(k % QK_K == 0);
|
| 3549 |
+
const int nb = k / QK_K;
|
| 3550 |
+
|
| 3551 |
+
uint32_t aux32;
|
| 3552 |
+
|
| 3553 |
+
for (int i = 0; i < nb; i++) {
|
| 3554 |
+
|
| 3555 |
+
const float d = GGML_FP16_TO_FP32(x[i].d);
|
| 3556 |
+
const uint8_t * qs = x[i].qs;
|
| 3557 |
+
const uint8_t * scales_and_signs = qs + QK_K/4;
|
| 3558 |
+
|
| 3559 |
+
for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
|
| 3560 |
+
memcpy(&aux32, scales_and_signs + 4*ib32, sizeof(uint32_t));
|
| 3561 |
+
const float db = d * (0.5f + (aux32 >> 28)) * 0.5f;
|
| 3562 |
+
for (int l = 0; l < 4; ++l) {
|
| 3563 |
+
const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
|
| 3564 |
+
const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + qs[2*l+0]);
|
| 3565 |
+
const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + qs[2*l+1]);
|
| 3566 |
+
for (int j = 0; j < 4; ++j) {
|
| 3567 |
+
y[j+0] = db * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
|
| 3568 |
+
y[j+4] = db * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
|
| 3569 |
+
}
|
| 3570 |
+
y += 8;
|
| 3571 |
+
}
|
| 3572 |
+
qs += 8;
|
| 3573 |
+
}
|
| 3574 |
+
}
|
| 3575 |
+
}
|
| 3576 |
+
|
| 3577 |
//===================================== Q8_K ==============================================
|
| 3578 |
|
| 3579 |
void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
|
|
|
|
| 8618 |
#endif
|
| 8619 |
}
|
| 8620 |
|
| 8621 |
+
// TODO
|
| 8622 |
+
void ggml_vec_dot_iq3_xxs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
| 8623 |
+
assert(n % QK_K == 0);
|
| 8624 |
+
|
| 8625 |
+
const block_iq3_xxs * restrict x = vx;
|
| 8626 |
+
const block_q8_K * restrict y = vy;
|
| 8627 |
+
|
| 8628 |
+
const int nb = n / QK_K;
|
| 8629 |
+
|
| 8630 |
+
#if defined(__ARM_NEON)
|
| 8631 |
+
|
| 8632 |
+
const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
|
| 8633 |
+
|
| 8634 |
+
uint32_t aux32[2];
|
| 8635 |
+
|
| 8636 |
+
ggml_int8x16x4_t q3s;
|
| 8637 |
+
ggml_int8x16x4_t q8b;
|
| 8638 |
+
|
| 8639 |
+
float sumf = 0;
|
| 8640 |
+
for (int i = 0; i < nb; ++i) {
|
| 8641 |
+
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
| 8642 |
+
const uint8_t * restrict q3 = x[i].qs;
|
| 8643 |
+
const uint8_t * restrict gas = x[i].qs + QK_K/4;
|
| 8644 |
+
const int8_t * restrict q8 = y[i].qs;
|
| 8645 |
+
float sumf1 = 0, sumf2 = 0;
|
| 8646 |
+
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
|
| 8647 |
+
q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
|
| 8648 |
+
memcpy(aux32, gas, 2*sizeof(uint32_t)); gas += 2*sizeof(uint32_t);
|
| 8649 |
+
const uint32x4_t aux32x4_0 = {iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]};
|
| 8650 |
+
const uint32x4_t aux32x4_1 = {iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]};
|
| 8651 |
+
const uint32x4_t aux32x4_2 = {iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]};
|
| 8652 |
+
const uint32x4_t aux32x4_3 = {iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]};
|
| 8653 |
+
q3 += 16;
|
| 8654 |
+
q3s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 7) & 127))));
|
| 8655 |
+
q3s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 21) & 127))));
|
| 8656 |
+
q3s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127))));
|
| 8657 |
+
q3s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127))));
|
| 8658 |
+
q3s.val[0] = vmulq_s8(q3s.val[0], vreinterpretq_s8_u32(aux32x4_0));
|
| 8659 |
+
q3s.val[1] = vmulq_s8(q3s.val[1], vreinterpretq_s8_u32(aux32x4_1));
|
| 8660 |
+
q3s.val[2] = vmulq_s8(q3s.val[2], vreinterpretq_s8_u32(aux32x4_2));
|
| 8661 |
+
q3s.val[3] = vmulq_s8(q3s.val[3], vreinterpretq_s8_u32(aux32x4_3));
|
| 8662 |
+
const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);
|
| 8663 |
+
const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
|
| 8664 |
+
sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[0] >> 28));
|
| 8665 |
+
sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[1] >> 28));
|
| 8666 |
+
}
|
| 8667 |
+
sumf += d*(sumf1 + sumf2);
|
| 8668 |
+
}
|
| 8669 |
+
*s = 0.5f * sumf;
|
| 8670 |
+
|
| 8671 |
+
#elif defined(__AVX2__)
|
| 8672 |
+
|
| 8673 |
+
const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
|
| 8674 |
+
|
| 8675 |
+
uint32_t aux32[2];
|
| 8676 |
+
|
| 8677 |
+
__m256 accumf = _mm256_setzero_ps();
|
| 8678 |
+
for (int i = 0; i < nb; ++i) {
|
| 8679 |
+
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
| 8680 |
+
const uint8_t * restrict q3 = x[i].qs;
|
| 8681 |
+
const uint8_t * restrict gas = x[i].qs + QK_K/4;
|
| 8682 |
+
const int8_t * restrict q8 = y[i].qs;
|
| 8683 |
+
__m256i sumi1 = _mm256_setzero_si256();
|
| 8684 |
+
__m256i sumi2 = _mm256_setzero_si256();
|
| 8685 |
+
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
|
| 8686 |
+
const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
|
| 8687 |
+
const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
|
| 8688 |
+
const __m256i q2_1 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
|
| 8689 |
+
iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
|
| 8690 |
+
q3 += 8;
|
| 8691 |
+
const __m256i q2_2 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
|
| 8692 |
+
iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
|
| 8693 |
+
q3 += 8;
|
| 8694 |
+
memcpy(aux32, gas, 8); gas += 8;
|
| 8695 |
+
const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127],
|
| 8696 |
+
signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]);
|
| 8697 |
+
const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
|
| 8698 |
+
signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
|
| 8699 |
+
const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
|
| 8700 |
+
const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
|
| 8701 |
+
const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
|
| 8702 |
+
const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
|
| 8703 |
+
const uint16_t ls1 = aux32[0] >> 28;
|
| 8704 |
+
const uint16_t ls2 = aux32[1] >> 28;
|
| 8705 |
+
const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
|
| 8706 |
+
const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
|
| 8707 |
+
sumi1 = _mm256_add_epi32(sumi1, p1);
|
| 8708 |
+
sumi2 = _mm256_add_epi32(sumi2, p2);
|
| 8709 |
+
}
|
| 8710 |
+
|
| 8711 |
+
accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
|
| 8712 |
+
|
| 8713 |
+
}
|
| 8714 |
+
|
| 8715 |
+
*s = 0.25f * hsum_float_8(accumf);
|
| 8716 |
+
|
| 8717 |
+
#else
|
| 8718 |
+
|
| 8719 |
+
uint32_t aux32;
|
| 8720 |
+
|
| 8721 |
+
float sumf = 0.f;
|
| 8722 |
+
for (int i = 0; i < nb; ++i) {
|
| 8723 |
+
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
| 8724 |
+
const uint8_t * restrict q3 = x[i].qs;
|
| 8725 |
+
const uint8_t * restrict gas = x[i].qs + QK_K/4;
|
| 8726 |
+
const int8_t * restrict q8 = y[i].qs;
|
| 8727 |
+
int32_t bsum = 0;
|
| 8728 |
+
for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
|
| 8729 |
+
memcpy(&aux32, gas, sizeof(uint32_t)); gas += sizeof(uint32_t);
|
| 8730 |
+
const uint32_t ls = 2*(aux32 >> 28) + 1;
|
| 8731 |
+
int32_t sumi = 0;
|
| 8732 |
+
for (int l = 0; l < 4; ++l) {
|
| 8733 |
+
const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*l+0]);
|
| 8734 |
+
const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]);
|
| 8735 |
+
const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
|
| 8736 |
+
for (int j = 0; j < 4; ++j) {
|
| 8737 |
+
sumi += grid1[j] * q8[j+0] * (signs & kmask_iq2xs[j+0] ? -1 : 1);
|
| 8738 |
+
sumi += grid2[j] * q8[j+4] * (signs & kmask_iq2xs[j+4] ? -1 : 1);
|
| 8739 |
+
}
|
| 8740 |
+
q8 += 8;
|
| 8741 |
+
}
|
| 8742 |
+
q3 += 8;
|
| 8743 |
+
bsum += sumi * ls;
|
| 8744 |
+
}
|
| 8745 |
+
sumf += d * bsum;
|
| 8746 |
+
}
|
| 8747 |
+
*s = 0.25f * sumf;
|
| 8748 |
+
#endif
|
| 8749 |
+
}
|
| 8750 |
+
|
| 8751 |
// ================================ IQ2 quantization =============================================
|
| 8752 |
|
| 8753 |
typedef struct {
|
|
|
|
| 9386 |
return nrow * nblock * sizeof(block_iq2_xs);
|
| 9387 |
}
|
| 9388 |
|
| 9389 |
+
//
|
| 9390 |
+
// ============================================= 3-bit using D4 lattice
|
| 9391 |
+
//
|
| 9392 |
+
|
| 9393 |
+
typedef struct {
|
| 9394 |
+
uint32_t * grid;
|
| 9395 |
+
int * map;
|
| 9396 |
+
uint16_t * neighbours;
|
| 9397 |
+
} iq3_entry_t;
|
| 9398 |
+
|
| 9399 |
+
static iq3_entry_t iq3_data[1] = {
|
| 9400 |
+
{NULL, NULL, NULL},
|
| 9401 |
+
};
|
| 9402 |
+
|
| 9403 |
+
static inline int iq3_data_index(int grid_size) {
|
| 9404 |
+
(void)grid_size;
|
| 9405 |
+
GGML_ASSERT(grid_size == 256);
|
| 9406 |
+
return 0;
|
| 9407 |
+
}
|
| 9408 |
+
|
| 9409 |
+
static int iq3_compare_func(const void * left, const void * right) {
|
| 9410 |
+
const int * l = (const int *)left;
|
| 9411 |
+
const int * r = (const int *)right;
|
| 9412 |
+
return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 1 : 0;
|
| 9413 |
+
}
|
| 9414 |
+
|
| 9415 |
+
void iq3xs_init_impl(int grid_size) {
|
| 9416 |
+
const int gindex = iq3_data_index(grid_size);
|
| 9417 |
+
if (iq3_data[gindex].grid) {
|
| 9418 |
+
return;
|
| 9419 |
+
}
|
| 9420 |
+
static const uint16_t kgrid_256[256] = {
|
| 9421 |
+
0, 2, 4, 9, 11, 15, 16, 18, 25, 34, 59, 61, 65, 67, 72, 74,
|
| 9422 |
+
81, 85, 88, 90, 97, 108, 120, 128, 130, 132, 137, 144, 146, 153, 155, 159,
|
| 9423 |
+
169, 175, 189, 193, 199, 200, 202, 213, 248, 267, 287, 292, 303, 315, 317, 321,
|
| 9424 |
+
327, 346, 362, 413, 436, 456, 460, 462, 483, 497, 513, 515, 520, 522, 529, 531,
|
| 9425 |
+
536, 538, 540, 551, 552, 576, 578, 585, 592, 594, 641, 643, 648, 650, 657, 664,
|
| 9426 |
+
698, 704, 706, 720, 729, 742, 758, 769, 773, 808, 848, 852, 870, 889, 901, 978,
|
| 9427 |
+
992, 1024, 1026, 1033, 1035, 1040, 1042, 1046, 1049, 1058, 1089, 1091, 1093, 1096, 1098, 1105,
|
| 9428 |
+
1112, 1139, 1143, 1144, 1152, 1154, 1161, 1167, 1168, 1170, 1183, 1184, 1197, 1217, 1224, 1228,
|
| 9429 |
+
1272, 1276, 1309, 1323, 1347, 1367, 1377, 1404, 1473, 1475, 1486, 1509, 1537, 1544, 1546, 1553,
|
| 9430 |
+
1555, 1576, 1589, 1594, 1600, 1602, 1616, 1625, 1636, 1638, 1665, 1667, 1672, 1685, 1706, 1722,
|
| 9431 |
+
1737, 1755, 1816, 1831, 1850, 1856, 1862, 1874, 1901, 1932, 1950, 1971, 2011, 2032, 2052, 2063,
|
| 9432 |
+
2077, 2079, 2091, 2095, 2172, 2192, 2207, 2208, 2224, 2230, 2247, 2277, 2308, 2345, 2356, 2389,
|
| 9433 |
+
2403, 2424, 2501, 2504, 2506, 2520, 2570, 2593, 2616, 2624, 2630, 2646, 2669, 2700, 2714, 2746,
|
| 9434 |
+
2754, 2795, 2824, 2835, 2839, 2874, 2882, 2905, 2984, 3028, 3042, 3092, 3108, 3110, 3124, 3153,
|
| 9435 |
+
3185, 3215, 3252, 3288, 3294, 3364, 3397, 3434, 3483, 3523, 3537, 3587, 3589, 3591, 3592, 3610,
|
| 9436 |
+
3626, 3670, 3680, 3722, 3749, 3754, 3776, 3789, 3803, 3824, 3857, 3873, 3904, 3906, 3924, 3992,
|
| 9437 |
+
};
|
| 9438 |
+
const int kmap_size = 4096;
|
| 9439 |
+
const int nwant = 2;
|
| 9440 |
+
const uint16_t * kgrid = kgrid_256;
|
| 9441 |
+
uint32_t * kgrid_q3xs;
|
| 9442 |
+
int * kmap_q3xs;
|
| 9443 |
+
uint16_t * kneighbors_q3xs;
|
| 9444 |
+
|
| 9445 |
+
printf("================================================================= %s(grid_size = %d)\n", __func__, grid_size);
|
| 9446 |
+
uint32_t * the_grid = (uint32_t *)malloc(grid_size*sizeof(uint32_t));
|
| 9447 |
+
for (int k = 0; k < grid_size; ++k) {
|
| 9448 |
+
int8_t * pos = (int8_t *)(the_grid + k);
|
| 9449 |
+
for (int i = 0; i < 4; ++i) {
|
| 9450 |
+
int l = (kgrid[k] >> 3*i) & 0x7;
|
| 9451 |
+
pos[i] = 2*l + 1;
|
| 9452 |
+
}
|
| 9453 |
+
}
|
| 9454 |
+
kgrid_q3xs = the_grid;
|
| 9455 |
+
iq3_data[gindex].grid = the_grid;
|
| 9456 |
+
kmap_q3xs = (int *)malloc(kmap_size*sizeof(int));
|
| 9457 |
+
iq3_data[gindex].map = kmap_q3xs;
|
| 9458 |
+
for (int i = 0; i < kmap_size; ++i) kmap_q3xs[i] = -1;
|
| 9459 |
+
uint32_t aux32;
|
| 9460 |
+
uint8_t * aux8 = (uint8_t *)&aux32;
|
| 9461 |
+
for (int i = 0; i < grid_size; ++i) {
|
| 9462 |
+
aux32 = kgrid_q3xs[i];
|
| 9463 |
+
uint16_t index = 0;
|
| 9464 |
+
for (int k=0; k<4; ++k) {
|
| 9465 |
+
uint16_t q = (aux8[k] - 1)/2;
|
| 9466 |
+
index |= (q << 3*k);
|
| 9467 |
+
}
|
| 9468 |
+
kmap_q3xs[index] = i;
|
| 9469 |
+
}
|
| 9470 |
+
int8_t pos[4];
|
| 9471 |
+
int * dist2 = (int *)malloc(2*grid_size*sizeof(int));
|
| 9472 |
+
int num_neighbors = 0, num_not_in_map = 0;
|
| 9473 |
+
for (int i = 0; i < kmap_size; ++i) {
|
| 9474 |
+
if (kmap_q3xs[i] >= 0) continue;
|
| 9475 |
+
++num_not_in_map;
|
| 9476 |
+
for (int k = 0; k < 4; ++k) {
|
| 9477 |
+
int l = (i >> 3*k) & 0x7;
|
| 9478 |
+
pos[k] = 2*l + 1;
|
| 9479 |
+
}
|
| 9480 |
+
for (int j = 0; j < grid_size; ++j) {
|
| 9481 |
+
const int8_t * pg = (const int8_t *)(kgrid_q3xs + j);
|
| 9482 |
+
int d2 = 0;
|
| 9483 |
+
for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
|
| 9484 |
+
dist2[2*j+0] = d2;
|
| 9485 |
+
dist2[2*j+1] = j;
|
| 9486 |
+
}
|
| 9487 |
+
qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func);
|
| 9488 |
+
int n = 0; int d2 = dist2[0];
|
| 9489 |
+
int nhave = 1;
|
| 9490 |
+
for (int j = 0; j < grid_size; ++j) {
|
| 9491 |
+
if (dist2[2*j] > d2) {
|
| 9492 |
+
if (nhave == nwant) break;
|
| 9493 |
+
d2 = dist2[2*j];
|
| 9494 |
+
++nhave;
|
| 9495 |
+
}
|
| 9496 |
+
++n;
|
| 9497 |
+
}
|
| 9498 |
+
num_neighbors += n;
|
| 9499 |
+
}
|
| 9500 |
+
printf("%s: %d neighbours in total\n", __func__, num_neighbors);
|
| 9501 |
+
kneighbors_q3xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t));
|
| 9502 |
+
iq3_data[gindex].neighbours = kneighbors_q3xs;
|
| 9503 |
+
int counter = 0;
|
| 9504 |
+
for (int i = 0; i < kmap_size; ++i) {
|
| 9505 |
+
if (kmap_q3xs[i] >= 0) continue;
|
| 9506 |
+
for (int k = 0; k < 4; ++k) {
|
| 9507 |
+
int l = (i >> 3*k) & 0x7;
|
| 9508 |
+
pos[k] = 2*l + 1;
|
| 9509 |
+
}
|
| 9510 |
+
for (int j = 0; j < grid_size; ++j) {
|
| 9511 |
+
const int8_t * pg = (const int8_t *)(kgrid_q3xs + j);
|
| 9512 |
+
int d2 = 0;
|
| 9513 |
+
for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
|
| 9514 |
+
dist2[2*j+0] = d2;
|
| 9515 |
+
dist2[2*j+1] = j;
|
| 9516 |
+
}
|
| 9517 |
+
qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func);
|
| 9518 |
+
kmap_q3xs[i] = -(counter + 1);
|
| 9519 |
+
int d2 = dist2[0];
|
| 9520 |
+
uint16_t * start = &kneighbors_q3xs[counter++];
|
| 9521 |
+
int n = 0, nhave = 1;
|
| 9522 |
+
for (int j = 0; j < grid_size; ++j) {
|
| 9523 |
+
if (dist2[2*j] > d2) {
|
| 9524 |
+
if (nhave == nwant) break;
|
| 9525 |
+
d2 = dist2[2*j];
|
| 9526 |
+
++nhave;
|
| 9527 |
+
}
|
| 9528 |
+
kneighbors_q3xs[counter++] = dist2[2*j+1];
|
| 9529 |
+
++n;
|
| 9530 |
+
}
|
| 9531 |
+
*start = n;
|
| 9532 |
+
}
|
| 9533 |
+
free(dist2);
|
| 9534 |
+
}
|
| 9535 |
+
|
| 9536 |
+
void iq3xs_free_impl(int grid_size) {
|
| 9537 |
+
GGML_ASSERT(grid_size == 256);
|
| 9538 |
+
const int gindex = iq3_data_index(grid_size);
|
| 9539 |
+
if (iq3_data[gindex].grid) {
|
| 9540 |
+
free(iq3_data[gindex].grid); iq3_data[gindex].grid = NULL;
|
| 9541 |
+
free(iq3_data[gindex].map); iq3_data[gindex].map = NULL;
|
| 9542 |
+
free(iq3_data[gindex].neighbours); iq3_data[gindex].neighbours = NULL;
|
| 9543 |
+
}
|
| 9544 |
+
}
|
| 9545 |
+
|
| 9546 |
+
static int iq3_find_best_neighbour(const uint16_t * restrict neighbours, const uint32_t * restrict grid,
|
| 9547 |
+
const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L) {
|
| 9548 |
+
int num_neighbors = neighbours[0];
|
| 9549 |
+
GGML_ASSERT(num_neighbors > 0);
|
| 9550 |
+
float best_d2 = FLT_MAX;
|
| 9551 |
+
int grid_index = -1;
|
| 9552 |
+
for (int j = 1; j <= num_neighbors; ++j) {
|
| 9553 |
+
const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
|
| 9554 |
+
float d2 = 0;
|
| 9555 |
+
for (int i = 0; i < 4; ++i) {
|
| 9556 |
+
float q = pg[i];
|
| 9557 |
+
float diff = scale*q - xval[i];
|
| 9558 |
+
d2 += weight[i]*diff*diff;
|
| 9559 |
+
}
|
| 9560 |
+
if (d2 < best_d2) {
|
| 9561 |
+
best_d2 = d2; grid_index = neighbours[j];
|
| 9562 |
+
}
|
| 9563 |
+
}
|
| 9564 |
+
GGML_ASSERT(grid_index >= 0);
|
| 9565 |
+
const int8_t * pg = (const int8_t *)(grid + grid_index);
|
| 9566 |
+
for (int i = 0; i < 4; ++i) L[i] = (pg[i] - 1)/2;
|
| 9567 |
+
return grid_index;
|
| 9568 |
+
}
|
| 9569 |
+
|
| 9570 |
+
static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
|
| 9571 |
+
|
| 9572 |
+
const int gindex = iq3_data_index(256);
|
| 9573 |
+
|
| 9574 |
+
const uint32_t * kgrid_q3xs = iq3_data[gindex].grid;
|
| 9575 |
+
const int * kmap_q3xs = iq3_data[gindex].map;
|
| 9576 |
+
const uint16_t * kneighbors_q3xs = iq3_data[gindex].neighbours;
|
| 9577 |
+
|
| 9578 |
+
//GGML_ASSERT(quant_weights && "missing quantization weights");
|
| 9579 |
+
GGML_ASSERT(kgrid_q3xs && "forgot to call ggml_quantize_init()?");
|
| 9580 |
+
GGML_ASSERT(kmap_q3xs && "forgot to call ggml_quantize_init()?");
|
| 9581 |
+
GGML_ASSERT(kneighbors_q3xs && "forgot to call ggml_quantize_init()?");
|
| 9582 |
+
GGML_ASSERT(n%QK_K == 0);
|
| 9583 |
+
|
| 9584 |
+
const int kMaxQ = 8;
|
| 9585 |
+
|
| 9586 |
+
const int nbl = n/256;
|
| 9587 |
+
|
| 9588 |
+
block_iq3_xxs * y = vy;
|
| 9589 |
+
|
| 9590 |
+
float scales[QK_K/32];
|
| 9591 |
+
float weight[32];
|
| 9592 |
+
float xval[32];
|
| 9593 |
+
int8_t L[32];
|
| 9594 |
+
int8_t Laux[32];
|
| 9595 |
+
float waux[32];
|
| 9596 |
+
bool is_on_grid[8];
|
| 9597 |
+
bool is_on_grid_aux[8];
|
| 9598 |
+
uint8_t block_signs[8];
|
| 9599 |
+
uint8_t q3[3*(QK_K/8)];
|
| 9600 |
+
uint32_t * scales_and_signs = (uint32_t *)(q3 + QK_K/4);
|
| 9601 |
+
|
| 9602 |
+
for (int ibl = 0; ibl < nbl; ++ibl) {
|
| 9603 |
+
|
| 9604 |
+
y[ibl].d = GGML_FP32_TO_FP16(0.f);
|
| 9605 |
+
memset(q3, 0, 3*QK_K/8);
|
| 9606 |
+
|
| 9607 |
+
float max_scale = 0;
|
| 9608 |
+
|
| 9609 |
+
const float * xbl = x + QK_K*ibl;
|
| 9610 |
+
float sumx2 = 0;
|
| 9611 |
+
for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
|
| 9612 |
+
float sigma2 = sumx2/QK_K;
|
| 9613 |
+
|
| 9614 |
+
for (int ib = 0; ib < QK_K/32; ++ib) {
|
| 9615 |
+
const float * xb = xbl + 32*ib;
|
| 9616 |
+
if (quant_weights) {
|
| 9617 |
+
const float * qw = quant_weights + QK_K*ibl + 32*ib;
|
| 9618 |
+
for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
|
| 9619 |
+
} else {
|
| 9620 |
+
for (int i = 0; i < 32; ++i) weight[i] = xb[i]*xb[i];
|
| 9621 |
+
}
|
| 9622 |
+
for (int i = 0; i < 32; ++i) waux[i] = sqrtf(weight[i]);
|
| 9623 |
+
for (int k = 0; k < 4; ++k) {
|
| 9624 |
+
int nflip = 0;
|
| 9625 |
+
uint8_t s = 0;
|
| 9626 |
+
for (int i = 0; i < 8; ++i) {
|
| 9627 |
+
if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
|
| 9628 |
+
else {
|
| 9629 |
+
xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i);
|
| 9630 |
+
}
|
| 9631 |
+
}
|
| 9632 |
+
if (nflip%2) {
|
| 9633 |
+
int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin];
|
| 9634 |
+
for (int i = 1; i < 8; ++i) {
|
| 9635 |
+
float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i];
|
| 9636 |
+
if (ax < min) {
|
| 9637 |
+
min = ax; imin = i;
|
| 9638 |
+
}
|
| 9639 |
+
}
|
| 9640 |
+
xval[8*k+imin] = -xval[8*k+imin];
|
| 9641 |
+
s ^= (1 << imin);
|
| 9642 |
+
}
|
| 9643 |
+
block_signs[k] = s & 127;
|
| 9644 |
+
}
|
| 9645 |
+
float max = xval[0];
|
| 9646 |
+
for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]);
|
| 9647 |
+
if (!max) {
|
| 9648 |
+
scales[ib] = 0;
|
| 9649 |
+
memset(L, 0, 32);
|
| 9650 |
+
continue;
|
| 9651 |
+
}
|
| 9652 |
+
float best = 0;
|
| 9653 |
+
float scale = max/(2*kMaxQ-1);
|
| 9654 |
+
for (int is = -15; is <= 15; ++is) {
|
| 9655 |
+
float id = (2*kMaxQ-1+is*0.2f)/max;
|
| 9656 |
+
float this_scale = 1/id;
|
| 9657 |
+
for (int k = 0; k < 8; ++k) {
|
| 9658 |
+
for (int i = 0; i < 4; ++i) {
|
| 9659 |
+
int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
|
| 9660 |
+
Laux[4*k+i] = MAX(0, MIN(kMaxQ-1, l));
|
| 9661 |
+
}
|
| 9662 |
+
uint16_t u = 0;
|
| 9663 |
+
for (int i = 0; i < 4; ++i) u |= (Laux[4*k+i] << 3*i);
|
| 9664 |
+
int grid_index = kmap_q3xs[u];
|
| 9665 |
+
is_on_grid_aux[k] = true;
|
| 9666 |
+
if (grid_index < 0) {
|
| 9667 |
+
is_on_grid_aux[k] = false;
|
| 9668 |
+
const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
|
| 9669 |
+
grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, this_scale, Laux + 4*k);
|
| 9670 |
+
}
|
| 9671 |
+
}
|
| 9672 |
+
float sumqx = 0, sumq2 = 0;
|
| 9673 |
+
for (int i = 0; i < 32; ++i) {
|
| 9674 |
+
float w = weight[i];
|
| 9675 |
+
float q = 2*Laux[i] + 1;
|
| 9676 |
+
sumqx += w*xval[i]*q;
|
| 9677 |
+
sumq2 += w*q*q;
|
| 9678 |
+
}
|
| 9679 |
+
if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
|
| 9680 |
+
scale = sumqx/sumq2; best = scale*sumqx;
|
| 9681 |
+
for (int i = 0; i < 32; ++i) L[i] = Laux[i];
|
| 9682 |
+
for (int k = 0; k < 8; ++k) is_on_grid[k] = is_on_grid_aux[k];
|
| 9683 |
+
}
|
| 9684 |
+
}
|
| 9685 |
+
int n_not_ongrid = 0;
|
| 9686 |
+
for (int k = 0; k < 8; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
|
| 9687 |
+
if (n_not_ongrid > 0 && scale > 0) {
|
| 9688 |
+
float id = 1/scale;
|
| 9689 |
+
for (int k = 0; k < 8; ++k) {
|
| 9690 |
+
if (is_on_grid[k]) continue;
|
| 9691 |
+
uint16_t u = 0;
|
| 9692 |
+
for (int i = 0; i < 4; ++i) {
|
| 9693 |
+
int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
|
| 9694 |
+
l = MAX(0, MIN(kMaxQ-1, l));
|
| 9695 |
+
u |= (l << 3*i);
|
| 9696 |
+
}
|
| 9697 |
+
int grid_index = kmap_q3xs[u];
|
| 9698 |
+
if (grid_index < 0) {
|
| 9699 |
+
const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
|
| 9700 |
+
grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, scale, L + 4*k);
|
| 9701 |
+
}
|
| 9702 |
+
const int8_t * pg = (const int8_t *)(kgrid_q3xs + grid_index);
|
| 9703 |
+
for (int i = 0; i < 4; ++i) L[4*k+i] = (pg[i] - 1)/2;
|
| 9704 |
+
}
|
| 9705 |
+
float sumqx = 0, sumq2 = 0;
|
| 9706 |
+
for (int i = 0; i < 32; ++i) {
|
| 9707 |
+
float w = weight[i];
|
| 9708 |
+
float q = 2*L[i] + 1;
|
| 9709 |
+
sumqx += w*xval[i]*q;
|
| 9710 |
+
sumq2 += w*q*q;
|
| 9711 |
+
}
|
| 9712 |
+
if (sumq2 > 0) scale = sumqx/sumq2;
|
| 9713 |
+
}
|
| 9714 |
+
if (scale < 0) {
|
| 9715 |
+
// This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale)
|
| 9716 |
+
// and correspondingly flip quant signs.
|
| 9717 |
+
scale = -scale;
|
| 9718 |
+
for (int k = 0; k < 4; ++k) block_signs[k] = (~block_signs[k]) & 127;
|
| 9719 |
+
}
|
| 9720 |
+
for (int k = 0; k < 8; ++k) {
|
| 9721 |
+
uint16_t u = 0;
|
| 9722 |
+
for (int i = 0; i < 4; ++i) u |= (L[4*k+i] << 3*i);
|
| 9723 |
+
int grid_index = kmap_q3xs[u];
|
| 9724 |
+
if (grid_index < 0) {
|
| 9725 |
+
printf("Oops: found point %u not on grid:", u);
|
| 9726 |
+
for (int i = 0; i < 4; ++i) printf(" %d", L[4*k+i]);
|
| 9727 |
+
printf("\n");
|
| 9728 |
+
GGML_ASSERT(false);
|
| 9729 |
+
}
|
| 9730 |
+
q3[8*ib+k] = grid_index;
|
| 9731 |
+
}
|
| 9732 |
+
scales_and_signs[ib] = block_signs[0] | (block_signs[1] << 7) | (block_signs[2] << 14) | (block_signs[3] << 21);
|
| 9733 |
+
GGML_ASSERT(scale >= 0);
|
| 9734 |
+
scales[ib] = scale;
|
| 9735 |
+
max_scale = MAX(max_scale, scale);
|
| 9736 |
+
}
|
| 9737 |
+
|
| 9738 |
+
if (!max_scale) {
|
| 9739 |
+
memset(y[ibl].qs, 0, 3*QK_K/8);
|
| 9740 |
+
continue;
|
| 9741 |
+
}
|
| 9742 |
+
|
| 9743 |
+
float d = max_scale/31;
|
| 9744 |
+
y[ibl].d = GGML_FP32_TO_FP16(d);
|
| 9745 |
+
float id = 1/d;
|
| 9746 |
+
float sumqx = 0, sumq2 = 0;
|
| 9747 |
+
for (int ib = 0; ib < QK_K/32; ++ib) {
|
| 9748 |
+
int l = nearest_int(0.5f*(id*scales[ib]-1));
|
| 9749 |
+
l = MAX(0, MIN(15, l));
|
| 9750 |
+
scales_and_signs[ib] |= ((uint32_t)l << 28);
|
| 9751 |
+
if (false) {
|
| 9752 |
+
const float * xb = xbl + 32*ib;
|
| 9753 |
+
if (quant_weights) {
|
| 9754 |
+
const float * qw = quant_weights + QK_K*ibl + 32*ib;
|
| 9755 |
+
for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
|
| 9756 |
+
} else {
|
| 9757 |
+
for (int i = 0; i < 32; ++i) weight[i] = xb[i]*xb[i];
|
| 9758 |
+
}
|
| 9759 |
+
const float db = 0.25f * d * (1 + 2*l);
|
| 9760 |
+
for (int k = 0; k < 8; ++k) {
|
| 9761 |
+
const int8_t * signs = keven_signs_q2xs + 8*((scales_and_signs[ib] >> 7*(k/2)) & 127) + 4*(k%2);
|
| 9762 |
+
const float * xk = xb + 4*k;
|
| 9763 |
+
const float * wk = weight + 4*k;
|
| 9764 |
+
//const uint8_t * grid = (const uint8_t *)(kgrid_q3xs + q3[8*ib+k]);
|
| 9765 |
+
const uint8_t * grid = (const uint8_t *)(iq3xxs_grid + q3[8*ib+k]);
|
| 9766 |
+
float best_mse = 0; int best_index = q3[8*ib+k];
|
| 9767 |
+
for (int j = 0; j < 4; ++j) {
|
| 9768 |
+
float diff = db * grid[j] * signs[j] - xk[j];
|
| 9769 |
+
best_mse += wk[j] * diff * diff;
|
| 9770 |
+
}
|
| 9771 |
+
for (int idx = 0; idx < 256; ++idx) {
|
| 9772 |
+
//grid = (const uint8_t *)(kgrid_q3xs + idx);
|
| 9773 |
+
grid = (const uint8_t *)(iq3xxs_grid + idx);
|
| 9774 |
+
float mse = 0;
|
| 9775 |
+
for (int j = 0; j < 4; ++j) {
|
| 9776 |
+
float diff = db * grid[j] * signs[j] - xk[j];
|
| 9777 |
+
mse += wk[j] * diff * diff;
|
| 9778 |
+
}
|
| 9779 |
+
if (mse < best_mse) {
|
| 9780 |
+
best_mse = mse; best_index = idx;
|
| 9781 |
+
}
|
| 9782 |
+
}
|
| 9783 |
+
q3[8*ib+k] = best_index;
|
| 9784 |
+
//grid = (const uint8_t *)(kgrid_q3xs + best_index);
|
| 9785 |
+
grid = (const uint8_t *)(iq3xxs_grid + best_index);
|
| 9786 |
+
for (int j = 0; j < 4; ++j) {
|
| 9787 |
+
float q = db * grid[j] * signs[j];
|
| 9788 |
+
sumqx += wk[j] * q * xk[j];
|
| 9789 |
+
sumq2 += wk[j] * q * q;
|
| 9790 |
+
}
|
| 9791 |
+
}
|
| 9792 |
+
if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(d*sumqx/sumq2);
|
| 9793 |
+
}
|
| 9794 |
+
}
|
| 9795 |
+
memcpy(y[ibl].qs, q3, 3*QK_K/8);
|
| 9796 |
+
}
|
| 9797 |
+
}
|
| 9798 |
+
|
| 9799 |
+
size_t quantize_iq3_xxs(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
|
| 9800 |
+
(void)hist;
|
| 9801 |
+
GGML_ASSERT(n_per_row%QK_K == 0);
|
| 9802 |
+
int nblock = n_per_row/QK_K;
|
| 9803 |
+
char * qrow = (char *)dst;
|
| 9804 |
+
for (int row = 0; row < nrow; ++row) {
|
| 9805 |
+
quantize_row_iq3_xxs_impl(src, qrow, n_per_row, quant_weights);
|
| 9806 |
+
src += n_per_row;
|
| 9807 |
+
qrow += nblock*sizeof(block_iq3_xxs);
|
| 9808 |
+
}
|
| 9809 |
+
return nrow * nblock * sizeof(block_iq3_xxs);
|
| 9810 |
+
}
|
| 9811 |
+
|
| 9812 |
+
void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int k) {
|
| 9813 |
+
assert(k % QK_K == 0);
|
| 9814 |
+
block_iq3_xxs * restrict y = vy;
|
| 9815 |
+
quantize_row_iq3_xxs_reference(x, y, k);
|
| 9816 |
+
}
|
| 9817 |
+
|
| 9818 |
+
// Reference (no importance matrix) IQ3_XXS row quantization.
// k must be a multiple of QK_K.
void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int k) {
    assert(k % QK_K == 0);
    // A NULL quant_weights argument selects the unweighted (x^2) weighting
    // inside the shared implementation.
    quantize_row_iq3_xxs_impl(x, y, k, NULL);
}
|
|
@@ -166,7 +166,7 @@ typedef struct {
|
|
| 166 |
static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding");
|
| 167 |
|
| 168 |
// (Almost) "true" 2-bit quantization.
|
| 169 |
-
// Due to the need to use blocks as per ggml
|
| 170 |
// 2.0625 bpw because of the 16-bit scale for each block of 256.
|
| 171 |
typedef struct {
|
| 172 |
ggml_fp16_t d;
|
|
@@ -182,6 +182,15 @@ typedef struct {
|
|
| 182 |
} block_iq2_xs;
|
| 183 |
static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding");
|
| 184 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
// Quantization
|
| 186 |
void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k);
|
| 187 |
void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k);
|
|
@@ -196,6 +205,7 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
|
|
| 196 |
void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k);
|
| 197 |
void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k);
|
| 198 |
void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k);
|
|
|
|
| 199 |
|
| 200 |
void quantize_row_q4_0(const float * restrict x, void * restrict y, int k);
|
| 201 |
void quantize_row_q4_1(const float * restrict x, void * restrict y, int k);
|
|
@@ -210,6 +220,7 @@ void quantize_row_q4_K(const float * restrict x, void * restrict y, int k);
|
|
| 210 |
void quantize_row_q5_K(const float * restrict x, void * restrict y, int k);
|
| 211 |
void quantize_row_q6_K(const float * restrict x, void * restrict y, int k);
|
| 212 |
void quantize_row_q8_K(const float * restrict x, void * restrict y, int k);
|
|
|
|
| 213 |
|
| 214 |
// Dequantization
|
| 215 |
void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k);
|
|
@@ -227,6 +238,7 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int
|
|
| 227 |
void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k);
|
| 228 |
void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k);
|
| 229 |
void dequantize_row_iq2_xs (const block_iq2_xs * restrict x, float * restrict y, int k);
|
|
|
|
| 230 |
|
| 231 |
// Dot product
|
| 232 |
void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
|
@@ -242,12 +254,14 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx,
|
|
| 242 |
void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
| 243 |
void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
| 244 |
void ggml_vec_dot_iq2_xs_q8_K (int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
|
|
|
| 245 |
|
| 246 |
//
|
| 247 |
// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
|
| 248 |
//
|
| 249 |
size_t quantize_iq2_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
| 250 |
size_t quantize_iq2_xs (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
|
|
|
| 251 |
size_t quantize_q2_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
| 252 |
size_t quantize_q3_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
| 253 |
size_t quantize_q4_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
|
@@ -260,3 +274,5 @@ size_t quantize_q5_1 (const float * src, void * dst, int nrows, int n_per_row,
|
|
| 260 |
|
| 261 |
void iq2xs_init_impl(int grid_size);
|
| 262 |
void iq2xs_free_impl(int grid_size);
|
|
|
|
|
|
|
|
|
| 166 |
static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding");
|
| 167 |
|
| 168 |
// (Almost) "true" 2-bit quantization.
|
| 169 |
+
// Due to the need to use blocks as per ggml design, it ends up using
|
| 170 |
// 2.0625 bpw because of the 16-bit scale for each block of 256.
|
| 171 |
typedef struct {
|
| 172 |
ggml_fp16_t d;
|
|
|
|
| 182 |
} block_iq2_xs;
|
| 183 |
static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding");
|
| 184 |
|
| 185 |
+
// (Almost) "true" 3-bit quantization.
|
| 186 |
+
// Due to the need to use blocks as per ggml design, it ends up using
|
| 187 |
+
// 3.0625 bpw because of the 16-bit scale for each block of 256.
|
| 188 |
+
typedef struct {
|
| 189 |
+
ggml_fp16_t d;
|
| 190 |
+
uint8_t qs[3*QK_K/8];
|
| 191 |
+
} block_iq3_xxs;
|
| 192 |
+
static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
|
| 193 |
+
|
| 194 |
// Quantization
|
| 195 |
void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k);
|
| 196 |
void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k);
|
|
|
|
| 205 |
void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k);
|
| 206 |
void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k);
|
| 207 |
void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k);
|
| 208 |
+
void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int k);
|
| 209 |
|
| 210 |
void quantize_row_q4_0(const float * restrict x, void * restrict y, int k);
|
| 211 |
void quantize_row_q4_1(const float * restrict x, void * restrict y, int k);
|
|
|
|
| 220 |
void quantize_row_q5_K(const float * restrict x, void * restrict y, int k);
|
| 221 |
void quantize_row_q6_K(const float * restrict x, void * restrict y, int k);
|
| 222 |
void quantize_row_q8_K(const float * restrict x, void * restrict y, int k);
|
| 223 |
+
void quantize_row_iq3_xxs(const float * restrict x, void * restrict y, int k);
|
| 224 |
|
| 225 |
// Dequantization
|
| 226 |
void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k);
|
|
|
|
| 238 |
void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k);
|
| 239 |
void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k);
|
| 240 |
void dequantize_row_iq2_xs (const block_iq2_xs * restrict x, float * restrict y, int k);
|
| 241 |
+
void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y, int k);
|
| 242 |
|
| 243 |
// Dot product
|
| 244 |
void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
|
|
|
| 254 |
void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
| 255 |
void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
| 256 |
void ggml_vec_dot_iq2_xs_q8_K (int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
| 257 |
+
void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
| 258 |
|
| 259 |
//
|
| 260 |
// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
|
| 261 |
//
|
| 262 |
size_t quantize_iq2_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
| 263 |
size_t quantize_iq2_xs (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
| 264 |
+
size_t quantize_iq3_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
| 265 |
size_t quantize_q2_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
| 266 |
size_t quantize_q3_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
| 267 |
size_t quantize_q4_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
|
|
|
| 274 |
|
| 275 |
void iq2xs_init_impl(int grid_size);
|
| 276 |
void iq2xs_free_impl(int grid_size);
|
| 277 |
+
void iq3xs_init_impl(int grid_size);
|
| 278 |
+
void iq3xs_free_impl(int grid_size);
|
|
@@ -632,6 +632,17 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|
| 632 |
.vec_dot = ggml_vec_dot_iq2_xs_q8_K,
|
| 633 |
.vec_dot_type = GGML_TYPE_Q8_K,
|
| 634 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 635 |
[GGML_TYPE_Q8_K] = {
|
| 636 |
.type_name = "q8_K",
|
| 637 |
.blck_size = QK_K,
|
|
@@ -2177,6 +2188,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
|
|
| 2177 |
case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break;
|
| 2178 |
case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break;
|
| 2179 |
case GGML_FTYPE_MOSTLY_IQ2_XS: wtype = GGML_TYPE_IQ2_XS; break;
|
|
|
|
| 2180 |
case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
|
| 2181 |
case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
|
| 2182 |
}
|
|
@@ -7570,6 +7582,7 @@ static void ggml_compute_forward_add(
|
|
| 7570 |
case GGML_TYPE_Q6_K:
|
| 7571 |
case GGML_TYPE_IQ2_XXS:
|
| 7572 |
case GGML_TYPE_IQ2_XS:
|
|
|
|
| 7573 |
{
|
| 7574 |
ggml_compute_forward_add_q_f32(params, src0, src1, dst);
|
| 7575 |
} break;
|
|
@@ -7836,6 +7849,7 @@ static void ggml_compute_forward_add1(
|
|
| 7836 |
case GGML_TYPE_Q6_K:
|
| 7837 |
case GGML_TYPE_IQ2_XXS:
|
| 7838 |
case GGML_TYPE_IQ2_XS:
|
|
|
|
| 7839 |
{
|
| 7840 |
ggml_compute_forward_add1_q_f32(params, src0, src1, dst);
|
| 7841 |
} break;
|
|
@@ -7955,6 +7969,7 @@ static void ggml_compute_forward_acc(
|
|
| 7955 |
case GGML_TYPE_Q6_K:
|
| 7956 |
case GGML_TYPE_IQ2_XXS:
|
| 7957 |
case GGML_TYPE_IQ2_XS:
|
|
|
|
| 7958 |
default:
|
| 7959 |
{
|
| 7960 |
GGML_ASSERT(false);
|
|
@@ -10706,6 +10721,7 @@ static void ggml_compute_forward_out_prod(
|
|
| 10706 |
case GGML_TYPE_Q6_K:
|
| 10707 |
case GGML_TYPE_IQ2_XXS:
|
| 10708 |
case GGML_TYPE_IQ2_XS:
|
|
|
|
| 10709 |
{
|
| 10710 |
ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
|
| 10711 |
} break;
|
|
@@ -10885,6 +10901,7 @@ static void ggml_compute_forward_set(
|
|
| 10885 |
case GGML_TYPE_Q6_K:
|
| 10886 |
case GGML_TYPE_IQ2_XXS:
|
| 10887 |
case GGML_TYPE_IQ2_XS:
|
|
|
|
| 10888 |
default:
|
| 10889 |
{
|
| 10890 |
GGML_ASSERT(false);
|
|
@@ -11081,6 +11098,7 @@ static void ggml_compute_forward_get_rows(
|
|
| 11081 |
case GGML_TYPE_Q6_K:
|
| 11082 |
case GGML_TYPE_IQ2_XXS:
|
| 11083 |
case GGML_TYPE_IQ2_XS:
|
|
|
|
| 11084 |
{
|
| 11085 |
ggml_compute_forward_get_rows_q(params, src0, src1, dst);
|
| 11086 |
} break;
|
|
@@ -11728,6 +11746,7 @@ static void ggml_compute_forward_alibi(
|
|
| 11728 |
case GGML_TYPE_Q6_K:
|
| 11729 |
case GGML_TYPE_IQ2_XXS:
|
| 11730 |
case GGML_TYPE_IQ2_XS:
|
|
|
|
| 11731 |
case GGML_TYPE_Q8_K:
|
| 11732 |
case GGML_TYPE_I8:
|
| 11733 |
case GGML_TYPE_I16:
|
|
@@ -11804,6 +11823,7 @@ static void ggml_compute_forward_clamp(
|
|
| 11804 |
case GGML_TYPE_Q6_K:
|
| 11805 |
case GGML_TYPE_IQ2_XXS:
|
| 11806 |
case GGML_TYPE_IQ2_XS:
|
|
|
|
| 11807 |
case GGML_TYPE_Q8_K:
|
| 11808 |
case GGML_TYPE_I8:
|
| 11809 |
case GGML_TYPE_I16:
|
|
@@ -18860,6 +18880,7 @@ void ggml_quantize_init(enum ggml_type type) {
|
|
| 18860 |
switch (type) {
|
| 18861 |
case GGML_TYPE_IQ2_XXS: iq2xs_init_impl(256); break;
|
| 18862 |
case GGML_TYPE_IQ2_XS: iq2xs_init_impl(512); break;
|
|
|
|
| 18863 |
default: // nothing
|
| 18864 |
break;
|
| 18865 |
}
|
|
@@ -19122,6 +19143,15 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
|
|
| 19122 |
result = quantize_iq2_xs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
|
| 19123 |
GGML_ASSERT(result == row_size * nrows);
|
| 19124 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19125 |
case GGML_TYPE_F16:
|
| 19126 |
{
|
| 19127 |
size_t elemsize = sizeof(ggml_fp16_t);
|
|
|
|
| 632 |
.vec_dot = ggml_vec_dot_iq2_xs_q8_K,
|
| 633 |
.vec_dot_type = GGML_TYPE_Q8_K,
|
| 634 |
},
|
| 635 |
+
[GGML_TYPE_IQ3_XXS] = {
|
| 636 |
+
.type_name = "iq3_xxs",
|
| 637 |
+
.blck_size = QK_K,
|
| 638 |
+
.type_size = sizeof(block_iq3_xxs),
|
| 639 |
+
.is_quantized = true,
|
| 640 |
+
.to_float = (ggml_to_float_t) dequantize_row_iq3_xxs,
|
| 641 |
+
.from_float = quantize_row_iq3_xxs,
|
| 642 |
+
.from_float_reference = (ggml_from_float_t)quantize_row_iq3_xxs_reference,
|
| 643 |
+
.vec_dot = ggml_vec_dot_iq3_xxs_q8_K,
|
| 644 |
+
.vec_dot_type = GGML_TYPE_Q8_K,
|
| 645 |
+
},
|
| 646 |
[GGML_TYPE_Q8_K] = {
|
| 647 |
.type_name = "q8_K",
|
| 648 |
.blck_size = QK_K,
|
|
|
|
| 2188 |
case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break;
|
| 2189 |
case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break;
|
| 2190 |
case GGML_FTYPE_MOSTLY_IQ2_XS: wtype = GGML_TYPE_IQ2_XS; break;
|
| 2191 |
+
case GGML_FTYPE_MOSTLY_IQ3_XXS: wtype = GGML_TYPE_IQ3_XXS; break;
|
| 2192 |
case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
|
| 2193 |
case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
|
| 2194 |
}
|
|
|
|
| 7582 |
case GGML_TYPE_Q6_K:
|
| 7583 |
case GGML_TYPE_IQ2_XXS:
|
| 7584 |
case GGML_TYPE_IQ2_XS:
|
| 7585 |
+
case GGML_TYPE_IQ3_XXS:
|
| 7586 |
{
|
| 7587 |
ggml_compute_forward_add_q_f32(params, src0, src1, dst);
|
| 7588 |
} break;
|
|
|
|
| 7849 |
case GGML_TYPE_Q6_K:
|
| 7850 |
case GGML_TYPE_IQ2_XXS:
|
| 7851 |
case GGML_TYPE_IQ2_XS:
|
| 7852 |
+
case GGML_TYPE_IQ3_XXS:
|
| 7853 |
{
|
| 7854 |
ggml_compute_forward_add1_q_f32(params, src0, src1, dst);
|
| 7855 |
} break;
|
|
|
|
| 7969 |
case GGML_TYPE_Q6_K:
|
| 7970 |
case GGML_TYPE_IQ2_XXS:
|
| 7971 |
case GGML_TYPE_IQ2_XS:
|
| 7972 |
+
case GGML_TYPE_IQ3_XXS:
|
| 7973 |
default:
|
| 7974 |
{
|
| 7975 |
GGML_ASSERT(false);
|
|
|
|
| 10721 |
case GGML_TYPE_Q6_K:
|
| 10722 |
case GGML_TYPE_IQ2_XXS:
|
| 10723 |
case GGML_TYPE_IQ2_XS:
|
| 10724 |
+
case GGML_TYPE_IQ3_XXS:
|
| 10725 |
{
|
| 10726 |
ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
|
| 10727 |
} break;
|
|
|
|
| 10901 |
case GGML_TYPE_Q6_K:
|
| 10902 |
case GGML_TYPE_IQ2_XXS:
|
| 10903 |
case GGML_TYPE_IQ2_XS:
|
| 10904 |
+
case GGML_TYPE_IQ3_XXS:
|
| 10905 |
default:
|
| 10906 |
{
|
| 10907 |
GGML_ASSERT(false);
|
|
|
|
| 11098 |
case GGML_TYPE_Q6_K:
|
| 11099 |
case GGML_TYPE_IQ2_XXS:
|
| 11100 |
case GGML_TYPE_IQ2_XS:
|
| 11101 |
+
case GGML_TYPE_IQ3_XXS:
|
| 11102 |
{
|
| 11103 |
ggml_compute_forward_get_rows_q(params, src0, src1, dst);
|
| 11104 |
} break;
|
|
|
|
| 11746 |
case GGML_TYPE_Q6_K:
|
| 11747 |
case GGML_TYPE_IQ2_XXS:
|
| 11748 |
case GGML_TYPE_IQ2_XS:
|
| 11749 |
+
case GGML_TYPE_IQ3_XXS:
|
| 11750 |
case GGML_TYPE_Q8_K:
|
| 11751 |
case GGML_TYPE_I8:
|
| 11752 |
case GGML_TYPE_I16:
|
|
|
|
| 11823 |
case GGML_TYPE_Q6_K:
|
| 11824 |
case GGML_TYPE_IQ2_XXS:
|
| 11825 |
case GGML_TYPE_IQ2_XS:
|
| 11826 |
+
case GGML_TYPE_IQ3_XXS:
|
| 11827 |
case GGML_TYPE_Q8_K:
|
| 11828 |
case GGML_TYPE_I8:
|
| 11829 |
case GGML_TYPE_I16:
|
|
|
|
| 18880 |
switch (type) {
|
| 18881 |
case GGML_TYPE_IQ2_XXS: iq2xs_init_impl(256); break;
|
| 18882 |
case GGML_TYPE_IQ2_XS: iq2xs_init_impl(512); break;
|
| 18883 |
+
case GGML_TYPE_IQ3_XXS: iq3xs_init_impl(256); break;
|
| 18884 |
default: // nothing
|
| 18885 |
break;
|
| 18886 |
}
|
|
|
|
| 19143 |
result = quantize_iq2_xs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
|
| 19144 |
GGML_ASSERT(result == row_size * nrows);
|
| 19145 |
} break;
|
| 19146 |
+
case GGML_TYPE_IQ3_XXS:
|
| 19147 |
+
{
|
| 19148 |
+
GGML_ASSERT(start % QK_K == 0);
|
| 19149 |
+
GGML_ASSERT(start % n_per_row == 0);
|
| 19150 |
+
size_t start_row = start / n_per_row;
|
| 19151 |
+
size_t row_size = ggml_row_size(type, n_per_row);
|
| 19152 |
+
result = quantize_iq3_xxs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
|
| 19153 |
+
GGML_ASSERT(result == row_size * nrows);
|
| 19154 |
+
} break;
|
| 19155 |
case GGML_TYPE_F16:
|
| 19156 |
{
|
| 19157 |
size_t elemsize = sizeof(ggml_fp16_t);
|
|
@@ -353,6 +353,7 @@ extern "C" {
|
|
| 353 |
GGML_TYPE_Q8_K = 15,
|
| 354 |
GGML_TYPE_IQ2_XXS = 16,
|
| 355 |
GGML_TYPE_IQ2_XS = 17,
|
|
|
|
| 356 |
GGML_TYPE_I8,
|
| 357 |
GGML_TYPE_I16,
|
| 358 |
GGML_TYPE_I32,
|
|
@@ -389,6 +390,7 @@ extern "C" {
|
|
| 389 |
GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors
|
| 390 |
GGML_FTYPE_MOSTLY_IQ2_XXS = 15, // except 1d tensors
|
| 391 |
GGML_FTYPE_MOSTLY_IQ2_XS = 16, // except 1d tensors
|
|
|
|
| 392 |
};
|
| 393 |
|
| 394 |
// available tensor operations:
|
|
|
|
| 353 |
GGML_TYPE_Q8_K = 15,
|
| 354 |
GGML_TYPE_IQ2_XXS = 16,
|
| 355 |
GGML_TYPE_IQ2_XS = 17,
|
| 356 |
+
GGML_TYPE_IQ3_XXS = 18,
|
| 357 |
GGML_TYPE_I8,
|
| 358 |
GGML_TYPE_I16,
|
| 359 |
GGML_TYPE_I32,
|
|
|
|
| 390 |
GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors
|
| 391 |
GGML_FTYPE_MOSTLY_IQ2_XXS = 15, // except 1d tensors
|
| 392 |
GGML_FTYPE_MOSTLY_IQ2_XS = 16, // except 1d tensors
|
| 393 |
+
GGML_FTYPE_MOSTLY_IQ3_XXS = 17, // except 1d tensors
|
| 394 |
};
|
| 395 |
|
| 396 |
// available tensor operations:
|