Kawrakow (ikawrakow) committed on
Commit 4649943 · unverified · 1 Parent(s): 8181686

SOTA 3-bit quants (llama/5196)


* iq3_xxs: quantize/dequantize

RMSE seems a bit high-ish, at about half-way between q2_K and
q3_K, so this needs more checking.

* iq3_xxs: CUDA dequantize works

* iq2_xxs: tuning quantization

* iq3_xxs: starting to look better

PPL on wiki.test.raw
LLaMA-v1-7B: 6.4218
LLaMA-v2-7B: 6.3560
Mistral-7B : 6.0717

This is better than Q3_K_XS, with a 5% reduction in quantized model
size (a bits-per-weight sanity check follows the file list below).

* iq3_xxs: CUDA dot product

We have
PP-512: 5891 t/s
TG-128: 143.9 t/s

* iq3_xxs: scalar and AVX2 dot products

* iq3_xxs: ARM_NEON and Metal

Metal performance is decent, ARM_NEON is pathetic

* iq3_xxs: slightly better grid points

* Faster iq3_xxs and iq2_xs dot products on CUDA

* iq3_xxs: add some quant mix

* iq3_xxs: fix failing quantization test

Dot product still fails. Is this real?

* iq3_xxs: hopefully fix ROCm

* iq3_xxs: failing tests

This time the dot product accuracy test did find an actual bug
in the AVX2 implementation.

* Add IQ3_XXS to test-backend-ops

---------

Co-authored-by: Iwan Kawrakow <[email protected]>

Files changed (7)
  1. ggml-cuda.cu +189 -11
  2. ggml-metal.m +35 -0
  3. ggml-metal.metal +274 -0
  4. ggml-quants.c +630 -0
  5. ggml-quants.h +17 -1
  6. ggml.c +30 -0
  7. ggml.h +2 -0
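
A quick sanity check of the size claim above, derived from the new block layout introduced in the diffs below (illustration only, not code from this commit):

    /* One IQ3_XXS block covers QK_K = 256 weights and stores an fp16 scale
       plus 3*(256/8) = 96 bytes of packed data: 2 + 96 = 98 bytes per block,
       i.e. 98*8/256 = 3.0625 bits per weight. */
    #include <stdio.h>

    int main(void) {
        const int QK_K = 256;
        const int block_bytes = 2 + 3*(QK_K/8);        // half d + qs[3*QK_K/8]
        printf("%.4f bpw\n", 8.0*block_bytes/QK_K);    // prints 3.0625
        return 0;
    }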
ggml-cuda.cu CHANGED
@@ -191,6 +191,10 @@ static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
 #endif // __has_builtin(__builtin_elementwise_sub_sat)
 }
 
+static __device__ __forceinline__ int __vsub4(const int a, const int b) {
+    return __vsubss4(a, b);
+}
+
 static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
 #if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
     c = __builtin_amdgcn_sdot4(a, b, c, false);
@@ -505,6 +509,14 @@ typedef struct {
 } block_iq2_xs;
 static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding");
 
+#define QR3_XXS 8
+#define QI3_XXS (QK_K / (4*QR3_XXS))
+typedef struct {
+    half d;
+    uint8_t qs[3*(QK_K/8)];
+} block_iq3_xxs;
+static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
+
 #define WARP_SIZE 32
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
@@ -1613,6 +1625,41 @@ static const __device__ uint64_t iq2xs_grid[512] = {
     0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
 };
 
+static const __device__ uint32_t iq3xxs_grid[256] = {
+    0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
+    0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
+    0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404,
+    0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e,
+    0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c,
+    0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c,
+    0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34,
+    0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c,
+    0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c,
+    0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04,
+    0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c,
+    0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414,
+    0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434,
+    0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c,
+    0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e,
+    0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24,
+    0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24,
+    0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c,
+    0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c,
+    0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14,
+    0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414,
+    0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e,
+    0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404,
+    0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c,
+    0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c,
+    0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14,
+    0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c,
+    0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c,
+    0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14,
+    0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14,
+    0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
+    0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
+};
+
 static const __device__ uint8_t ksigns_iq2xs[128] = {
       0, 129, 130,   3, 132,   5,   6, 135, 136,   9,  10, 139,  12, 141, 142,  15,
     144,  17,  18, 147,  20, 149, 150,  23,  24, 153, 154,  27, 156,  29,  30, 159,
@@ -1624,6 +1671,43 @@ static const __device__ uint8_t ksigns_iq2xs[128] = {
     240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
 };
 
+//#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+static const __device__ uint64_t ksigns64[128] = {
+    0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff,
+    0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff,
+    0xff000000ff000000, 0x00000000ff0000ff, 0x00000000ff00ff00, 0xff000000ff00ffff,
+    0x00000000ffff0000, 0xff000000ffff00ff, 0xff000000ffffff00, 0x00000000ffffffff,
+    0xff0000ff00000000, 0x000000ff000000ff, 0x000000ff0000ff00, 0xff0000ff0000ffff,
+    0x000000ff00ff0000, 0xff0000ff00ff00ff, 0xff0000ff00ffff00, 0x000000ff00ffffff,
+    0x000000ffff000000, 0xff0000ffff0000ff, 0xff0000ffff00ff00, 0x000000ffff00ffff,
+    0xff0000ffffff0000, 0x000000ffffff00ff, 0x000000ffffffff00, 0xff0000ffffffffff,
+    0xff00ff0000000000, 0x0000ff00000000ff, 0x0000ff000000ff00, 0xff00ff000000ffff,
+    0x0000ff0000ff0000, 0xff00ff0000ff00ff, 0xff00ff0000ffff00, 0x0000ff0000ffffff,
+    0x0000ff00ff000000, 0xff00ff00ff0000ff, 0xff00ff00ff00ff00, 0x0000ff00ff00ffff,
+    0xff00ff00ffff0000, 0x0000ff00ffff00ff, 0x0000ff00ffffff00, 0xff00ff00ffffffff,
+    0x0000ffff00000000, 0xff00ffff000000ff, 0xff00ffff0000ff00, 0x0000ffff0000ffff,
+    0xff00ffff00ff0000, 0x0000ffff00ff00ff, 0x0000ffff00ffff00, 0xff00ffff00ffffff,
+    0xff00ffffff000000, 0x0000ffffff0000ff, 0x0000ffffff00ff00, 0xff00ffffff00ffff,
+    0x0000ffffffff0000, 0xff00ffffffff00ff, 0xff00ffffffffff00, 0x0000ffffffffffff,
+    0xffff000000000000, 0x00ff0000000000ff, 0x00ff00000000ff00, 0xffff00000000ffff,
+    0x00ff000000ff0000, 0xffff000000ff00ff, 0xffff000000ffff00, 0x00ff000000ffffff,
+    0x00ff0000ff000000, 0xffff0000ff0000ff, 0xffff0000ff00ff00, 0x00ff0000ff00ffff,
+    0xffff0000ffff0000, 0x00ff0000ffff00ff, 0x00ff0000ffffff00, 0xffff0000ffffffff,
+    0x00ff00ff00000000, 0xffff00ff000000ff, 0xffff00ff0000ff00, 0x00ff00ff0000ffff,
+    0xffff00ff00ff0000, 0x00ff00ff00ff00ff, 0x00ff00ff00ffff00, 0xffff00ff00ffffff,
+    0xffff00ffff000000, 0x00ff00ffff0000ff, 0x00ff00ffff00ff00, 0xffff00ffff00ffff,
+    0x00ff00ffffff0000, 0xffff00ffffff00ff, 0xffff00ffffffff00, 0x00ff00ffffffffff,
+    0x00ffff0000000000, 0xffffff00000000ff, 0xffffff000000ff00, 0x00ffff000000ffff,
+    0xffffff0000ff0000, 0x00ffff0000ff00ff, 0x00ffff0000ffff00, 0xffffff0000ffffff,
+    0xffffff00ff000000, 0x00ffff00ff0000ff, 0x00ffff00ff00ff00, 0xffffff00ff00ffff,
+    0x00ffff00ffff0000, 0xffffff00ffff00ff, 0xffffff00ffffff00, 0x00ffff00ffffffff,
+    0xffffffff00000000, 0x00ffffff000000ff, 0x00ffffff0000ff00, 0xffffffff0000ffff,
+    0x00ffffff00ff0000, 0xffffffff00ff00ff, 0xffffffff00ffff00, 0x00ffffff00ffffff,
+    0x00ffffffff000000, 0xffffffffff0000ff, 0xffffffffff00ff00, 0x00ffffffff00ffff,
+    0xffffffffffff0000, 0x00ffffffffff00ff, 0x00ffffffffffff00, 0xffffffffffffffff,
+};
+//#endif
+
 static const __device__ uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
 
 inline bool ggml_cuda_supports_mmq(enum ggml_type type) {
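
A note on the ksigns64 table above (scalar illustration only, not code from this commit): each 7-bit sign index expands to eight bytes that are either 0x00 or 0xff, with the eighth byte carrying the parity bit, so a grid value can be conditionally negated branch-free with an XOR followed by a byte-wise subtract, which is exactly what __vsub4(grid ^ signs, signs) does four bytes at a time in the new dot products:

    /* With s = 0x00 the value passes through; with s = 0xff, (g ^ s) - s is
       -g in two's complement. Illustration only, not part of the commit. */
    #include <stdint.h>
    #include <assert.h>

    static int8_t flip(uint8_t g, uint8_t s) {
        return (int8_t)(uint8_t)((g ^ s) - s);
    }

    int main(void) {
        assert(flip(0x04, 0x00) ==  4);
        assert(flip(0x04, 0xff) == -4);
        return 0;
    }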
 
@@ -1690,6 +1774,34 @@ static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst
 
 }
 
+template<typename dst_t>
+static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int i = blockIdx.x;
+    const block_iq3_xxs * x = (const block_iq3_xxs *) vx;
+
+    const int tid = threadIdx.x;
+#if QK_K == 256
+    const int il = tid/8; // 0...3
+    const int ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const uint8_t  * q3 = x[i].qs + 8*ib;
+    const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;
+    const uint8_t  * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*il+0]);
+    const uint8_t  * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*il+1]);
+    const uint32_t aux32 = gas[0] | (gas[1] << 16);
+    const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.5f;
+    const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
+    for (int j = 0; j < 4; ++j) {
+        y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+        y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
+    }
+#else
+    assert(false);
+#endif
+
+}
+
 static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
 
     static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
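
The 32-bit aux32 word read by the kernel above packs four 7-bit sign indices (bits 0..27) and a 4-bit block scale (bits 28..31); the effective scale is d * (0.5 + (aux32 >> 28)) * 0.5. A small decoding sketch (illustration only, with made-up values, not code from this commit):

    /* aux32 layout for one 32-weight sub-block: [4-bit scale | 4 x 7-bit sign index] */
    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const uint32_t aux32 = 0xB0000000u | (5u << 21) | (17u << 14) | (99u << 7) | 3u;
        for (int il = 0; il < 4; ++il) {
            printf("sign index %d: %u\n", il, (aux32 >> 7*il) & 127);  // 3, 99, 17, 5
        }
        const float d = 1.0f;                                     // stand-in for x[i].d
        printf("scale = %.3f\n", d * (0.5f + (aux32 >> 28)) * 0.5f);  // (0.5+11)*0.5 = 5.75
        return 0;
    }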
 
@@ -4313,6 +4425,7 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
 
 static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
 #if QK_K == 256
     const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq;
 
@@ -4323,20 +4436,22 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
     const uint8_t ls2 = bq2->scales[ib32] >> 4;
     int sumi1 = 0;
     for (int l = 0; l < 2; ++l) {
-        const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
-        const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
-        for (int j = 0; j < 8; ++j) {
-            sumi1 += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
-        }
+        const uint32_t * grid = (const uint32_t *)(iq2xs_grid + (q2[l] & 511));
+        const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
+        const int grid_l = __vsub4(grid[0] ^ signs[0], signs[0]);
+        const int grid_h = __vsub4(grid[1] ^ signs[1], signs[1]);
+        sumi1 = __dp4a(grid_l, *((const int *)q8 + 0), sumi1);
+        sumi1 = __dp4a(grid_h, *((const int *)q8 + 1), sumi1);
         q8 += 8;
     }
     int sumi2 = 0;
     for (int l = 2; l < 4; ++l) {
-        const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
-        const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
-        for (int j = 0; j < 8; ++j) {
-            sumi2 += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
-        }
+        const uint32_t * grid = (const uint32_t *)(iq2xs_grid + (q2[l] & 511));
+        const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
+        const int grid_l = __vsub4(grid[0] ^ signs[0], signs[0]);
+        const int grid_h = __vsub4(grid[1] ^ signs[1], signs[1]);
+        sumi2 = __dp4a(grid_l, *((const int *)q8 + 0), sumi2);
+        sumi2 = __dp4a(grid_h, *((const int *)q8 + 1), sumi2);
         q8 += 8;
     }
     const float d = (float)bq2->d * __low2float(bq8_1[ib32].ds) * 0.25f;
@@ -4345,6 +4460,45 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
     assert(false);
     return 0.f;
 #endif
+#else
+    assert(false);
+    return 0.f;
+#endif
+}
+
+static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+#if QK_K == 256
+    const block_iq3_xxs * bq2 = (const block_iq3_xxs *) vbq;
+
+    const int ib32 = iqs;
+    const uint8_t  * q3 = bq2->qs + 8*ib32;
+    const uint16_t * gas = (const uint16_t *)(bq2->qs + QK_K/4) + 2*ib32;
+    const int8_t   * q8 = bq8_1[ib32].qs;
+    uint32_t aux32 = gas[0] | (gas[1] << 16);
+    int sumi = 0;
+    for (int l = 0; l < 4; ++l) {
+        const uint32_t * grid1 = iq3xxs_grid + q3[2*l+0];
+        const uint32_t * grid2 = iq3xxs_grid + q3[2*l+1];
+        const uint32_t * signs = (const uint32_t *)(ksigns64 + (aux32 & 127));
+        const int grid_l = __vsub4(grid1[0] ^ signs[0], signs[0]);
+        const int grid_h = __vsub4(grid2[0] ^ signs[1], signs[1]);
+        sumi = __dp4a(grid_l, *((int *)q8+0), sumi);
+        sumi = __dp4a(grid_h, *((int *)q8+1), sumi);
+        q8 += 8;
+        aux32 >>= 7;
+    }
+    const float d = (float)bq2->d * (0.5f + aux32) * __low2float(bq8_1[ib32].ds) * 0.5f;
+    return d * sumi;
+#else
+    assert(false);
+    return 0.f;
+#endif
+#else
+    assert(false);
+    return 0.f;
+#endif
 }
 
 template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
@@ -6394,6 +6548,12 @@ static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int k,
     dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
 }
 
+template<typename dst_t>
+static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
+}
+
 template <typename src_t, typename dst_t>
 static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
@@ -6431,6 +6591,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
         return dequantize_row_iq2_xxs_cuda;
     case GGML_TYPE_IQ2_XS:
         return dequantize_row_iq2_xs_cuda;
+    case GGML_TYPE_IQ3_XXS:
+        return dequantize_row_iq3_xxs_cuda;
     case GGML_TYPE_F32:
         return convert_unary_cuda<float>;
     default:
@@ -6464,6 +6626,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
        return dequantize_row_iq2_xxs_cuda;
    case GGML_TYPE_IQ2_XS:
        return dequantize_row_iq2_xs_cuda;
+    case GGML_TYPE_IQ3_XXS:
+        return dequantize_row_iq3_xxs_cuda;
    case GGML_TYPE_F16:
        return convert_unary_cuda<half>;
    default:
@@ -6676,6 +6840,15 @@ static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, float
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
+static void mul_mat_vec_iq3_xxs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
 static void ggml_mul_mat_q4_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
@@ -8239,6 +8412,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
     case GGML_TYPE_Q6_K:
     case GGML_TYPE_IQ2_XXS:
     case GGML_TYPE_IQ2_XS:
+    case GGML_TYPE_IQ3_XXS:
         return max_compute_capability >= CC_RDNA2 ? 128 : 64;
     default:
         GGML_ASSERT(false);
@@ -8261,6 +8435,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
     case GGML_TYPE_Q5_K:
     case GGML_TYPE_IQ2_XXS:
    case GGML_TYPE_IQ2_XS:
+    case GGML_TYPE_IQ3_XXS:
        return max_compute_capability >= CC_VOLTA ? 128 : 64;
    case GGML_TYPE_Q6_K:
        return 64;
@@ -8332,6 +8507,9 @@ static void ggml_cuda_op_mul_mat_vec_q(
     case GGML_TYPE_IQ2_XS:
         mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
         break;
+    case GGML_TYPE_IQ3_XXS:
+        mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+        break;
     default:
         GGML_ASSERT(false);
         break;
@@ -10968,7 +11146,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
             return false;
         }
         ggml_type a_type = a->type;
-        if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS) {
+        if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS) {
             if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
                 return false;
             }
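
For reference, a scalar sketch of IQ3_XXS dequantization for one 256-weight block, following the CUDA kernel above. This mirrors the device code's data layout; it is not the CPU routine added to ggml-quants.c in this commit:

    #include <stdint.h>
    #include <string.h>

    #define QK_K 256
    extern const uint32_t iq3xxs_grid[256];   // codebook from the diff above
    extern const uint8_t  ksigns_iq2xs[128];  // sign patterns with parity bit
    extern const uint8_t  kmask_iq2xs[8];     // {1, 2, 4, 8, 16, 32, 64, 128}

    static void dequant_iq3_xxs_block(float d, const uint8_t qs[3*QK_K/8], float y[QK_K]) {
        for (int ib = 0; ib < QK_K/32; ++ib) {          // 8 sub-blocks of 32 weights
            uint32_t aux32;                             // little-endian load, as gas[0] | gas[1]<<16
            memcpy(&aux32, qs + QK_K/4 + 4*ib, 4);
            const float db = d * (0.5f + (aux32 >> 28)) * 0.5f;
            for (int l = 0; l < 4; ++l) {               // 4 pairs of grid indices
                const uint8_t * g1 = (const uint8_t *)(iq3xxs_grid + qs[8*ib + 2*l + 0]);
                const uint8_t * g2 = (const uint8_t *)(iq3xxs_grid + qs[8*ib + 2*l + 1]);
                const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
                for (int j = 0; j < 4; ++j) {
                    y[32*ib + 8*l + j + 0] = db * g1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
                    y[32*ib + 8*l + j + 4] = db * g2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
                }
            }
        }
    }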
ggml-metal.m CHANGED
@@ -60,6 +60,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
     GGML_METAL_KERNEL_TYPE_RMS_NORM,
     GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -81,6 +82,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
     //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
@@ -98,6 +100,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
@@ -112,6 +115,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
@@ -126,6 +130,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F16,
     GGML_METAL_KERNEL_TYPE_ALIBI_F32,
@@ -426,6 +431,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
@@ -447,6 +453,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
         //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
@@ -464,6 +471,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
@@ -478,6 +486,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
@@ -492,6 +501,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
@@ -1279,6 +1289,7 @@ static bool ggml_metal_graph_compute(
                 case GGML_TYPE_Q6_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32   ].pipeline; break;
                 case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
                 case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
+                case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
                 default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
             }
 
@@ -1407,6 +1418,12 @@ static bool ggml_metal_graph_compute(
                         nth1 = 16;
                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
                     } break;
+                case GGML_TYPE_IQ3_XXS:
+                    {
+                        nth0 = 4;
+                        nth1 = 16;
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
+                    } break;
                 default:
                     {
                         GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
@@ -1449,6 +1466,11 @@ static bool ggml_metal_graph_compute(
                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
             }
+            else if (src0t == GGML_TYPE_IQ3_XXS) {
+                const int mem_size = 256*4+128;
+                [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+            }
             else if (src0t == GGML_TYPE_Q4_K) {
                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
             }
@@ -1543,6 +1565,7 @@ static bool ggml_metal_graph_compute(
                 case GGML_TYPE_Q6_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32   ].pipeline; break;
                 case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
                 case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
+                case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
                 default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
             }
 
@@ -1674,6 +1697,12 @@ static bool ggml_metal_graph_compute(
                         nth1 = 16;
                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
                     } break;
+                case GGML_TYPE_IQ3_XXS:
+                    {
+                        nth0 = 4;
+                        nth1 = 16;
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
+                    } break;
                 default:
                     {
                         GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
@@ -1732,6 +1761,11 @@ static bool ggml_metal_graph_compute(
                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
                 [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
             }
+            else if (src2t == GGML_TYPE_IQ3_XXS) {
+                const int mem_size = 256*4+128;
+                [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+                [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+            }
             else if (src2t == GGML_TYPE_Q4_K) {
                 [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
             }
@@ -1772,6 +1806,7 @@ static bool ggml_metal_graph_compute(
             case GGML_TYPE_Q6_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K   ].pipeline; break;
             case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
             case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
+            case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
            case GGML_TYPE_I32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32    ].pipeline; break;
            default: GGML_ASSERT(false && "not implemented");
        }
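
The mem_size = 256*4+128 used for the IQ3_XXS dispatches above is sized for the two lookup tables the kernel stages in threadgroup memory: iq3xxs_grid (256 x uint32) and ksigns_iq2xs (128 x uint8). A sanity check (illustration only, not code from this commit):

    #include <assert.h>
    #include <stdint.h>

    int main(void) {
        const int grid_bytes  = 256 * (int)sizeof(uint32_t); // 1024
        const int signs_bytes = 128 * (int)sizeof(uint8_t);  //  128
        assert(grid_bytes + signs_bytes == 256*4 + 128);     // 1152 bytes total
        return 0;
    }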
ggml-metal.metal CHANGED
@@ -2459,6 +2459,12 @@ typedef struct {
 } block_iq2_xs;
 // 74 bytes / block for QK_K = 256, so 2.3125 bpw
 
+typedef struct {
+    half d;
+    uint8_t qs[3*QK_K/8];
+} block_iq3_xxs;
+// 98 bytes / block for QK_K = 256, so 3.0625 bpw
+
 //====================================== dot products =========================
 
 void kernel_mul_mv_q2_K_f32_impl(
@@ -3681,6 +3687,42 @@ constexpr constant static uint64_t iq2xs_grid[512] = {
     0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
 };
 
+constexpr constant static uint32_t iq3xxs_grid[256] = {
+    0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
+    0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
+    0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404,
+    0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e,
+    0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c,
+    0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c,
+    0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34,
+    0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c,
+    0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c,
+    0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04,
+    0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c,
+    0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414,
+    0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434,
+    0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c,
+    0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e,
+    0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24,
+    0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24,
+    0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c,
+    0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c,
+    0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14,
+    0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414,
+    0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e,
+    0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404,
+    0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c,
+    0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c,
+    0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14,
+    0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c,
+    0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c,
+    0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14,
+    0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14,
+    0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
+    0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
+};
+
+
 constexpr constant static uint8_t ksigns_iq2xs[128] = {
       0, 129, 130,   3, 132,   5,   6, 135, 136,   9,  10, 139,  12, 141, 142,  15,
     144,  17,  18, 147,  20, 149, 150,  23,  24, 153, 154,  27, 156,  29,  30, 159,
@@ -3970,6 +4012,143 @@ kernel void kernel_mul_mv_iq2_xs_f32(
     kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
 }
 
+void kernel_mul_mv_iq3_xxs_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        threadgroup int8_t * shared_values [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    const int nb = ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int ib_row = first_row * nb;
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+    device const block_iq3_xxs * x = (device const block_iq3_xxs *) src0 + ib_row + offset0;
+    device const float         * y = (device const float         *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float yl[32];
+    float sumf[N_DST]={0.f}, all_sum;
+
+    const int nb32 = nb * (QK_K / 32);
+
+    threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
+    threadgroup uint8_t  * shared_signs = (threadgroup uint8_t *)(values + 256);
+    {
+        int nval = 4;
+        int pos  = (32*sgitg + tiisg)*nval;
+        for (int i = 0; i < nval; ++i) values[pos + i] = iq3xxs_grid[pos + i];
+        nval = 2;
+        pos  = (32*sgitg + tiisg)*nval;
+        for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+#if QK_K == 256
+    const int ix = tiisg;
+
+    device const float * y4 = y + 32 * ix;
+
+    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+        for (int i = 0; i < 32; ++i) {
+            yl[i] = y4[i];
+        }
+
+        const int ibl = ib32 / (QK_K / 32);
+        const int ib  = ib32 % (QK_K / 32);
+
+        device const block_iq3_xxs * xr = x + ibl;
+        device const uint8_t  * q3  = xr->qs + 8 * ib;
+        device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib;
+        device const half * dh = &xr->d;
+
+        for (int row = 0; row < N_DST; row++) {
+
+            const float db = dh[0];
+            const uint32_t aux32 = gas[0] | (gas[1] << 16);
+            const float d = db * (0.5f + (aux32 >> 28));
+
+            float2 sum = {0};
+            for (int l = 0; l < 4; ++l) {
+                const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + q3[2*l+0]);
+                const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + q3[2*l+1]);
+                const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
+                for (int j = 0; j < 4; ++j) {
+                    sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+                    sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
+                }
+            }
+            sumf[row] += d * (sum[0] + sum[1]);
+
+            dh  += nb*sizeof(block_iq3_xxs)/2;
+            q3  += nb*sizeof(block_iq3_xxs);
+            gas += nb*sizeof(block_iq3_xxs)/2;
+        }
+
+        y4 += 32 * 32;
+    }
+#else
+    // TODO
+#endif
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_iq3_xxs_f32")]]
+kernel void kernel_mul_mv_iq3_xxs_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        threadgroup int8_t * shared_values [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+}
+
+
 //============================= templates and their specializations =============================
 
 // NOTE: this is not dequantizing - we are simply fitting the template
@@ -4287,6 +4466,33 @@ void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4
     }
 }
 
+template <typename type4x4>
+void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const float d = xb->d;
+    const int ib32 = il/2;
+    il = il%2;
+    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+    device const uint8_t * q3 = xb->qs + 8*ib32;
+    device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32;
+    const uint32_t aux32 = gas[0] | (gas[1] << 16);
+    const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f;
+    constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]);
+    constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]);
+    uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127];
+    for (int i = 0; i < 4; ++i) {
+        reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
+        reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
+    }
+    grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]);
+    grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]);
+    signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127];
+    for (int i = 0; i < 4; ++i) {
+        reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
+        reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
+    }
+}
+
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
 kernel void kernel_get_rows(
     device const  void * src0,
@@ -4828,6 +5034,7 @@ template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows
 template [[host_name("kernel_get_rows_q6_K")]]    kernel get_rows_t kernel_get_rows<block_q6_K,    QK_NL, dequantize_q6_K>;
 template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
 template [[host_name("kernel_get_rows_iq2_xs")]]  kernel get_rows_t kernel_get_rows<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
+template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
 
 //
 // matrix-matrix multiplication
@@ -4866,6 +5073,7 @@ template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
 template [[host_name("kernel_mul_mm_q6_K_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q6_K,    QK_NL, dequantize_q6_K>;
 template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
 template [[host_name("kernel_mul_mm_iq2_xs_f32")]]  kernel mat_mm_t kernel_mul_mm<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
+template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
 
 //
 // indirect matrix-matrix multiplication
@@ -4916,6 +5124,7 @@ template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mu
 template [[host_name("kernel_mul_mm_id_q6_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K,    QK_NL, dequantize_q6_K>;
 template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
 template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
+template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
 
 //
 // matrix-vector multiplication
@@ -5818,3 +6027,68 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
         tiisg,
         sgitg);
 }
+
+[[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
+kernel void kernel_mul_mv_id_iq3_xxs_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         float * dst,
+        constant    uint64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant    uint64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        threadgroup int8_t   * shared_values [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiitg[[thread_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_iq3_xxs_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        dst + bid*ne0,
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        shared_values,
+        tgpig,
+        tiisg,
+        sgitg);
+}
ggml-quants.c CHANGED
@@ -3441,6 +3441,41 @@ static const uint64_t iq2xs_grid[512] = {
 0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
 };
 
+static const uint32_t iq3xxs_grid[256] = {
+    0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
+    0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
+    0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404,
+    0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e,
+    0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c,
+    0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c,
+    0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34,
+    0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c,
+    0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c,
+    0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04,
+    0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c,
+    0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414,
+    0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434,
+    0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c,
+    0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e,
+    0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24,
+    0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24,
+    0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c,
+    0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c,
+    0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14,
+    0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414,
+    0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e,
+    0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404,
+    0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c,
+    0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c,
+    0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14,
+    0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c,
+    0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c,
+    0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14,
+    0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14,
+    0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
+    0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
+};
+
 static const uint8_t ksigns_iq2xs[128] = {
       0, 129, 130,   3, 132,   5,   6, 135, 136,   9,  10, 139,  12, 141, 142,  15,
     144,  17,  18, 147,  20, 149, 150,  23,  24, 153, 154,  27, 156,  29,  30, 159,
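An observation on the table above (editor's note, not part of the diff): each iq3xxs_grid entry packs four byte-sized lattice coordinates, and every byte is one of 0x04, 0x0c, 0x14, 0x1c, 0x24, 0x2c, 0x34, 0x3e — roughly 4*(2l+1) for l = 0..7, with the top point nudged from 60 to 62. The factor of 4 is divided back out by the 0.25f in the dequantization and dot-product paths. A hedged sketch of extracting a coordinate (the helper name is hypothetical):

    // Sketch only: j-th coordinate (j = 0..3) of grid point g. Equivalent to the
    // little-endian pointer cast used in the kernels: ((const uint8_t *)&iq3xxs_grid[g])[j].
    static inline int iq3_xxs_coord(uint32_t grid_val, int j) {
        return (grid_val >> 8*j) & 0xff; // one of 4, 12, 20, 28, 36, 44, 52, 62
    }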
 
@@ -3507,6 +3542,38 @@ void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y,
     }
 }
 
+// ====================== 3.0625 bpw (de)-quantization
+
+void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    uint32_t aux32;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        const uint8_t * qs = x[i].qs;
+        const uint8_t * scales_and_signs = qs + QK_K/4;
+
+        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+            memcpy(&aux32, scales_and_signs + 4*ib32, sizeof(uint32_t));
+            const float db = d * (0.5f + (aux32 >> 28)) * 0.5f;
+            for (int l = 0; l < 4; ++l) {
+                const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
+                const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + qs[2*l+0]);
+                const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + qs[2*l+1]);
+                for (int j = 0; j < 4; ++j) {
+                    y[j+0] = db * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+                    y[j+4] = db * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
+                }
+                y += 8;
+            }
+            qs += 8;
+        }
+    }
+}
+
 //===================================== Q8_K ==============================================
 
 void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
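A minimal usage sketch for the new dequantization path (illustrative only; assumes the usual ggml headers and stdlib, and n a multiple of QK_K):

    void example_dequant_iq3_xxs(const block_iq3_xxs * x, int n) {
        float * y = malloc(n * sizeof(float));   // n/QK_K blocks expand to n floats
        dequantize_row_iq3_xxs(x, y, n);
        // ... use y ...
        free(y);
    }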
 
@@ -8551,6 +8618,136 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
 #endif
 }
 
+// TODO
+void ggml_vec_dot_iq3_xxs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    assert(n % QK_K == 0);
+
+    const block_iq3_xxs * restrict x = vx;
+    const block_q8_K    * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#if defined(__ARM_NEON)
+
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    uint32_t aux32[2];
+
+    ggml_int8x16x4_t q3s;
+    ggml_int8x16x4_t q8b;
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict gas = x[i].qs + QK_K/4;
+        const int8_t  * restrict q8 = y[i].qs;
+        float sumf1 = 0, sumf2 = 0;
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
+            memcpy(aux32, gas, 2*sizeof(uint32_t)); gas += 2*sizeof(uint32_t);
+            const uint32x4_t aux32x4_0 = {iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]};
+            const uint32x4_t aux32x4_1 = {iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]};
+            const uint32x4_t aux32x4_2 = {iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]};
+            const uint32x4_t aux32x4_3 = {iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]};
+            q3 += 16;
+            q3s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >>  0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >>  7) & 127))));
+            q3s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 21) & 127))));
+            q3s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >>  0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >>  7) & 127))));
+            q3s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127))));
+            q3s.val[0] = vmulq_s8(q3s.val[0], vreinterpretq_s8_u32(aux32x4_0));
+            q3s.val[1] = vmulq_s8(q3s.val[1], vreinterpretq_s8_u32(aux32x4_1));
+            q3s.val[2] = vmulq_s8(q3s.val[2], vreinterpretq_s8_u32(aux32x4_2));
+            q3s.val[3] = vmulq_s8(q3s.val[3], vreinterpretq_s8_u32(aux32x4_3));
+            const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);
+            const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
+            sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[0] >> 28));
+            sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[1] >> 28));
+        }
+        sumf += d*(sumf1 + sumf2);
+    }
+    *s = 0.5f * sumf;
+
+#elif defined(__AVX2__)
+
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    uint32_t aux32[2];
+
+    __m256 accumf = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict gas = x[i].qs + QK_K/4;
+        const int8_t  * restrict q8 = y[i].qs;
+        __m256i sumi1 = _mm256_setzero_si256();
+        __m256i sumi2 = _mm256_setzero_si256();
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q2_1 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
+                                                  iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
+            q3 += 8;
+            const __m256i q2_2 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
+                                                  iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
+            q3 += 8;
+            memcpy(aux32, gas, 8); gas += 8;
+            const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127],
+                                                   signs64[(aux32[0] >>  7) & 127], signs64[(aux32[0] >>  0) & 127]);
+            const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
+                                                   signs64[(aux32[1] >>  7) & 127], signs64[(aux32[1] >>  0) & 127]);
+            const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
+            const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
+            const __m256i dot1  = _mm256_maddubs_epi16(q2_1, q8s_1);
+            const __m256i dot2  = _mm256_maddubs_epi16(q2_2, q8s_2);
+            const uint16_t ls1 = aux32[0] >> 28;
+            const uint16_t ls2 = aux32[1] >> 28;
+            const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
+            const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
+            sumi1 = _mm256_add_epi32(sumi1, p1);
+            sumi2 = _mm256_add_epi32(sumi2, p2);
+        }
+
+        accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
+
+    }
+
+    *s = 0.25f * hsum_float_8(accumf);
+
+#else
+
+    uint32_t aux32;
+
+    float sumf = 0.f;
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict gas = x[i].qs + QK_K/4;
+        const int8_t  * restrict q8 = y[i].qs;
+        int32_t bsum = 0;
+        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+            memcpy(&aux32, gas, sizeof(uint32_t)); gas += sizeof(uint32_t);
+            const uint32_t ls = 2*(aux32 >> 28) + 1;
+            int32_t sumi = 0;
+            for (int l = 0; l < 4; ++l) {
+                const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*l+0]);
+                const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]);
+                const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
+                for (int j = 0; j < 4; ++j) {
+                    sumi += grid1[j] * q8[j+0] * (signs & kmask_iq2xs[j+0] ? -1 : 1);
+                    sumi += grid2[j] * q8[j+4] * (signs & kmask_iq2xs[j+4] ? -1 : 1);
+                }
+                q8 += 8;
+            }
+            q3 += 8;
+            bsum += sumi * ls;
+        }
+        sumf += d * bsum;
+    }
+    *s = 0.25f * sumf;
+#endif
+}
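Editor's note on the trailing factors above, worked out from the code: per weight, every path computes w = d * (2s + 1)/4 * g * sign, where s is the 4-bit sub-block scale and g the grid byte — exactly what dequantize_row_iq3_xxs produces via db = d * (0.5f + s) * 0.5f = d * (2s + 1)/4. The AVX2 and scalar paths accumulate g*q8 against ls = 2s + 1 and apply the remaining 1/4 once at the end (*s = 0.25f * ...); the NEON path folds one factor of 1/2 into (0.5f + s) = (2s + 1)/2 and applies the other as *s = 0.5f * sumf.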
+
 // ================================ IQ2 quantization =============================================
 
 typedef struct {

@@ -9189,3 +9386,436 @@ size_t quantize_iq2_xs(const float * src, void * dst, int nrow, int n_per_row, i
     return nrow * nblock * sizeof(block_iq2_xs);
 }
 
+//
+// ============================================= 3-bit using D4 lattice
+//
+
+typedef struct {
+    uint32_t * grid;
+    int      * map;
+    uint16_t * neighbours;
+} iq3_entry_t;
+
+static iq3_entry_t iq3_data[1] = {
+    {NULL, NULL, NULL},
+};
+
+static inline int iq3_data_index(int grid_size) {
+    (void)grid_size;
+    GGML_ASSERT(grid_size == 256);
+    return 0;
+}
+
+static int iq3_compare_func(const void * left, const void * right) {
+    const int * l = (const int *)left;
+    const int * r = (const int *)right;
+    return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 1 : 0;
+}
+
+void iq3xs_init_impl(int grid_size) {
+    const int gindex = iq3_data_index(grid_size);
+    if (iq3_data[gindex].grid) {
+        return;
+    }
+    static const uint16_t kgrid_256[256] = {
+            0,    2,    4,    9,   11,   15,   16,   18,   25,   34,   59,   61,   65,   67,   72,   74,
+           81,   85,   88,   90,   97,  108,  120,  128,  130,  132,  137,  144,  146,  153,  155,  159,
+          169,  175,  189,  193,  199,  200,  202,  213,  248,  267,  287,  292,  303,  315,  317,  321,
+          327,  346,  362,  413,  436,  456,  460,  462,  483,  497,  513,  515,  520,  522,  529,  531,
+          536,  538,  540,  551,  552,  576,  578,  585,  592,  594,  641,  643,  648,  650,  657,  664,
+          698,  704,  706,  720,  729,  742,  758,  769,  773,  808,  848,  852,  870,  889,  901,  978,
+          992, 1024, 1026, 1033, 1035, 1040, 1042, 1046, 1049, 1058, 1089, 1091, 1093, 1096, 1098, 1105,
+         1112, 1139, 1143, 1144, 1152, 1154, 1161, 1167, 1168, 1170, 1183, 1184, 1197, 1217, 1224, 1228,
+         1272, 1276, 1309, 1323, 1347, 1367, 1377, 1404, 1473, 1475, 1486, 1509, 1537, 1544, 1546, 1553,
+         1555, 1576, 1589, 1594, 1600, 1602, 1616, 1625, 1636, 1638, 1665, 1667, 1672, 1685, 1706, 1722,
+         1737, 1755, 1816, 1831, 1850, 1856, 1862, 1874, 1901, 1932, 1950, 1971, 2011, 2032, 2052, 2063,
+         2077, 2079, 2091, 2095, 2172, 2192, 2207, 2208, 2224, 2230, 2247, 2277, 2308, 2345, 2356, 2389,
+         2403, 2424, 2501, 2504, 2506, 2520, 2570, 2593, 2616, 2624, 2630, 2646, 2669, 2700, 2714, 2746,
+         2754, 2795, 2824, 2835, 2839, 2874, 2882, 2905, 2984, 3028, 3042, 3092, 3108, 3110, 3124, 3153,
+         3185, 3215, 3252, 3288, 3294, 3364, 3397, 3434, 3483, 3523, 3537, 3587, 3589, 3591, 3592, 3610,
+         3626, 3670, 3680, 3722, 3749, 3754, 3776, 3789, 3803, 3824, 3857, 3873, 3904, 3906, 3924, 3992,
+    };
+    const int kmap_size = 4096;
+    const int nwant = 2;
+    const uint16_t * kgrid = kgrid_256;
+    uint32_t * kgrid_q3xs;
+    int      * kmap_q3xs;
+    uint16_t * kneighbors_q3xs;
+
+    printf("================================================================= %s(grid_size = %d)\n", __func__, grid_size);
+    uint32_t * the_grid = (uint32_t *)malloc(grid_size*sizeof(uint32_t));
+    for (int k = 0; k < grid_size; ++k) {
+        int8_t * pos = (int8_t *)(the_grid + k);
+        for (int i = 0; i < 4; ++i) {
+            int l = (kgrid[k] >> 3*i) & 0x7;
+            pos[i] = 2*l + 1;
+        }
+    }
+    kgrid_q3xs = the_grid;
+    iq3_data[gindex].grid = the_grid;
+    kmap_q3xs = (int *)malloc(kmap_size*sizeof(int));
+    iq3_data[gindex].map = kmap_q3xs;
+    for (int i = 0; i < kmap_size; ++i) kmap_q3xs[i] = -1;
+    uint32_t aux32;
+    uint8_t * aux8 = (uint8_t *)&aux32;
+    for (int i = 0; i < grid_size; ++i) {
+        aux32 = kgrid_q3xs[i];
+        uint16_t index = 0;
+        for (int k=0; k<4; ++k) {
+            uint16_t q = (aux8[k] - 1)/2;
+            index |= (q << 3*k);
+        }
+        kmap_q3xs[index] = i;
+    }
+    int8_t pos[4];
+    int * dist2 = (int *)malloc(2*grid_size*sizeof(int));
+    int num_neighbors = 0, num_not_in_map = 0;
+    for (int i = 0; i < kmap_size; ++i) {
+        if (kmap_q3xs[i] >= 0) continue;
+        ++num_not_in_map;
+        for (int k = 0; k < 4; ++k) {
+            int l = (i >> 3*k) & 0x7;
+            pos[k] = 2*l + 1;
+        }
+        for (int j = 0; j < grid_size; ++j) {
+            const int8_t * pg = (const int8_t *)(kgrid_q3xs + j);
+            int d2 = 0;
+            for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
+            dist2[2*j+0] = d2;
+            dist2[2*j+1] = j;
+        }
+        qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func);
+        int n = 0; int d2 = dist2[0];
+        int nhave = 1;
+        for (int j = 0; j < grid_size; ++j) {
+            if (dist2[2*j] > d2) {
+                if (nhave == nwant) break;
+                d2 = dist2[2*j];
+                ++nhave;
+            }
+            ++n;
+        }
+        num_neighbors += n;
+    }
+    printf("%s: %d neighbours in total\n", __func__, num_neighbors);
+    kneighbors_q3xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t));
+    iq3_data[gindex].neighbours = kneighbors_q3xs;
+    int counter = 0;
+    for (int i = 0; i < kmap_size; ++i) {
+        if (kmap_q3xs[i] >= 0) continue;
+        for (int k = 0; k < 4; ++k) {
+            int l = (i >> 3*k) & 0x7;
+            pos[k] = 2*l + 1;
+        }
+        for (int j = 0; j < grid_size; ++j) {
+            const int8_t * pg = (const int8_t *)(kgrid_q3xs + j);
+            int d2 = 0;
+            for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
+            dist2[2*j+0] = d2;
+            dist2[2*j+1] = j;
+        }
+        qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func);
+        kmap_q3xs[i] = -(counter + 1);
+        int d2 = dist2[0];
+        uint16_t * start = &kneighbors_q3xs[counter++];
+        int n = 0, nhave = 1;
+        for (int j = 0; j < grid_size; ++j) {
+            if (dist2[2*j] > d2) {
+                if (nhave == nwant) break;
+                d2 = dist2[2*j];
+                ++nhave;
+            }
+            kneighbors_q3xs[counter++] = dist2[2*j+1];
+            ++n;
+        }
+        *start = n;
+    }
+    free(dist2);
+}
+
+void iq3xs_free_impl(int grid_size) {
+    GGML_ASSERT(grid_size == 256);
+    const int gindex = iq3_data_index(grid_size);
+    if (iq3_data[gindex].grid) {
+        free(iq3_data[gindex].grid);       iq3_data[gindex].grid = NULL;
+        free(iq3_data[gindex].map);        iq3_data[gindex].map  = NULL;
+        free(iq3_data[gindex].neighbours); iq3_data[gindex].neighbours = NULL;
+    }
+}
+
+static int iq3_find_best_neighbour(const uint16_t * restrict neighbours, const uint32_t * restrict grid,
+        const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L) {
+    int num_neighbors = neighbours[0];
+    GGML_ASSERT(num_neighbors > 0);
+    float best_d2 = FLT_MAX;
+    int grid_index = -1;
+    for (int j = 1; j <= num_neighbors; ++j) {
+        const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
+        float d2 = 0;
+        for (int i = 0; i < 4; ++i) {
+            float q = pg[i];
+            float diff = scale*q - xval[i];
+            d2 += weight[i]*diff*diff;
+        }
+        if (d2 < best_d2) {
+            best_d2 = d2; grid_index = neighbours[j];
+        }
+    }
+    GGML_ASSERT(grid_index >= 0);
+    const int8_t * pg = (const int8_t *)(grid + grid_index);
+    for (int i = 0; i < 4; ++i) L[i] = (pg[i] - 1)/2;
+    return grid_index;
+}
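A sketch (not in the diff; helper name hypothetical) of how the kmap/kneighbours encoding built in iq3xs_init_impl is consumed, matching the lookups in quantize_row_iq3_xxs_impl below:

    // u is the 12-bit packed quant (4 x 3 bits). kmap[u] >= 0 means u is one of the
    // 256 grid points and kmap[u] is its index. Otherwise -kmap[u]-1 is an offset into
    // kneighbors, where the first entry is the neighbour count followed by that many
    // grid indices, among which the weighted-closest point is chosen.
    static int iq3_nearest_grid_point(const int * kmap, const uint16_t * kneighbors,
                                      const uint32_t * kgrid, uint16_t u,
                                      const float * xval, const float * waux, float scale, int8_t * L) {
        int grid_index = kmap[u];
        if (grid_index < 0) {
            const uint16_t * neighbours = kneighbors - kmap[u] - 1;
            grid_index = iq3_find_best_neighbour(neighbours, kgrid, xval, waux, scale, L);
        }
        return grid_index;
    }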
+
+static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
+
+    const int gindex = iq3_data_index(256);
+
+    const uint32_t * kgrid_q3xs      = iq3_data[gindex].grid;
+    const int      * kmap_q3xs       = iq3_data[gindex].map;
+    const uint16_t * kneighbors_q3xs = iq3_data[gindex].neighbours;
+
+    //GGML_ASSERT(quant_weights   && "missing quantization weights");
+    GGML_ASSERT(kgrid_q3xs      && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(kmap_q3xs       && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(kneighbors_q3xs && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(n%QK_K == 0);
+
+    const int kMaxQ = 8;
+
+    const int nbl = n/256;
+
+    block_iq3_xxs * y = vy;
+
+    float scales[QK_K/32];
+    float weight[32];
+    float xval[32];
+    int8_t L[32];
+    int8_t Laux[32];
+    float  waux[32];
+    bool   is_on_grid[8];
+    bool   is_on_grid_aux[8];
+    uint8_t block_signs[8];
+    uint8_t q3[3*(QK_K/8)];
+    uint32_t * scales_and_signs = (uint32_t *)(q3 + QK_K/4);
+
+    for (int ibl = 0; ibl < nbl; ++ibl) {
+
+        y[ibl].d = GGML_FP32_TO_FP16(0.f);
+        memset(q3, 0, 3*QK_K/8);
+
+        float max_scale = 0;
+
+        const float * xbl = x + QK_K*ibl;
+        float sumx2 = 0;
+        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
+        float sigma2 = sumx2/QK_K;
+
+        for (int ib = 0; ib < QK_K/32; ++ib) {
+            const float * xb = xbl + 32*ib;
+            if (quant_weights) {
+                const float * qw = quant_weights + QK_K*ibl + 32*ib;
+                for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+            } else {
+                for (int i = 0; i < 32; ++i) weight[i] = xb[i]*xb[i];
+            }
+            for (int i = 0; i < 32; ++i) waux[i] = sqrtf(weight[i]);
+            for (int k = 0; k < 4; ++k) {
+                int nflip = 0;
+                uint8_t s = 0;
+                for (int i = 0; i < 8; ++i) {
+                    if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
+                    else {
+                        xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i);
+                    }
+                }
+                if (nflip%2) {
+                    int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin];
+                    for (int i = 1; i < 8; ++i) {
+                        float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i];
+                        if (ax < min) {
+                            min = ax; imin = i;
+                        }
+                    }
+                    xval[8*k+imin] = -xval[8*k+imin];
+                    s ^= (1 << imin);
+                }
+                block_signs[k] = s & 127;
+            }
+            float max = xval[0];
+            for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]);
+            if (!max) {
+                scales[ib] = 0;
+                memset(L, 0, 32);
+                continue;
+            }
+            float best = 0;
+            float scale = max/(2*kMaxQ-1);
+            for (int is = -15; is <= 15; ++is) {
+                float id = (2*kMaxQ-1+is*0.2f)/max;
+                float this_scale = 1/id;
+                for (int k = 0; k < 8; ++k) {
+                    for (int i = 0; i < 4; ++i) {
+                        int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
+                        Laux[4*k+i] = MAX(0, MIN(kMaxQ-1, l));
+                    }
+                    uint16_t u = 0;
+                    for (int i = 0; i < 4; ++i) u |= (Laux[4*k+i] << 3*i);
+                    int grid_index = kmap_q3xs[u];
+                    is_on_grid_aux[k] = true;
+                    if (grid_index < 0) {
+                        is_on_grid_aux[k] = false;
+                        const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
+                        grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, this_scale, Laux + 4*k);
+                    }
+                }
+                float sumqx = 0, sumq2 = 0;
+                for (int i = 0; i < 32; ++i) {
+                    float w = weight[i];
+                    float q = 2*Laux[i] + 1;
+                    sumqx += w*xval[i]*q;
+                    sumq2 += w*q*q;
+                }
+                if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+                    scale = sumqx/sumq2; best = scale*sumqx;
+                    for (int i = 0; i < 32; ++i) L[i] = Laux[i];
+                    for (int k = 0; k <  8; ++k) is_on_grid[k] = is_on_grid_aux[k];
+                }
+            }
+            int n_not_ongrid = 0;
+            for (int k = 0; k < 8; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
+            if (n_not_ongrid > 0 && scale > 0) {
+                float id = 1/scale;
+                for (int k = 0; k < 8; ++k) {
+                    if (is_on_grid[k]) continue;
+                    uint16_t u = 0;
+                    for (int i = 0; i < 4; ++i) {
+                        int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
+                        l = MAX(0, MIN(kMaxQ-1, l));
+                        u |= (l << 3*i);
+                    }
+                    int grid_index = kmap_q3xs[u];
+                    if (grid_index < 0) {
+                        const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
+                        grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, scale, L + 4*k);
+                    }
+                    const int8_t * pg = (const int8_t *)(kgrid_q3xs + grid_index);
+                    for (int i = 0; i < 4; ++i) L[4*k+i] = (pg[i] - 1)/2;
+                }
+                float sumqx = 0, sumq2 = 0;
+                for (int i = 0; i < 32; ++i) {
+                    float w = weight[i];
+                    float q = 2*L[i] + 1;
+                    sumqx += w*xval[i]*q;
+                    sumq2 += w*q*q;
+                }
+                if (sumq2 > 0) scale = sumqx/sumq2;
+            }
+            if (scale < 0) {
+                // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale)
+                // and correspondingly flip quant signs.
+                scale = -scale;
+                for (int k = 0; k < 4; ++k) block_signs[k] = (~block_signs[k]) & 127;
+            }
+            for (int k = 0; k < 8; ++k) {
+                uint16_t u = 0;
+                for (int i = 0; i < 4; ++i) u |= (L[4*k+i] << 3*i);
+                int grid_index = kmap_q3xs[u];
+                if (grid_index < 0) {
+                    printf("Oops: found point %u not on grid:", u);
+                    for (int i = 0; i < 4; ++i) printf(" %d", L[4*k+i]);
+                    printf("\n");
+                    GGML_ASSERT(false);
+                }
+                q3[8*ib+k] = grid_index;
+            }
+            scales_and_signs[ib] = block_signs[0] | (block_signs[1] << 7) | (block_signs[2] << 14) | (block_signs[3] << 21);
+            GGML_ASSERT(scale >= 0);
+            scales[ib] = scale;
+            max_scale = MAX(max_scale, scale);
+        }
+
+        if (!max_scale) {
+            memset(y[ibl].qs, 0, 3*QK_K/8);
+            continue;
+        }
+
+        float d = max_scale/31;
+        y[ibl].d = GGML_FP32_TO_FP16(d);
+        float id = 1/d;
+        float sumqx = 0, sumq2 = 0;
+        for (int ib = 0; ib < QK_K/32; ++ib) {
+            int l = nearest_int(0.5f*(id*scales[ib]-1));
+            l = MAX(0, MIN(15, l));
+            scales_and_signs[ib] |= ((uint32_t)l << 28);
+            if (false) {
+                const float * xb = xbl + 32*ib;
+                if (quant_weights) {
+                    const float * qw = quant_weights + QK_K*ibl + 32*ib;
+                    for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+                } else {
+                    for (int i = 0; i < 32; ++i) weight[i] = xb[i]*xb[i];
+                }
+                const float db = 0.25f * d * (1 + 2*l);
+                for (int k = 0; k < 8; ++k) {
+                    const int8_t * signs = keven_signs_q2xs + 8*((scales_and_signs[ib] >> 7*(k/2)) & 127) + 4*(k%2);
+                    const float * xk = xb + 4*k;
+                    const float * wk = weight + 4*k;
+                    //const uint8_t * grid = (const uint8_t *)(kgrid_q3xs + q3[8*ib+k]);
+                    const uint8_t * grid = (const uint8_t *)(iq3xxs_grid + q3[8*ib+k]);
+                    float best_mse = 0; int best_index = q3[8*ib+k];
+                    for (int j = 0; j < 4; ++j) {
+                        float diff = db * grid[j] * signs[j] - xk[j];
+                        best_mse += wk[j] * diff * diff;
+                    }
+                    for (int idx = 0; idx < 256; ++idx) {
+                        //grid = (const uint8_t *)(kgrid_q3xs + idx);
+                        grid = (const uint8_t *)(iq3xxs_grid + idx);
+                        float mse = 0;
+                        for (int j = 0; j < 4; ++j) {
+                            float diff = db * grid[j] * signs[j] - xk[j];
+                            mse += wk[j] * diff * diff;
+                        }
+                        if (mse < best_mse) {
+                            best_mse = mse; best_index = idx;
+                        }
+                    }
+                    q3[8*ib+k] = best_index;
+                    //grid = (const uint8_t *)(kgrid_q3xs + best_index);
+                    grid = (const uint8_t *)(iq3xxs_grid + best_index);
+                    for (int j = 0; j < 4; ++j) {
+                        float q = db * grid[j] * signs[j];
+                        sumqx += wk[j] * q * xk[j];
+                        sumq2 += wk[j] * q * q;
+                    }
+                }
+                if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(d*sumqx/sumq2);
+            }
+        }
+        memcpy(y[ibl].qs, q3, 3*QK_K/8);
+    }
+}
+
+size_t quantize_iq3_xxs(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
+    (void)hist;
+    GGML_ASSERT(n_per_row%QK_K == 0);
+    int nblock = n_per_row/QK_K;
+    char * qrow = (char *)dst;
+    for (int row = 0; row < nrow; ++row) {
+        quantize_row_iq3_xxs_impl(src, qrow, n_per_row, quant_weights);
+        src += n_per_row;
+        qrow += nblock*sizeof(block_iq3_xxs);
+    }
+    return nrow * nblock * sizeof(block_iq3_xxs);
+}
+
+void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK_K == 0);
+    block_iq3_xxs * restrict y = vy;
+    quantize_row_iq3_xxs_reference(x, y, k);
+}
+
+void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int k) {
+    assert(k % QK_K == 0);
+    quantize_row_iq3_xxs_impl(x, y, k, NULL);
+}
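End-to-end usage sketch for the new quantizer (illustrative only; error handling omitted). The grid/map/neighbour tables must be built first, which ggml does via ggml_quantize_init(GGML_TYPE_IQ3_XXS) (see ggml.c below) or directly:

    void example_quantize_iq3_xxs(const float * src, int nrow, int n_per_row) {
        iq3xs_init_impl(256);                     // build lookup tables once
        int nblock = n_per_row/QK_K;
        block_iq3_xxs * q = malloc(nrow * nblock * sizeof(block_iq3_xxs));
        // hist is unused (see (void)hist above); a NULL imatrix falls back to x^2 weights
        quantize_iq3_xxs(src, q, nrow, n_per_row, NULL, NULL);
        // ... write q to the model file ...
        free(q);
        iq3xs_free_impl(256);
    }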
ggml-quants.h CHANGED
@@ -166,7 +166,7 @@ typedef struct {
 static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding");
 
 // (Almost) "true" 2-bit quantization.
-// Due to the need to use blocks as per ggml dsign, it ends up using
+// Due to the need to use blocks as per ggml design, it ends up using
 // 2.0625 bpw because of the 16-bit scale for each block of 256.
 typedef struct {
     ggml_fp16_t d;
@@ -182,6 +182,15 @@ typedef struct {
 } block_iq2_xs;
 static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding");
 
+// (Almost) "true" 3-bit quantization.
+// Due to the need to use blocks as per ggml design, it ends up using
+// 3.0625 bpw because of the 16-bit scale for each block of 256.
+typedef struct {
+    ggml_fp16_t d;
+    uint8_t qs[3*QK_K/8];
+} block_iq3_xxs;
+static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
+
 // Quantization
 void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k);
 void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k);
@@ -196,6 +205,7 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
 void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k);
 void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k);
 void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k);
+void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int k);
 
 void quantize_row_q4_0(const float * restrict x, void * restrict y, int k);
 void quantize_row_q4_1(const float * restrict x, void * restrict y, int k);
@@ -210,6 +220,7 @@ void quantize_row_q4_K(const float * restrict x, void * restrict y, int k);
 void quantize_row_q5_K(const float * restrict x, void * restrict y, int k);
 void quantize_row_q6_K(const float * restrict x, void * restrict y, int k);
 void quantize_row_q8_K(const float * restrict x, void * restrict y, int k);
+void quantize_row_iq3_xxs(const float * restrict x, void * restrict y, int k);
 
 // Dequantization
 void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k);
@@ -227,6 +238,7 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int
 void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k);
 void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k);
 void dequantize_row_iq2_xs (const block_iq2_xs  * restrict x, float * restrict y, int k);
+void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y, int k);
 
 // Dot product
 void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
@@ -242,12 +254,14 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx,
 void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 void ggml_vec_dot_iq2_xs_q8_K (int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 
 //
 // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
 //
 size_t quantize_iq2_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_iq2_xs (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
+size_t quantize_iq3_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_q2_K   (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_q3_K   (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_q4_K   (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
@@ -260,3 +274,5 @@ size_t quantize_q5_1 (const float * src, void * dst, int nrows, int n_per_row,
 
 void iq2xs_init_impl(int grid_size);
 void iq2xs_free_impl(int grid_size);
+void iq3xs_init_impl(int grid_size);
+void iq3xs_free_impl(int grid_size);
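The 3.0625 bpw figure in the new comment follows directly from the struct: per block of QK_K = 256 weights there are sizeof(ggml_fp16_t) = 2 bytes of scale plus 3*256/8 = 96 bytes of payload (64 bytes of grid indices and 32 bytes of packed scales/signs), so (2 + 96)*8/256 = 3.0625 bits per weight.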
ggml.c CHANGED
@@ -632,6 +632,17 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .vec_dot                  = ggml_vec_dot_iq2_xs_q8_K,
         .vec_dot_type             = GGML_TYPE_Q8_K,
     },
+    [GGML_TYPE_IQ3_XXS] = {
+        .type_name                = "iq3_xxs",
+        .blck_size                = QK_K,
+        .type_size                = sizeof(block_iq3_xxs),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_iq3_xxs,
+        .from_float               = quantize_row_iq3_xxs,
+        .from_float_reference     = (ggml_from_float_t)quantize_row_iq3_xxs_reference,
+        .vec_dot                  = ggml_vec_dot_iq3_xxs_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+    },
     [GGML_TYPE_Q8_K] = {
         .type_name                = "q8_K",
         .blck_size                = QK_K,
@@ -2177,6 +2188,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
         case GGML_FTYPE_MOSTLY_Q6_K:          wtype = GGML_TYPE_Q6_K;    break;
         case GGML_FTYPE_MOSTLY_IQ2_XXS:       wtype = GGML_TYPE_IQ2_XXS; break;
         case GGML_FTYPE_MOSTLY_IQ2_XS:        wtype = GGML_TYPE_IQ2_XS;  break;
+        case GGML_FTYPE_MOSTLY_IQ3_XXS:       wtype = GGML_TYPE_IQ3_XXS; break;
         case GGML_FTYPE_UNKNOWN:              wtype = GGML_TYPE_COUNT;   break;
         case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT;   break;
     }
@@ -7570,6 +7582,7 @@ static void ggml_compute_forward_add(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
             {
                 ggml_compute_forward_add_q_f32(params, src0, src1, dst);
             } break;
@@ -7836,6 +7849,7 @@ static void ggml_compute_forward_add1(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_IQ2_XXS:
        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
             {
                 ggml_compute_forward_add1_q_f32(params, src0, src1, dst);
             } break;
@@ -7955,6 +7969,7 @@ static void ggml_compute_forward_acc(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
         default:
             {
                 GGML_ASSERT(false);
@@ -10706,6 +10721,7 @@ static void ggml_compute_forward_out_prod(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
             {
                 ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
             } break;
@@ -10885,6 +10901,7 @@ static void ggml_compute_forward_set(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
         default:
             {
                 GGML_ASSERT(false);
@@ -11081,6 +11098,7 @@ static void ggml_compute_forward_get_rows(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
             {
                 ggml_compute_forward_get_rows_q(params, src0, src1, dst);
             } break;
@@ -11728,6 +11746,7 @@ static void ggml_compute_forward_alibi(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_Q8_K:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
@@ -11804,6 +11823,7 @@ static void ggml_compute_forward_clamp(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_Q8_K:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
@@ -18860,6 +18880,7 @@ void ggml_quantize_init(enum ggml_type type) {
     switch (type) {
         case GGML_TYPE_IQ2_XXS: iq2xs_init_impl(256); break;
         case GGML_TYPE_IQ2_XS:  iq2xs_init_impl(512); break;
+        case GGML_TYPE_IQ3_XXS: iq3xs_init_impl(256); break;
         default: // nothing
             break;
    }
@@ -19122,6 +19143,15 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
                 result = quantize_iq2_xs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
                 GGML_ASSERT(result == row_size * nrows);
             } break;
+        case GGML_TYPE_IQ3_XXS:
+            {
+                GGML_ASSERT(start % QK_K == 0);
+                GGML_ASSERT(start % n_per_row == 0);
+                size_t start_row = start / n_per_row;
+                size_t row_size = ggml_row_size(type, n_per_row);
+                result = quantize_iq3_xxs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
+                GGML_ASSERT(result == row_size * nrows);
+            } break;
         case GGML_TYPE_F16:
             {
                 size_t elemsize = sizeof(ggml_fp16_t);
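With the traits entry registered, the generic ggml paths pick up the new type automatically. A sketch of the dispatch (assuming the usual ggml_internal_get_type_traits accessor; illustrative only):

    // Round-trip one block through the callbacks registered for GGML_TYPE_IQ3_XXS.
    ggml_type_traits_t tt = ggml_internal_get_type_traits(GGML_TYPE_IQ3_XXS);
    float src[QK_K] = {0}, dst[QK_K];
    block_iq3_xxs q[1];
    tt.from_float(src, q, QK_K);   // -> quantize_row_iq3_xxs
    tt.to_float(q, dst, QK_K);     // -> dequantize_row_iq3_xxs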
ggml.h CHANGED
@@ -353,6 +353,7 @@ extern "C" {
         GGML_TYPE_Q8_K    = 15,
         GGML_TYPE_IQ2_XXS = 16,
         GGML_TYPE_IQ2_XS  = 17,
+        GGML_TYPE_IQ3_XXS = 18,
         GGML_TYPE_I8,
         GGML_TYPE_I16,
         GGML_TYPE_I32,
@@ -389,6 +390,7 @@ extern "C" {
         GGML_FTYPE_MOSTLY_Q6_K    = 14, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ2_XXS = 15, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ2_XS  = 16, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ3_XXS = 17, // except 1d tensors
     };
 
     // available tensor operations: