Commit 154bf2b · Parent(s): 78a5b67
CUDA: int8 tensor cores for MMQ (q4_K, q5_K, q6_K) (llama/7860)
Files changed:
- ggml-cuda/mma.cuh +66 -0
- ggml-cuda/mmq.cuh +294 -6
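The commit keeps both code paths and picks between them at compile time: when the build defines INT8_MMA_AVAILABLE the MMQ type traits point at the new *_mma kernels (int8 tensor cores), otherwise they stay on the existing dp4a kernels. Below is a minimal, self-contained sketch of that selection pattern, not ggml code: the vec_dot_fn typedef, traits_q4_K struct, and placeholder function bodies are invented for illustration; only the INT8_MMA_AVAILABLE macro and the _mma/_dp4a naming come from the diff.

// Hedged sketch of the compile-time dispatch pattern introduced by this commit.
#include <cstdio>

typedef void (*vec_dot_fn)(const int * x, const int * y, float * sum);

void vec_dot_q4_K_q8_1_dp4a(const int *, const int *, float * sum) { *sum = 1.0f; } // stand-in body
void vec_dot_q4_K_q8_1_mma (const int *, const int *, float * sum) { *sum = 2.0f; } // stand-in body

struct traits_q4_K { // simplified stand-in for mmq_type_traits<..., GGML_TYPE_Q4_K>
#ifdef INT8_MMA_AVAILABLE
    static constexpr vec_dot_fn vec_dot = vec_dot_q4_K_q8_1_mma;  // int8 tensor core path
#else
    static constexpr vec_dot_fn vec_dot = vec_dot_q4_K_q8_1_dp4a; // dp4a fallback
#endif // INT8_MMA_AVAILABLE
};

int main() {
    float sum = 0.0f;
    traits_q4_K::vec_dot(nullptr, nullptr, &sum); // resolved at compile time, no runtime branch
    std::printf("selected path marker: %.1f\n", sum);
    return 0;
}

The real selection lives in the mmq_type_traits specializations at the end of ggml-cuda/mmq.cuh below, which additionally switch write_back between mmq_write_back_mma and mmq_write_back_dp4a.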
ggml-cuda/mma.cuh
CHANGED
@@ -1,5 +1,27 @@
 #include "common.cuh"
 
+struct mma_int_A_I16K4 {
+    static constexpr int I = 16;
+    static constexpr int K = 4;
+    static constexpr int ne = 2;
+
+    int x[ne] = {0};
+
+    static __device__ __forceinline__ int get_i(const int l) {
+        const int ret = (l%2) * (I/2) + threadIdx.x / K;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret < I);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_k(const int /* l */) {
+        const int ret = threadIdx.x % K;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret < K);
+        return ret;
+    }
+};
+
 struct mma_int_A_I16K8 {
     static constexpr int I = 16;
     static constexpr int K = 8;
@@ -22,6 +44,28 @@ struct mma_int_A_I16K8 {
     }
 };
 
+struct mma_int_B_J8K4 {
+    static constexpr int J = 8;
+    static constexpr int K = 4;
+    static constexpr int ne = 1;
+
+    int x[ne] = {0};
+
+    static __device__ __forceinline__ int get_j(const int /* l */) {
+        const int ret = threadIdx.x / K;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret < J);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_k(const int /* l */) {
+        const int ret = threadIdx.x % K;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret < K);
+        return ret;
+    }
+};
+
 struct mma_int_B_J8K8 {
     static constexpr int J = 8;
     static constexpr int K = 8;
@@ -65,6 +109,28 @@ struct mma_int_C_I16J8 {
         return ret;
     }
 
+    __device__ __forceinline__ void mma_K4(const mma_int_A_I16K4 & mma_A, const mma_int_B_J8K4 & mma_B) {
+#ifdef INT8_MMA_AVAILABLE
+#if __CUDA_ARCH__ >= CC_AMPERE
+        asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
+            : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+#else
+        // On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead:
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[0]), "+r"(x[1])
+            : "r"(mma_A.x[0]), "r"(mma_B.x[0]));
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[2]), "+r"(x[3])
+            : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+#endif // __CUDA_ARCH__ >= CC_AMPERE
+#else
+        GGML_UNUSED(mma_A);
+        GGML_UNUSED(mma_B);
+        NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
+    }
+
     __device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) {
 #ifdef INT8_MMA_AVAILABLE
 #if __CUDA_ARCH__ >= CC_AMPERE
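The new mma_int_A_I16K4 and mma_int_B_J8K4 fragments map each of the 32 lanes of a warp to fixed positions in a 16x4 and an 8x4 tile of packed ints via the get_i/get_k/get_j helpers above. The following host-only sketch re-implements those index formulas with the lane as an explicit argument, just to print the mapping; it is not the device code, and the host_ prefixed names are invented here.

// Host-side inspection of the lane-to-element mapping encoded in the new fragment types.
#include <cstdio>

struct host_mma_int_A_I16K4 {
    static constexpr int I = 16;
    static constexpr int K = 4;
    static constexpr int ne = 2; // two packed int registers per lane
    static int get_i(int l, int lane) { return (l % 2) * (I / 2) + lane / K; }
    static int get_k(int /*l*/, int lane) { return lane % K; }
};

struct host_mma_int_B_J8K4 {
    static constexpr int J = 8;
    static constexpr int K = 4;
    static constexpr int ne = 1; // one packed int register per lane
    static int get_j(int /*l*/, int lane) { return lane / K; }
    static int get_k(int /*l*/, int lane) { return lane % K; }
};

int main() {
    for (int lane = 0; lane < 32; ++lane) {
        std::printf("lane %2d: A owns (i,k) =", lane);
        for (int l = 0; l < host_mma_int_A_I16K4::ne; ++l) {
            std::printf(" (%2d,%d)", host_mma_int_A_I16K4::get_i(l, lane),
                                     host_mma_int_A_I16K4::get_k(l, lane));
        }
        std::printf(" | B owns (j,k) = (%d,%d)\n",
                    host_mma_int_B_J8K4::get_j(0, lane),
                    host_mma_int_B_J8K4::get_k(0, lane));
    }
    return 0;
}

Each k index here denotes one 32-bit group of four int8 values, so the 16x4 A tile and 8x4 B tile are 16x16 and 8x16 tiles of int8, matching the m16n8k16 shape that mma_K4 issues on Ampere.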
ggml-cuda/mmq.cuh
CHANGED
@@ -1089,7 +1089,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mul_mat(
+static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
@@ -1115,6 +1115,97 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mul_mat(
     }
 }
 
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+    typedef mma_int_A_I16K8 mma_A;
+    typedef mma_int_B_J8K8 mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    const int * y_qs = (const int *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
+
+    const int i0 = threadIdx.y*mma_A::I;
+    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+
+    mma_A A[2];
+    int scA[mma_C::ne/2][2];
+    int mA[mma_C::ne/2][2];
+    half2 dmA[mma_C::ne/2];
+#pragma unroll
+    for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 4) {
+#pragma unroll
+        for (int l = 0; l < mma_A::ne; ++l) {
+            const int i = i0 + mma_A::get_i(l);
+            const int k = k0 + mma_A::get_k(l);
+
+            A[kvdr/4].x[l] = (x_ql[i*(WARP_SIZE + 1) + k] >> kvdr) & 0x0F0F0F0F;
+        }
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + mma_C::get_i(2*l);
+
+            const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
+            const uint8_t * m = sc + 8;
+
+            scA[l][kvdr/4] = sc[kvdr/4];
+            mA[l][kvdr/4] = m[kvdr/4];
+        }
+    }
+
+#pragma unroll
+    for (int l = 0; l < mma_C::ne/2; ++l) {
+        const int i = i0 + mma_C::get_i(2*l);
+
+        dmA[l] = x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + k0/QI5_K];
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
+        float tmpd[mma_C::ne] = {0.0f};
+        float tmpm[mma_C::ne] = {0.0f};
+
+#pragma unroll
+        for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
+            mma_C C;
+            mma_B B;
+            half2 dsB[mma_C::ne/2];
+
+#pragma unroll
+            for (int l = 0; l < mma_B::ne; ++l) {
+                const int j = j0 + mma_B::get_j(l);
+                const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE;
+
+                B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
+            }
+#pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int j = j0 + mma_C::get_j(l);
+
+                dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
+            }
+
+            C.mma_K8(A[kvdr/4], B);
+
+#pragma unroll
+            for (int l = 0; l < mma_C::ne; ++l) {
+                tmpd[l] += (C.x[l]*scA[l/2][kvdr/4]) * __low2float(dsB[l%2]);
+                tmpm[l] += mA[l/2][kvdr/4] * __high2float(dsB[l%2]);
+            }
+        }
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne; ++l) {
+            sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l];
+        }
+    }
+}
+
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
     const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
@@ -1188,7 +1279,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mul_mat(
+static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
@@ -1214,6 +1305,97 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mul_mat(
     }
 }
 
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+    typedef mma_int_A_I16K8 mma_A;
+    typedef mma_int_B_J8K8 mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    const int * y_qs = (const int *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
+
+    const int i0 = threadIdx.y*mma_A::I;
+    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+
+    mma_A A[2];
+    int scA[mma_C::ne/2][2];
+    int mA[mma_C::ne/2][2];
+    half2 dmA[mma_C::ne/2];
+#pragma unroll
+    for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
+#pragma unroll
+        for (int l = 0; l < mma_A::ne; ++l) {
+            const int i = i0 + mma_A::get_i(l);
+            const int k = QR5_K*k0 + QR5_K*kvdr + mma_A::get_k(l);
+
+            A[kvdr/4].x[l] = x_ql[i*(QR5_K*WARP_SIZE + 1) + k];
+        }
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + mma_C::get_i(2*l);
+
+            const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
+            const uint8_t * m = sc + 8;
+
+            scA[l][kvdr/4] = sc[kvdr/4];
+            mA[l][kvdr/4] = m[kvdr/4];
+        }
+    }
+
+#pragma unroll
+    for (int l = 0; l < mma_C::ne/2; ++l) {
+        const int i = i0 + mma_C::get_i(2*l);
+
+        dmA[l] = x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + k0/QI5_K];
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
+        float tmpd[mma_C::ne] = {0.0f};
+        float tmpm[mma_C::ne] = {0.0f};
+
+#pragma unroll
+        for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
+            mma_C C;
+            mma_B B;
+            half2 dsB[mma_C::ne/2];
+
+#pragma unroll
+            for (int l = 0; l < mma_B::ne; ++l) {
+                const int j = j0 + mma_B::get_j(l);
+                const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE;
+
+                B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
+            }
+#pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int j = j0 + mma_C::get_j(l);
+
+                dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
+            }
+
+            C.mma_K8(A[kvdr/4], B);
+
+#pragma unroll
+            for (int l = 0; l < mma_C::ne; ++l) {
+                tmpd[l] += (C.x[l]*scA[l/2][kvdr/4]) * __low2float(dsB[l%2]);
+                tmpm[l] += mA[l/2][kvdr/4] * __high2float(dsB[l%2]);
+            }
+        }
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne; ++l) {
+            sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l];
+        }
+    }
+}
+
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
     const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
@@ -1280,7 +1462,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat(
+static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
@@ -1307,6 +1489,97 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat(
     }
 }
 
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+    typedef mma_int_A_I16K4 mma_A;
+    typedef mma_int_B_J8K4 mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    const float * x_df = (const float *) x_dm;
+    const int * y_qs = (const int *) y + 4;
+    const float * y_df = (const float *) y;
+
+    const int i0 = threadIdx.y*mma_A::I;
+    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+
+    mma_A A[4];
+    int scA[mma_C::ne/2][4];
+    float dA[mma_C::ne/2];
+#pragma unroll
+    for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) {
+#pragma unroll
+        for (int l = 0; l < mma_A::ne; ++l) {
+            const int i = i0 + mma_A::get_i(l);
+            const int k = QR6_K*k0 + QR6_K*kvdr + mma_A::get_k(l);
+
+            A[kvdr/2 + 0].x[l] = x_ql[i*(QR6_K*WARP_SIZE + 1) + k + 0];
+            A[kvdr/2 + 1].x[l] = x_ql[i*(QR6_K*WARP_SIZE + 1) + k + mma_A::K];
+        }
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + mma_C::get_i(2*l);
+
+            const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]);
+
+            scA[l][kvdr/2 + 0] = sc[kvdr/2 + 0];
+            scA[l][kvdr/2 + 1] = sc[kvdr/2 + 1];
+        }
+    }
+
+#pragma unroll
+    for (int l = 0; l < mma_C::ne/2; ++l) {
+        const int i = i0 + mma_C::get_i(2*l);
+
+        dA[l] = x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + k0/QI6_K];
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
+        float tmp[mma_C::ne] = {0.0f};
+
+#pragma unroll
+        for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) {
+            mma_C C[2];
+            mma_B B[2];
+            float dB[mma_C::ne/2];
+
+#pragma unroll
+            for (int l = 0; l < mma_B::ne; ++l) {
+                const int j = j0 + mma_B::get_j(l);
+                const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE;
+
+                B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0];
+                B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K];
+            }
+#pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int j = j0 + mma_C::get_j(l);
+
+                dB[l] = y_df[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
+            }
+
+            C[0].mma_K4(A[kvdr/2 + 0], B[0]);
+            C[1].mma_K4(A[kvdr/2 + 1], B[1]);
+
+#pragma unroll
+            for (int l = 0; l < mma_C::ne; ++l) {
+                tmp[l] += (C[0].x[l]*scA[l/2][kvdr/2 + 0] + C[1].x[l]*scA[l/2][kvdr/2 + 1])*dB[l%2];
+            }
+        }
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne; ++l) {
+            sum[(j0/mma_B::J)*mma_C::ne + l] += tmp[l]*dA[l/2];
+        }
+    }
+}
+
 template<int mmq_x, int mmq_y, int nwarps, bool need_check>
 static __device__ __forceinline__ void mmq_write_back_dp4a(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
 #pragma unroll
@@ -1448,24 +1721,39 @@ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
     static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+#ifdef INT8_MMA_AVAILABLE
+    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
+#else
+    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
     static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
+#endif // INT8_MMA_AVAILABLE
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
     static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+#ifdef INT8_MMA_AVAILABLE
+    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
+#else
+    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
     static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
+#endif // INT8_MMA_AVAILABLE
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
     static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+#ifdef INT8_MMA_AVAILABLE
+    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
+#else
+    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
+#endif // INT8_MMA_AVAILABLE
 };
 
 static int mmq_need_sum(const ggml_type type_x) {
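One detail of vec_dot_q4_K_q8_1_mma above worth calling out is how the 4-bit q4_K data becomes int8 operands for the tensor cores: (x_ql[...] >> kvdr) & 0x0F0F0F0F with kvdr in {0, 4} extracts the low and high nibbles of one packed 32-bit word, leaving one 4-bit quant per byte. The host-side sketch below isolates just that bit manipulation; the sample input word is made up, only the shift-and-mask expression mirrors the diff.

// Hedged illustration of the nibble extraction used in vec_dot_q4_K_q8_1_mma.
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t packed = 0xA3F17C58u; // arbitrary example word of packed 4-bit quants

    for (int kvdr = 0; kvdr < 8; kvdr += 4) {
        const uint32_t nibbles = (packed >> kvdr) & 0x0F0F0F0F; // same expression as in the kernel

        std::printf("kvdr=%d -> 0x%08X -> int8 lanes:", kvdr, nibbles);
        for (int b = 0; b < 4; ++b) {
            std::printf(" %d", (int)((nibbles >> (8 * b)) & 0xFF)); // each byte is one quant in [0, 15]
        }
        std::printf("\n");
    }
    return 0;
}

In the kernel the two extractions land in A[0] and A[1] (indexed by kvdr/4), so the low and high nibbles of one shared-memory word become the two A fragments that are fed to C.mma_K8 against the q8_1 activations.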