JohannesGaessler committed

Commit 154bf2b · 1 Parent(s): 78a5b67

CUDA: int8 tensor cores for MMQ (q4_K, q5_K, q6_K) (llama/7860)
Files changed (2):
  1. ggml-cuda/mma.cuh  +66 -0
  2. ggml-cuda/mmq.cuh  +294 -6
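
Note: the new mma_K4/mma_K8 fragment methods in ggml-cuda/mma.cuh are thin wrappers around PTX int8 tensor core instructions. As a shape reference only (the helper name mma_m16n8k16_s8 is illustrative and not part of this commit), the Ampere-path instruction used by mma_K4 has the following operand layout:

// Illustrative sketch, not from the patch: D/C is a 16x8 s32 tile (4 registers per thread),
// A is a 16x16 s8 tile in row-major layout (2 registers per thread), and B is a 16x8 s8 tile
// in column-major ("col") layout (1 register per thread). The instruction accumulates into D.
static __device__ __forceinline__ void mma_m16n8k16_s8(int (&dc)[4], const int (&a)[2], const int b) {
    asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
        : "+r"(dc[0]), "+r"(dc[1]), "+r"(dc[2]), "+r"(dc[3])
        : "r"(a[0]), "r"(a[1]), "r"(b));
}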
ggml-cuda/mma.cuh CHANGED
@@ -1,5 +1,27 @@
 #include "common.cuh"
 
+struct mma_int_A_I16K4 {
+    static constexpr int I  = 16;
+    static constexpr int K  = 4;
+    static constexpr int ne = 2;
+
+    int x[ne] = {0};
+
+    static __device__ __forceinline__ int get_i(const int l) {
+        const int ret = (l%2) * (I/2) + threadIdx.x / K;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret < I);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_k(const int /* l */) {
+        const int ret = threadIdx.x % K;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret < K);
+        return ret;
+    }
+};
+
 struct mma_int_A_I16K8 {
     static constexpr int I  = 16;
     static constexpr int K  = 8;
@@ -22,6 +44,28 @@ struct mma_int_A_I16K8 {
     }
 };
 
+struct mma_int_B_J8K4 {
+    static constexpr int J  = 8;
+    static constexpr int K  = 4;
+    static constexpr int ne = 1;
+
+    int x[ne] = {0};
+
+    static __device__ __forceinline__ int get_j(const int /* l */) {
+        const int ret = threadIdx.x / K;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret < J);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_k(const int /* l */) {
+        const int ret = threadIdx.x % K;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret < K);
+        return ret;
+    }
+};
+
 struct mma_int_B_J8K8 {
     static constexpr int J  = 8;
     static constexpr int K  = 8;
@@ -65,6 +109,28 @@ struct mma_int_C_I16J8 {
         return ret;
     }
 
+    __device__ __forceinline__ void mma_K4(const mma_int_A_I16K4 & mma_A, const mma_int_B_J8K4 & mma_B) {
+#ifdef INT8_MMA_AVAILABLE
+#if __CUDA_ARCH__ >= CC_AMPERE
+        asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
+            : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+#else
+        // On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead:
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[0]), "+r"(x[1])
+            : "r"(mma_A.x[0]), "r"(mma_B.x[0]));
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[2]), "+r"(x[3])
+            : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+#endif // __CUDA_ARCH__ >= CC_AMPERE
+#else
+        GGML_UNUSED(mma_A);
+        GGML_UNUSED(mma_B);
+        NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
+    }
+
     __device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) {
 #ifdef INT8_MMA_AVAILABLE
 #if __CUDA_ARCH__ >= CC_AMPERE
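
Editorial note: a minimal, hypothetical smoke test (not part of this commit) for the new K4 fragment types above. It assumes the ggml-cuda headers are on the include path, that mma_int_C_I16J8 exposes I, J, ne, get_i and get_j as in the existing code, one warp per block, and packed row-major int8 inputs: A is 16x16 int8 (16 rows x 4 packed ints) and B is 8x16 int8 (8 rows x 4 packed ints), so C[i][j] = sum_k A[i][k]*B[j][k].

#include "common.cuh" // assumption: compiled inside ggml-cuda with its headers available
#include "mma.cuh"

// Hypothetical test kernel: one warp computes a 16x8 int32 tile with the new K4 fragments.
static __global__ void mma_k4_smoke_test(const int * A32, const int * B32, int * C32) {
    typedef mma_int_A_I16K4 mma_A; // 16x16 int8 (16 rows x 4 packed ints)
    typedef mma_int_B_J8K4  mma_B; //  8x16 int8 ( 8 rows x 4 packed ints)
    typedef mma_int_C_I16J8 mma_C; // 16x8  int32 accumulator

    mma_A A;
    mma_B B;
    mma_C C;

#pragma unroll
    for (int l = 0; l < mma_A::ne; ++l) {
        A.x[l] = A32[mma_A::get_i(l)*mma_A::K + mma_A::get_k(l)]; // each lane loads its packed int8 quad
    }
#pragma unroll
    for (int l = 0; l < mma_B::ne; ++l) {
        B.x[l] = B32[mma_B::get_j(l)*mma_B::K + mma_B::get_k(l)];
    }

    C.mma_K4(A, B); // m16n8k16 on Ampere, 2x m8n8k16 on Turing

#pragma unroll
    for (int l = 0; l < mma_C::ne; ++l) {
        C32[mma_C::get_i(l)*mma_C::J + mma_C::get_j(l)] = C.x[l];
    }
}

A plausible launch would be mma_k4_smoke_test<<<1, 32>>>(dA, dB, dC) with device buffers of 16*4, 8*4 and 16*8 ints, respectively.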
ggml-cuda/mmq.cuh CHANGED
@@ -1089,7 +1089,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mul_mat(
+static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
@@ -1115,6 +1115,97 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mul_mat(
     }
 }
 
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+    typedef mma_int_A_I16K8 mma_A;
+    typedef mma_int_B_J8K8 mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    const int * y_qs = (const int *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
+
+    const int i0 = threadIdx.y*mma_A::I;
+    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+
+    mma_A A[2];
+    int scA[mma_C::ne/2][2];
+    int mA[mma_C::ne/2][2];
+    half2 dmA[mma_C::ne/2];
+#pragma unroll
+    for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 4) {
+#pragma unroll
+        for (int l = 0; l < mma_A::ne; ++l) {
+            const int i = i0 + mma_A::get_i(l);
+            const int k = k0 + mma_A::get_k(l);
+
+            A[kvdr/4].x[l] = (x_ql[i*(WARP_SIZE + 1) + k] >> kvdr) & 0x0F0F0F0F;
+        }
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + mma_C::get_i(2*l);
+
+            const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
+            const uint8_t * m = sc + 8;
+
+            scA[l][kvdr/4] = sc[kvdr/4];
+            mA[l][kvdr/4] = m[kvdr/4];
+        }
+    }
+
+#pragma unroll
+    for (int l = 0; l < mma_C::ne/2; ++l) {
+        const int i = i0 + mma_C::get_i(2*l);
+
+        dmA[l] = x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + k0/QI5_K];
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
+        float tmpd[mma_C::ne] = {0.0f};
+        float tmpm[mma_C::ne] = {0.0f};
+
+#pragma unroll
+        for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
+            mma_C C;
+            mma_B B;
+            half2 dsB[mma_C::ne/2];
+
+#pragma unroll
+            for (int l = 0; l < mma_B::ne; ++l) {
+                const int j = j0 + mma_B::get_j(l);
+                const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE;
+
+                B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
+            }
+#pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int j = j0 + mma_C::get_j(l);
+
+                dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
+            }
+
+            C.mma_K8(A[kvdr/4], B);
+
+#pragma unroll
+            for (int l = 0; l < mma_C::ne; ++l) {
+                tmpd[l] += (C.x[l]*scA[l/2][kvdr/4]) * __low2float(dsB[l%2]);
+                tmpm[l] += mA[l/2][kvdr/4] * __high2float(dsB[l%2]);
+            }
+        }
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne; ++l) {
+            sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l];
+        }
+    }
+}
+
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
     const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
@@ -1188,7 +1279,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mul_mat(
+static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
@@ -1214,6 +1305,97 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mul_mat(
     }
 }
 
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+    typedef mma_int_A_I16K8 mma_A;
+    typedef mma_int_B_J8K8 mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    const int * y_qs = (const int *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
+
+    const int i0 = threadIdx.y*mma_A::I;
+    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+
+    mma_A A[2];
+    int scA[mma_C::ne/2][2];
+    int mA[mma_C::ne/2][2];
+    half2 dmA[mma_C::ne/2];
+#pragma unroll
+    for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
+#pragma unroll
+        for (int l = 0; l < mma_A::ne; ++l) {
+            const int i = i0 + mma_A::get_i(l);
+            const int k = QR5_K*k0 + QR5_K*kvdr + mma_A::get_k(l);
+
+            A[kvdr/4].x[l] = x_ql[i*(QR5_K*WARP_SIZE + 1) + k];
+        }
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + mma_C::get_i(2*l);
+
+            const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
+            const uint8_t * m = sc + 8;
+
+            scA[l][kvdr/4] = sc[kvdr/4];
+            mA[l][kvdr/4] = m[kvdr/4];
+        }
+    }
+
+#pragma unroll
+    for (int l = 0; l < mma_C::ne/2; ++l) {
+        const int i = i0 + mma_C::get_i(2*l);
+
+        dmA[l] = x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + k0/QI5_K];
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
+        float tmpd[mma_C::ne] = {0.0f};
+        float tmpm[mma_C::ne] = {0.0f};
+
+#pragma unroll
+        for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
+            mma_C C;
+            mma_B B;
+            half2 dsB[mma_C::ne/2];
+
+#pragma unroll
+            for (int l = 0; l < mma_B::ne; ++l) {
+                const int j = j0 + mma_B::get_j(l);
+                const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE;
+
+                B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
+            }
+#pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int j = j0 + mma_C::get_j(l);
+
+                dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
+            }
+
+            C.mma_K8(A[kvdr/4], B);
+
+#pragma unroll
+            for (int l = 0; l < mma_C::ne; ++l) {
+                tmpd[l] += (C.x[l]*scA[l/2][kvdr/4]) * __low2float(dsB[l%2]);
+                tmpm[l] += mA[l/2][kvdr/4] * __high2float(dsB[l%2]);
+            }
+        }
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne; ++l) {
+            sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l];
+        }
+    }
+}
+
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
     const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
@@ -1280,7 +1462,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat(
+static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
@@ -1307,6 +1489,97 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat(
     }
 }
 
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+    typedef mma_int_A_I16K4 mma_A;
+    typedef mma_int_B_J8K4 mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    const float * x_df = (const float *) x_dm;
+    const int * y_qs = (const int *) y + 4;
+    const float * y_df = (const float *) y;
+
+    const int i0 = threadIdx.y*mma_A::I;
+    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+
+    mma_A A[4];
+    int scA[mma_C::ne/2][4];
+    float dA[mma_C::ne/2];
+#pragma unroll
+    for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) {
+#pragma unroll
+        for (int l = 0; l < mma_A::ne; ++l) {
+            const int i = i0 + mma_A::get_i(l);
+            const int k = QR6_K*k0 + QR6_K*kvdr + mma_A::get_k(l);
+
+            A[kvdr/2 + 0].x[l] = x_ql[i*(QR6_K*WARP_SIZE + 1) + k + 0];
+            A[kvdr/2 + 1].x[l] = x_ql[i*(QR6_K*WARP_SIZE + 1) + k + mma_A::K];
+        }
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + mma_C::get_i(2*l);
+
+            const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]);
+
+            scA[l][kvdr/2 + 0] = sc[kvdr/2 + 0];
+            scA[l][kvdr/2 + 1] = sc[kvdr/2 + 1];
+        }
+    }
+
+#pragma unroll
+    for (int l = 0; l < mma_C::ne/2; ++l) {
+        const int i = i0 + mma_C::get_i(2*l);
+
+        dA[l] = x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + k0/QI6_K];
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
+        float tmp[mma_C::ne] = {0.0f};
+
+#pragma unroll
+        for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) {
+            mma_C C[2];
+            mma_B B[2];
+            float dB[mma_C::ne/2];
+
+#pragma unroll
+            for (int l = 0; l < mma_B::ne; ++l) {
+                const int j = j0 + mma_B::get_j(l);
+                const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE;
+
+                B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0];
+                B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K];
+            }
+#pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int j = j0 + mma_C::get_j(l);
+
+                dB[l] = y_df[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
+            }
+
+            C[0].mma_K4(A[kvdr/2 + 0], B[0]);
+            C[1].mma_K4(A[kvdr/2 + 1], B[1]);
+
+#pragma unroll
+            for (int l = 0; l < mma_C::ne; ++l) {
+                tmp[l] += (C[0].x[l]*scA[l/2][kvdr/2 + 0] + C[1].x[l]*scA[l/2][kvdr/2 + 1])*dB[l%2];
+            }
+        }
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne; ++l) {
+            sum[(j0/mma_B::J)*mma_C::ne + l] += tmp[l]*dA[l/2];
+        }
+    }
+}
+
 template<int mmq_x, int mmq_y, int nwarps, bool need_check>
 static __device__ __forceinline__ void mmq_write_back_dp4a(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
 #pragma unroll
@@ -1448,24 +1721,39 @@ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
     static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+#ifdef INT8_MMA_AVAILABLE
+    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
+#else
+    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
     static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
+#endif // INT8_MMA_AVAILABLE
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
     static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+#ifdef INT8_MMA_AVAILABLE
+    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
+#else
+    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
     static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
+#endif // INT8_MMA_AVAILABLE
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
     static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+#ifdef INT8_MMA_AVAILABLE
+    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
+#else
+    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
     static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
+#endif // INT8_MMA_AVAILABLE
 };
 
 static int mmq_need_sum(const ggml_type type_x) {
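
Editorial note: when INT8_MMA_AVAILABLE is not defined, the traits above fall back to the existing dp4a path. As a hypothetical reference (not part of the commit), the kernel below reproduces the 16x8 s32 tile of a single mma_K4 call with __dp4a, which is the per-element work the tensor core instruction replaces:

#include <cuda_runtime.h>

// Hypothetical reference: one thread per output element of a 16x8 s32 tile.
// A is 16x16 int8 (16 rows x 4 packed ints), B is 8x16 int8 (8 rows x 4 packed ints),
// so C[i][j] = sum_k A[i][k]*B[j][k], matching the row.col operand order of the MMA.
static __global__ void dp4a_reference_m16n8k16(const int * A32, const int * B32, int * C32) {
    const int i = threadIdx.x / 8; // output row,    0..15
    const int j = threadIdx.x % 8; // output column, 0..7
    if (threadIdx.x >= 16*8) {
        return;
    }

    int sumi = 0;
#pragma unroll
    for (int k = 0; k < 4; ++k) {
        sumi = __dp4a(A32[i*4 + k], B32[j*4 + k], sumi); // 4 int8 products accumulated per call
    }
    C32[i*8 + j] = sumi;
}

Launched as dp4a_reference_m16n8k16<<<1, 128>>>(dA, dB, dC); __dp4a requires compute capability 6.1 or newer.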