JohannesGaessler committed on
Commit
f5cd546
·
1 Parent(s): e093044

CUDA: use switch statements in constexpr functions (llama/13095)

Browse files
ggml/src/ggml-cuda/mmq.cuh CHANGED
@@ -155,25 +155,27 @@ static constexpr __device__ int get_mmq_y_device() {
155
  #define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
156
 
157
  static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
158
- return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
159
- type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
160
- type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q8_0 :
161
- type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q8_1 :
162
- type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 :
163
- type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
164
- type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
165
- type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
166
- type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
167
- type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
168
- type == GGML_TYPE_IQ2_XXS ? MMQ_DP4A_TXS_Q8_0 :
169
- type == GGML_TYPE_IQ2_XS ? MMQ_DP4A_TXS_Q8_0_16 :
170
- type == GGML_TYPE_IQ2_S ? MMQ_DP4A_TXS_Q8_0_16 :
171
- type == GGML_TYPE_IQ3_XXS ? MMQ_DP4A_TXS_Q8_0 :
172
- type == GGML_TYPE_IQ3_S ? MMQ_DP4A_TXS_Q8_0 :
173
- type == GGML_TYPE_IQ1_S ? MMQ_DP4A_TXS_Q8_0 :
174
- type == GGML_TYPE_IQ4_XS ? MMQ_DP4A_TXS_Q8_0 :
175
- type == GGML_TYPE_IQ4_NL ? MMQ_DP4A_TXS_Q8_0 :
176
- tile_x_sizes{0, 0, 0};
 
 
177
  }
178
 
179
  #define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4)
@@ -189,25 +191,27 @@ static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
189
  static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
190
 
191
  static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
192
- return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
193
- type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
194
- type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
195
- type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
196
- type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
197
- type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
198
- type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
199
- type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q8_1 :
200
- type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q8_1 :
201
- type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
202
- type == GGML_TYPE_IQ2_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
203
- type == GGML_TYPE_IQ2_XS ? MMQ_MMA_TILE_X_K_Q3_K :
204
- type == GGML_TYPE_IQ2_S ? MMQ_MMA_TILE_X_K_Q3_K :
205
- type == GGML_TYPE_IQ3_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
206
- type == GGML_TYPE_IQ3_S ? MMQ_MMA_TILE_X_K_Q8_0 :
207
- type == GGML_TYPE_IQ1_S ? MMQ_MMA_TILE_X_K_Q8_0 :
208
- type == GGML_TYPE_IQ4_XS ? MMQ_MMA_TILE_X_K_Q8_0 :
209
- type == GGML_TYPE_IQ4_NL ? MMQ_MMA_TILE_X_K_Q8_0 :
210
- 0;
 
 
211
  }
212
 
213
  #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
 
155
  #define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
156
 
157
  static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
158
+ switch (type) {
159
+ case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0;
160
+ case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1;
161
+ case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
162
+ case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1;
163
+ case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0;
164
+ case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K;
165
+ case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K;
166
+ case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K;
167
+ case GGML_TYPE_Q5_K: return MMQ_DP4A_TXS_Q5_K;
168
+ case GGML_TYPE_Q6_K: return MMQ_DP4A_TXS_Q6_K;
169
+ case GGML_TYPE_IQ2_XXS: return MMQ_DP4A_TXS_Q8_0;
170
+ case GGML_TYPE_IQ2_XS: return MMQ_DP4A_TXS_Q8_0_16;
171
+ case GGML_TYPE_IQ2_S: return MMQ_DP4A_TXS_Q8_0_16;
172
+ case GGML_TYPE_IQ3_XXS: return MMQ_DP4A_TXS_Q8_0;
173
+ case GGML_TYPE_IQ3_S: return MMQ_DP4A_TXS_Q8_0;
174
+ case GGML_TYPE_IQ1_S: return MMQ_DP4A_TXS_Q8_0;
175
+ case GGML_TYPE_IQ4_XS: return MMQ_DP4A_TXS_Q8_0;
176
+ case GGML_TYPE_IQ4_NL: return MMQ_DP4A_TXS_Q8_0;
177
+ default: return tile_x_sizes{0, 0, 0};
178
+ }
179
  }
180
 
181
  #define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4)
 
191
  static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
192
 
193
  static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
194
+ switch (type) {
195
+ case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0;
196
+ case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1;
197
+ case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
198
+ case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
199
+ case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
200
+ case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
201
+ case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
202
+ case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
203
+ case GGML_TYPE_Q5_K: return MMQ_MMA_TILE_X_K_Q8_1;
204
+ case GGML_TYPE_Q6_K: return MMQ_MMA_TILE_X_K_Q6_K;
205
+ case GGML_TYPE_IQ2_XXS: return MMQ_MMA_TILE_X_K_Q8_0;
206
+ case GGML_TYPE_IQ2_XS: return MMQ_MMA_TILE_X_K_Q3_K;
207
+ case GGML_TYPE_IQ2_S: return MMQ_MMA_TILE_X_K_Q3_K;
208
+ case GGML_TYPE_IQ3_XXS: return MMQ_MMA_TILE_X_K_Q8_0;
209
+ case GGML_TYPE_IQ3_S: return MMQ_MMA_TILE_X_K_Q8_0;
210
+ case GGML_TYPE_IQ1_S: return MMQ_MMA_TILE_X_K_Q8_0;
211
+ case GGML_TYPE_IQ4_XS: return MMQ_MMA_TILE_X_K_Q8_0;
212
+ case GGML_TYPE_IQ4_NL: return MMQ_MMA_TILE_X_K_Q8_0;
213
+ default: return 0;
214
+ }
215
  }
216
 
217
  #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
ggml/src/ggml-cuda/mmvq.cu CHANGED
@@ -7,47 +7,51 @@
7
  typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
8
 
9
  static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
10
- return type == GGML_TYPE_Q4_0 ? vec_dot_q4_0_q8_1 :
11
- type == GGML_TYPE_Q4_1 ? vec_dot_q4_1_q8_1 :
12
- type == GGML_TYPE_Q5_0 ? vec_dot_q5_0_q8_1 :
13
- type == GGML_TYPE_Q5_1 ? vec_dot_q5_1_q8_1 :
14
- type == GGML_TYPE_Q8_0 ? vec_dot_q8_0_q8_1 :
15
- type == GGML_TYPE_Q2_K ? vec_dot_q2_K_q8_1 :
16
- type == GGML_TYPE_Q3_K ? vec_dot_q3_K_q8_1 :
17
- type == GGML_TYPE_Q4_K ? vec_dot_q4_K_q8_1 :
18
- type == GGML_TYPE_Q5_K ? vec_dot_q5_K_q8_1 :
19
- type == GGML_TYPE_Q6_K ? vec_dot_q6_K_q8_1 :
20
- type == GGML_TYPE_IQ2_XXS ? vec_dot_iq2_xxs_q8_1 :
21
- type == GGML_TYPE_IQ2_XS ? vec_dot_iq2_xs_q8_1 :
22
- type == GGML_TYPE_IQ2_S ? vec_dot_iq2_s_q8_1 :
23
- type == GGML_TYPE_IQ3_XXS ? vec_dot_iq3_xxs_q8_1 :
24
- type == GGML_TYPE_IQ1_S ? vec_dot_iq1_s_q8_1 :
25
- type == GGML_TYPE_IQ1_M ? vec_dot_iq1_m_q8_1 :
26
- type == GGML_TYPE_IQ4_NL ? vec_dot_iq4_nl_q8_1 :
27
- type == GGML_TYPE_IQ4_XS ? vec_dot_iq4_xs_q8_1 :
28
- type == GGML_TYPE_IQ3_S ? vec_dot_iq3_s_q8_1 :
29
- nullptr;
 
 
30
  }
31
 
32
  static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
33
- return type == GGML_TYPE_Q4_0 ? VDR_Q4_0_Q8_1_MMVQ :
34
- type == GGML_TYPE_Q4_1 ? VDR_Q4_1_Q8_1_MMVQ :
35
- type == GGML_TYPE_Q5_0 ? VDR_Q5_0_Q8_1_MMVQ :
36
- type == GGML_TYPE_Q5_1 ? VDR_Q5_1_Q8_1_MMVQ :
37
- type == GGML_TYPE_Q8_0 ? VDR_Q8_0_Q8_1_MMVQ :
38
- type == GGML_TYPE_Q2_K ? VDR_Q2_K_Q8_1_MMVQ :
39
- type == GGML_TYPE_Q3_K ? VDR_Q3_K_Q8_1_MMVQ :
40
- type == GGML_TYPE_Q4_K ? VDR_Q4_K_Q8_1_MMVQ :
41
- type == GGML_TYPE_Q5_K ? VDR_Q5_K_Q8_1_MMVQ :
42
- type == GGML_TYPE_Q6_K ? VDR_Q6_K_Q8_1_MMVQ :
43
- type == GGML_TYPE_IQ2_XXS ? VDR_IQ2_XXS_Q8_1_MMVQ :
44
- type == GGML_TYPE_IQ2_XS ? VDR_IQ2_XS_Q8_1_MMVQ :
45
- type == GGML_TYPE_IQ2_S ? VDR_IQ2_S_Q8_1_MMVQ :
46
- type == GGML_TYPE_IQ3_XXS ? VDR_IQ3_XXS_Q8_1_MMVQ :
47
- type == GGML_TYPE_IQ3_S ? VDR_IQ3_S_Q8_1_MMVQ :
48
- type == GGML_TYPE_IQ4_NL ? VDR_IQ4_NL_Q8_1_MMVQ :
49
- type == GGML_TYPE_IQ4_XS ? VDR_IQ4_XS_Q8_1_MMVQ :
50
- 1;
 
 
51
  }
52
 
53
  enum mmvq_parameter_table_id {
 
7
  typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
8
 
9
  static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
10
+ switch (type) {
11
+ case GGML_TYPE_Q4_0: return vec_dot_q4_0_q8_1;
12
+ case GGML_TYPE_Q4_1: return vec_dot_q4_1_q8_1;
13
+ case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1;
14
+ case GGML_TYPE_Q5_1: return vec_dot_q5_1_q8_1;
15
+ case GGML_TYPE_Q8_0: return vec_dot_q8_0_q8_1;
16
+ case GGML_TYPE_Q2_K: return vec_dot_q2_K_q8_1;
17
+ case GGML_TYPE_Q3_K: return vec_dot_q3_K_q8_1;
18
+ case GGML_TYPE_Q4_K: return vec_dot_q4_K_q8_1;
19
+ case GGML_TYPE_Q5_K: return vec_dot_q5_K_q8_1;
20
+ case GGML_TYPE_Q6_K: return vec_dot_q6_K_q8_1;
21
+ case GGML_TYPE_IQ2_XXS: return vec_dot_iq2_xxs_q8_1;
22
+ case GGML_TYPE_IQ2_XS: return vec_dot_iq2_xs_q8_1;
23
+ case GGML_TYPE_IQ2_S: return vec_dot_iq2_s_q8_1;
24
+ case GGML_TYPE_IQ3_XXS: return vec_dot_iq3_xxs_q8_1;
25
+ case GGML_TYPE_IQ1_S: return vec_dot_iq1_s_q8_1;
26
+ case GGML_TYPE_IQ1_M: return vec_dot_iq1_m_q8_1;
27
+ case GGML_TYPE_IQ4_NL: return vec_dot_iq4_nl_q8_1;
28
+ case GGML_TYPE_IQ4_XS: return vec_dot_iq4_xs_q8_1;
29
+ case GGML_TYPE_IQ3_S: return vec_dot_iq3_s_q8_1;
30
+ default: return nullptr;
31
+ }
32
  }
33
 
34
  static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
35
+ switch (type) {
36
+ case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ;
37
+ case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ;
38
+ case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ;
39
+ case GGML_TYPE_Q5_1: return VDR_Q5_1_Q8_1_MMVQ;
40
+ case GGML_TYPE_Q8_0: return VDR_Q8_0_Q8_1_MMVQ;
41
+ case GGML_TYPE_Q2_K: return VDR_Q2_K_Q8_1_MMVQ;
42
+ case GGML_TYPE_Q3_K: return VDR_Q3_K_Q8_1_MMVQ;
43
+ case GGML_TYPE_Q4_K: return VDR_Q4_K_Q8_1_MMVQ;
44
+ case GGML_TYPE_Q5_K: return VDR_Q5_K_Q8_1_MMVQ;
45
+ case GGML_TYPE_Q6_K: return VDR_Q6_K_Q8_1_MMVQ;
46
+ case GGML_TYPE_IQ2_XXS: return VDR_IQ2_XXS_Q8_1_MMVQ;
47
+ case GGML_TYPE_IQ2_XS: return VDR_IQ2_XS_Q8_1_MMVQ;
48
+ case GGML_TYPE_IQ2_S: return VDR_IQ2_S_Q8_1_MMVQ;
49
+ case GGML_TYPE_IQ3_XXS: return VDR_IQ3_XXS_Q8_1_MMVQ;
50
+ case GGML_TYPE_IQ3_S: return VDR_IQ3_S_Q8_1_MMVQ;
51
+ case GGML_TYPE_IQ4_NL: return VDR_IQ4_NL_Q8_1_MMVQ;
52
+ case GGML_TYPE_IQ4_XS: return VDR_IQ4_XS_Q8_1_MMVQ;
53
+ default: return 1;
54
+ }
55
  }
56
 
57
  enum mmvq_parameter_table_id {