ggml : mul_mat_id use the same tensor for all the experts (llama/6387)

* ggml : update mul_mat_id to use the same tensor for all the experts
* update cuda
* minor
* update metal
* update test-backend-ops
* fix cuda
* Update ggml-metal.m
Co-authored-by: Georgi Gerganov <[email protected]>
* update convert.py
* update convert-hf-to-gguf.py
* update convert.py for mixtral hf models
* Update convert-hf-to-gguf.py
Co-authored-by: Georgi Gerganov <[email protected]>
* cuda : support non-pow-2 number of experts
* allow quantize to work for split and merged experts models in the same way
* cleanup + disable mmap automatically with split tensors models
* update imatrix
* test-backend-ops : test qwen argsort
* update grok model loading
* llama : add merged experts tensors to the grok tensor map
* minor
* gguf : bump version
* fix quantizing of merged experts
* convert-hf-to-gguf.py : update grok (untested)
* make linter happy
* cuda/argsort : use shared memory instead of pool memory
* convert : fix grok tensor names
* metal : add support for non-pow-2 argsort
* llama : more loader cleanup, better error checking
* cuda : fix warning
* llama : still use mmap for loading old models, but copy the data to a host buffer
* add review note
* llama : remove ffn tensor counting + add sanity check
ggml-ci
* convert : fix handling of n_experts == None
ggml-ci
* imatrix : fix ncall counters
* llama : produce error if imatrix size does not match
* quantize : terminate on errors + trace logs
ggml-ci
* metal : pad shared memory to 16 bytes
---------
Co-authored-by: Georgi Gerganov <[email protected]>
- ggml-cuda.cu +17 -197
- ggml-cuda/argsort.cu +39 -13
- ggml-metal.m +100 -109
- ggml-metal.metal +122 -288
- ggml.c +23 -34
- ggml.h +1 -2
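
The shape of the change, restated: before this commit, the mul_mat_id operator received the routing ids as src[0] and one 2D weight tensor per expert as src[2 + e]; after it, src[0] is a single 3D tensor with the experts stacked along dimension 2 and the ids move to src[2]. A minimal sketch of the addressing this enables (the helper name is illustrative, not from the tree):

#include <stddef.h>
#include <stdint.h>

// expert e of the merged tensor begins e strides into the buffer;
// nb2 is the byte stride between consecutive experts (src0->nb[2])
static inline const char * expert_ptr(const char * data, int64_t e, size_t nb2) {
    return data + (size_t) e*nb2;
}

Because selection is plain stride arithmetic, the backends no longer need a fixed bank of expert arguments (the old Metal path asserted n_as <= 8) or a power-of-two expert count.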
ggml-cuda.cu

@@ -401,10 +401,8 @@ GGML_CALL static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t
 GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
     ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
 
-    if (tensor->view_src != NULL && tensor->view_offs == 0) {
+    if (tensor->view_src != NULL) {
         assert(tensor->view_src->buffer->buft == buffer->buft);
-        tensor->backend = tensor->view_src->backend;
-        tensor->extra = tensor->view_src->extra;
         return;
     }
 
@@ -1962,227 +1960,49 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         }
     }
 }
 
-#if 0
-template<typename ... Srcs>
-static __global__ void k_compute_batched_ptrs_id(
-        const void ** ptrs_src, void ** ptrs_dst,
-        int ne12, int ne13,
-        int ne23,
-        int nb02, int nb03,
-        int nb12, int nb13,
-        int nb2, int nb3,
-        int r2, int r3,
-        ggml_type src0_type, half * src0_as_f16, int64_t src0_ne,
-        const half * src1_f16, half * dst_f16,
-        const int32_t * ids, const int id,
-        Srcs... src0s) {
-
-    int i = ids[id];
-
-    half * src0_f16;
-    const void * srcs_ar[] = { (const half *) src0s... };
-    if (src0_type == GGML_TYPE_F16) {
-        src0_f16 = (half *) srcs_ar[i];
-    } else {
-        src0_f16 = src0_as_f16;
-        if (threadIdx.x == 0 && threadIdx.y == 0) {
-            const to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(src0_type);
-            to_fp16(srcs_ar[i], src0_f16, src0_ne, cudaStreamFireAndForget);
-        }
-    }
-
-    int i13 = blockIdx.x * blockDim.x + threadIdx.x;
-    int i12 = blockIdx.y * blockDim.y + threadIdx.y;
-
-    if (i13 >= ne13 || i12 >= ne12) {
-        return;
-    }
-
-    int i03 = i13 / r3;
-    int i02 = i12 / r2;
-
-    ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_f16 + i02*nb02 + i03*nb03;
-    ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_f16 + i12*nb12/2 + i13*nb13/2;
-    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)  dst_f16 + i12* nb2/2 + i13* nb3/2;
-}
-
-static void ggml_cuda_mul_mat_id_cublas(ggml_tensor * dst) {
-    const struct ggml_tensor * ids   = dst->src[0];
-    const struct ggml_tensor * src1  = dst->src[1];
-    const struct ggml_tensor * src00 = dst->src[2];
-
-    const int id = dst->op_params[0];
-
-    GGML_ASSERT(!ggml_is_transposed(src00));
-    GGML_ASSERT(!ggml_is_transposed(src1));
-
-    GGML_ASSERT(src00->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-    const int64_t ne00 = src00->ne[0]; GGML_UNUSED(ne00);
-    const int64_t ne01 = src00->ne[1];
-    const int64_t ne02 = src00->ne[2];
-    const int64_t ne03 = src00->ne[3];
-
-    //const int64_t nb01 = src00->nb[1];
-    const int64_t nb02 = src00->nb[2]; GGML_UNUSED(nb02);
-    const int64_t nb03 = src00->nb[3]; GGML_UNUSED(nb03);
-
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
-    const int64_t ne12 = src1->ne[2];
-    const int64_t ne13 = src1->ne[3];
-
-    //const int64_t nb11 = src1->nb[1];
-    const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12);
-    const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13);
-
-    const int64_t ne1 = ggml_nelements(src1);
-    const int64_t ne  = ggml_nelements(dst);
-
-    ggml_cuda_set_device(g_main_device);
-    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
-
-    CUBLAS_CHECK(cublasSetStream(g_cublas_handles[g_main_device], main_stream));
-
-    //ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
-    //void * src0_ddq = src0_extra->data_device[g_main_device];
-    //half * src0_as_f16 = (half *) src0_ddq;
-
-    ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
-    float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
-
-    ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
-    float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
-
-    // convert src1 to fp16
-    const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
-    GGML_ASSERT(to_fp16_cuda != nullptr);
-
-    size_t src1_as = 0;
-    half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
-    to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
-
-    size_t dst_as = 0;
-    half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
-
-    GGML_ASSERT(ne12 % ne02 == 0);
-    GGML_ASSERT(ne13 % ne03 == 0);
-
-    // broadcast factors
-    const int64_t r2 = ne12/ne02;
-    const int64_t r3 = ne13/ne03;
-
-    const half alpha_f16 = 1.0f;
-    const half beta_f16  = 0.0f;
-
-    // use cublasGemmBatchedEx
-    const int ne23 = ne12*ne13;
-
-    const void ** ptrs_src = nullptr;
-          void ** ptrs_dst = nullptr;
-
-    size_t ptrs_src_s = 0;
-    size_t ptrs_dst_s = 0;
-
-    ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s);
-    ptrs_dst = (      void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s);
-
-    int64_t src0_ne = ggml_nelements(src00);
-    half * src0_as_f16 = nullptr;
-    size_t src0_as = 0;
-    if (src00->type != GGML_TYPE_F16) {
-        src0_as_f16 = (half *) ggml_cuda_pool_malloc(src0_ne * sizeof(half), &src0_as);
-    }
-
-    static_assert(GGML_MAX_SRC == 6, "GGML_MAX_SRC == 6");
-    dim3 block_dims(ne13, ne12);
-    k_compute_batched_ptrs_id<<<1, block_dims, 0, main_stream>>>(
-            ptrs_src, ptrs_dst,
-            ne12, ne13,
-            ne23,
-            ne00*ne01*sizeof(half), ne00*ne01*ne02*sizeof(half),
-            nb12, nb13,
-            dst->nb[2], dst->nb[3],
-            r2, r3,
-            src00->type, src0_as_f16, src0_ne,
-            src1_as_f16, dst_f16,
-            (const int *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device], id,
-            dst->src[2] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[2]->extra)->data_device[g_main_device] : nullptr,
-            dst->src[3] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[3]->extra)->data_device[g_main_device] : nullptr,
-            dst->src[4] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[4]->extra)->data_device[g_main_device] : nullptr,
-            dst->src[5] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[5]->extra)->data_device[g_main_device] : nullptr
-    );
-    CUDA_CHECK(cudaGetLastError());
-
-    CUBLAS_CHECK(
-    cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
-            ne01, ne11, ne10,
-            &alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, ne00,
-                        (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, ne10,
-            &beta_f16,  (      void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01,
-            ne23,
-            CUBLAS_COMPUTE_16F,
-            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-
-    if (src0_as != 0) {
-        ggml_cuda_pool_free(src0_as_f16, src0_as);
-    }
-    if (ptrs_src_s != 0) {
-        ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
-    }
-    if (ptrs_dst_s != 0) {
-        ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s);
-    }
-
-    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-    to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
-
-    ggml_cuda_pool_free(src1_as_f16, src1_as);
-    ggml_cuda_pool_free(dst_f16, dst_as);
-}
-#endif
-
 static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-#if 0
-    ggml_cuda_mul_mat_id_cublas(dst);
-    // TODO: mmq/mmv support
-#endif
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * ids  = dst->src[2];
+
+    GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers");
 
     cudaStream_t stream = ctx.stream();
 
     const size_t nb11 = src1->nb[1];
     const size_t nb1  =  dst->nb[1];
 
-    const struct ggml_tensor * ids = src0;
     const int32_t id = ((int32_t *) dst->op_params)[0];
-    const int32_t n_as = ((int32_t *) dst->op_params)[1];
+    const int32_t n_as = src0->ne[2];
 
     std::vector<char> ids_host(ggml_nbytes(ids));
     const char * ids_dev = (const char *) ids->data;
     CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
     CUDA_CHECK(cudaStreamSynchronize(stream));
 
+    ggml_tensor src0_row = *src0;
     ggml_tensor src1_row = *src1;
     ggml_tensor dst_row  = *dst;
 
+    char * src0_original = (char *) src0->data;
     char * src1_original = (char *) src1->data;
     char * dst_original  = (char *)  dst->data;
 
+    src0_row.ne[2] = 1;
+    src0_row.ne[3] = 1;
+    src0_row.nb[3] = src0->nb[2];
+
     if (src1->ne[1] == 1) {
         for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
             const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
 
             GGML_ASSERT(row_id >= 0 && row_id < n_as);
 
-            const struct ggml_tensor * src0_row = dst->src[row_id + 2];
-
+            src0_row.data = src0_original + row_id*src0->nb[2];
             src1_row.data = src1_original + i01*src1->nb[1];
             dst_row.data  =  dst_original + i01*dst->nb[1];
 
-            ggml_cuda_mul_mat(ctx, src0_row, &src1_row, &dst_row);
+            ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
         }
     } else {
         ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
 
@@ -2192,8 +2012,6 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
         dst_row.data  =  dst_contiguous.get();
 
         for (int32_t row_id = 0; row_id < n_as; ++row_id) {
-            const struct ggml_tensor * src0_row = dst->src[row_id + 2];
-
             int64_t num_src1_rows = 0;
             for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
                 const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
 
@@ -2213,6 +2031,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
                 continue;
             }
 
+            src0_row.data = src0_original + row_id*src0->nb[2];
+
             src1_row.ne[1] = num_src1_rows;
             dst_row.ne[1]  = num_src1_rows;
 
@@ -2224,7 +2044,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
             dst_row.nb[2] = num_src1_rows*nb1;
             dst_row.nb[3] = num_src1_rows*nb1;
 
-            ggml_cuda_mul_mat(ctx, src0_row, &src1_row, &dst_row);
+            ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
 
             num_src1_rows = 0;
             for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
 
@@ -2389,7 +2209,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         cudaError_t err = cudaGetLastError();
         if (err != cudaSuccess) {
             fprintf(stderr, "%s: %s failed\n", __func__, ggml_op_desc(dst));
-            GGML_ASSERT(false);
+            CUDA_CHECK(err);
         }
 
         return true;
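The rewritten CUDA path selects an expert by building a shallow 2D view over the merged tensor instead of dereferencing dst->src[row_id + 2]; a compact sketch of that view construction (C++, assuming ggml.h; the function name is illustrative):

// mirror of the src0_row setup above: same type and row strides, expert
// dimension collapsed, data pointer offset to the chosen expert
static ggml_tensor make_expert_view(const ggml_tensor * as, int64_t e) {
    ggml_tensor v = *as;        // shallow copy keeps type and nb[0..1]
    v.ne[2] = 1;                // collapse the expert dimension
    v.ne[3] = 1;
    v.nb[3] = as->nb[2];        // keep nb[] consistent after the collapse
    v.data  = (char *) as->data + e*as->nb[2];
    return v;
}

No device memory moves: ggml_cuda_mul_mat simply consumes the view as an ordinary 2D matrix.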
ggml-cuda/argsort.cu

@@ -8,32 +8,41 @@ static inline __device__ void ggml_cuda_swap(T & a, T & b) {
 }
 
 template<ggml_sort_order order>
-static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols) {
+static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) {
     // bitonic sort
     int col = threadIdx.x;
     int row = blockIdx.y;
 
-    if (col >= ncols) return;
+    if (col >= ncols_pad) {
+        return;
+    }
 
     const float * x_row = x + row * ncols;
-    int * dst_row = dst + row * ncols;
+    extern __shared__ int dst_row[];
 
     // initialize indices
-    if (col < ncols) {
-        dst_row[col] = col;
-    }
+    dst_row[col] = col;
+
     __syncthreads();
 
-    for (int k = 2; k <= ncols; k *= 2) {
+    for (int k = 2; k <= ncols_pad; k *= 2) {
         for (int j = k / 2; j > 0; j /= 2) {
             int ixj = col ^ j;
             if (ixj > col) {
                 if ((col & k) == 0) {
-                    if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
+                    if (dst_row[col] >= ncols ||
+                        (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+                    ) {
                         ggml_cuda_swap(dst_row[col], dst_row[ixj]);
                     }
                 } else {
-                    if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
+                    if (dst_row[ixj] >= ncols ||
+                        (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+                    ) {
                         ggml_cuda_swap(dst_row[col], dst_row[ixj]);
                     }
                 }
 
@@ -41,18 +50,35 @@ static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int n
             __syncthreads();
         }
     }
+
+    // copy the result to dst without the padding
+    if (col < ncols) {
+        dst[row * ncols + col] = dst_row[col];
+    }
+}
+
+static int next_power_of_2(int x) {
+    int n = 1;
+    while (n < x) {
+        n *= 2;
+    }
+    return n;
 }
 
 static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
     // bitonic sort requires ncols to be power of 2
-    GGML_ASSERT((ncols & (ncols - 1)) == 0);
+    const int ncols_pad = next_power_of_2(ncols);
 
-    const dim3 block_dims(ncols, 1, 1);
+    const dim3 block_dims(ncols_pad, 1, 1);
     const dim3 block_nums(1, nrows, 1);
+    const size_t shared_mem = ncols_pad * sizeof(int);
+
+    GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
+
     if (order == GGML_SORT_ORDER_ASC) {
-        k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+        k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
    } else if (order == GGML_SORT_ORDER_DESC) {
-        k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+        k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
    } else {
        GGML_ASSERT(false);
    }
 }
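The argsort rewrite drops the old power-of-two requirement: the row is padded to the next power of two, the bitonic network runs over the padded width with the index array in shared memory, and padding indices (>= ncols) lose every comparison so they sink to the end and are skipped by the copy-out. The launch arithmetic, with assumed example values:

// example: a 60-wide row, a non-power-of-2 width of the kind exercised by the
// "test qwen argsort" case added to test-backend-ops in this commit
const int    ncols      = 60;
const int    ncols_pad  = next_power_of_2(ncols);   // 64
const size_t shared_mem = ncols_pad*sizeof(int);    // 256 B of per-block indices
// one thread per padded column; the final 4 sentinel indices are discarded:
// k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<dim3(1, nrows, 1),
//         dim3(ncols_pad, 1, 1), shared_mem, stream>>>(x, dst, ncols, ncols_pad);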
ggml-metal.m

@@ -1685,37 +1685,31 @@ static enum ggml_status ggml_metal_graph_compute(
                 {
                     //GGML_ASSERT(ne00 == ne10);
                     //GGML_ASSERT(ne03 == ne13);
-
-                    GGML_ASSERT(src0t == GGML_TYPE_I32);
-
-                    const int n_as = ((int32_t *) dst->op_params)[1];
-
-                    // TODO: make this more general
-                    GGML_ASSERT(n_as <= 8);
+                    const int n_as = src0->ne[2];
 
                     // max size of the src1ids array in the kernel shared buffer
                     GGML_ASSERT(ne11 <= 4096);
 
-                    const int64_t  ne20 = src2 ? src2->ne[0] : 0;
-                    const int64_t  ne21 = src2 ? src2->ne[1] : 0;
-                    const int64_t  ne22 = src2 ? src2->ne[2] : 0;
+                    // src2 = ids
+                    const int64_t  ne20 = src2->ne[0]; GGML_UNUSED(ne20);
+                    const int64_t  ne21 = src2->ne[1];
+                    const int64_t  ne22 = src2->ne[2]; GGML_UNUSED(ne22);
+                    const int64_t  ne23 = src2->ne[3]; GGML_UNUSED(ne23);
 
-                    const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
-                    const uint64_t nb21 = src2 ? src2->nb[1] : 0;
-                    const uint64_t nb22 = src2 ? src2->nb[2] : 0;
-                    const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
+                    const uint64_t nb20 = src2->nb[0]; GGML_UNUSED(nb20);
+                    const uint64_t nb21 = src2->nb[1];
+                    const uint64_t nb22 = src2->nb[2]; GGML_UNUSED(nb22);
+                    const uint64_t nb23 = src2->nb[3]; GGML_UNUSED(nb23);
 
-                    GGML_ASSERT(!ggml_is_transposed(src2));
+                    const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
+
+                    GGML_ASSERT(src2t == GGML_TYPE_I32);
+
+                    GGML_ASSERT(!ggml_is_transposed(src0));
                     GGML_ASSERT(!ggml_is_transposed(src1));
 
                     GGML_ASSERT(src1t == GGML_TYPE_F32);
 
-                    const uint r2 = ne12/ne22;
-                    const uint r3 = ne13/ne23;
-
                     // find the break-even point where the matrix-matrix kernel becomes more efficient compared
                     // to the matrix-vector kernel
                     int ne11_mm_min = n_as;

@@ -1723,7 +1717,10 @@ static enum ggml_status ggml_metal_graph_compute(
                     const int idx = ((int32_t *) dst->op_params)[0];
 
                     // batch size
-                    GGML_ASSERT(ne01 == ne11);
+                    GGML_ASSERT(ne21 == ne11); // ?
+                    GGML_ASSERT(ne12 == 1 && ne13 == 1); // no broadcasting
+                    const uint r2 = 1;
+                    const uint r3 = 1;
 
                     // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                     // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel

@@ -1732,7 +1729,7 @@ static enum ggml_status ggml_metal_graph_compute(
                     // indirect matrix multiplication
                     // !!!
                     if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
-                        ne20 % 32 == 0 && ne20 >= 64 &&
+                        ne00 % 32 == 0 && ne00 >= 64 &&
                         ne11 > ne11_mm_min) {

@@ -1745,7 +1742,7 @@ static enum ggml_status ggml_metal_graph_compute(
 
                         id<MTLComputePipelineState> pipeline = nil;
 
-                        switch (src2->type) {
+                        switch (src0->type) {
                             case GGML_TYPE_F32:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
                             case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
                             case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32].pipeline; break;

@@ -1774,36 +1771,27 @@ static enum ggml_status ggml_metal_graph_compute(
                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                         [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                         [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                        [encoder setBytes:... atIndex:3 … 18];   // 16 scalar-argument bindings of the old layout (truncated in the capture)
-
-                        for (int j = 0; j < 8; ++j) {
-                            // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
-                            struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
-
-                            size_t offs_src_cur = 0;
-                            id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur);
-
-                            [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
-                        }
+                        [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
+                        [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
+                        [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
+                        [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
+                        [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                        [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                        [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:9];
+                        [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:10];
+                        [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
+                        [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
+                        [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:13];
+                        [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:14];
+                        [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:15];
+                        [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:16];
+                        [encoder setBytes:&r2   length:sizeof(r2)   atIndex:17];
+                        [encoder setBytes:&r3   length:sizeof(r3)   atIndex:18];
+                        [encoder setBytes:&idx  length:sizeof(idx)  atIndex:19];
 
                         [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0];
 
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                     } else {
                         int nth0 = 32;
                         int nth1 = 1;

@@ -1813,7 +1801,7 @@ static enum ggml_status ggml_metal_graph_compute(
                         id<MTLComputePipelineState> pipeline = nil;
 
                         // use custom matrix x vector kernel
-                        switch (src2t) {
+                        switch (src0t) {
                             case GGML_TYPE_F32:
                                 {
                                     GGML_ASSERT(src1t == GGML_TYPE_F32);

@@ -1947,8 +1935,8 @@ static enum ggml_status ggml_metal_graph_compute(
                             }
                         };
 
-                    if (ggml_is_quantized(src2t)) {
-                        GGML_ASSERT(ne20 >= nth0*nth1);
+                    if (ggml_is_quantized(src0t)) {
+                        GGML_ASSERT(ne00 >= nth0*nth1);
                     }
 
                     const int64_t _ne1 = 1; // kernels needs a reference in constant memory

@@ -1957,75 +1945,66 @@ static enum ggml_status ggml_metal_graph_compute(
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                     [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                    [encoder setBytes:... atIndex:3 … 22];   // 20 scalar-argument bindings of the old layout (truncated in the capture)
-
-                    for (int j = 0; j < 8; ++j) {
-                        // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
-                        struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
-
-                        size_t offs_src_cur = 0;
-                        id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur);
-
-                        [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
-                    }
+                    [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
+                    [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
+                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
+                    [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:6];
+                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:7];
+                    [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:8];
+                    [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
+                    [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
+                    [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+                    [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:12];
+                    [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+                    [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+                    [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+                    [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+                    [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+                    [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:18];
+                    [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:19];
+                    [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:20];
+                    [encoder setBytes:&r2   length:sizeof(r2)   atIndex:21];
+                    [encoder setBytes:&r3   length:sizeof(r3)   atIndex:22];
+                    [encoder setBytes:&idx  length:sizeof(idx)  atIndex:23];
 
-                    // (old dispatch chain keyed on src2t with (ne21 + …) grids — truncated in the capture)
+                    if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
+                        src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
+                        src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                    }
+                    else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
+                        const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
                         [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                    }
+                    else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
+                        const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
                         [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                    }
+                    else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
                         const int mem_size = 32*sizeof(float);
                         [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                    }
+                    else if (src0t == GGML_TYPE_Q4_K) {
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                    }
+                    else if (src0t == GGML_TYPE_Q3_K) {
 #ifdef GGML_QKK_64
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #else
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #endif
                    }
+                    else if (src0t == GGML_TYPE_Q5_K) {
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                    }
+                    else if (src0t == GGML_TYPE_Q6_K) {
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                    } else {
                        const int64_t ny = (_ne1 + nrows - 1)/nrows;
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                    }
                }
            } break;

@@ -2432,6 +2411,16 @@ static enum ggml_status ggml_metal_graph_compute(
 
                    enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
 
+                    // bitonic sort requires the number of elements to be power of 2
+                    int64_t ne00_padded = 1;
+                    while (ne00_padded < ne00) {
+                        ne00_padded *= 2;
+                    }
+
+                    // Metal kernels require the buffer size to be multiple of 16 bytes
+                    // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
+                    const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16);
+
                    id<MTLComputePipelineState> pipeline = nil;
 
                    switch (order) {

@@ -2441,11 +2430,13 @@ static enum ggml_status ggml_metal_graph_compute(
                    };
 
                    [encoder setComputePipelineState:pipeline];
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                    [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                    [encoder setBuffer:id_src0     offset:offs_src0        atIndex:0];
+                    [encoder setBuffer:id_dst      offset:offs_dst         atIndex:1];
+                    [encoder setBytes:&ne00        length:sizeof( int64_t) atIndex:2];
+                    [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3];
+                    [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
 
-                    [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
+                    [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
                } break;
            case GGML_OP_LEAKY_RELU:
                {
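The argsort encoder above rounds its threadgroup allocation with GGML_PAD because Metal requires setThreadgroupMemoryLength: to be a multiple of 16 bytes (see the linked Apple doc and the "metal : pad shared memory to 16 bytes" item in the commit message). GGML_PAD is ggml's stock round-up macro:

// round x up to a multiple of n (n must be a power of two) — as in ggml.h
#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
// ne00_padded = 2  -> 2*sizeof(int32_t)  = 8   -> GGML_PAD(8, 16)   = 16
// ne00_padded = 64 -> 64*sizeof(int32_t) = 256 -> GGML_PAD(256, 16) = 256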
ggml-metal.metal

@@ -13,8 +13,8 @@ using namespace metal;
 #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
 
 enum ggml_sort_order {
-    GGML_SORT_ASC,
-    GGML_SORT_DESC,
+    GGML_SORT_ORDER_ASC,
+    GGML_SORT_ORDER_DESC,
 };
 
 // general-purpose kernel for addition, multiplication and division of two tensors

@@ -1973,9 +1973,11 @@ kernel void kernel_timestep_embedding_f32(
 
 // bitonic sort implementation following the CUDA kernels as reference
 typedef void (argsort_t)(
-        device const float   * x,
-        device       int32_t * dst,
-        constant     int64_t & ncols,
+        device const float   * x,
+        device       int32_t * dst,
+        constant     int64_t & ncols,
+        constant     int64_t & ncols_pad,
+        threadgroup int32_t  * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]]);

@@ -1984,33 +1986,42 @@ kernel void kernel_argsort_f32_i32(
         device const float   * x,
         device       int32_t * dst,
         constant     int64_t & ncols,
+        constant     int64_t & ncols_pad,
+        threadgroup int32_t  * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]]) {
    // bitonic sort
    int col = tpitg[0];
    int row = tgpig[1];
 
-    if (col >= ncols) return;
+    if (col >= ncols_pad) return;
 
-    device const float   * x_row   = x + row * ncols;
-    device       int32_t * dst_row = dst + row * ncols;
+    device const float   * x_row   = x + row * ncols;
+    threadgroup int32_t  * dst_row = shared_values;
 
    // initialize indices
-    if (col < ncols) {
-        dst_row[col] = col;
-    }
+    dst_row[col] = col;
+
    threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    for (int k = 2; k <= ncols; k *= 2) {
+    for (int k = 2; k <= ncols_pad; k *= 2) {
        for (int j = k / 2; j > 0; j /= 2) {
            int ixj = col ^ j;
            if (ixj > col) {
                if ((col & k) == 0) {
-                    if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
+                    if (dst_row[col] >= ncols ||
+                        (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+                    ) {
                        SWAP(dst_row[col], dst_row[ixj]);
                    }
                } else {
-                    if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
+                    if (dst_row[ixj] >= ncols ||
+                        (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+                    ) {
                        SWAP(dst_row[col], dst_row[ixj]);
                    }
                }

@@ -2018,10 +2029,15 @@ kernel void kernel_argsort_f32_i32(
            threadgroup_barrier(mem_flags::mem_threadgroup);
        }
    }
+
+    // copy the result to dst without the padding
+    if (col < ncols) {
+        dst[row * ncols + col] = dst_row[col];
+    }
 }
 
-template [[host_name("kernel_argsort_f32_i32_asc")]]  kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
-template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
+template [[host_name("kernel_argsort_f32_i32_asc")]]  kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
+template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
 
 kernel void kernel_leaky_relu_f32(
     device const float * src0,

@@ -5785,9 +5801,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
 
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
 kernel void kernel_mul_mm_id(
-        device const   uchar * ids,
+        device const   uchar * src0,
         device const   uchar * src1,
         device         float * dst,
+        device const   uchar * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne02,

@@ -5804,22 +5821,14 @@ kernel void kernel_mul_mm_id(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const   uchar * src00,
-        device const   uchar * src01,
-        device const   uchar * src02,
-        device const   uchar * src03,
-        device const   uchar * src04,
-        device const   uchar * src05,
-        device const   uchar * src06,
-        device const   uchar * src07,
         threadgroup uchar * shared_memory [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiitg[[thread_index_in_threadgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
 
    // expert id
    const int32_t id = tgpig.z/(ne12*ne13);
+
    tgpig.z = tgpig.z%(ne12*ne13);

@@ -5834,7 +5843,7 @@ kernel void kernel_mul_mm_id(
    }
 
    kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
-        src0s[id],
+        src0 + id*nb02,
        src1,
        src1ids,
        dst,

@@ -5960,9 +5969,10 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_m
 //
 
 typedef void (mat_mm_id_t)(
-        device const   uchar * ids,
+        device const   uchar * src0,
         device const   uchar * src1,
         device         float * dst,
+        device const   uchar * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne02,

@@ -5979,14 +5989,6 @@ typedef void (mat_mm_id_t)(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const   uchar * src00,
-        device const   uchar * src01,
-        device const   uchar * src02,
-        device const   uchar * src03,
-        device const   uchar * src04,
-        device const   uchar * src05,
-        device const   uchar * src06,
-        device const   uchar * src07,
         threadgroup uchar *,
         uint3, uint, uint);

@@ -6022,9 +6024,10 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel
 
 [[host_name("kernel_mul_mv_id_f32_f32")]]
 kernel void kernel_mul_mv_id_f32_f32(
-        device const    char * ids,
+        device const    char * src0,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,

@@ -6045,28 +6048,19 @@ kernel void kernel_mul_mv_id_f32_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiitg[[thread_index_in_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
    const int64_t bid = tgpig.z/(ne12*ne13);
 
    tgpig.z = tgpig.z%(ne12*ne13);
 
    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
 
    kernel_mul_mv_f32_f32_impl(
-        src0[id],
+        src0 + id*nb02,
        src1 + bid*nb11,
        dst + bid*ne0,
        ne00,

The same two-hunk change — a single src0 argument plus an explicit ids buffer replacing the eight src00…src07 parameters, and src0 + id*nb02 in place of src0[id] — is applied to each remaining variant: kernel_mul_mv_id_f16_f32, _q8_0_f32, _q4_0_f32, _q4_1_f32, _q5_0_f32, _q5_1_f32, _q2_K_f32, _q3_K_f32, …; the capture ends inside the q3_K_f32 hunk at @@ -6561,28 +6500,19 @@.
|
|
| 6561 |
constant uint & r2,
|
| 6562 |
constant uint & r3,
|
| 6563 |
constant int & idx,
|
| 6564 |
-
device const char * src00,
|
| 6565 |
-
device const char * src01,
|
| 6566 |
-
device const char * src02,
|
| 6567 |
-
device const char * src03,
|
| 6568 |
-
device const char * src04,
|
| 6569 |
-
device const char * src05,
|
| 6570 |
-
device const char * src06,
|
| 6571 |
-
device const char * src07,
|
| 6572 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6573 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 6574 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 6575 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 6576 |
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
| 6577 |
-
|
| 6578 |
const int64_t bid = tgpig.z/(ne12*ne13);
|
| 6579 |
|
| 6580 |
tgpig.z = tgpig.z%(ne12*ne13);
|
| 6581 |
|
| 6582 |
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
|
| 6583 |
|
| 6584 |
kernel_mul_mv_q3_K_f32_impl(
|
| 6585 |
-
src0
|
| 6586 |
(device const float *) (src1 + bid*nb11),
|
| 6587 |
dst + bid*ne0,
|
| 6588 |
ne00,
|
|
@@ -6601,9 +6531,10 @@ kernel void kernel_mul_mv_id_q3_K_f32(
|
|
| 6601 |
|
| 6602 |
[[host_name("kernel_mul_mv_id_q4_K_f32")]]
|
| 6603 |
kernel void kernel_mul_mv_id_q4_K_f32(
|
| 6604 |
-
device const char *
|
| 6605 |
device const char * src1,
|
| 6606 |
device float * dst,
|
|
|
|
| 6607 |
constant uint64_t & nbi1,
|
| 6608 |
constant int64_t & ne00,
|
| 6609 |
constant int64_t & ne01,
|
|
@@ -6624,28 +6555,19 @@ kernel void kernel_mul_mv_id_q4_K_f32(
|
|
| 6624 |
constant uint & r2,
|
| 6625 |
constant uint & r3,
|
| 6626 |
constant int & idx,
|
| 6627 |
-
device const char * src00,
|
| 6628 |
-
device const char * src01,
|
| 6629 |
-
device const char * src02,
|
| 6630 |
-
device const char * src03,
|
| 6631 |
-
device const char * src04,
|
| 6632 |
-
device const char * src05,
|
| 6633 |
-
device const char * src06,
|
| 6634 |
-
device const char * src07,
|
| 6635 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6636 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 6637 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 6638 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 6639 |
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
| 6640 |
-
|
| 6641 |
const int64_t bid = tgpig.z/(ne12*ne13);
|
| 6642 |
|
| 6643 |
tgpig.z = tgpig.z%(ne12*ne13);
|
| 6644 |
|
| 6645 |
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
|
| 6646 |
|
| 6647 |
kernel_mul_mv_q4_K_f32_impl(
|
| 6648 |
-
src0
|
| 6649 |
(device const float *) (src1 + bid*nb11),
|
| 6650 |
dst + bid*ne0,
|
| 6651 |
ne00,
|
|
@@ -6664,9 +6586,10 @@ kernel void kernel_mul_mv_id_q4_K_f32(
|
|
| 6664 |
|
| 6665 |
[[host_name("kernel_mul_mv_id_q5_K_f32")]]
|
| 6666 |
kernel void kernel_mul_mv_id_q5_K_f32(
|
| 6667 |
-
device const char *
|
| 6668 |
device const char * src1,
|
| 6669 |
device float * dst,
|
|
|
|
| 6670 |
constant uint64_t & nbi1,
|
| 6671 |
constant int64_t & ne00,
|
| 6672 |
constant int64_t & ne01,
|
|
@@ -6687,28 +6610,19 @@ kernel void kernel_mul_mv_id_q5_K_f32(
|
|
| 6687 |
constant uint & r2,
|
| 6688 |
constant uint & r3,
|
| 6689 |
constant int & idx,
|
| 6690 |
-
device const char * src00,
|
| 6691 |
-
device const char * src01,
|
| 6692 |
-
device const char * src02,
|
| 6693 |
-
device const char * src03,
|
| 6694 |
-
device const char * src04,
|
| 6695 |
-
device const char * src05,
|
| 6696 |
-
device const char * src06,
|
| 6697 |
-
device const char * src07,
|
| 6698 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6699 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 6700 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 6701 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 6702 |
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
| 6703 |
-
|
| 6704 |
const int64_t bid = tgpig.z/(ne12*ne13);
|
| 6705 |
|
| 6706 |
tgpig.z = tgpig.z%(ne12*ne13);
|
| 6707 |
|
| 6708 |
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
|
| 6709 |
|
| 6710 |
kernel_mul_mv_q5_K_f32_impl(
|
| 6711 |
-
src0
|
| 6712 |
(device const float *) (src1 + bid*nb11),
|
| 6713 |
dst + bid*ne0,
|
| 6714 |
ne00,
|
|
@@ -6727,9 +6641,10 @@ kernel void kernel_mul_mv_id_q5_K_f32(
|
|
| 6727 |
|
| 6728 |
[[host_name("kernel_mul_mv_id_q6_K_f32")]]
|
| 6729 |
kernel void kernel_mul_mv_id_q6_K_f32(
|
| 6730 |
-
device const char *
|
| 6731 |
device const char * src1,
|
| 6732 |
device float * dst,
|
|
|
|
| 6733 |
constant uint64_t & nbi1,
|
| 6734 |
constant int64_t & ne00,
|
| 6735 |
constant int64_t & ne01,
|
|
@@ -6750,28 +6665,19 @@ kernel void kernel_mul_mv_id_q6_K_f32(
|
|
| 6750 |
constant uint & r2,
|
| 6751 |
constant uint & r3,
|
| 6752 |
constant int & idx,
|
| 6753 |
-
device const char * src00,
|
| 6754 |
-
device const char * src01,
|
| 6755 |
-
device const char * src02,
|
| 6756 |
-
device const char * src03,
|
| 6757 |
-
device const char * src04,
|
| 6758 |
-
device const char * src05,
|
| 6759 |
-
device const char * src06,
|
| 6760 |
-
device const char * src07,
|
| 6761 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6762 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 6763 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 6764 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 6765 |
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
| 6766 |
-
|
| 6767 |
const int64_t bid = tgpig.z/(ne12*ne13);
|
| 6768 |
|
| 6769 |
tgpig.z = tgpig.z%(ne12*ne13);
|
| 6770 |
|
| 6771 |
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
|
| 6772 |
|
| 6773 |
kernel_mul_mv_q6_K_f32_impl(
|
| 6774 |
-
src0
|
| 6775 |
(device const float *) (src1 + bid*nb11),
|
| 6776 |
dst + bid*ne0,
|
| 6777 |
ne00,
|
|
@@ -6790,9 +6696,10 @@ kernel void kernel_mul_mv_id_q6_K_f32(
|
|
| 6790 |
|
| 6791 |
[[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
|
| 6792 |
kernel void kernel_mul_mv_id_iq2_xxs_f32(
|
| 6793 |
-
device const char *
|
| 6794 |
device const char * src1,
|
| 6795 |
device float * dst,
|
|
|
|
| 6796 |
constant uint64_t & nbi1,
|
| 6797 |
constant int64_t & ne00,
|
| 6798 |
constant int64_t & ne01,
|
|
@@ -6813,29 +6720,20 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
|
|
| 6813 |
constant uint & r2,
|
| 6814 |
constant uint & r3,
|
| 6815 |
constant int & idx,
|
| 6816 |
-
device const char * src00,
|
| 6817 |
-
device const char * src01,
|
| 6818 |
-
device const char * src02,
|
| 6819 |
-
device const char * src03,
|
| 6820 |
-
device const char * src04,
|
| 6821 |
-
device const char * src05,
|
| 6822 |
-
device const char * src06,
|
| 6823 |
-
device const char * src07,
|
| 6824 |
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
| 6825 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6826 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 6827 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 6828 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 6829 |
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
| 6830 |
-
|
| 6831 |
const int64_t bid = tgpig.z/(ne12*ne13);
|
| 6832 |
|
| 6833 |
tgpig.z = tgpig.z%(ne12*ne13);
|
| 6834 |
|
| 6835 |
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
|
| 6836 |
|
| 6837 |
kernel_mul_mv_iq2_xxs_f32_impl(
|
| 6838 |
-
src0
|
| 6839 |
(device const float *) (src1 + bid*nb11),
|
| 6840 |
dst + bid*ne0,
|
| 6841 |
ne00,
|
|
@@ -6855,9 +6753,10 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
|
|
| 6855 |
|
| 6856 |
[[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
|
| 6857 |
kernel void kernel_mul_mv_id_iq2_xs_f32(
|
| 6858 |
-
device const char *
|
| 6859 |
device const char * src1,
|
| 6860 |
device float * dst,
|
|
|
|
| 6861 |
constant uint64_t & nbi1,
|
| 6862 |
constant int64_t & ne00,
|
| 6863 |
constant int64_t & ne01,
|
|
@@ -6878,29 +6777,20 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
|
|
| 6878 |
constant uint & r2,
|
| 6879 |
constant uint & r3,
|
| 6880 |
constant int & idx,
|
| 6881 |
-
device const char * src00,
|
| 6882 |
-
device const char * src01,
|
| 6883 |
-
device const char * src02,
|
| 6884 |
-
device const char * src03,
|
| 6885 |
-
device const char * src04,
|
| 6886 |
-
device const char * src05,
|
| 6887 |
-
device const char * src06,
|
| 6888 |
-
device const char * src07,
|
| 6889 |
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
| 6890 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6891 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 6892 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 6893 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 6894 |
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
| 6895 |
-
|
| 6896 |
const int64_t bid = tgpig.z/(ne12*ne13);
|
| 6897 |
|
| 6898 |
tgpig.z = tgpig.z%(ne12*ne13);
|
| 6899 |
|
| 6900 |
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
|
| 6901 |
|
| 6902 |
kernel_mul_mv_iq2_xs_f32_impl(
|
| 6903 |
-
src0
|
| 6904 |
(device const float *) (src1 + bid*nb11),
|
| 6905 |
dst + bid*ne0,
|
| 6906 |
ne00,
|
|
@@ -6920,9 +6810,10 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
|
|
| 6920 |
|
| 6921 |
[[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
|
| 6922 |
kernel void kernel_mul_mv_id_iq3_xxs_f32(
|
| 6923 |
-
device const char *
|
| 6924 |
device const char * src1,
|
| 6925 |
device float * dst,
|
|
|
|
| 6926 |
constant uint64_t & nbi1,
|
| 6927 |
constant int64_t & ne00,
|
| 6928 |
constant int64_t & ne01,
|
|
@@ -6943,29 +6834,20 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
|
|
| 6943 |
constant uint & r2,
|
| 6944 |
constant uint & r3,
|
| 6945 |
constant int & idx,
|
| 6946 |
-
device const char * src00,
|
| 6947 |
-
device const char * src01,
|
| 6948 |
-
device const char * src02,
|
| 6949 |
-
device const char * src03,
|
| 6950 |
-
device const char * src04,
|
| 6951 |
-
device const char * src05,
|
| 6952 |
-
device const char * src06,
|
| 6953 |
-
device const char * src07,
|
| 6954 |
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
| 6955 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6956 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 6957 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 6958 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 6959 |
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
| 6960 |
-
|
| 6961 |
const int64_t bid = tgpig.z/(ne12*ne13);
|
| 6962 |
|
| 6963 |
tgpig.z = tgpig.z%(ne12*ne13);
|
| 6964 |
|
| 6965 |
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
|
| 6966 |
|
| 6967 |
kernel_mul_mv_iq3_xxs_f32_impl(
|
| 6968 |
-
src0
|
| 6969 |
(device const float *) (src1 + bid*nb11),
|
| 6970 |
dst + bid*ne0,
|
| 6971 |
ne00,
|
|
@@ -6985,9 +6867,10 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
|
|
| 6985 |
|
| 6986 |
[[host_name("kernel_mul_mv_id_iq3_s_f32")]]
|
| 6987 |
kernel void kernel_mul_mv_id_iq3_s_f32(
|
| 6988 |
-
device const char *
|
| 6989 |
device const char * src1,
|
| 6990 |
device float * dst,
|
|
|
|
| 6991 |
constant uint64_t & nbi1,
|
| 6992 |
constant int64_t & ne00,
|
| 6993 |
constant int64_t & ne01,
|
|
@@ -7008,29 +6891,20 @@ kernel void kernel_mul_mv_id_iq3_s_f32(
|
|
| 7008 |
constant uint & r2,
|
| 7009 |
constant uint & r3,
|
| 7010 |
constant int & idx,
|
| 7011 |
-
device const char * src00,
|
| 7012 |
-
device const char * src01,
|
| 7013 |
-
device const char * src02,
|
| 7014 |
-
device const char * src03,
|
| 7015 |
-
device const char * src04,
|
| 7016 |
-
device const char * src05,
|
| 7017 |
-
device const char * src06,
|
| 7018 |
-
device const char * src07,
|
| 7019 |
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
| 7020 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 7021 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 7022 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 7023 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 7024 |
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
| 7025 |
-
|
| 7026 |
const int64_t bid = tgpig.z/(ne12*ne13);
|
| 7027 |
|
| 7028 |
tgpig.z = tgpig.z%(ne12*ne13);
|
| 7029 |
|
| 7030 |
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
|
| 7031 |
|
| 7032 |
kernel_mul_mv_iq3_s_f32_impl(
|
| 7033 |
-
src0
|
| 7034 |
(device const float *) (src1 + bid*nb11),
|
| 7035 |
dst + bid*ne0,
|
| 7036 |
ne00,
|
|
@@ -7050,9 +6924,10 @@ kernel void kernel_mul_mv_id_iq3_s_f32(
|
|
| 7050 |
|
| 7051 |
[[host_name("kernel_mul_mv_id_iq2_s_f32")]]
|
| 7052 |
kernel void kernel_mul_mv_id_iq2_s_f32(
|
| 7053 |
-
device const char *
|
| 7054 |
device const char * src1,
|
| 7055 |
device float * dst,
|
|
|
|
| 7056 |
constant uint64_t & nbi1,
|
| 7057 |
constant int64_t & ne00,
|
| 7058 |
constant int64_t & ne01,
|
|
@@ -7073,29 +6948,20 @@ kernel void kernel_mul_mv_id_iq2_s_f32(
|
|
| 7073 |
constant uint & r2,
|
| 7074 |
constant uint & r3,
|
| 7075 |
constant int & idx,
|
| 7076 |
-
device const char * src00,
|
| 7077 |
-
device const char * src01,
|
| 7078 |
-
device const char * src02,
|
| 7079 |
-
device const char * src03,
|
| 7080 |
-
device const char * src04,
|
| 7081 |
-
device const char * src05,
|
| 7082 |
-
device const char * src06,
|
| 7083 |
-
device const char * src07,
|
| 7084 |
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
| 7085 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 7086 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 7087 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 7088 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 7089 |
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
| 7090 |
-
|
| 7091 |
const int64_t bid = tgpig.z/(ne12*ne13);
|
| 7092 |
|
| 7093 |
tgpig.z = tgpig.z%(ne12*ne13);
|
| 7094 |
|
| 7095 |
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
|
| 7096 |
|
| 7097 |
kernel_mul_mv_iq2_s_f32_impl(
|
| 7098 |
-
src0
|
| 7099 |
(device const float *) (src1 + bid*nb11),
|
| 7100 |
dst + bid*ne0,
|
| 7101 |
ne00,
|
|
@@ -7115,9 +6981,10 @@ kernel void kernel_mul_mv_id_iq2_s_f32(
|
|
| 7115 |
|
| 7116 |
[[host_name("kernel_mul_mv_id_iq1_s_f32")]]
|
| 7117 |
kernel void kernel_mul_mv_id_iq1_s_f32(
|
| 7118 |
-
device const char *
|
| 7119 |
device const char * src1,
|
| 7120 |
device float * dst,
|
|
|
|
| 7121 |
constant uint64_t & nbi1,
|
| 7122 |
constant int64_t & ne00,
|
| 7123 |
constant int64_t & ne01,
|
|
@@ -7138,28 +7005,19 @@ kernel void kernel_mul_mv_id_iq1_s_f32(
|
|
| 7138 |
constant uint & r2,
|
| 7139 |
constant uint & r3,
|
| 7140 |
constant int & idx,
|
| 7141 |
-
device const char * src00,
|
| 7142 |
-
device const char * src01,
|
| 7143 |
-
device const char * src02,
|
| 7144 |
-
device const char * src03,
|
| 7145 |
-
device const char * src04,
|
| 7146 |
-
device const char * src05,
|
| 7147 |
-
device const char * src06,
|
| 7148 |
-
device const char * src07,
|
| 7149 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 7150 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 7151 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 7152 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 7153 |
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
| 7154 |
-
|
| 7155 |
const int64_t bid = tgpig.z/(ne12*ne13);
|
| 7156 |
|
| 7157 |
tgpig.z = tgpig.z%(ne12*ne13);
|
| 7158 |
|
| 7159 |
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
|
| 7160 |
|
| 7161 |
kernel_mul_mv_iq1_s_f32_impl(
|
| 7162 |
-
src0
|
| 7163 |
(device const float *) (src1 + bid*nb11),
|
| 7164 |
dst + bid*ne0,
|
| 7165 |
ne00,
|
|
@@ -7178,9 +7036,10 @@ kernel void kernel_mul_mv_id_iq1_s_f32(
|
|
| 7178 |
|
| 7179 |
[[host_name("kernel_mul_mv_id_iq1_m_f32")]]
|
| 7180 |
kernel void kernel_mul_mv_id_iq1_m_f32(
|
| 7181 |
-
device const char *
|
| 7182 |
device const char * src1,
|
| 7183 |
device float * dst,
|
|
|
|
| 7184 |
constant uint64_t & nbi1,
|
| 7185 |
constant int64_t & ne00,
|
| 7186 |
constant int64_t & ne01,
|
|
@@ -7201,28 +7060,19 @@ kernel void kernel_mul_mv_id_iq1_m_f32(
|
|
| 7201 |
constant uint & r2,
|
| 7202 |
constant uint & r3,
|
| 7203 |
constant int & idx,
|
| 7204 |
-
device const char * src00,
|
| 7205 |
-
device const char * src01,
|
| 7206 |
-
device const char * src02,
|
| 7207 |
-
device const char * src03,
|
| 7208 |
-
device const char * src04,
|
| 7209 |
-
device const char * src05,
|
| 7210 |
-
device const char * src06,
|
| 7211 |
-
device const char * src07,
|
| 7212 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 7213 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 7214 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 7215 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 7216 |
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
| 7217 |
-
|
| 7218 |
const int64_t bid = tgpig.z/(ne12*ne13);
|
| 7219 |
|
| 7220 |
tgpig.z = tgpig.z%(ne12*ne13);
|
| 7221 |
|
| 7222 |
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
|
| 7223 |
|
| 7224 |
kernel_mul_mv_iq1_m_f32_impl(
|
| 7225 |
-
src0
|
| 7226 |
(device const float *) (src1 + bid*nb11),
|
| 7227 |
dst + bid*ne0,
|
| 7228 |
ne00,
|
|
@@ -7241,9 +7091,10 @@ kernel void kernel_mul_mv_id_iq1_m_f32(
|
|
| 7241 |
|
| 7242 |
[[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
|
| 7243 |
kernel void kernel_mul_mv_id_iq4_nl_f32(
|
| 7244 |
-
device const char *
|
| 7245 |
device const char * src1,
|
| 7246 |
device float * dst,
|
|
|
|
| 7247 |
constant uint64_t & nbi1,
|
| 7248 |
constant int64_t & ne00,
|
| 7249 |
constant int64_t & ne01,
|
|
@@ -7264,29 +7115,20 @@ kernel void kernel_mul_mv_id_iq4_nl_f32(
|
|
| 7264 |
constant uint & r2,
|
| 7265 |
constant uint & r3,
|
| 7266 |
constant int & idx,
|
| 7267 |
-
device const char * src00,
|
| 7268 |
-
device const char * src01,
|
| 7269 |
-
device const char * src02,
|
| 7270 |
-
device const char * src03,
|
| 7271 |
-
device const char * src04,
|
| 7272 |
-
device const char * src05,
|
| 7273 |
-
device const char * src06,
|
| 7274 |
-
device const char * src07,
|
| 7275 |
threadgroup float * shared_values [[threadgroup(0)]],
|
| 7276 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 7277 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 7278 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 7279 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 7280 |
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
| 7281 |
-
|
| 7282 |
const int64_t bid = tgpig.z/(ne12*ne13);
|
| 7283 |
|
| 7284 |
tgpig.z = tgpig.z%(ne12*ne13);
|
| 7285 |
|
| 7286 |
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
|
| 7287 |
|
| 7288 |
kernel_mul_mv_iq4_nl_f32_impl(
|
| 7289 |
-
src0
|
| 7290 |
(device const float *) (src1 + bid*nb11),
|
| 7291 |
dst + bid*ne0,
|
| 7292 |
ne00,
|
|
@@ -7306,9 +7148,10 @@ kernel void kernel_mul_mv_id_iq4_nl_f32(
|
|
| 7306 |
|
| 7307 |
[[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
|
| 7308 |
kernel void kernel_mul_mv_id_iq4_xs_f32(
|
| 7309 |
-
device const char *
|
| 7310 |
device const char * src1,
|
| 7311 |
device float * dst,
|
|
|
|
| 7312 |
constant uint64_t & nbi1,
|
| 7313 |
constant int64_t & ne00,
|
| 7314 |
constant int64_t & ne01,
|
|
@@ -7329,33 +7172,24 @@ kernel void kernel_mul_mv_id_iq4_xs_f32(
|
|
| 7329 |
constant uint & r2,
|
| 7330 |
constant uint & r3,
|
| 7331 |
constant int & idx,
|
| 7332 |
-
device const char * src00,
|
| 7333 |
-
device const char * src01,
|
| 7334 |
-
device const char * src02,
|
| 7335 |
-
device const char * src03,
|
| 7336 |
-
device const char * src04,
|
| 7337 |
-
device const char * src05,
|
| 7338 |
-
device const char * src06,
|
| 7339 |
-
device const char * src07,
|
| 7340 |
threadgroup float * shared_values [[threadgroup(0)]],
|
| 7341 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 7342 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 7343 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 7344 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 7345 |
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
| 7346 |
-
|
| 7347 |
const int64_t bid = tgpig.z/(ne12*ne13);
|
| 7348 |
|
| 7349 |
tgpig.z = tgpig.z%(ne12*ne13);
|
| 7350 |
|
| 7351 |
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
|
| 7352 |
|
| 7353 |
#if QK_K == 64
|
| 7354 |
kernel_mul_mv_iq4_nl_f32_impl(
|
| 7355 |
#else
|
| 7356 |
kernel_mul_mv_iq4_xs_f32_impl(
|
| 7357 |
#endif
|
| 7358 |
-
src0
|
| 7359 |
(device const float *) (src1 + bid*nb11),
|
| 7360 |
dst + bid*ne0,
|
| 7361 |
ne00,
|
|
|
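All of the removals above follow one pattern: the old kernels received up to eight per-expert src0 pointers and indexed into a local array, while the replacements (shown in the new listing below) receive a single merged tensor and select an expert by byte offset. A minimal C sketch of the two addressing schemes, with hypothetical helper names that are not part of this commit:

    #include <stdint.h>

    /* before: up to 8 separate expert tensors, one pointer each (hard cap of 8) */
    static const char * select_expert_old(const char * src0[8], int32_t id) {
        return src0[id];
    }

    /* after: one merged 3D tensor; nb02 is the byte stride between experts */
    static const char * select_expert_new(const char * src0s, int32_t id, uint64_t nb02) {
        return src0s + id*nb02;
    }

The second form is what removes the eight-expert limit and lets CUDA and Metal handle a non-power-of-2 number of experts.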
 
 #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
 
 enum ggml_sort_order {
+    GGML_SORT_ORDER_ASC,
+    GGML_SORT_ORDER_DESC,
 };
 
 // general-purpose kernel for addition, multiplication and division of two tensors
 
 // bitonic sort implementation following the CUDA kernels as reference
 typedef void (argsort_t)(
+        device const float  * x,
+        device       int32_t * dst,
+        constant     int64_t & ncols,
+        constant     int64_t & ncols_pad,
+        threadgroup  int32_t * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]]);
 
 template<ggml_sort_order order>
 kernel void kernel_argsort_f32_i32(
         device const float  * x,
         device       int32_t * dst,
         constant     int64_t & ncols,
+        constant     int64_t & ncols_pad,
+        threadgroup  int32_t * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]]) {
     // bitonic sort
     int col = tpitg[0];
     int row = tgpig[1];
 
+    if (col >= ncols_pad) return;
 
+    device const float  * x_row   = x + row * ncols;
+    threadgroup int32_t * dst_row = shared_values;
 
     // initialize indices
+    dst_row[col] = col;
+
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
+    for (int k = 2; k <= ncols_pad; k *= 2) {
         for (int j = k / 2; j > 0; j /= 2) {
             int ixj = col ^ j;
             if (ixj > col) {
                 if ((col & k) == 0) {
+                    if (dst_row[col] >= ncols ||
+                        (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+                    ) {
                         SWAP(dst_row[col], dst_row[ixj]);
                     }
                 } else {
+                    if (dst_row[ixj] >= ncols ||
+                        (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+                    ) {
                         SWAP(dst_row[col], dst_row[ixj]);
                     }
                 }
             }
             threadgroup_barrier(mem_flags::mem_threadgroup);
         }
     }
+
+    // copy the result to dst without the padding
+    if (col < ncols) {
+        dst[row * ncols + col] = dst_row[col];
+    }
 }
 
+template [[host_name("kernel_argsort_f32_i32_asc")]]  kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
+template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
 
 kernel void kernel_leaky_relu_f32(
         device const float * src0,
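A bitonic network only sorts a power-of-two number of lanes, so the kernel above sorts ncols_pad indices in threadgroup memory, treats indices >= ncols as larger than everything, and drops the padding on the final copy. A hedged host-side sketch of the sizing this implies (helper names are illustrative; the 16-byte rounding follows the "metal : pad shared memory to 16 bytes" note in the commit log):

    #include <stddef.h>
    #include <stdint.h>

    /* next power of two >= n, matching the kernel's ncols_pad requirement */
    static int64_t next_pow2(int64_t n) {
        int64_t p = 1;
        while (p < n) {
            p *= 2;
        }
        return p;
    }

    /* threadgroup bytes needed for dst_row, rounded up to 16 bytes */
    static size_t argsort_shared_mem_bytes(int64_t ncols) {
        size_t bytes = (size_t) next_pow2(ncols) * sizeof(int32_t);
        return (bytes + 15) & ~(size_t) 15;
    }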
 
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
 kernel void kernel_mul_mm_id(
+    device const uchar * src0s,
     device const uchar * src1,
     device       float * dst,
+    device const uchar * ids,
     constant  uint64_t & nbi1,
     constant   int64_t & ne00,
     constant   int64_t & ne02,
 
     constant      uint & r2,
     constant      uint & r3,
     constant       int & idx,
     threadgroup uchar * shared_memory [[threadgroup(0)]],
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint  tiitg[[thread_index_in_threadgroup]],
     uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
     // expert id
     const int32_t id = tgpig.z/(ne12*ne13);
+    device const uchar * src0 = src0s + id*nb02;
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     }
 
     kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
+        src0,
         src1,
         src1ids,
         dst,
 
 //
 
 typedef void (mat_mm_id_t)(
+    device const uchar * src0s,
     device const uchar * src1,
     device       float * dst,
+    device const uchar * ids,
     constant  uint64_t & nbi1,
     constant   int64_t & ne00,
     constant   int64_t & ne02,
 
     constant      uint & r2,
     constant      uint & r3,
     constant       int & idx,
     threadgroup uchar *,
     uint3, uint, uint);
 
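In kernel_mul_mm_id the expert index is not looked up per row; instead the dispatch grid's z dimension enumerates expert x batch, and the kernel splits it with a div/mod, as in the lines above. A small C sketch of that decode, under the assumption that z was packed as id*(ne12*ne13) + batch:

    #include <stdint.h>

    typedef struct {
        int32_t expert_id; /* which expert slice of src0s to use     */
        int64_t batch;     /* remaining z: position within ne12*ne13 */
    } mmid_coord;

    /* mirrors: id = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); */
    static mmid_coord decode_grid_z(int64_t z, int64_t ne12, int64_t ne13) {
        mmid_coord c;
        c.expert_id = (int32_t) (z / (ne12*ne13));
        c.batch     = z % (ne12*ne13);
        return c;
    }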
 
 [[host_name("kernel_mul_mv_id_f32_f32")]]
 kernel void kernel_mul_mv_id_f32_f32(
+    device const char * src0s,
     device const char * src1,
     device      float * dst,
+    device const char * ids,
     constant uint64_t & nbi1,
     constant  int64_t & ne00,
     constant  int64_t & ne01,
 
     constant     uint & r2,
     constant     uint & r3,
     constant      int & idx,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint  tiitg[[thread_index_in_threadgroup]],
     uint  tiisg[[thread_index_in_simdgroup]],
     uint  sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_f32_f32_impl(
+        src0,
         src1 + bid*nb11,
         dst + bid*ne0,
         ne00,
 
 [[host_name("kernel_mul_mv_id_f16_f32")]]
 kernel void kernel_mul_mv_id_f16_f32(
+    device const char * src0s,
     device const char * src1,
     device      float * dst,
+    device const char * ids,
     constant uint64_t & nbi1,
     constant  int64_t & ne00,
     constant  int64_t & ne01,
 
     constant     uint & r2,
     constant     uint & r3,
     constant      int & idx,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint  tiitg[[thread_index_in_threadgroup]],
     uint  tiisg[[thread_index_in_simdgroup]],
     uint  sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_f16_f32_impl(
+        src0,
         src1 + bid*nb11,
         dst + bid*ne0,
         ne00,
 
 [[host_name("kernel_mul_mv_id_q8_0_f32")]]
 kernel void kernel_mul_mv_id_q8_0_f32(
+    device const char * src0s,
     device const char * src1,
     device      float * dst,
+    device const char * ids,
     constant uint64_t & nbi1,
     constant  int64_t & ne00,
     constant  int64_t & ne01,
 
     constant     uint & r2,
     constant     uint & r3,
     constant      int & idx,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint  tiitg[[thread_index_in_threadgroup]],
     uint  tiisg[[thread_index_in_simdgroup]],
     uint  sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_q8_0_f32_impl(
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
 
 [[host_name("kernel_mul_mv_id_q4_0_f32")]]
 kernel void kernel_mul_mv_id_q4_0_f32(
+    device const char * src0s,
     device const char * src1,
     device      float * dst,
+    device const char * ids,
     constant uint64_t & nbi1,
     constant  int64_t & ne00,
     constant  int64_t & ne01,
 
     constant     uint & r2,
     constant     uint & r3,
     constant      int & idx,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint  tiitg[[thread_index_in_threadgroup]],
     uint  tiisg[[thread_index_in_simdgroup]],
     uint  sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
 
 [[host_name("kernel_mul_mv_id_q4_1_f32")]]
 kernel void kernel_mul_mv_id_q4_1_f32(
+    device const char * src0s,
     device const char * src1,
     device      float * dst,
+    device const char * ids,
     constant uint64_t & nbi1,
     constant  int64_t & ne00,
     constant  int64_t & ne01,
 
     constant     uint & r2,
     constant     uint & r3,
     constant      int & idx,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint  tiitg[[thread_index_in_threadgroup]],
     uint  tiisg[[thread_index_in_simdgroup]],
     uint  sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
 
 [[host_name("kernel_mul_mv_id_q5_0_f32")]]
 kernel void kernel_mul_mv_id_q5_0_f32(
+    device const char * src0s,
     device const char * src1,
     device      float * dst,
+    device const char * ids,
     constant uint64_t & nbi1,
     constant  int64_t & ne00,
     constant  int64_t & ne01,
 
     constant     uint & r2,
     constant     uint & r3,
     constant      int & idx,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint  tiitg[[thread_index_in_threadgroup]],
     uint  tiisg[[thread_index_in_simdgroup]],
     uint  sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
 
 [[host_name("kernel_mul_mv_id_q5_1_f32")]]
 kernel void kernel_mul_mv_id_q5_1_f32(
+    device const char * src0s,
     device const char * src1,
     device      float * dst,
+    device const char * ids,
     constant uint64_t & nbi1,
     constant  int64_t & ne00,
     constant  int64_t & ne01,
 
     constant     uint & r2,
     constant     uint & r3,
     constant      int & idx,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint  tiitg[[thread_index_in_threadgroup]],
     uint  tiisg[[thread_index_in_simdgroup]],
     uint  sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
 
 [[host_name("kernel_mul_mv_id_q2_K_f32")]]
 kernel void kernel_mul_mv_id_q2_K_f32(
+    device const char * src0s,
     device const char * src1,
     device      float * dst,
+    device const char * ids,
     constant uint64_t & nbi1,
     constant  int64_t & ne00,
     constant  int64_t & ne01,
 
     constant     uint & r2,
     constant     uint & r3,
     constant      int & idx,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint  tiitg[[thread_index_in_threadgroup]],
     uint  tiisg[[thread_index_in_simdgroup]],
     uint  sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_q2_K_f32_impl(
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
 
 [[host_name("kernel_mul_mv_id_q3_K_f32")]]
 kernel void kernel_mul_mv_id_q3_K_f32(
+    device const char * src0s,
     device const char * src1,
     device      float * dst,
+    device const char * ids,
     constant uint64_t & nbi1,
     constant  int64_t & ne00,
     constant  int64_t & ne01,
 
     constant     uint & r2,
     constant     uint & r3,
     constant      int & idx,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint  tiitg[[thread_index_in_threadgroup]],
     uint  tiisg[[thread_index_in_simdgroup]],
     uint  sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_q3_K_f32_impl(
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
 
 [[host_name("kernel_mul_mv_id_q4_K_f32")]]
 kernel void kernel_mul_mv_id_q4_K_f32(
+    device const char * src0s,
     device const char * src1,
     device      float * dst,
+    device const char * ids,
     constant uint64_t & nbi1,
     constant  int64_t & ne00,
     constant  int64_t & ne01,
 
     constant     uint & r2,
     constant     uint & r3,
     constant      int & idx,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint  tiitg[[thread_index_in_threadgroup]],
     uint  tiisg[[thread_index_in_simdgroup]],
     uint  sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_q4_K_f32_impl(
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
 
 [[host_name("kernel_mul_mv_id_q5_K_f32")]]
 kernel void kernel_mul_mv_id_q5_K_f32(
+    device const char * src0s,
     device const char * src1,
     device      float * dst,
+    device const char * ids,
     constant uint64_t & nbi1,
     constant  int64_t & ne00,
     constant  int64_t & ne01,
 
     constant     uint & r2,
     constant     uint & r3,
     constant      int & idx,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint  tiitg[[thread_index_in_threadgroup]],
     uint  tiisg[[thread_index_in_simdgroup]],
     uint  sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_q5_K_f32_impl(
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
 
 [[host_name("kernel_mul_mv_id_q6_K_f32")]]
 kernel void kernel_mul_mv_id_q6_K_f32(
+    device const char * src0s,
     device const char * src1,
     device      float * dst,
+    device const char * ids,
     constant uint64_t & nbi1,
     constant  int64_t & ne00,
     constant  int64_t & ne01,
 
     constant     uint & r2,
     constant     uint & r3,
     constant      int & idx,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint  tiitg[[thread_index_in_threadgroup]],
     uint  tiisg[[thread_index_in_simdgroup]],
     uint  sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_q6_K_f32_impl(
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
 
 [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
 kernel void kernel_mul_mv_id_iq2_xxs_f32(
+    device const char * src0s,
     device const char * src1,
     device      float * dst,
+    device const char * ids,
     constant uint64_t & nbi1,
     constant  int64_t & ne00,
     constant  int64_t & ne01,
 
     constant     uint & r2,
     constant     uint & r3,
     constant      int & idx,
     threadgroup int8_t * shared_values [[threadgroup(0)]],
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint  tiitg[[thread_index_in_threadgroup]],
     uint  tiisg[[thread_index_in_simdgroup]],
     uint  sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_iq2_xxs_f32_impl(
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
 
 [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
 kernel void kernel_mul_mv_id_iq2_xs_f32(
+    device const char * src0s,
     device const char * src1,
     device      float * dst,
+    device const char * ids,
     constant uint64_t & nbi1,
     constant  int64_t & ne00,
     constant  int64_t & ne01,
 
     constant     uint & r2,
     constant     uint & r3,
     constant      int & idx,
     threadgroup int8_t * shared_values [[threadgroup(0)]],
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint  tiitg[[thread_index_in_threadgroup]],
     uint  tiisg[[thread_index_in_simdgroup]],
     uint  sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_iq2_xs_f32_impl(
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
 
 [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
 kernel void kernel_mul_mv_id_iq3_xxs_f32(
+    device const char * src0s,
     device const char * src1,
     device      float * dst,
+    device const char * ids,
     constant uint64_t & nbi1,
     constant  int64_t & ne00,
     constant  int64_t & ne01,
 
     constant     uint & r2,
     constant     uint & r3,
     constant      int & idx,
     threadgroup int8_t * shared_values [[threadgroup(0)]],
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint  tiitg[[thread_index_in_threadgroup]],
     uint  tiisg[[thread_index_in_simdgroup]],
     uint  sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_iq3_xxs_f32_impl(
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
 
 [[host_name("kernel_mul_mv_id_iq3_s_f32")]]
 kernel void kernel_mul_mv_id_iq3_s_f32(
+    device const char * src0s,
     device const char * src1,
     device      float * dst,
+    device const char * ids,
     constant uint64_t & nbi1,
     constant  int64_t & ne00,
     constant  int64_t & ne01,
 
     constant     uint & r2,
     constant     uint & r3,
     constant      int & idx,
     threadgroup int8_t * shared_values [[threadgroup(0)]],
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint  tiitg[[thread_index_in_threadgroup]],
     uint  tiisg[[thread_index_in_simdgroup]],
     uint  sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_iq3_s_f32_impl(
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
 
 [[host_name("kernel_mul_mv_id_iq2_s_f32")]]
 kernel void kernel_mul_mv_id_iq2_s_f32(
+    device const char * src0s,
     device const char * src1,
     device      float * dst,
+    device const char * ids,
     constant uint64_t & nbi1,
     constant  int64_t & ne00,
     constant  int64_t & ne01,
 
     constant     uint & r2,
     constant     uint & r3,
     constant      int & idx,
     threadgroup int8_t * shared_values [[threadgroup(0)]],
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint  tiitg[[thread_index_in_threadgroup]],
     uint  tiisg[[thread_index_in_simdgroup]],
     uint  sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_iq2_s_f32_impl(
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
 
 [[host_name("kernel_mul_mv_id_iq1_s_f32")]]
 kernel void kernel_mul_mv_id_iq1_s_f32(
+    device const char * src0s,
     device const char * src1,
     device      float * dst,
+    device const char * ids,
     constant uint64_t & nbi1,
     constant  int64_t & ne00,
     constant  int64_t & ne01,
 
     constant     uint & r2,
     constant     uint & r3,
     constant      int & idx,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint  tiitg[[thread_index_in_threadgroup]],
     uint  tiisg[[thread_index_in_simdgroup]],
     uint  sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_iq1_s_f32_impl(
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
 
 [[host_name("kernel_mul_mv_id_iq1_m_f32")]]
 kernel void kernel_mul_mv_id_iq1_m_f32(
+    device const char * src0s,
     device const char * src1,
     device      float * dst,
+    device const char * ids,
     constant uint64_t & nbi1,
     constant  int64_t & ne00,
     constant  int64_t & ne01,
 
     constant     uint & r2,
     constant     uint & r3,
     constant      int & idx,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint  tiitg[[thread_index_in_threadgroup]],
     uint  tiisg[[thread_index_in_simdgroup]],
     uint  sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_iq1_m_f32_impl(
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
 
 [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
 kernel void kernel_mul_mv_id_iq4_nl_f32(
+    device const char * src0s,
     device const char * src1,
     device      float * dst,
+    device const char * ids,
     constant uint64_t & nbi1,
     constant  int64_t & ne00,
     constant  int64_t & ne01,
 
     constant     uint & r2,
     constant     uint & r3,
     constant      int & idx,
     threadgroup float * shared_values [[threadgroup(0)]],
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint  tiitg[[thread_index_in_threadgroup]],
     uint  tiisg[[thread_index_in_simdgroup]],
     uint  sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_iq4_nl_f32_impl(
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
 
 [[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
 kernel void kernel_mul_mv_id_iq4_xs_f32(
+    device const char * src0s,
     device const char * src1,
     device      float * dst,
+    device const char * ids,
     constant uint64_t & nbi1,
     constant  int64_t & ne00,
     constant  int64_t & ne01,
 
     constant     uint & r2,
     constant     uint & r3,
     constant      int & idx,
     threadgroup float * shared_values [[threadgroup(0)]],
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint  tiitg[[thread_index_in_threadgroup]],
     uint  tiisg[[thread_index_in_simdgroup]],
     uint  sgitg[[simdgroup_index_in_threadgroup]]) {
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
 #if QK_K == 64
     kernel_mul_mv_iq4_nl_f32_impl(
 #else
     kernel_mul_mv_iq4_xs_f32_impl(
 #endif
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
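Every kernel_mul_mv_id_* variant above now performs the same three-step lookup: derive the src1 row id bid from the grid's z dimension, read the expert index for that row from the ids matrix, and offset into the merged weights by id*nb02. A compact C restatement of that lookup, as a sketch only (the authoritative code is the Metal above):

    #include <stdint.h>

    /* bid*nbi1 selects the row of ids for this src1 row; idx selects which of
     * the routed experts in that row to use; id*nb02 offsets into src0s */
    static const char * mul_mv_id_expert(const char * src0s, const char * ids,
                                         int64_t bid, uint64_t nbi1, int idx,
                                         uint64_t nb02) {
        const int32_t id = ((const int32_t *) (ids + bid*nbi1))[idx];
        return src0s + (uint64_t) id * nb02;
    }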
@@ -4573,45 +4573,38 @@ void ggml_mul_mat_set_prec(
|
|
| 4573 |
|
| 4574 |
// ggml_mul_mat_id
|
| 4575 |
|
|
|
|
|
|
|
| 4576 |
struct ggml_tensor * ggml_mul_mat_id(
|
| 4577 |
struct ggml_context * ctx,
|
| 4578 |
-
struct ggml_tensor *
|
| 4579 |
-
int n_as,
|
| 4580 |
struct ggml_tensor * ids,
|
| 4581 |
int id,
|
| 4582 |
struct ggml_tensor * b) {
|
| 4583 |
|
| 4584 |
GGML_ASSERT(ids->type == GGML_TYPE_I32);
|
| 4585 |
-
GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1);
|
| 4586 |
-
GGML_ASSERT(ids->ne[1] == b->ne[1]);
|
| 4587 |
GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]);
|
| 4588 |
-
GGML_ASSERT(
|
| 4589 |
-
GGML_ASSERT(
|
| 4590 |
|
| 4591 |
bool is_node = false;
|
| 4592 |
|
| 4593 |
-
if (as
|
| 4594 |
is_node = true;
|
| 4595 |
}
|
| 4596 |
|
| 4597 |
-
const int64_t ne[4] = { as
|
| 4598 |
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
|
| 4599 |
|
| 4600 |
ggml_set_op_params_i32(result, 0, id);
|
| 4601 |
-
ggml_set_op_params_i32(result, 1, n_as);
|
| 4602 |
|
| 4603 |
result->op = GGML_OP_MUL_MAT_ID;
|
| 4604 |
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
| 4605 |
-
result->src[0] =
|
| 4606 |
result->src[1] = b;
|
| 4607 |
-
|
| 4608 |
-
for (int i = 0; i < n_as; i++) {
|
| 4609 |
-
struct ggml_tensor * a = as[i];
|
| 4610 |
-
GGML_ASSERT(ggml_are_same_shape(as[0], a));
|
| 4611 |
-
GGML_ASSERT(ggml_can_mul_mat(a, b));
|
| 4612 |
-
GGML_ASSERT(!ggml_is_transposed(a));
|
| 4613 |
-
result->src[i + 2] = a;
|
| 4614 |
-
}
|
| 4615 |
|
| 4616 |
return result;
|
| 4617 |
}
|
|
@@ -10948,10 +10941,9 @@ static void ggml_compute_forward_mul_mat_id(
|
|
| 10948 |
const struct ggml_compute_params * params,
|
| 10949 |
struct ggml_tensor * dst) {
|
| 10950 |
|
| 10951 |
-
const struct ggml_tensor *
|
| 10952 |
const struct ggml_tensor * src1 = dst->src[1];
|
| 10953 |
-
|
| 10954 |
-
const struct ggml_tensor * src0 = dst->src[2]; // only for GGML_TENSOR_BINARY_OP_LOCALS
|
| 10955 |
|
| 10956 |
GGML_TENSOR_BINARY_OP_LOCALS
|
| 10957 |
|
|
@@ -10981,13 +10973,13 @@ static void ggml_compute_forward_mul_mat_id(
|
|
| 10981 |
GGML_ASSERT(nb1 <= nb2);
|
| 10982 |
GGML_ASSERT(nb2 <= nb3);
|
| 10983 |
|
| 10984 |
-
// broadcast
|
| 10985 |
-
|
| 10986 |
-
|
| 10987 |
|
| 10988 |
// row groups
|
| 10989 |
const int id = ggml_get_op_params_i32(dst, 0);
|
| 10990 |
-
const int n_as =
|
| 10991 |
|
| 10992 |
char * wdata_src1_end = (src1->type == vec_dot_type) ?
|
| 10993 |
(char *) params->wdata :
|
|
@@ -11047,7 +11039,7 @@ static void ggml_compute_forward_mul_mat_id(
             continue;
         }
 
-        const struct ggml_tensor * src0_cur = dst->src[cur_a + 2];
+        size_t src0_offset = cur_a*src0->nb[2];
 
         const void * wdata    = (src1->type == vec_dot_type) ? src1->data : params->wdata;
         const size_t row_size = ggml_row_size(vec_dot_type, ne10);
@@ -11082,9 +11074,6 @@ static void ggml_compute_forward_mul_mat_id(
             continue;
         }
 
-        assert(ne12 % ne02 == 0);
-        assert(ne13 % ne03 == 0);
-
         // block-tiling attempt
         const int64_t blck_0 = 16;
         const int64_t blck_1 = 16;
@@ -11101,14 +11090,14 @@ static void ggml_compute_forward_mul_mat_id(
                 const int64_t i11 = MMID_MATRIX_ROW(cur_a, _i11);
 
                 // broadcast src0 into src1
-                const int64_t i03 = i13/r3;
-                const int64_t i02 = i12/r2;
+                //const int64_t i03 = i13/r3;
+                //const int64_t i02 = i12/r2;
 
                 const int64_t i1 = i11;
                 const int64_t i2 = i12;
                 const int64_t i3 = i13;
 
-                const char * src0_row = (const char *) src0_cur->data + (0 + i02*nb02 + i03*nb03);
+                const char * src0_row = (const char *) src0->data + src0_offset;
 
                 // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
                 // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
@@ -18464,13 +18453,13 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threads) {
             case GGML_OP_MUL_MAT_ID:
                 {
                     cur = 0;
-                    const struct ggml_tensor * src0 = node->src[2];
+                    const struct ggml_tensor * src0 = node->src[0];
                     const struct ggml_tensor * src1 = node->src[1];
                     const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type;
                     if (src1->type != vec_dot_type) {
                         cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
                     }
-                    const int n_as = ggml_get_op_params_i32(node, 1);
+                    const int n_as = src0->ne[2];
                     cur += GGML_PAD(cur, sizeof(int64_t)); // align
                     cur += n_as * sizeof(int64_t); // matrix_row_counts
                     cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows
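The planner change derives `n_as` from the tensor shape instead of an op parameter; the work buffer itself still holds the converted src1 followed by the two row-group tables. A sketch of the same size formula with the surrounding graph machinery stripped away (the helper and its parameters are illustrative; `PAD` reproduces `GGML_PAD`):

```c
#include <stddef.h>
#include <stdint.h>

#define PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))  // same rounding as GGML_PAD

// Work size for MUL_MAT_ID as computed above: optional src1 conversion,
// then the matrix_row_counts and matrix_rows tables. `need_conv` stands in
// for the src1->type != vec_dot_type check; the `cur += PAD(cur, ...)` step
// mirrors the original expression verbatim.
static size_t mmid_work_size(size_t converted_src1_bytes, int need_conv,
                             int64_t n_as, int64_t src1_rows) {
    size_t cur = 0;
    if (need_conv) {
        cur += converted_src1_bytes;            // src1 in vec_dot_type
    }
    cur += PAD(cur, sizeof(int64_t));           // align
    cur += n_as * sizeof(int64_t);              // matrix_row_counts
    cur += n_as * src1_rows * sizeof(int64_t);  // matrix_rows
    return cur;
}
```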
ggml.h
@@ -1164,8 +1164,7 @@ extern "C" {
     // ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
     GGML_API struct ggml_tensor * ggml_mul_mat_id(
             struct ggml_context * ctx,
-            struct ggml_tensor  * const as[],
-            int                   n_as,
+            struct ggml_tensor  * as,
             struct ggml_tensor  * ids,
             int                   id,
             struct ggml_tensor  * b);
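The header comment retains the old contract: the op behaves like a plain `ggml_mul_mat` against the selected expert. With the merged layout that expert is simply a 2D view into `as` at byte offset `e*nb[2]`, which a caller could write out explicitly. A conceptual sketch (`expert_view` is a hypothetical helper; `ggml_view_2d` is the existing API):

```c
#include "ggml.h"

// Expert e of the stacked tensor as a plain 2D matrix. For a given row,
// ggml_mul_mat_id(ctx, as, ids, id, b) should match
// ggml_mul_mat(ctx, expert_view(ctx, as, e), b) when e equals the value
// stored at position `id` of that row in ids.
static struct ggml_tensor * expert_view(struct ggml_context * ctx,
                                        struct ggml_tensor * as, int64_t e) {
    return ggml_view_2d(ctx, as,
            as->ne[0], as->ne[1],       // one expert's matrix shape
            as->nb[1],                  // row stride in bytes
            (size_t) e * as->nb[2]);    // byte offset of expert e
}
```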