slaren, ggerganov committed
Commit 26fdc9f · unverified · 1 parent: 7a97623

ggml : mul_mat_id use the same tensor for all the experts (llama/6387)


* ggml : update mul_mat_id to use the same tensor for all the experts

* update cuda

* minor

* update metal

* update test-backend-ops

* fix cuda

* Update ggml-metal.m

Co-authored-by: Georgi Gerganov <[email protected]>

* update convert.py

* update convert-hf-to-gguf.py

* update convert.py for mixtral hf models

* Update convert-hf-to-gguf.py

Co-authored-by: Georgi Gerganov <[email protected]>

* cuda : support non-pow-2 number of experts

* allow quantize to work for split and merged experts models in the same way

* cleanup + disable mmap automatically with split tensors models

* update imatrix

* test-backend-ops : test qwen argsort

* update grok model loading

* llama : add merged experts tensors to the grok tensor map

* minor

* gguf : bump version

* fix quantizing of merged experts

* convert-hf-to-gguf.py : update grok (untested)

* make linter happy

* cuda/argsort : use shared memory instead of pool memory

* convert : fix grok tensor names

* metal : add support for non-pow-2 argsort

* llama : more loader cleanup, better error checking

* cuda : fix warning

* llama : still use mmap for loading old models, but copy the data to a host buffer

* add review note

* llama : remove ffn tensor counting + add sanity check

ggml-ci

* convert : fix handling of n_experts == None

ggml-ci

* imatrix : fix ncall counters

* llama : produce error if imatrix size does not match

* quantize : terminate on errors + trace logs

ggml-ci

* metal : pad shared memory to 16 bytes

---------

Co-authored-by: Georgi Gerganov <[email protected]>
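
The core of the change: previously each expert weight matrix was passed as a separate source tensor (the old code indexed dst->src[row_id + 2]); now mul_mat_id takes a single 3D tensor with all experts stacked along dimension 2, and the ids tensor moves to dst->src[2]. A minimal sketch of the resulting addressing, following the CUDA changes below (the helper name is illustrative, not part of the commit):

    // experts are stacked along ne[2] of the merged tensor, so the weights
    // of expert row_id start at a fixed byte offset row_id*src0->nb[2]
    static void * expert_weights(const struct ggml_tensor * src0, int64_t row_id) {
        GGML_ASSERT(row_id >= 0 && row_id < src0->ne[2]); // n_as = src0->ne[2]
        return (char *) src0->data + row_id*src0->nb[2];
    }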

Files changed (6)
  1. ggml-cuda.cu +17 -197
  2. ggml-cuda/argsort.cu +39 -13
  3. ggml-metal.m +100 -109
  4. ggml-metal.metal +122 -288
  5. ggml.c +23 -34
  6. ggml.h +1 -2
ggml-cuda.cu CHANGED
@@ -401,10 +401,8 @@ GGML_CALL static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t
 GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
     ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
 
-    if (tensor->view_src != NULL && tensor->view_offs == 0) {
+    if (tensor->view_src != NULL) {
         assert(tensor->view_src->buffer->buft == buffer->buft);
-        tensor->backend = tensor->view_src->backend;
-        tensor->extra = tensor->view_src->extra;
         return;
     }
 
@@ -1962,227 +1960,49 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         }
     }
 
-#if 0
-template<typename ... Srcs>
-static __global__ void k_compute_batched_ptrs_id(
-        const void ** ptrs_src, void ** ptrs_dst,
-        int ne12, int ne13,
-        int ne23,
-        int nb02, int nb03,
-        int nb12, int nb13,
-        int nb2, int nb3,
-        int r2, int r3,
-        ggml_type src0_type, half * src0_as_f16, int64_t src0_ne,
-        const half * src1_f16, half * dst_f16,
-        const int32_t * ids, const int id,
-        Srcs... src0s) {
-
-    int i = ids[id];
-
-    half * src0_f16;
-    const void * srcs_ar[] = { (const half *) src0s... };
-    if (src0_type == GGML_TYPE_F16) {
-        src0_f16 = (half *) srcs_ar[i];
-    } else {
-        src0_f16 = src0_as_f16;
-        if (threadIdx.x == 0 && threadIdx.y == 0) {
-            const to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(src0_type);
-            to_fp16(srcs_ar[i], src0_f16, src0_ne, cudaStreamFireAndForget);
-        }
-    }
-
-    int i13 = blockIdx.x * blockDim.x + threadIdx.x;
-    int i12 = blockIdx.y * blockDim.y + threadIdx.y;
-
-    if (i13 >= ne13 || i12 >= ne12) {
-        return;
-    }
-
-    int i03 = i13 / r3;
-    int i02 = i12 / r2;
-
-    ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_f16 + i02*nb02 + i03*nb03;
-    ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_f16 + i12*nb12/2 + i13*nb13/2;
-    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)  dst_f16 + i12* nb2/2 + i13* nb3/2;
-}
-
-static void ggml_cuda_mul_mat_id_cublas(ggml_tensor * dst) {
-    const struct ggml_tensor * ids = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-    const struct ggml_tensor * src00 = dst->src[2];
-
-    const int id = dst->op_params[0];
-
-    GGML_ASSERT(!ggml_is_transposed(src00));
-    GGML_ASSERT(!ggml_is_transposed(src1));
-
-    GGML_ASSERT(src00->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-    const int64_t ne00 = src00->ne[0]; GGML_UNUSED(ne00);
-    const int64_t ne01 = src00->ne[1];
-    const int64_t ne02 = src00->ne[2];
-    const int64_t ne03 = src00->ne[3];
-
-    //const int64_t nb01 = src00->nb[1];
-    const int64_t nb02 = src00->nb[2]; GGML_UNUSED(nb02);
-    const int64_t nb03 = src00->nb[3]; GGML_UNUSED(nb03);
-
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
-    const int64_t ne12 = src1->ne[2];
-    const int64_t ne13 = src1->ne[3];
-
-    //const int64_t nb11 = src1->nb[1];
-    const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12);
-    const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13);
-
-    const int64_t ne1 = ggml_nelements(src1);
-    const int64_t ne  = ggml_nelements(dst);
-
-    ggml_cuda_set_device(g_main_device);
-    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
-
-    CUBLAS_CHECK(cublasSetStream(g_cublas_handles[g_main_device], main_stream));
-
-    //ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
-    //void * src0_ddq = src0_extra->data_device[g_main_device];
-    //half * src0_as_f16 = (half *) src0_ddq;
-
-    ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
-    float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
-
-    ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
-    float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
-
-    // convert src1 to fp16
-    const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
-    GGML_ASSERT(to_fp16_cuda != nullptr);
-
-    size_t src1_as = 0;
-    half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
-    to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
-
-    size_t dst_as = 0;
-    half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
-
-    GGML_ASSERT(ne12 % ne02 == 0);
-    GGML_ASSERT(ne13 % ne03 == 0);
-
-    // broadcast factors
-    const int64_t r2 = ne12/ne02;
-    const int64_t r3 = ne13/ne03;
-
-    const half alpha_f16 = 1.0f;
-    const half beta_f16  = 0.0f;
-
-    // use cublasGemmBatchedEx
-    const int ne23 = ne12*ne13;
-
-    const void ** ptrs_src = nullptr;
-          void ** ptrs_dst = nullptr;
-
-    size_t ptrs_src_s = 0;
-    size_t ptrs_dst_s = 0;
-
-    ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s);
-    ptrs_dst = (      void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s);
-
-    int64_t src0_ne = ggml_nelements(src00);
-    half * src0_as_f16 = nullptr;
-    size_t src0_as = 0;
-    if (src00->type != GGML_TYPE_F16) {
-        src0_as_f16 = (half *) ggml_cuda_pool_malloc(src0_ne * sizeof(half), &src0_as);
-    }
-
-    static_assert(GGML_MAX_SRC == 6, "GGML_MAX_SRC == 6");
-    dim3 block_dims(ne13, ne12);
-    k_compute_batched_ptrs_id<<<1, block_dims, 0, main_stream>>>(
-            ptrs_src, ptrs_dst,
-            ne12, ne13,
-            ne23,
-            ne00*ne01*sizeof(half), ne00*ne01*ne02*sizeof(half),
-            nb12, nb13,
-            dst->nb[2], dst->nb[3],
-            r2, r3,
-            src00->type, src0_as_f16, src0_ne,
-            src1_as_f16, dst_f16,
-            (const int *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device], id,
-            dst->src[2] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[2]->extra)->data_device[g_main_device] : nullptr,
-            dst->src[3] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[3]->extra)->data_device[g_main_device] : nullptr,
-            dst->src[4] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[4]->extra)->data_device[g_main_device] : nullptr,
-            dst->src[5] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[5]->extra)->data_device[g_main_device] : nullptr
-    );
-    CUDA_CHECK(cudaGetLastError());
-
-    CUBLAS_CHECK(
-    cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
-            ne01, ne11, ne10,
-            &alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, ne00,
-                        (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, ne10,
-            &beta_f16,  (      void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01,
-            ne23,
-            CUBLAS_COMPUTE_16F,
-            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-
-    if (src0_as != 0) {
-        ggml_cuda_pool_free(src0_as_f16, src0_as);
-    }
-    if (ptrs_src_s != 0) {
-        ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
-    }
-    if (ptrs_dst_s != 0) {
-        ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s);
-    }
-
-    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-    to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
-
-    ggml_cuda_pool_free(src1_as_f16, src1_as);
-    ggml_cuda_pool_free(dst_f16, dst_as);
-}
-#endif
-
 static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-#if 0
-    ggml_cuda_mul_mat_id_cublas(dst);
-    // TODO: mmq/mmv support
-#endif
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * ids  = dst->src[2];
+
+    GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers");
 
     cudaStream_t stream = ctx.stream();
 
     const size_t nb11 = src1->nb[1];
     const size_t nb1  =  dst->nb[1];
 
-    const struct ggml_tensor * ids = src0;
     const int32_t id = ((int32_t *) dst->op_params)[0];
-    const int32_t n_as = ((int32_t *) dst->op_params)[1];
+    const int32_t n_as = src0->ne[2];
 
     std::vector<char> ids_host(ggml_nbytes(ids));
     const char * ids_dev = (const char *) ids->data;
     CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
     CUDA_CHECK(cudaStreamSynchronize(stream));
 
+    ggml_tensor src0_row = *src0;
     ggml_tensor src1_row = *src1;
     ggml_tensor dst_row  = *dst;
 
+    char * src0_original = (char *) src0->data;
     char * src1_original = (char *) src1->data;
     char * dst_original  = (char *)  dst->data;
 
+    src0_row.ne[2] = 1;
+    src0_row.ne[3] = 1;
+    src0_row.nb[3] = src0->nb[2];
+
     if (src1->ne[1] == 1) {
         for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
             const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
 
             GGML_ASSERT(row_id >= 0 && row_id < n_as);
 
-            const struct ggml_tensor * src0_row = dst->src[row_id + 2];
-
+            src0_row.data = src0_original + row_id*src0->nb[2];
+
             src1_row.data = src1_original + i01*src1->nb[1];
             dst_row.data  =  dst_original + i01*dst->nb[1];
 
-            ggml_cuda_mul_mat(ctx, src0_row, &src1_row, &dst_row);
+            ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
         }
     } else {
         ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
@@ -2192,8 +2012,6 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
         dst_row.data  =  dst_contiguous.get();
 
         for (int32_t row_id = 0; row_id < n_as; ++row_id) {
-            const struct ggml_tensor * src0_row = dst->src[row_id + 2];
-
             int64_t num_src1_rows = 0;
             for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
                 const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
@@ -2213,6 +2031,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
                 continue;
             }
 
+            src0_row.data = src0_original + row_id*src0->nb[2];
+
             src1_row.ne[1] = num_src1_rows;
             dst_row.ne[1]  = num_src1_rows;
 
@@ -2224,7 +2044,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
             dst_row.nb[2] = num_src1_rows*nb1;
             dst_row.nb[3] = num_src1_rows*nb1;
 
-            ggml_cuda_mul_mat(ctx, src0_row, &src1_row, &dst_row);
+            ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
 
             num_src1_rows = 0;
             for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
@@ -2389,7 +2209,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         cudaError_t err = cudaGetLastError();
         if (err != cudaSuccess) {
             fprintf(stderr, "%s: %s failed\n", __func__, ggml_op_desc(dst));
-            GGML_ASSERT(false);
+            CUDA_CHECK(err);
         }
 
         return true;
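
For reference, the expert chosen for token i01 is read out of the host copy of the ids tensor exactly as in the loops above; isolated into a hypothetical helper for clarity (not part of the commit):

    // ids is an I32 tensor copied to the host; nb[0]/nb[1] are byte strides,
    // and `id` (from dst->op_params[0]) selects which top-k expert slot to use
    static int32_t expert_for_token(const std::vector<char> & ids_host,
                                    const ggml_tensor * ids, int64_t i01, int32_t id) {
        return *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
    }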
ggml-cuda/argsort.cu CHANGED
@@ -8,32 +8,41 @@ static inline __device__ void ggml_cuda_swap(T & a, T & b) {
 }
 
 template<ggml_sort_order order>
-static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols) {
+static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) {
     // bitonic sort
     int col = threadIdx.x;
     int row = blockIdx.y;
 
-    if (col >= ncols) return;
+    if (col >= ncols_pad) {
+        return;
+    }
 
     const float * x_row = x + row * ncols;
-    int * dst_row = dst + row * ncols;
+    extern __shared__ int dst_row[];
 
     // initialize indices
-    if (col < ncols) {
-        dst_row[col] = col;
-    }
+    dst_row[col] = col;
+
     __syncthreads();
 
-    for (int k = 2; k <= ncols; k *= 2) {
+    for (int k = 2; k <= ncols_pad; k *= 2) {
         for (int j = k / 2; j > 0; j /= 2) {
             int ixj = col ^ j;
             if (ixj > col) {
                 if ((col & k) == 0) {
-                    if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
+                    if (dst_row[col] >= ncols ||
+                        (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+                    ) {
                         ggml_cuda_swap(dst_row[col], dst_row[ixj]);
                     }
                 } else {
-                    if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
+                    if (dst_row[ixj] >= ncols ||
+                        (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+                    ) {
                         ggml_cuda_swap(dst_row[col], dst_row[ixj]);
                    }
                }
@@ -41,18 +50,35 @@ static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int n
             __syncthreads();
         }
     }
+
+    // copy the result to dst without the padding
+    if (col < ncols) {
+        dst[row * ncols + col] = dst_row[col];
+    }
+}
+
+static int next_power_of_2(int x) {
+    int n = 1;
+    while (n < x) {
+        n *= 2;
+    }
+    return n;
 }
 
 static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
     // bitonic sort requires ncols to be power of 2
-    GGML_ASSERT((ncols & (ncols - 1)) == 0);
+    const int ncols_pad = next_power_of_2(ncols);
 
-    const dim3 block_dims(ncols, 1, 1);
+    const dim3 block_dims(ncols_pad, 1, 1);
     const dim3 block_nums(1, nrows, 1);
+    const size_t shared_mem = ncols_pad * sizeof(int);
+
+    GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
+
     if (order == GGML_SORT_ORDER_ASC) {
-        k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+        k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
     } else if (order == GGML_SORT_ORDER_DESC) {
-        k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+        k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
     } else {
         GGML_ASSERT(false);
     }
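
The padding scheme above is what lifts the old power-of-two restriction: the thread count is rounded up to the next power of two, and any index >= ncols compares as larger than every real element, so the padding sinks to the end of the shared-memory row and is never copied out. A small standalone illustration of the rounding (assumed example values, same next_power_of_2 as above):

    #include <cstdio>

    static int next_power_of_2(int x) {
        int n = 1;
        while (n < x) {
            n *= 2;
        }
        return n;
    }

    int main() {
        // e.g. a MoE model with 60 experts: argsort rows of 60 pad to 64
        std::printf("%d -> %d\n", 60, next_power_of_2(60)); // prints: 60 -> 64
        return 0;
    }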
ggml-metal.m CHANGED
@@ -1685,37 +1685,31 @@ static enum ggml_status ggml_metal_graph_compute(
             {
                 //GGML_ASSERT(ne00 == ne10);
                 //GGML_ASSERT(ne03 == ne13);
-
-                GGML_ASSERT(src0t == GGML_TYPE_I32);
-
-                const int n_as = ((int32_t *) dst->op_params)[1];
-
-                // TODO: make this more general
-                GGML_ASSERT(n_as <= 8);
+                const int n_as = src0->ne[2];
 
                 // max size of the src1ids array in the kernel shared buffer
                 GGML_ASSERT(ne11 <= 4096);
 
-                const int64_t  ne20 = src2 ? src2->ne[0] : 0;
-                const int64_t  ne21 = src2 ? src2->ne[1] : 0;
-                const int64_t  ne22 = src2 ? src2->ne[2] : 0;
-                const int64_t  ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
+                // src2 = ids
+                const int64_t  ne20 = src2->ne[0]; GGML_UNUSED(ne20);
+                const int64_t  ne21 = src2->ne[1];
+                const int64_t  ne22 = src2->ne[2]; GGML_UNUSED(ne22);
+                const int64_t  ne23 = src2->ne[3]; GGML_UNUSED(ne23);
 
-                const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
-                const uint64_t nb21 = src2 ? src2->nb[1] : 0;
-                const uint64_t nb22 = src2 ? src2->nb[2] : 0;
-                const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
+                const uint64_t nb20 = src2->nb[0]; GGML_UNUSED(nb20);
+                const uint64_t nb21 = src2->nb[1];
+                const uint64_t nb22 = src2->nb[2]; GGML_UNUSED(nb22);
+                const uint64_t nb23 = src2->nb[3]; GGML_UNUSED(nb23);
 
-                const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
+                const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
+
+                GGML_ASSERT(src2t == GGML_TYPE_I32);
 
-                GGML_ASSERT(!ggml_is_transposed(src2));
+                GGML_ASSERT(!ggml_is_transposed(src0));
                 GGML_ASSERT(!ggml_is_transposed(src1));
 
                 GGML_ASSERT(src1t == GGML_TYPE_F32);
 
-                const uint r2 = ne12/ne22;
-                const uint r3 = ne13/ne23;
-
                 // find the break-even point where the matrix-matrix kernel becomes more efficient compared
                 // to the matrix-vector kernel
                 int ne11_mm_min = n_as;
@@ -1723,7 +1717,10 @@ static enum ggml_status ggml_metal_graph_compute(
                 const int idx = ((int32_t *) dst->op_params)[0];
 
                 // batch size
-                GGML_ASSERT(ne01 == ne11);
+                GGML_ASSERT(ne21 == ne11); // ?
+                GGML_ASSERT(ne12 == 1 && ne13 == 1); // no broadcasting
+                const uint r2 = 1;
+                const uint r3 = 1;
 
                 // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                 // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
@@ -1732,7 +1729,7 @@ static enum ggml_status ggml_metal_graph_compute(
                 // indirect matrix multiplication
                 // !!!
                 if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
-                    ne20 % 32 == 0 && ne20 >= 64 &&
+                    ne00 % 32 == 0 && ne00 >= 64 &&
                     ne11 > ne11_mm_min) {
 
                     // some Metal matrix data types require aligned pointers
@@ -1745,7 +1742,7 @@ static enum ggml_status ggml_metal_graph_compute(
 
                     id<MTLComputePipelineState> pipeline = nil;
 
-                    switch (src2->type) {
+                    switch (src0->type) {
                         case GGML_TYPE_F32:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
                         case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
                         case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32].pipeline; break;
@@ -1774,36 +1771,27 @@ static enum ggml_status ggml_metal_graph_compute(
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                     [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                    [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
-                    [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
-                    [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
-                    [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
-                    [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
-                    [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
-                    [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
-                    [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
-                    [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
-                    [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
-                    [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:13];
-                    [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:14];
-                    [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
-                    [encoder setBytes:&r2   length:sizeof(r2)   atIndex:16];
-                    [encoder setBytes:&r3   length:sizeof(r3)   atIndex:17];
-                    [encoder setBytes:&idx  length:sizeof(idx)  atIndex:18];
-                    // TODO: how to make this an array? read Metal docs
-                    for (int j = 0; j < 8; ++j) {
-                        // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
-                        struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
-
-                        size_t offs_src_cur = 0;
-                        id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur);
-
-                        [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
-                    }
+                    [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
+                    [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
+                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
+                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
+                    [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                    [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                    [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:9];
+                    [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:10];
+                    [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
+                    [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
+                    [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:13];
+                    [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:14];
+                    [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:15];
+                    [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:16];
+                    [encoder setBytes:&r2   length:sizeof(r2)   atIndex:17];
+                    [encoder setBytes:&r3   length:sizeof(r3)   atIndex:18];
+                    [encoder setBytes:&idx  length:sizeof(idx)  atIndex:19];
 
                     [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0];
 
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                 } else {
                     int nth0 = 32;
                     int nth1 = 1;
@@ -1813,7 +1801,7 @@ static enum ggml_status ggml_metal_graph_compute(
                     id<MTLComputePipelineState> pipeline = nil;
 
                     // use custom matrix x vector kernel
-                    switch (src2t) {
+                    switch (src0t) {
                         case GGML_TYPE_F32:
                             {
                                 GGML_ASSERT(src1t == GGML_TYPE_F32);
@@ -1947,8 +1935,8 @@ static enum ggml_status ggml_metal_graph_compute(
                         }
                     };
 
-                    if (ggml_is_quantized(src2t)) {
-                        GGML_ASSERT(ne20 >= nth0*nth1);
+                    if (ggml_is_quantized(src0t)) {
+                        GGML_ASSERT(ne00 >= nth0*nth1);
                     }
 
                     const int64_t _ne1 = 1; // kernels needs a reference in constant memory
@@ -1957,75 +1945,66 @@ static enum ggml_status ggml_metal_graph_compute(
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                     [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                    [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
-                    [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
-                    [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
-                    [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
-                    [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
-                    [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
-                    [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
-                    [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
-                    [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
-                    [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
-                    [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
-                    [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
-                    [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
-                    [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
-                    [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:17];
-                    [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
-                    [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:19];
-                    [encoder setBytes:&r2   length:sizeof(r2)   atIndex:20];
-                    [encoder setBytes:&r3   length:sizeof(r3)   atIndex:21];
-                    [encoder setBytes:&idx  length:sizeof(idx)  atIndex:22];
-                    // TODO: how to make this an array? read Metal docs
-                    for (int j = 0; j < 8; ++j) {
-                        // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
-                        struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
-
-                        size_t offs_src_cur = 0;
-                        id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur);
-
-                        [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
-                    }
+                    [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
+                    [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
+                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
+                    [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:6];
+                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:7];
+                    [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:8];
+                    [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
+                    [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
+                    [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+                    [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:12];
+                    [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+                    [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+                    [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+                    [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+                    [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+                    [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:18];
+                    [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:19];
+                    [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:20];
+                    [encoder setBytes:&r2   length:sizeof(r2)   atIndex:21];
+                    [encoder setBytes:&r3   length:sizeof(r3)   atIndex:22];
+                    [encoder setBytes:&idx  length:sizeof(idx)  atIndex:23];
 
-                    if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 || src2t == GGML_TYPE_Q5_0 ||
-                        src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 || src2t == GGML_TYPE_Q2_K ||
-                        src2t == GGML_TYPE_IQ1_S || src2t == GGML_TYPE_IQ1_M || src2t == GGML_TYPE_IQ2_S) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
+                        src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
+                        src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     }
-                    else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
-                        const int mem_size = src2t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
+                    else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
+                        const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
                         [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     }
-                    else if (src2t == GGML_TYPE_IQ3_XXS || src2t == GGML_TYPE_IQ3_S) {
-                        const int mem_size = src2t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
+                    else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
+                        const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
                         [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     }
-                    else if (src2t == GGML_TYPE_IQ4_NL || src2t == GGML_TYPE_IQ4_XS) {
+                    else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
                         const int mem_size = 32*sizeof(float);
                         [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     }
-                    else if (src2t == GGML_TYPE_Q4_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    else if (src0t == GGML_TYPE_Q4_K) {
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     }
-                    else if (src2t == GGML_TYPE_Q3_K) {
+                    else if (src0t == GGML_TYPE_Q3_K) {
 #ifdef GGML_QKK_64
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #else
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #endif
                     }
-                    else if (src2t == GGML_TYPE_Q5_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    else if (src0t == GGML_TYPE_Q5_K) {
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     }
-                    else if (src2t == GGML_TYPE_Q6_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    else if (src0t == GGML_TYPE_Q6_K) {
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     } else {
                         const int64_t ny = (_ne1 + nrows - 1)/nrows;
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     }
                 }
             } break;
@@ -2432,6 +2411,16 @@ static enum ggml_status ggml_metal_graph_compute(
 
                 enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
 
+                // bitonic sort requires the number of elements to be power of 2
+                int64_t ne00_padded = 1;
+                while (ne00_padded < ne00) {
+                    ne00_padded *= 2;
+                }
+
+                // Metal kernels require the buffer size to be multiple of 16 bytes
+                // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
+                const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16);
+
                 id<MTLComputePipelineState> pipeline = nil;
 
                 switch (order) {
@@ -2441,11 +2430,13 @@ static enum ggml_status ggml_metal_graph_compute(
                 };
 
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                [encoder setBuffer:id_src0     offset:offs_src0        atIndex:0];
+                [encoder setBuffer:id_dst      offset:offs_dst         atIndex:1];
+                [encoder setBytes:&ne00        length:sizeof( int64_t) atIndex:2];
+                [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3];
+                [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
 
-                [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
+                [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
             } break;
         case GGML_OP_LEAKY_RELU:
             {
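
The "pad shared memory to 16 bytes" fix above relies on GGML_PAD, which rounds a size up to a multiple of a power-of-two alignment; for reference, the macro as defined in ggml.h:

    // round x up to a multiple of n (n must be a power of two)
    #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

For example, ne00_padded = 2 would request 2*sizeof(int32_t) = 8 bytes of threadgroup memory, which GGML_PAD(8, 16) rounds up to the 16-byte multiple that setThreadgroupMemoryLength requires.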
ggml-metal.metal CHANGED
@@ -13,8 +13,8 @@ using namespace metal;
 #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
 
 enum ggml_sort_order {
-    GGML_SORT_ASC,
-    GGML_SORT_DESC,
+    GGML_SORT_ORDER_ASC,
+    GGML_SORT_ORDER_DESC,
 };
 
 // general-purpose kernel for addition, multiplication and division of two tensors
@@ -1973,9 +1973,11 @@ kernel void kernel_timestep_embedding_f32(
 
 // bitonic sort implementation following the CUDA kernels as reference
 typedef void (argsort_t)(
-        device const float * x,
-        device int32_t * dst,
-        constant int64_t & ncols,
+        device const float   * x,
+        device       int32_t * dst,
+        constant     int64_t & ncols,
+        constant     int64_t & ncols_pad,
+        threadgroup  int32_t * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]]);
 
@@ -1984,33 +1986,42 @@ kernel void kernel_argsort_f32_i32(
         device const float * x,
         device int32_t * dst,
        constant int64_t & ncols,
+        constant int64_t & ncols_pad,
+        threadgroup int32_t * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]]) {
     // bitonic sort
     int col = tpitg[0];
     int row = tgpig[1];
 
-    if (col >= ncols) return;
+    if (col >= ncols_pad) return;
 
-    device const float * x_row = x + row * ncols;
-    device int32_t * dst_row = dst + row * ncols;
+    device const float  * x_row   = x + row * ncols;
+    threadgroup int32_t * dst_row = shared_values;
 
     // initialize indices
-    if (col < ncols) {
-        dst_row[col] = col;
-    }
+    dst_row[col] = col;
+
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    for (int k = 2; k <= ncols; k *= 2) {
+    for (int k = 2; k <= ncols_pad; k *= 2) {
         for (int j = k / 2; j > 0; j /= 2) {
             int ixj = col ^ j;
             if (ixj > col) {
                 if ((col & k) == 0) {
-                    if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
+                    if (dst_row[col] >= ncols ||
+                        (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+                    ) {
                         SWAP(dst_row[col], dst_row[ixj]);
                     }
                 } else {
-                    if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
+                    if (dst_row[ixj] >= ncols ||
+                        (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+                    ) {
                         SWAP(dst_row[col], dst_row[ixj]);
                     }
                 }
@@ -2018,10 +2029,15 @@ kernel void kernel_argsort_f32_i32(
             threadgroup_barrier(mem_flags::mem_threadgroup);
         }
     }
+
+    // copy the result to dst without the padding
+    if (col < ncols) {
+        dst[row * ncols + col] = dst_row[col];
+    }
 }
 
-template [[host_name("kernel_argsort_f32_i32_asc")]]  kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
-template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
+template [[host_name("kernel_argsort_f32_i32_asc")]]  kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
+template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
 
 kernel void kernel_leaky_relu_f32(
         device const float * src0,
@@ -5785,9 +5801,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
 
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
 kernel void kernel_mul_mm_id(
-        device const uchar * ids,
+        device const uchar * src0s,
         device const uchar * src1,
         device       float * dst,
+        device const uchar * ids,
         constant   uint64_t & nbi1,
         constant    int64_t & ne00,
         constant    int64_t & ne02,
@@ -5804,22 +5821,14 @@ kernel void kernel_mul_mm_id(
         constant       uint & r2,
         constant       uint & r3,
        constant        int & idx,
-        device const uchar * src00,
-        device const uchar * src01,
-        device const uchar * src02,
-        device const uchar * src03,
-        device const uchar * src04,
-        device const uchar * src05,
-        device const uchar * src06,
-        device const uchar * src07,
         threadgroup  uchar * shared_memory [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiitg[[thread_index_in_threadgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
 
     // expert id
     const int32_t id = tgpig.z/(ne12*ne13);
+    device const uchar * src0 = src0s + id*nb02;
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
@@ -5834,7 +5843,7 @@ kernel void kernel_mul_mm_id(
     }
 
     kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
-        src0s[id],
+        src0,
         src1,
         src1ids,
         dst,
@@ -5960,9 +5969,10 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_m
 //
 
 typedef void (mat_mm_id_t)(
-        device const uchar * ids,
+        device const uchar * src0s,
         device const uchar * src1,
         device       float * dst,
+        device const uchar * ids,
         constant   uint64_t & nbi1,
         constant    int64_t & ne00,
         constant    int64_t & ne02,
@@ -5979,14 +5989,6 @@ typedef void (mat_mm_id_t)(
         constant       uint & r2,
         constant       uint & r3,
         constant        int & idx,
-        device const uchar * src00,
-        device const uchar * src01,
-        device const uchar * src02,
-        device const uchar * src03,
-        device const uchar * src04,
-        device const uchar * src05,
-        device const uchar * src06,
-        device const uchar * src07,
         threadgroup uchar *,
         uint3, uint, uint);
 
@@ -6022,9 +6024,10 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel
 
 [[host_name("kernel_mul_mv_id_f32_f32")]]
 kernel void kernel_mul_mv_id_f32_f32(
-        device const char * ids,
+        device const char * src0s,
         device const char * src1,
         device      float * dst,
+        device const char * ids,
         constant   uint64_t & nbi1,
         constant    int64_t & ne00,
         constant    int64_t & ne01,
@@ -6045,28 +6048,19 @@ kernel void kernel_mul_mv_id_f32_f32(
         constant       uint & r2,
         constant       uint & r3,
         constant        int & idx,
-        device const char * src00,
-        device const char * src01,
-        device const char * src02,
-        device const char * src03,
-        device const char * src04,
-        device const char * src05,
-        device const char * src06,
-        device const char * src07,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiitg[[thread_index_in_threadgroup]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_f32_f32_impl(
-        src0[id],
+        src0,
         src1 + bid*nb11,
         dst  + bid*ne0,
         ne00,
@@ -6091,9 +6085,10 @@ kernel void kernel_mul_mv_id_f32_f32(
 
 [[host_name("kernel_mul_mv_id_f16_f32")]]
 kernel void kernel_mul_mv_id_f16_f32(
-        device const char * ids,
+        device const char * src0s,
         device const char * src1,
         device      float * dst,
+        device const char * ids,
         constant   uint64_t & nbi1,
         constant    int64_t & ne00,
         constant    int64_t & ne01,
@@ -6114,28 +6109,19 @@ kernel void kernel_mul_mv_id_f16_f32(
         constant       uint & r2,
         constant       uint & r3,
         constant        int & idx,
-        device const char * src00,
-        device const char * src01,
-        device const char * src02,
-        device const char * src03,
-        device const char * src04,
-        device const char * src05,
-        device const char * src06,
-        device const char * src07,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiitg[[thread_index_in_threadgroup]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_f16_f32_impl(
-        src0[id],
+        src0,
         src1 + bid*nb11,
         dst  + bid*ne0,
         ne00,
@@ -6160,9 +6146,10 @@ kernel void kernel_mul_mv_id_f16_f32(
 
 [[host_name("kernel_mul_mv_id_q8_0_f32")]]
 kernel void kernel_mul_mv_id_q8_0_f32(
-        device const char * ids,
+        device const char * src0s,
         device const char * src1,
         device      float * dst,
+        device const char * ids,
         constant   uint64_t & nbi1,
         constant    int64_t & ne00,
         constant    int64_t & ne01,
@@ -6183,28 +6170,19 @@ kernel void kernel_mul_mv_id_q8_0_f32(
         constant       uint & r2,
         constant       uint & r3,
         constant        int & idx,
-        device const char * src00,
-        device const char * src01,
-        device const char * src02,
-        device const char * src03,
-        device const char * src04,
-        device const char * src05,
-        device const char * src06,
-        device const char * src07,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiitg[[thread_index_in_threadgroup]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_q8_0_f32_impl(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -6223,9 +6201,10 @@ kernel void kernel_mul_mv_id_q8_0_f32(
 
 [[host_name("kernel_mul_mv_id_q4_0_f32")]]
 kernel void kernel_mul_mv_id_q4_0_f32(
-        device const char * ids,
+        device const char * src0s,
         device const char * src1,
         device      float * dst,
+        device const char * ids,
         constant   uint64_t & nbi1,
         constant    int64_t & ne00,
         constant    int64_t & ne01,
@@ -6246,28 +6225,19 @@ kernel void kernel_mul_mv_id_q4_0_f32(
         constant       uint & r2,
         constant       uint & r3,
         constant        int & idx,
-        device const char * src00,
-        device const char * src01,
-        device const char * src02,
-        device const char * src03,
-        device const char * src04,
-        device const char * src05,
-        device const char * src06,
-        device const char * src07,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiitg[[thread_index_in_threadgroup]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -6286,9 +6256,10 @@ kernel void kernel_mul_mv_id_q4_0_f32(
 
 [[host_name("kernel_mul_mv_id_q4_1_f32")]]
 kernel void kernel_mul_mv_id_q4_1_f32(
-        device const char * ids,
+        device const char * src0s,
         device const char * src1,
         device      float * dst,
+        device const char * ids,
         constant   uint64_t & nbi1,
         constant    int64_t & ne00,
         constant    int64_t & ne01,
@@ -6309,28 +6280,19 @@ kernel void kernel_mul_mv_id_q4_1_f32(
         constant       uint & r2,
         constant       uint & r3,
         constant        int & idx,
-        device const char * src00,
-        device const char * src01,
-        device const char * src02,
-        device const char * src03,
-        device const char * src04,
-        device const char * src05,
-        device const char * src06,
-        device const char * src07,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiitg[[thread_index_in_threadgroup]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -6349,9 +6311,10 @@ kernel void kernel_mul_mv_id_q4_1_f32(
 
 [[host_name("kernel_mul_mv_id_q5_0_f32")]]
 kernel void kernel_mul_mv_id_q5_0_f32(
-        device const char * ids,
+        device const char * src0s,
         device const char * src1,
         device      float * dst,
+        device const char * ids,
         constant   uint64_t & nbi1,
         constant    int64_t & ne00,
         constant    int64_t & ne01,
@@ -6372,28 +6335,19 @@ kernel void kernel_mul_mv_id_q5_0_f32(
         constant       uint & r2,
         constant       uint & r3,
         constant        int & idx,
-        device const char * src00,
-        device const char * src01,
-        device const char * src02,
-        device const char * src03,
-        device const char * src04,
-        device const char * src05,
-        device const char * src06,
-        device const char * src07,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiitg[[thread_index_in_threadgroup]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -6412,9 +6366,10 @@ kernel void kernel_mul_mv_id_q5_0_f32(
 
 [[host_name("kernel_mul_mv_id_q5_1_f32")]]
 kernel void kernel_mul_mv_id_q5_1_f32(
-        device const char * ids,
+        device const char * src0s,
         device const char * src1,
         device      float * dst,
+        device const char * ids,
         constant   uint64_t & nbi1,
         constant    int64_t & ne00,
         constant    int64_t & ne01,
@@ -6435,28 +6390,19 @@ kernel void kernel_mul_mv_id_q5_1_f32(
         constant       uint & r2,
         constant       uint & r3,
         constant        int & idx,
-        device const char * src00,
-        device const char * src01,
-        device const char * src02,
-        device const char * src03,
-        device const char * src04,
-        device const char * src05,
-        device const char * src06,
-        device const char * src07,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiitg[[thread_index_in_threadgroup]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -6475,9 +6421,10 @@ kernel void kernel_mul_mv_id_q5_1_f32(
 
 [[host_name("kernel_mul_mv_id_q2_K_f32")]]
 kernel void kernel_mul_mv_id_q2_K_f32(
-        device const char * ids,
+        device const char * src0s,
         device const char * src1,
         device      float * dst,
+        device const char * ids,
         constant   uint64_t & nbi1,
         constant    int64_t & ne00,
         constant    int64_t & ne01,
@@ -6498,28 +6445,19 @@ kernel void kernel_mul_mv_id_q2_K_f32(
         constant       uint & r2,
         constant       uint & r3,
         constant        int & idx,
-        device const char * src00,
-        device const char * src01,
-        device const char * src02,
-        device const char * src03,
-        device const char * src04,
-        device const char * src05,
-        device const char * src06,
-        device const char * src07,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiitg[[thread_index_in_threadgroup]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_q2_K_f32_impl(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -6538,9 +6476,10 @@ kernel void kernel_mul_mv_id_q2_K_f32(
 
 [[host_name("kernel_mul_mv_id_q3_K_f32")]]
 kernel void kernel_mul_mv_id_q3_K_f32(
-        device const char * ids,
+        device const char * src0s,
         device const char * src1,
         device      float * dst,
+        device const char * ids,
         constant   uint64_t & nbi1,
         constant    int64_t & ne00,
         constant    int64_t & ne01,
@@ -6561,28 +6500,19 @@ kernel void kernel_mul_mv_id_q3_K_f32(
         constant       uint & r2,
         constant       uint & r3,
         constant        int & idx,
-        device const char * src00,
-        device const char * src01,
-        device const char * src02,
-        device const char * src03,
-        device const char * src04,
-        device const char * src05,
-        device const char * src06,
-        device const char * src07,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiitg[[thread_index_in_threadgroup]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_q3_K_f32_impl(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -6601,9 +6531,10 @@ kernel void kernel_mul_mv_id_q3_K_f32(
 
 [[host_name("kernel_mul_mv_id_q4_K_f32")]]
 kernel void kernel_mul_mv_id_q4_K_f32(
-        device const char * ids,
+        device const char * src0s,
         device const char * src1,
         device      float * dst,
+        device const char * ids,
         constant   uint64_t & nbi1,
         constant    int64_t & ne00,
         constant    int64_t & ne01,
@@ -6624,28 +6555,19 @@ kernel void kernel_mul_mv_id_q4_K_f32(
         constant       uint & r2,
         constant       uint & r3,
         constant        int & idx,
-        device const char * src00,
-        device const char * src01,
-        device const char * src02,
-        device const char * src03,
-        device const char * src04,
-        device const char * src05,
-        device const char * src06,
-        device const char * src07,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiitg[[thread_index_in_threadgroup]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_q4_K_f32_impl(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -6664,9 +6586,10 @@ kernel void kernel_mul_mv_id_q4_K_f32(
 
 [[host_name("kernel_mul_mv_id_q5_K_f32")]]
 kernel void kernel_mul_mv_id_q5_K_f32(
-        device const char * ids,
+        device const char * src0s,
         device const char * src1,
         device      float * dst,
+        device const char * ids,
         constant   uint64_t & nbi1,
         constant    int64_t & ne00,
         constant    int64_t & ne01,
@@ -6687,28 +6610,19 @@ kernel void kernel_mul_mv_id_q5_K_f32(
         constant       uint & r2,
         constant       uint & r3,
         constant        int & idx,
-        device const char * src00,
-        device const char * src01,
-        device const char * src02,
-        device const char * src03,
-        device const char * src04,
-        device const char * src05,
-        device const char * src06,
6698
  uint3 tgpig[[threadgroup_position_in_grid]],
6699
  uint tiitg[[thread_index_in_threadgroup]],
6700
  uint tiisg[[thread_index_in_simdgroup]],
6701
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6702
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6703
-
6704
  const int64_t bid = tgpig.z/(ne12*ne13);
6705
 
6706
  tgpig.z = tgpig.z%(ne12*ne13);
6707
 
6708
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
 
6709
 
6710
  kernel_mul_mv_q5_K_f32_impl(
6711
- src0[id],
6712
  (device const float *) (src1 + bid*nb11),
6713
  dst + bid*ne0,
6714
  ne00,
@@ -6727,9 +6641,10 @@ kernel void kernel_mul_mv_id_q5_K_f32(
6727
 
6728
  [[host_name("kernel_mul_mv_id_q6_K_f32")]]
6729
  kernel void kernel_mul_mv_id_q6_K_f32(
6730
- device const char * ids,
6731
  device const char * src1,
6732
  device float * dst,
 
6733
  constant uint64_t & nbi1,
6734
  constant int64_t & ne00,
6735
  constant int64_t & ne01,
@@ -6750,28 +6665,19 @@ kernel void kernel_mul_mv_id_q6_K_f32(
6750
  constant uint & r2,
6751
  constant uint & r3,
6752
  constant int & idx,
6753
- device const char * src00,
6754
- device const char * src01,
6755
- device const char * src02,
6756
- device const char * src03,
6757
- device const char * src04,
6758
- device const char * src05,
6759
- device const char * src06,
6760
- device const char * src07,
6761
  uint3 tgpig[[threadgroup_position_in_grid]],
6762
  uint tiitg[[thread_index_in_threadgroup]],
6763
  uint tiisg[[thread_index_in_simdgroup]],
6764
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6765
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6766
-
6767
  const int64_t bid = tgpig.z/(ne12*ne13);
6768
 
6769
  tgpig.z = tgpig.z%(ne12*ne13);
6770
 
6771
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
 
6772
 
6773
  kernel_mul_mv_q6_K_f32_impl(
6774
- src0[id],
6775
  (device const float *) (src1 + bid*nb11),
6776
  dst + bid*ne0,
6777
  ne00,
@@ -6790,9 +6696,10 @@ kernel void kernel_mul_mv_id_q6_K_f32(
6790
 
6791
  [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
6792
  kernel void kernel_mul_mv_id_iq2_xxs_f32(
6793
- device const char * ids,
6794
  device const char * src1,
6795
  device float * dst,
 
6796
  constant uint64_t & nbi1,
6797
  constant int64_t & ne00,
6798
  constant int64_t & ne01,
@@ -6813,29 +6720,20 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
6813
  constant uint & r2,
6814
  constant uint & r3,
6815
  constant int & idx,
6816
- device const char * src00,
6817
- device const char * src01,
6818
- device const char * src02,
6819
- device const char * src03,
6820
- device const char * src04,
6821
- device const char * src05,
6822
- device const char * src06,
6823
- device const char * src07,
6824
  threadgroup int8_t * shared_values [[threadgroup(0)]],
6825
  uint3 tgpig[[threadgroup_position_in_grid]],
6826
  uint tiitg[[thread_index_in_threadgroup]],
6827
  uint tiisg[[thread_index_in_simdgroup]],
6828
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6829
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6830
-
6831
  const int64_t bid = tgpig.z/(ne12*ne13);
6832
 
6833
  tgpig.z = tgpig.z%(ne12*ne13);
6834
 
6835
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
 
6836
 
6837
  kernel_mul_mv_iq2_xxs_f32_impl(
6838
- src0[id],
6839
  (device const float *) (src1 + bid*nb11),
6840
  dst + bid*ne0,
6841
  ne00,
@@ -6855,9 +6753,10 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
6855
 
6856
  [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
6857
  kernel void kernel_mul_mv_id_iq2_xs_f32(
6858
- device const char * ids,
6859
  device const char * src1,
6860
  device float * dst,
 
6861
  constant uint64_t & nbi1,
6862
  constant int64_t & ne00,
6863
  constant int64_t & ne01,
@@ -6878,29 +6777,20 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
6878
  constant uint & r2,
6879
  constant uint & r3,
6880
  constant int & idx,
6881
- device const char * src00,
6882
- device const char * src01,
6883
- device const char * src02,
6884
- device const char * src03,
6885
- device const char * src04,
6886
- device const char * src05,
6887
- device const char * src06,
6888
- device const char * src07,
6889
  threadgroup int8_t * shared_values [[threadgroup(0)]],
6890
  uint3 tgpig[[threadgroup_position_in_grid]],
6891
  uint tiitg[[thread_index_in_threadgroup]],
6892
  uint tiisg[[thread_index_in_simdgroup]],
6893
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6894
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6895
-
6896
  const int64_t bid = tgpig.z/(ne12*ne13);
6897
 
6898
  tgpig.z = tgpig.z%(ne12*ne13);
6899
 
6900
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
 
6901
 
6902
  kernel_mul_mv_iq2_xs_f32_impl(
6903
- src0[id],
6904
  (device const float *) (src1 + bid*nb11),
6905
  dst + bid*ne0,
6906
  ne00,
@@ -6920,9 +6810,10 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
6920
 
6921
  [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
6922
  kernel void kernel_mul_mv_id_iq3_xxs_f32(
6923
- device const char * ids,
6924
  device const char * src1,
6925
  device float * dst,
 
6926
  constant uint64_t & nbi1,
6927
  constant int64_t & ne00,
6928
  constant int64_t & ne01,
@@ -6943,29 +6834,20 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
6943
  constant uint & r2,
6944
  constant uint & r3,
6945
  constant int & idx,
6946
- device const char * src00,
6947
- device const char * src01,
6948
- device const char * src02,
6949
- device const char * src03,
6950
- device const char * src04,
6951
- device const char * src05,
6952
- device const char * src06,
6953
- device const char * src07,
6954
  threadgroup int8_t * shared_values [[threadgroup(0)]],
6955
  uint3 tgpig[[threadgroup_position_in_grid]],
6956
  uint tiitg[[thread_index_in_threadgroup]],
6957
  uint tiisg[[thread_index_in_simdgroup]],
6958
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6959
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6960
-
6961
  const int64_t bid = tgpig.z/(ne12*ne13);
6962
 
6963
  tgpig.z = tgpig.z%(ne12*ne13);
6964
 
6965
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
 
6966
 
6967
  kernel_mul_mv_iq3_xxs_f32_impl(
6968
- src0[id],
6969
  (device const float *) (src1 + bid*nb11),
6970
  dst + bid*ne0,
6971
  ne00,
@@ -6985,9 +6867,10 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
6985
 
6986
  [[host_name("kernel_mul_mv_id_iq3_s_f32")]]
6987
  kernel void kernel_mul_mv_id_iq3_s_f32(
6988
- device const char * ids,
6989
  device const char * src1,
6990
  device float * dst,
 
6991
  constant uint64_t & nbi1,
6992
  constant int64_t & ne00,
6993
  constant int64_t & ne01,
@@ -7008,29 +6891,20 @@ kernel void kernel_mul_mv_id_iq3_s_f32(
7008
  constant uint & r2,
7009
  constant uint & r3,
7010
  constant int & idx,
7011
- device const char * src00,
7012
- device const char * src01,
7013
- device const char * src02,
7014
- device const char * src03,
7015
- device const char * src04,
7016
- device const char * src05,
7017
- device const char * src06,
7018
- device const char * src07,
7019
  threadgroup int8_t * shared_values [[threadgroup(0)]],
7020
  uint3 tgpig[[threadgroup_position_in_grid]],
7021
  uint tiitg[[thread_index_in_threadgroup]],
7022
  uint tiisg[[thread_index_in_simdgroup]],
7023
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
7024
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
7025
-
7026
  const int64_t bid = tgpig.z/(ne12*ne13);
7027
 
7028
  tgpig.z = tgpig.z%(ne12*ne13);
7029
 
7030
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
 
7031
 
7032
  kernel_mul_mv_iq3_s_f32_impl(
7033
- src0[id],
7034
  (device const float *) (src1 + bid*nb11),
7035
  dst + bid*ne0,
7036
  ne00,
@@ -7050,9 +6924,10 @@ kernel void kernel_mul_mv_id_iq3_s_f32(
7050
 
7051
  [[host_name("kernel_mul_mv_id_iq2_s_f32")]]
7052
  kernel void kernel_mul_mv_id_iq2_s_f32(
7053
- device const char * ids,
7054
  device const char * src1,
7055
  device float * dst,
 
7056
  constant uint64_t & nbi1,
7057
  constant int64_t & ne00,
7058
  constant int64_t & ne01,
@@ -7073,29 +6948,20 @@ kernel void kernel_mul_mv_id_iq2_s_f32(
7073
  constant uint & r2,
7074
  constant uint & r3,
7075
  constant int & idx,
7076
- device const char * src00,
7077
- device const char * src01,
7078
- device const char * src02,
7079
- device const char * src03,
7080
- device const char * src04,
7081
- device const char * src05,
7082
- device const char * src06,
7083
- device const char * src07,
7084
  threadgroup int8_t * shared_values [[threadgroup(0)]],
7085
  uint3 tgpig[[threadgroup_position_in_grid]],
7086
  uint tiitg[[thread_index_in_threadgroup]],
7087
  uint tiisg[[thread_index_in_simdgroup]],
7088
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
7089
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
7090
-
7091
  const int64_t bid = tgpig.z/(ne12*ne13);
7092
 
7093
  tgpig.z = tgpig.z%(ne12*ne13);
7094
 
7095
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
 
7096
 
7097
  kernel_mul_mv_iq2_s_f32_impl(
7098
- src0[id],
7099
  (device const float *) (src1 + bid*nb11),
7100
  dst + bid*ne0,
7101
  ne00,
@@ -7115,9 +6981,10 @@ kernel void kernel_mul_mv_id_iq2_s_f32(
7115
 
7116
  [[host_name("kernel_mul_mv_id_iq1_s_f32")]]
7117
  kernel void kernel_mul_mv_id_iq1_s_f32(
7118
- device const char * ids,
7119
  device const char * src1,
7120
  device float * dst,
 
7121
  constant uint64_t & nbi1,
7122
  constant int64_t & ne00,
7123
  constant int64_t & ne01,
@@ -7138,28 +7005,19 @@ kernel void kernel_mul_mv_id_iq1_s_f32(
7138
  constant uint & r2,
7139
  constant uint & r3,
7140
  constant int & idx,
7141
- device const char * src00,
7142
- device const char * src01,
7143
- device const char * src02,
7144
- device const char * src03,
7145
- device const char * src04,
7146
- device const char * src05,
7147
- device const char * src06,
7148
- device const char * src07,
7149
  uint3 tgpig[[threadgroup_position_in_grid]],
7150
  uint tiitg[[thread_index_in_threadgroup]],
7151
  uint tiisg[[thread_index_in_simdgroup]],
7152
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
7153
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
7154
-
7155
  const int64_t bid = tgpig.z/(ne12*ne13);
7156
 
7157
  tgpig.z = tgpig.z%(ne12*ne13);
7158
 
7159
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
 
7160
 
7161
  kernel_mul_mv_iq1_s_f32_impl(
7162
- src0[id],
7163
  (device const float *) (src1 + bid*nb11),
7164
  dst + bid*ne0,
7165
  ne00,
@@ -7178,9 +7036,10 @@ kernel void kernel_mul_mv_id_iq1_s_f32(
7178
 
7179
  [[host_name("kernel_mul_mv_id_iq1_m_f32")]]
7180
  kernel void kernel_mul_mv_id_iq1_m_f32(
7181
- device const char * ids,
7182
  device const char * src1,
7183
  device float * dst,
 
7184
  constant uint64_t & nbi1,
7185
  constant int64_t & ne00,
7186
  constant int64_t & ne01,
@@ -7201,28 +7060,19 @@ kernel void kernel_mul_mv_id_iq1_m_f32(
7201
  constant uint & r2,
7202
  constant uint & r3,
7203
  constant int & idx,
7204
- device const char * src00,
7205
- device const char * src01,
7206
- device const char * src02,
7207
- device const char * src03,
7208
- device const char * src04,
7209
- device const char * src05,
7210
- device const char * src06,
7211
- device const char * src07,
7212
  uint3 tgpig[[threadgroup_position_in_grid]],
7213
  uint tiitg[[thread_index_in_threadgroup]],
7214
  uint tiisg[[thread_index_in_simdgroup]],
7215
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
7216
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
7217
-
7218
  const int64_t bid = tgpig.z/(ne12*ne13);
7219
 
7220
  tgpig.z = tgpig.z%(ne12*ne13);
7221
 
7222
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
 
7223
 
7224
  kernel_mul_mv_iq1_m_f32_impl(
7225
- src0[id],
7226
  (device const float *) (src1 + bid*nb11),
7227
  dst + bid*ne0,
7228
  ne00,
@@ -7241,9 +7091,10 @@ kernel void kernel_mul_mv_id_iq1_m_f32(
7241
 
7242
  [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
7243
  kernel void kernel_mul_mv_id_iq4_nl_f32(
7244
- device const char * ids,
7245
  device const char * src1,
7246
  device float * dst,
 
7247
  constant uint64_t & nbi1,
7248
  constant int64_t & ne00,
7249
  constant int64_t & ne01,
@@ -7264,29 +7115,20 @@ kernel void kernel_mul_mv_id_iq4_nl_f32(
7264
  constant uint & r2,
7265
  constant uint & r3,
7266
  constant int & idx,
7267
- device const char * src00,
7268
- device const char * src01,
7269
- device const char * src02,
7270
- device const char * src03,
7271
- device const char * src04,
7272
- device const char * src05,
7273
- device const char * src06,
7274
- device const char * src07,
7275
  threadgroup float * shared_values [[threadgroup(0)]],
7276
  uint3 tgpig[[threadgroup_position_in_grid]],
7277
  uint tiitg[[thread_index_in_threadgroup]],
7278
  uint tiisg[[thread_index_in_simdgroup]],
7279
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
7280
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
7281
-
7282
  const int64_t bid = tgpig.z/(ne12*ne13);
7283
 
7284
  tgpig.z = tgpig.z%(ne12*ne13);
7285
 
7286
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
 
7287
 
7288
  kernel_mul_mv_iq4_nl_f32_impl(
7289
- src0[id],
7290
  (device const float *) (src1 + bid*nb11),
7291
  dst + bid*ne0,
7292
  ne00,
@@ -7306,9 +7148,10 @@ kernel void kernel_mul_mv_id_iq4_nl_f32(
7306
 
7307
  [[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
7308
  kernel void kernel_mul_mv_id_iq4_xs_f32(
7309
- device const char * ids,
7310
  device const char * src1,
7311
  device float * dst,
 
7312
  constant uint64_t & nbi1,
7313
  constant int64_t & ne00,
7314
  constant int64_t & ne01,
@@ -7329,33 +7172,24 @@ kernel void kernel_mul_mv_id_iq4_xs_f32(
7329
  constant uint & r2,
7330
  constant uint & r3,
7331
  constant int & idx,
7332
- device const char * src00,
7333
- device const char * src01,
7334
- device const char * src02,
7335
- device const char * src03,
7336
- device const char * src04,
7337
- device const char * src05,
7338
- device const char * src06,
7339
- device const char * src07,
7340
  threadgroup float * shared_values [[threadgroup(0)]],
7341
  uint3 tgpig[[threadgroup_position_in_grid]],
7342
  uint tiitg[[thread_index_in_threadgroup]],
7343
  uint tiisg[[thread_index_in_simdgroup]],
7344
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
7345
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
7346
-
7347
  const int64_t bid = tgpig.z/(ne12*ne13);
7348
 
7349
  tgpig.z = tgpig.z%(ne12*ne13);
7350
 
7351
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
 
7352
 
7353
  #if QK_K == 64
7354
  kernel_mul_mv_iq4_nl_f32_impl(
7355
  #else
7356
  kernel_mul_mv_iq4_xs_f32_impl(
7357
  #endif
7358
- src0[id],
7359
  (device const float *) (src1 + bid*nb11),
7360
  dst + bid*ne0,
7361
  ne00,
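The removals above repeat the same pattern in every kernel_mul_mv_id_* variant: the eight per-expert buffer arguments src00..src07 and the per-kernel pointer table src0[8] go away, and the additions later in this file select the expert by offsetting into a single merged buffer instead. A minimal C sketch of the two addressing schemes (hypothetical helper names; nb02 is ggml's byte stride between the expert planes of the merged tensor):

#include <stddef.h>
#include <stdint.h>

// Before: one device pointer per expert, capped at 8 by the kernel signature.
static const char * expert_ptr_old(const char * src0[8], int32_t id) {
    return src0[id];
}

// After: all experts are planes of one 3D tensor; expert id lives at a fixed
// byte offset id*nb02 inside the single src0s buffer.
static const char * expert_ptr_new(const char * src0s, int32_t id, uint64_t nb02) {
    return src0s + (size_t) id * nb02;
}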
 
 #define N_SIMDWIDTH 32 // assuming SIMD group size is 32

 enum ggml_sort_order {
+    GGML_SORT_ORDER_ASC,
+    GGML_SORT_ORDER_DESC,
 };

 // general-purpose kernel for addition, multiplication and division of two tensors

 // bitonic sort implementation following the CUDA kernels as reference
 typedef void (argsort_t)(
+        device const float  * x,
+        device     int32_t  * dst,
+        constant   int64_t  & ncols,
+        constant   int64_t  & ncols_pad,
+        threadgroup int32_t * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]]);

         device const float * x,
         device     int32_t * dst,
         constant   int64_t & ncols,
+        constant   int64_t & ncols_pad,
+        threadgroup int32_t * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]]) {
     // bitonic sort
     int col = tpitg[0];
     int row = tgpig[1];

+    if (col >= ncols_pad) return;

+    device const float   * x_row   = x + row * ncols;
+    threadgroup int32_t  * dst_row = shared_values;

     // initialize indices
+    dst_row[col] = col;
+
     threadgroup_barrier(mem_flags::mem_threadgroup);

+    for (int k = 2; k <= ncols_pad; k *= 2) {
         for (int j = k / 2; j > 0; j /= 2) {
             int ixj = col ^ j;
             if (ixj > col) {
                 if ((col & k) == 0) {
+                    if (dst_row[col] >= ncols ||
+                        (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+                    ) {
                         SWAP(dst_row[col], dst_row[ixj]);
                     }
                 } else {
+                    if (dst_row[ixj] >= ncols ||
+                        (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+                    ) {
                         SWAP(dst_row[col], dst_row[ixj]);
                     }
                 }
             }
             threadgroup_barrier(mem_flags::mem_threadgroup);
         }
     }
+
+    // copy the result to dst without the padding
+    if (col < ncols) {
+        dst[row * ncols + col] = dst_row[col];
+    }
 }

+template [[host_name("kernel_argsort_f32_i32_asc")]]  kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
+template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;

 kernel void kernel_leaky_relu_f32(
         device const float * src0,
 
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
 kernel void kernel_mul_mm_id(
+        device const uchar * src0s,
         device const uchar * src1,
         device float * dst,
+        device const uchar * ids,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne02,

         constant uint & r2,
         constant uint & r3,
         constant int & idx,
         threadgroup uchar * shared_memory [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiitg[[thread_index_in_threadgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {

     // expert id
     const int32_t id = tgpig.z/(ne12*ne13);
+    device const uchar * src0 = src0s + id*nb02;

     tgpig.z = tgpig.z%(ne12*ne13);

     }

     kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
+        src0,
         src1,
         src1ids,
         dst,

 //

 typedef void (mat_mm_id_t)(
+        device const uchar * src0s,
         device const uchar * src1,
         device float * dst,
+        device const uchar * ids,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne02,

         constant uint & r2,
         constant uint & r3,
         constant int & idx,
         threadgroup uchar *,
         uint3, uint, uint);
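Note the two different uses of the grid z coordinate: kernel_mul_mm_id encodes the expert id there (id = tgpig.z/(ne12*ne13)), while the kernel_mul_mv_id_* kernels below use the same quotient as the src1 row index bid and read the expert from the ids buffer. A small C sketch of the encode/decode arithmetic for the matrix-matrix path (illustrative; assumes the host dispatches n_as*(ne12*ne13) threadgroups along z):

#include <stdint.h>

// How kernel_mul_mm_id packs the expert id and the batch slice into the
// grid z coordinate (sketch; mirrors the arithmetic in the kernel above).
static int64_t mmid_encode_z(int32_t expert_id, int64_t batch, int64_t ne12, int64_t ne13) {
    return (int64_t) expert_id * (ne12*ne13) + batch;   // batch in [0, ne12*ne13)
}

static void mmid_decode_z(int64_t z, int64_t ne12, int64_t ne13,
                          int32_t * expert_id, int64_t * batch) {
    *expert_id = (int32_t) (z / (ne12*ne13));           // tgpig.z / (ne12*ne13)
    *batch     = z % (ne12*ne13);                       // tgpig.z % (ne12*ne13)
}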
 
 
 [[host_name("kernel_mul_mv_id_f32_f32")]]
 kernel void kernel_mul_mv_id_f32_f32(
+        device const char * src0s,
         device const char * src1,
         device float * dst,
+        device const char * ids,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne01,

         constant uint & r2,
         constant uint & r3,
         constant int & idx,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiitg[[thread_index_in_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);

     tgpig.z = tgpig.z%(ne12*ne13);

     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;

     kernel_mul_mv_f32_f32_impl(
+        src0,
         src1 + bid*nb11,
         dst + bid*ne0,
         ne00,

 [[host_name("kernel_mul_mv_id_f16_f32")]]
 kernel void kernel_mul_mv_id_f16_f32(
+        device const char * src0s,
         device const char * src1,
         device float * dst,
+        device const char * ids,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne01,

         constant uint & r2,
         constant uint & r3,
         constant int & idx,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiitg[[thread_index_in_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);

     tgpig.z = tgpig.z%(ne12*ne13);

     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;

     kernel_mul_mv_f16_f32_impl(
+        src0,
         src1 + bid*nb11,
         dst + bid*ne0,
         ne00,

 [[host_name("kernel_mul_mv_id_q8_0_f32")]]
 kernel void kernel_mul_mv_id_q8_0_f32(
+        device const char * src0s,
         device const char * src1,
         device float * dst,
+        device const char * ids,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne01,

         constant uint & r2,
         constant uint & r3,
         constant int & idx,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiitg[[thread_index_in_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);

     tgpig.z = tgpig.z%(ne12*ne13);

     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;

     kernel_mul_mv_q8_0_f32_impl(
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,

 [[host_name("kernel_mul_mv_id_q4_0_f32")]]
 kernel void kernel_mul_mv_id_q4_0_f32(
+        device const char * src0s,
         device const char * src1,
         device float * dst,
+        device const char * ids,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne01,

         constant uint & r2,
         constant uint & r3,
         constant int & idx,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiitg[[thread_index_in_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);

     tgpig.z = tgpig.z%(ne12*ne13);

     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;

     mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,

 [[host_name("kernel_mul_mv_id_q4_1_f32")]]
 kernel void kernel_mul_mv_id_q4_1_f32(
+        device const char * src0s,
         device const char * src1,
         device float * dst,
+        device const char * ids,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne01,

         constant uint & r2,
         constant uint & r3,
         constant int & idx,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiitg[[thread_index_in_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);

     tgpig.z = tgpig.z%(ne12*ne13);

     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;

     mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,

 [[host_name("kernel_mul_mv_id_q5_0_f32")]]
 kernel void kernel_mul_mv_id_q5_0_f32(
+        device const char * src0s,
         device const char * src1,
         device float * dst,
+        device const char * ids,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne01,

         constant uint & r2,
         constant uint & r3,
         constant int & idx,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiitg[[thread_index_in_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);

     tgpig.z = tgpig.z%(ne12*ne13);

     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;

     mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,

 [[host_name("kernel_mul_mv_id_q5_1_f32")]]
 kernel void kernel_mul_mv_id_q5_1_f32(
+        device const char * src0s,
         device const char * src1,
         device float * dst,
+        device const char * ids,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne01,

         constant uint & r2,
         constant uint & r3,
         constant int & idx,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiitg[[thread_index_in_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);

     tgpig.z = tgpig.z%(ne12*ne13);

     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;

     mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,

 [[host_name("kernel_mul_mv_id_q2_K_f32")]]
 kernel void kernel_mul_mv_id_q2_K_f32(
+        device const char * src0s,
         device const char * src1,
         device float * dst,
+        device const char * ids,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne01,

         constant uint & r2,
         constant uint & r3,
         constant int & idx,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiitg[[thread_index_in_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);

     tgpig.z = tgpig.z%(ne12*ne13);

     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;

     kernel_mul_mv_q2_K_f32_impl(
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,

 [[host_name("kernel_mul_mv_id_q3_K_f32")]]
 kernel void kernel_mul_mv_id_q3_K_f32(
+        device const char * src0s,
         device const char * src1,
         device float * dst,
+        device const char * ids,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne01,

         constant uint & r2,
         constant uint & r3,
         constant int & idx,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiitg[[thread_index_in_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);

     tgpig.z = tgpig.z%(ne12*ne13);

     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;

     kernel_mul_mv_q3_K_f32_impl(
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,

 [[host_name("kernel_mul_mv_id_q4_K_f32")]]
 kernel void kernel_mul_mv_id_q4_K_f32(
+        device const char * src0s,
         device const char * src1,
         device float * dst,
+        device const char * ids,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne01,

         constant uint & r2,
         constant uint & r3,
         constant int & idx,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiitg[[thread_index_in_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);

     tgpig.z = tgpig.z%(ne12*ne13);

     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;

     kernel_mul_mv_q4_K_f32_impl(
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,

 [[host_name("kernel_mul_mv_id_q5_K_f32")]]
 kernel void kernel_mul_mv_id_q5_K_f32(
+        device const char * src0s,
         device const char * src1,
         device float * dst,
+        device const char * ids,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne01,

         constant uint & r2,
         constant uint & r3,
         constant int & idx,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiitg[[thread_index_in_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);

     tgpig.z = tgpig.z%(ne12*ne13);

     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;

     kernel_mul_mv_q5_K_f32_impl(
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,

 [[host_name("kernel_mul_mv_id_q6_K_f32")]]
 kernel void kernel_mul_mv_id_q6_K_f32(
+        device const char * src0s,
         device const char * src1,
         device float * dst,
+        device const char * ids,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne01,

         constant uint & r2,
         constant uint & r3,
         constant int & idx,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiitg[[thread_index_in_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);

     tgpig.z = tgpig.z%(ne12*ne13);

     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;

     kernel_mul_mv_q6_K_f32_impl(
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,

 [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
 kernel void kernel_mul_mv_id_iq2_xxs_f32(
+        device const char * src0s,
         device const char * src1,
         device float * dst,
+        device const char * ids,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne01,

         constant uint & r2,
         constant uint & r3,
         constant int & idx,
         threadgroup int8_t * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiitg[[thread_index_in_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);

     tgpig.z = tgpig.z%(ne12*ne13);

     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;

     kernel_mul_mv_iq2_xxs_f32_impl(
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,

 [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
 kernel void kernel_mul_mv_id_iq2_xs_f32(
+        device const char * src0s,
         device const char * src1,
         device float * dst,
+        device const char * ids,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne01,

         constant uint & r2,
         constant uint & r3,
         constant int & idx,
         threadgroup int8_t * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiitg[[thread_index_in_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);

     tgpig.z = tgpig.z%(ne12*ne13);

     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;

     kernel_mul_mv_iq2_xs_f32_impl(
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,

 [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
 kernel void kernel_mul_mv_id_iq3_xxs_f32(
+        device const char * src0s,
         device const char * src1,
         device float * dst,
+        device const char * ids,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne01,

         constant uint & r2,
         constant uint & r3,
         constant int & idx,
         threadgroup int8_t * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiitg[[thread_index_in_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);

     tgpig.z = tgpig.z%(ne12*ne13);

     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;

     kernel_mul_mv_iq3_xxs_f32_impl(
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,

 [[host_name("kernel_mul_mv_id_iq3_s_f32")]]
 kernel void kernel_mul_mv_id_iq3_s_f32(
+        device const char * src0s,
         device const char * src1,
         device float * dst,
+        device const char * ids,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne01,

         constant uint & r2,
         constant uint & r3,
         constant int & idx,
         threadgroup int8_t * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiitg[[thread_index_in_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);

     tgpig.z = tgpig.z%(ne12*ne13);

     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;

     kernel_mul_mv_iq3_s_f32_impl(
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,

 [[host_name("kernel_mul_mv_id_iq2_s_f32")]]
 kernel void kernel_mul_mv_id_iq2_s_f32(
+        device const char * src0s,
         device const char * src1,
         device float * dst,
+        device const char * ids,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne01,

         constant uint & r2,
         constant uint & r3,
         constant int & idx,
         threadgroup int8_t * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiitg[[thread_index_in_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);

     tgpig.z = tgpig.z%(ne12*ne13);

     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;

     kernel_mul_mv_iq2_s_f32_impl(
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,

 [[host_name("kernel_mul_mv_id_iq1_s_f32")]]
 kernel void kernel_mul_mv_id_iq1_s_f32(
+        device const char * src0s,
         device const char * src1,
         device float * dst,
+        device const char * ids,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne01,

         constant uint & r2,
         constant uint & r3,
         constant int & idx,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiitg[[thread_index_in_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);

     tgpig.z = tgpig.z%(ne12*ne13);

     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;

     kernel_mul_mv_iq1_s_f32_impl(
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,

 [[host_name("kernel_mul_mv_id_iq1_m_f32")]]
 kernel void kernel_mul_mv_id_iq1_m_f32(
+        device const char * src0s,
         device const char * src1,
         device float * dst,
+        device const char * ids,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne01,

         constant uint & r2,
         constant uint & r3,
         constant int & idx,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiitg[[thread_index_in_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);

     tgpig.z = tgpig.z%(ne12*ne13);

     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;

     kernel_mul_mv_iq1_m_f32_impl(
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,

 [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
 kernel void kernel_mul_mv_id_iq4_nl_f32(
+        device const char * src0s,
         device const char * src1,
         device float * dst,
+        device const char * ids,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne01,

         constant uint & r2,
         constant uint & r3,
         constant int & idx,
         threadgroup float * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiitg[[thread_index_in_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);

     tgpig.z = tgpig.z%(ne12*ne13);

     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;

     kernel_mul_mv_iq4_nl_f32_impl(
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,

 [[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
 kernel void kernel_mul_mv_id_iq4_xs_f32(
+        device const char * src0s,
         device const char * src1,
         device float * dst,
+        device const char * ids,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne01,

         constant uint & r2,
         constant uint & r3,
         constant int & idx,
         threadgroup float * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiitg[[thread_index_in_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);

     tgpig.z = tgpig.z%(ne12*ne13);

     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;

 #if QK_K == 64
     kernel_mul_mv_iq4_nl_f32_impl(
 #else
     kernel_mul_mv_iq4_xs_f32_impl(
 #endif
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
ggml.c CHANGED
@@ -4573,45 +4573,38 @@ void ggml_mul_mat_set_prec(

 // ggml_mul_mat_id

 struct ggml_tensor * ggml_mul_mat_id(
         struct ggml_context * ctx,
-        struct ggml_tensor  * const as[],
-        int                   n_as,
         struct ggml_tensor  * ids,
         int                   id,
         struct ggml_tensor  * b) {

     GGML_ASSERT(ids->type == GGML_TYPE_I32);
-    GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1);
-    GGML_ASSERT(ids->ne[1] == b->ne[1]);
     GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]);
-    GGML_ASSERT(n_as > 0 && n_as <= GGML_MAX_SRC - 2);
-    GGML_ASSERT(id >= 0 && id < ids->ne[0]);

     bool is_node = false;

-    if (as[0]->grad || b->grad) {
         is_node = true;
     }

-    const int64_t ne[4] = { as[0]->ne[1], b->ne[1], b->ne[2], b->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);

     ggml_set_op_params_i32(result, 0, id);
-    ggml_set_op_params_i32(result, 1, n_as);

     result->op   = GGML_OP_MUL_MAT_ID;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = ids;
     result->src[1] = b;
-
-    for (int i = 0; i < n_as; i++) {
-        struct ggml_tensor * a = as[i];
-        GGML_ASSERT(ggml_are_same_shape(as[0], a));
-        GGML_ASSERT(ggml_can_mul_mat(a, b));
-        GGML_ASSERT(!ggml_is_transposed(a));
-        result->src[i + 2] = a;
-    }

     return result;
 }
@@ -10948,10 +10941,9 @@ static void ggml_compute_forward_mul_mat_id(
         const struct ggml_compute_params * params,
               struct ggml_tensor * dst) {

-    const struct ggml_tensor * ids  = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
-
-    const struct ggml_tensor * src0 = dst->src[2]; // only for GGML_TENSOR_BINARY_OP_LOCALS

     GGML_TENSOR_BINARY_OP_LOCALS

@@ -10981,13 +10973,13 @@ static void ggml_compute_forward_mul_mat_id(
     GGML_ASSERT(nb1 <= nb2);
     GGML_ASSERT(nb2 <= nb3);

-    // broadcast factors
-    const int64_t r2 = ne12/ne02;
-    const int64_t r3 = ne13/ne03;

     // row groups
     const int id   = ggml_get_op_params_i32(dst, 0);
-    const int n_as = ggml_get_op_params_i32(dst, 1);

     char * wdata_src1_end = (src1->type == vec_dot_type) ?
             (char *) params->wdata :
@@ -11047,7 +11039,7 @@ static void ggml_compute_forward_mul_mat_id(
             continue;
         }

-        const struct ggml_tensor * src0_cur = dst->src[cur_a + 2];

         const void * wdata    = (src1->type == vec_dot_type) ? src1->data : params->wdata;
         const size_t row_size = ggml_row_size(vec_dot_type, ne10);
@@ -11082,9 +11074,6 @@ static void ggml_compute_forward_mul_mat_id(
             continue;
         }

-        assert(ne12 % ne02 == 0);
-        assert(ne13 % ne03 == 0);
-
         // block-tiling attempt
         const int64_t blck_0 = 16;
         const int64_t blck_1 = 16;
@@ -11101,14 +11090,14 @@ static void ggml_compute_forward_mul_mat_id(
                     const int64_t i11 = MMID_MATRIX_ROW(cur_a, _i11);

                     // broadcast src0 into src1
-                    const int64_t i03 = i13/r3;
-                    const int64_t i02 = i12/r2;

                     const int64_t i1 = i11;
                     const int64_t i2 = i12;
                     const int64_t i3 = i13;

-                    const char * src0_row = (const char *) src0_cur->data + (0 + i02*nb02 + i03*nb03);

                     // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
                     // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
@@ -18464,13 +18453,13 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threads) {
         case GGML_OP_MUL_MAT_ID:
             {
                 cur = 0;
-                const struct ggml_tensor * src0 = node->src[2];
                 const struct ggml_tensor * src1 = node->src[1];
                 const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type;
                 if (src1->type != vec_dot_type) {
                     cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
                 }
-                const int n_as = ggml_get_op_params_i32(node, 1);
                 cur += GGML_PAD(cur, sizeof(int64_t));       // align
                 cur += n_as * sizeof(int64_t);               // matrix_row_counts
                 cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows
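These removals also retire the old operand layout of GGML_OP_MUL_MAT_ID, in which src[0] held ids and the experts occupied src[2]..src[n_as+1] (the origin of the GGML_MAX_SRC - 2 cap). With the new layout set below, a backend can recover all operands from three fixed slots; a sketch (plain C, hypothetical helper):

// Operand layout of GGML_OP_MUL_MAT_ID after this change (sketch).
static void mul_mat_id_operands(const struct ggml_tensor * dst,
        const struct ggml_tensor ** as,   // merged experts, one per ne[2] plane
        const struct ggml_tensor ** b,    // activations
        const struct ggml_tensor ** ids,  // selected expert per row of b
        int * n_as) {
    *as   = dst->src[0];
    *b    = dst->src[1];
    *ids  = dst->src[2];
    *n_as = (int) (*as)->ne[2];  // expert count now comes from the tensor shape
}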
 
 // ggml_mul_mat_id

+// NOTE: id will be removed in the future and instead all the experts listed in ids will be computed
+//       this will allow computing all the used experts in a single matrix multiplication
 struct ggml_tensor * ggml_mul_mat_id(
         struct ggml_context * ctx,
+        struct ggml_tensor  * as,
         struct ggml_tensor  * ids,
         int                   id,
         struct ggml_tensor  * b) {

     GGML_ASSERT(ids->type == GGML_TYPE_I32);
+    GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d
+    GGML_ASSERT(ids->ne[1] == b->ne[1]); // must have an expert per b row
     GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]);
+    GGML_ASSERT(id >= 0 && id < ids->ne[0]); // valid id
+    GGML_ASSERT(as->ne[0] == b->ne[0]); // can_mul_mat

     bool is_node = false;

+    if (as->grad || b->grad) {
         is_node = true;
     }

+    const int64_t ne[4] = { as->ne[1], b->ne[1], b->ne[2], b->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);

     ggml_set_op_params_i32(result, 0, id);

     result->op   = GGML_OP_MUL_MAT_ID;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = as;
     result->src[1] = b;
+    result->src[2] = ids;

     return result;
 }

         const struct ggml_compute_params * params,
               struct ggml_tensor * dst) {

+    const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * ids  = dst->src[2];

     GGML_TENSOR_BINARY_OP_LOCALS

     GGML_ASSERT(nb1 <= nb2);
     GGML_ASSERT(nb2 <= nb3);

+    // broadcast is not supported with mmid
+    assert(ne12 == 1);
+    assert(ne13 == 1);

     // row groups
     const int id   = ggml_get_op_params_i32(dst, 0);
+    const int n_as = src0->ne[2];

     char * wdata_src1_end = (src1->type == vec_dot_type) ?
             (char *) params->wdata :

             continue;
         }

+        size_t src0_offset = cur_a*src0->nb[2];

         const void * wdata    = (src1->type == vec_dot_type) ? src1->data : params->wdata;
         const size_t row_size = ggml_row_size(vec_dot_type, ne10);

             continue;
         }

         // block-tiling attempt
         const int64_t blck_0 = 16;
         const int64_t blck_1 = 16;

                     const int64_t i11 = MMID_MATRIX_ROW(cur_a, _i11);

                     // broadcast src0 into src1
+                    //const int64_t i03 = i13/r3;
+                    //const int64_t i02 = i12/r2;

                     const int64_t i1 = i11;
                     const int64_t i2 = i12;
                     const int64_t i3 = i13;

+                    const char * src0_row = (const char *) src0->data + src0_offset;

                     // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
                     // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using

         case GGML_OP_MUL_MAT_ID:
             {
                 cur = 0;
+                const struct ggml_tensor * src0 = node->src[0];
                 const struct ggml_tensor * src1 = node->src[1];
                 const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type;
                 if (src1->type != vec_dot_type) {
                     cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
                 }
+                const int n_as = src0->ne[2];
                 cur += GGML_PAD(cur, sizeof(int64_t));       // align
                 cur += n_as * sizeof(int64_t);               // matrix_row_counts
                 cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows
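The matrix_row_counts/matrix_rows buffers sized above implement the "row groups": src1 rows are bucketed by the expert that ids assigns to them, so each expert is processed in one contiguous pass over its rows. A simplified C sketch of that bucketing (illustrative names; ggml keeps these arrays inside params->wdata and reads them back via MMID_MATRIX_ROW):

#include <stdint.h>
#include <string.h>

// Bucket src1 rows by selected expert (sketch of the row-group setup in
// ggml_compute_forward_mul_mat_id; parameter names are illustrative).
static void mmid_build_row_groups(
        const int32_t * ids_data,     // ids tensor data
        int64_t ids_row_stride,       // row stride of ids in int32 units (nbi1/sizeof(int32_t))
        int64_t ne11,                 // number of src1 rows (tokens)
        int     id,                   // which expert slot of ids this op computes
        int     n_as,                 // number of experts
        int64_t * matrix_row_counts,  // out: [n_as]
        int64_t * matrix_rows) {      // out: [n_as][ne11]
    memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
    for (int64_t i11 = 0; i11 < ne11; i11++) {
        const int32_t cur_a = ids_data[i11*ids_row_stride + id]; // expert for row i11
        matrix_rows[cur_a*ne11 + matrix_row_counts[cur_a]++] = i11;
    }
}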
ggml.h CHANGED
@@ -1164,8 +1164,7 @@ extern "C" {
     // ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
     GGML_API struct ggml_tensor * ggml_mul_mat_id(
             struct ggml_context * ctx,
-            struct ggml_tensor  * const as[],
-            int                   n_as,
             struct ggml_tensor  * ids,
             int                   id,
             struct ggml_tensor  * b);

     // ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
     GGML_API struct ggml_tensor * ggml_mul_mat_id(
             struct ggml_context * ctx,
+            struct ggml_tensor  * as,
             struct ggml_tensor  * ids,
             int                   id,
             struct ggml_tensor  * b);
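With the new signature, callers pass a single 3D tensor whose ne[2] planes are the experts, instead of an array of 2D tensors plus n_as. A hedged usage sketch (tensor names invented for the example):

// Sketch: one MoE projection for the expert in slot `id` of `selected`.
// ffn_up_exps: merged experts  [n_embd, n_ff, n_expert]
// selected:    GGML_TYPE_I32   [n_expert_used, n_tokens]
// cur:         activations     [n_embd, n_tokens]
struct ggml_tensor * moe_up(
        struct ggml_context * ctx,
        struct ggml_tensor  * ffn_up_exps,
        struct ggml_tensor  * selected,
        int                   id,
        struct ggml_tensor  * cur) {
    // ~= ggml_mul_mat(ffn_up_exps[selected[id]], cur), per the comment above
    return ggml_mul_mat_id(ctx, ffn_up_exps, selected, id, cur);
}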