Spaces:
Running
Running
Commit
·
f5cd546
1
Parent(s):
e093044
CUDA: use switch statements in constexpr functions (llama/13095)
Browse files- ggml/src/ggml-cuda/mmq.cuh +42 -38
- ggml/src/ggml-cuda/mmvq.cu +42 -38
ggml/src/ggml-cuda/mmq.cuh
CHANGED
|
@@ -155,25 +155,27 @@ static constexpr __device__ int get_mmq_y_device() {
|
|
| 155 |
#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
|
| 156 |
|
| 157 |
static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
|
|
|
|
|
|
| 177 |
}
|
| 178 |
|
| 179 |
#define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4)
|
|
@@ -189,25 +191,27 @@ static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
|
|
| 189 |
static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
|
| 190 |
|
| 191 |
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
|
|
|
|
|
|
| 211 |
}
|
| 212 |
|
| 213 |
#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
|
|
|
|
| 155 |
#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
|
| 156 |
|
| 157 |
// Maps a ggml quantization type to the tile_x_sizes value produced by the
// corresponding MMQ_DP4A_TXS_* macro.
//
// type:  quantization type to look up.
// mmq_y: not read directly in this function body; the MMQ_DP4A_TXS_* macros
//        expand to expressions that reference it by name (e.g.
//        MMQ_DP4A_TXS_Q6_K), so the parameter name must stay `mmq_y`.
//
// Returns tile_x_sizes{0, 0, 0} for any type without an entry below.
static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
    switch (type) {
        // Types that use their own layout macro.
        case GGML_TYPE_Q4_0:
            return MMQ_DP4A_TXS_Q4_0;
        case GGML_TYPE_Q4_1:
            return MMQ_DP4A_TXS_Q4_1;
        case GGML_TYPE_Q2_K:
            return MMQ_DP4A_TXS_Q2_K;
        case GGML_TYPE_Q3_K:
            return MMQ_DP4A_TXS_Q3_K;
        case GGML_TYPE_Q4_K:
            return MMQ_DP4A_TXS_Q4_K;
        case GGML_TYPE_Q5_K:
            return MMQ_DP4A_TXS_Q5_K;
        case GGML_TYPE_Q6_K:
            return MMQ_DP4A_TXS_Q6_K;

        // Types that share the Q8_1 layout macro.
        case GGML_TYPE_Q5_1:
            return MMQ_DP4A_TXS_Q8_1;

        // Types that share the Q8_0 layout macro.
        case GGML_TYPE_Q5_0:
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_IQ2_XXS:
        case GGML_TYPE_IQ3_XXS:
        case GGML_TYPE_IQ3_S:
        case GGML_TYPE_IQ1_S:
        case GGML_TYPE_IQ4_XS:
        case GGML_TYPE_IQ4_NL:
            return MMQ_DP4A_TXS_Q8_0;

        // Types that share the Q8_0_16 layout macro.
        case GGML_TYPE_IQ2_XS:
        case GGML_TYPE_IQ2_S:
            return MMQ_DP4A_TXS_Q8_0_16;

        // No entry for this type.
        default:
            return tile_x_sizes{0, 0, 0};
    }
}
|
| 180 |
|
| 181 |
#define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4)
|
|
|
|
| 191 |
static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
|
| 192 |
|
| 193 |
// Maps a ggml quantization type to one of the MMQ_MMA_TILE_X_K_* constants.
//
// type: quantization type to look up.
//
// Returns 0 for any type without an entry below.
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
    switch (type) {
        // Types that use MMQ_MMA_TILE_X_K_Q8_0.
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q5_0:
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_IQ2_XXS:
        case GGML_TYPE_IQ3_XXS:
        case GGML_TYPE_IQ3_S:
        case GGML_TYPE_IQ1_S:
        case GGML_TYPE_IQ4_XS:
        case GGML_TYPE_IQ4_NL:
            return MMQ_MMA_TILE_X_K_Q8_0;

        // Types that use MMQ_MMA_TILE_X_K_Q8_1.
        case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
            return MMQ_MMA_TILE_X_K_Q8_1;

        // Types that use MMQ_MMA_TILE_X_K_Q3_K.
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_IQ2_XS:
        case GGML_TYPE_IQ2_S:
            return MMQ_MMA_TILE_X_K_Q3_K;

        // Types with a constant of their own.
        case GGML_TYPE_Q2_K:
            return MMQ_MMA_TILE_X_K_Q2_K;
        case GGML_TYPE_Q6_K:
            return MMQ_MMA_TILE_X_K_Q6_K;

        // No entry for this type.
        default:
            return 0;
    }
}
|
| 216 |
|
| 217 |
#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
|
ggml/src/ggml-cuda/mmvq.cu
CHANGED
|
@@ -7,47 +7,51 @@
|
|
| 7 |
typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
|
| 8 |
|
| 9 |
static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
| 30 |
}
|
| 31 |
|
| 32 |
static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
|
|
|
|
|
|
| 51 |
}
|
| 52 |
|
| 53 |
enum mmvq_parameter_table_id {
|
|
|
|
| 7 |
typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
|
| 8 |
|
| 9 |
// Returns the vec_dot_*_q8_1 function pointer registered for a ggml
// quantization type.
//
// type: quantization type to look up.
//
// Returns nullptr for any type without an entry below, so callers must not
// invoke the result for such types.
static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
    switch (type) {
        // Q4/Q5/Q8 block types.
        case GGML_TYPE_Q4_0:
            return vec_dot_q4_0_q8_1;
        case GGML_TYPE_Q4_1:
            return vec_dot_q4_1_q8_1;
        case GGML_TYPE_Q5_0:
            return vec_dot_q5_0_q8_1;
        case GGML_TYPE_Q5_1:
            return vec_dot_q5_1_q8_1;
        case GGML_TYPE_Q8_0:
            return vec_dot_q8_0_q8_1;

        // *_K types.
        case GGML_TYPE_Q2_K:
            return vec_dot_q2_K_q8_1;
        case GGML_TYPE_Q3_K:
            return vec_dot_q3_K_q8_1;
        case GGML_TYPE_Q4_K:
            return vec_dot_q4_K_q8_1;
        case GGML_TYPE_Q5_K:
            return vec_dot_q5_K_q8_1;
        case GGML_TYPE_Q6_K:
            return vec_dot_q6_K_q8_1;

        // IQ* types.
        case GGML_TYPE_IQ1_S:
            return vec_dot_iq1_s_q8_1;
        case GGML_TYPE_IQ1_M:
            return vec_dot_iq1_m_q8_1;
        case GGML_TYPE_IQ2_XXS:
            return vec_dot_iq2_xxs_q8_1;
        case GGML_TYPE_IQ2_XS:
            return vec_dot_iq2_xs_q8_1;
        case GGML_TYPE_IQ2_S:
            return vec_dot_iq2_s_q8_1;
        case GGML_TYPE_IQ3_XXS:
            return vec_dot_iq3_xxs_q8_1;
        case GGML_TYPE_IQ3_S:
            return vec_dot_iq3_s_q8_1;
        case GGML_TYPE_IQ4_NL:
            return vec_dot_iq4_nl_q8_1;
        case GGML_TYPE_IQ4_XS:
            return vec_dot_iq4_xs_q8_1;

        // No entry for this type.
        default:
            return nullptr;
    }
}
|
| 33 |
|
| 34 |
// Returns the VDR_*_Q8_1_MMVQ constant registered for a ggml quantization
// type.
//
// type: quantization type to look up.
//
// Any type without an explicit case yields 1. That includes
// GGML_TYPE_IQ1_S and GGML_TYPE_IQ1_M, which have no entry in this switch.
static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
    switch (type) {
        // Q4/Q5/Q8 block types.
        case GGML_TYPE_Q4_0:
            return VDR_Q4_0_Q8_1_MMVQ;
        case GGML_TYPE_Q4_1:
            return VDR_Q4_1_Q8_1_MMVQ;
        case GGML_TYPE_Q5_0:
            return VDR_Q5_0_Q8_1_MMVQ;
        case GGML_TYPE_Q5_1:
            return VDR_Q5_1_Q8_1_MMVQ;
        case GGML_TYPE_Q8_0:
            return VDR_Q8_0_Q8_1_MMVQ;

        // *_K types.
        case GGML_TYPE_Q2_K:
            return VDR_Q2_K_Q8_1_MMVQ;
        case GGML_TYPE_Q3_K:
            return VDR_Q3_K_Q8_1_MMVQ;
        case GGML_TYPE_Q4_K:
            return VDR_Q4_K_Q8_1_MMVQ;
        case GGML_TYPE_Q5_K:
            return VDR_Q5_K_Q8_1_MMVQ;
        case GGML_TYPE_Q6_K:
            return VDR_Q6_K_Q8_1_MMVQ;

        // IQ* types.
        case GGML_TYPE_IQ2_XXS:
            return VDR_IQ2_XXS_Q8_1_MMVQ;
        case GGML_TYPE_IQ2_XS:
            return VDR_IQ2_XS_Q8_1_MMVQ;
        case GGML_TYPE_IQ2_S:
            return VDR_IQ2_S_Q8_1_MMVQ;
        case GGML_TYPE_IQ3_XXS:
            return VDR_IQ3_XXS_Q8_1_MMVQ;
        case GGML_TYPE_IQ3_S:
            return VDR_IQ3_S_Q8_1_MMVQ;
        case GGML_TYPE_IQ4_NL:
            return VDR_IQ4_NL_Q8_1_MMVQ;
        case GGML_TYPE_IQ4_XS:
            return VDR_IQ4_XS_Q8_1_MMVQ;

        // Fallback for every type not listed above.
        default:
            return 1;
    }
}
|
| 56 |
|
| 57 |
enum mmvq_parameter_table_id {
|