BB-fat alexju committed on
Commit
092277a
·
1 Parent(s): 838efb6

metal : simplify kernel arguments using a struct (ggml/3229) (llama/12194)

Browse files

* metal : refactor im2col parameters into a struct

* metal: Change im2col offset types from int32_t to uint64_t to support larger memory offsets

* metal : refactor sum_rows parameters into a struct

* metal : refactor soft_max parameters into a struct

* metal : refactor diag_mask_inf parameters into a struct

* metal : refactor ssm_conv parameters into a struct

* metal : refactor ssm_scan parameters into a struct

* metal : refactor get_rows parameters into a struct

* metal : refactor group_norm parameters into a struct

* metal : refactor conv_transpose_1d parameters into a struct

* metal : refactor upscale parameters into a struct

* metal : refactor pad parameters into a struct

* metal : refactor pad_reflect_1d parameters into a struct

* metal : refactor arange parameters into a struct

* metal : refactor timestep_embedding parameters into a struct

* metal : refactor argsort parameters into a struct

* metal : refactor leaky_relu parameters into a struct

* metal : refactor pool_2d parameters into a struct

* metal : fix trailing whitespace

---------

Co-authored-by: alexju <[email protected]>

ggml/src/ggml-metal/ggml-metal-impl.h CHANGED
@@ -285,4 +285,239 @@ typedef struct {
285
  float eps;
286
  } ggml_metal_kargs_rms_norm;
287
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
  #endif // GGML_METAL_IMPL
 
285
  float eps;
286
  } ggml_metal_kargs_rms_norm;
287
 
288
+ typedef struct {
289
+ int64_t ne00;
290
+ int64_t ne01;
291
+ int64_t ne02;
292
+ uint64_t nb00;
293
+ uint64_t nb01;
294
+ uint64_t nb02;
295
+ int32_t n_groups;
296
+ float eps;
297
+ } ggml_metal_kargs_group_norm;
298
+
299
+ typedef struct {
300
+ int32_t IC;
301
+ int32_t IL;
302
+ int32_t K;
303
+ int32_t s0;
304
+ uint64_t nb0;
305
+ uint64_t nb1;
306
+ } ggml_metal_kargs_conv_transpose_1d;
307
+
308
+ typedef struct {
309
+ uint64_t ofs0;
310
+ uint64_t ofs1;
311
+ int32_t IW;
312
+ int32_t IH;
313
+ int32_t CHW;
314
+ int32_t s0;
315
+ int32_t s1;
316
+ int32_t p0;
317
+ int32_t p1;
318
+ int32_t d0;
319
+ int32_t d1;
320
+ int32_t N;
321
+ int32_t KH;
322
+ int32_t KW;
323
+ int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources
324
+ } ggml_metal_kargs_im2col;
325
+
326
+ typedef struct {
327
+ int64_t ne00;
328
+ int64_t ne01;
329
+ int64_t ne02;
330
+ int64_t ne03;
331
+ uint64_t nb00;
332
+ uint64_t nb01;
333
+ uint64_t nb02;
334
+ uint64_t nb03;
335
+ int64_t ne10;
336
+ int64_t ne11;
337
+ int64_t ne12;
338
+ int64_t ne13;
339
+ uint64_t nb10;
340
+ uint64_t nb11;
341
+ uint64_t nb12;
342
+ uint64_t nb13;
343
+ int64_t ne0;
344
+ int64_t ne1;
345
+ int64_t ne2;
346
+ int64_t ne3;
347
+ uint64_t nb0;
348
+ uint64_t nb1;
349
+ uint64_t nb2;
350
+ uint64_t nb3;
351
+ } ggml_metal_kargs_sum_rows;
352
+
353
+ typedef struct {
354
+ int64_t ne00;
355
+ int64_t ne01;
356
+ int64_t ne02;
357
+ float scale;
358
+ float max_bias;
359
+ float m0;
360
+ float m1;
361
+ uint32_t n_head_log2;
362
+ } ggml_metal_kargs_soft_max;
363
+
364
+ typedef struct {
365
+ int64_t ne00;
366
+ int64_t ne01;
367
+ int n_past;
368
+ } ggml_metal_kargs_diag_mask_inf;
369
+
370
+ typedef struct {
371
+ int64_t ne00;
372
+ int64_t ne01;
373
+ int64_t ne02;
374
+ uint64_t nb00;
375
+ uint64_t nb01;
376
+ uint64_t nb02;
377
+ int64_t ne10;
378
+ int64_t ne11;
379
+ uint64_t nb10;
380
+ uint64_t nb11;
381
+ int64_t ne0;
382
+ int64_t ne1;
383
+ int64_t ne2;
384
+ uint64_t nb0;
385
+ uint64_t nb1;
386
+ uint64_t nb2;
387
+ } ggml_metal_kargs_ssm_conv;
388
+
389
+ typedef struct {
390
+ int64_t d_state;
391
+ int64_t d_inner;
392
+ int64_t n_seq_tokens;
393
+ int64_t n_seqs;
394
+ uint64_t nb00;
395
+ uint64_t nb01;
396
+ uint64_t nb02;
397
+ uint64_t nb10;
398
+ uint64_t nb11;
399
+ uint64_t nb12;
400
+ uint64_t nb13;
401
+ uint64_t nb20;
402
+ uint64_t nb21;
403
+ uint64_t nb22;
404
+ uint64_t nb30;
405
+ uint64_t nb31;
406
+ uint64_t nb40;
407
+ uint64_t nb41;
408
+ uint64_t nb42;
409
+ uint64_t nb50;
410
+ uint64_t nb51;
411
+ uint64_t nb52;
412
+ } ggml_metal_kargs_ssm_scan;
413
+
414
+ typedef struct {
415
+ int64_t ne00;
416
+ uint64_t nb01;
417
+ uint64_t nb02;
418
+ int64_t ne10;
419
+ uint64_t nb10;
420
+ uint64_t nb11;
421
+ uint64_t nb1;
422
+ uint64_t nb2;
423
+ } ggml_metal_kargs_get_rows;
424
+
425
+ typedef struct {
426
+ int64_t ne00;
427
+ int64_t ne01;
428
+ int64_t ne02;
429
+ int64_t ne03;
430
+ uint64_t nb00;
431
+ uint64_t nb01;
432
+ uint64_t nb02;
433
+ uint64_t nb03;
434
+ int64_t ne0;
435
+ int64_t ne1;
436
+ int64_t ne2;
437
+ int64_t ne3;
438
+ uint64_t nb0;
439
+ uint64_t nb1;
440
+ uint64_t nb2;
441
+ uint64_t nb3;
442
+ float sf0;
443
+ float sf1;
444
+ float sf2;
445
+ float sf3;
446
+ } ggml_metal_kargs_upscale;
447
+
448
+ typedef struct {
449
+ int64_t ne00;
450
+ int64_t ne01;
451
+ int64_t ne02;
452
+ int64_t ne03;
453
+ uint64_t nb00;
454
+ uint64_t nb01;
455
+ uint64_t nb02;
456
+ uint64_t nb03;
457
+ int64_t ne0;
458
+ int64_t ne1;
459
+ int64_t ne2;
460
+ int64_t ne3;
461
+ uint64_t nb0;
462
+ uint64_t nb1;
463
+ uint64_t nb2;
464
+ uint64_t nb3;
465
+ } ggml_metal_kargs_pad;
466
+
467
+ typedef struct {
468
+ int64_t ne00;
469
+ int64_t ne01;
470
+ int64_t ne02;
471
+ int64_t ne03;
472
+ uint64_t nb00;
473
+ uint64_t nb01;
474
+ uint64_t nb02;
475
+ uint64_t nb03;
476
+ int64_t ne0;
477
+ int64_t ne1;
478
+ int64_t ne2;
479
+ int64_t ne3;
480
+ uint64_t nb0;
481
+ uint64_t nb1;
482
+ uint64_t nb2;
483
+ uint64_t nb3;
484
+ int32_t p0;
485
+ int32_t p1;
486
+ } ggml_metal_kargs_pad_reflect_1d;
487
+
488
+ typedef struct {
489
+ uint64_t nb1;
490
+ int dim;
491
+ int max_period;
492
+ } ggml_metal_kargs_timestep_embedding;
493
+
494
+ typedef struct {
495
+ float slope;
496
+ } ggml_metal_kargs_leaky_relu;
497
+
498
+ typedef struct {
499
+ int64_t ncols;
500
+ int64_t ncols_pad;
501
+ } ggml_metal_kargs_argsort;
502
+
503
+ typedef struct {
504
+ int64_t ne0;
505
+ float start;
506
+ float step;
507
+ } ggml_metal_kargs_arange;
508
+
509
+ typedef struct {
510
+ int32_t k0;
511
+ int32_t k1;
512
+ int32_t s0;
513
+ int32_t s1;
514
+ int32_t p0;
515
+ int32_t p1;
516
+ int64_t IH;
517
+ int64_t IW;
518
+ int64_t OH;
519
+ int64_t OW;
520
+ int64_t parallel_elements;
521
+ } ggml_metal_kargs_pool_2d;
522
+
523
  #endif // GGML_METAL_IMPL
ggml/src/ggml-metal/ggml-metal.m CHANGED
@@ -1945,34 +1945,38 @@ static void ggml_metal_encode_node(
1945
 
1946
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
1947
 
1948
- // TODO: add ggml_metal_kargs struct
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1949
  [encoder setComputePipelineState:pipeline];
1950
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1951
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1952
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1953
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1954
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1955
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
1956
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1957
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1958
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1959
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
1960
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1961
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
1962
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1963
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1964
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1965
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1966
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1967
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
1968
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
1969
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
1970
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
1971
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
1972
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
1973
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
1974
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
1975
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
1976
 
1977
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1978
  } break;
@@ -2021,8 +2025,17 @@ static void ggml_metal_encode_node(
2021
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
2022
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
2023
 
2024
- // TODO: add ggml_metal_kargs struct
2025
- // TODO: optimize (see https://github.com/ggml-org/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6)
 
 
 
 
 
 
 
 
 
2026
  [encoder setComputePipelineState:pipeline];
2027
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2028
  if (id_src1) {
@@ -2031,14 +2044,7 @@ static void ggml_metal_encode_node(
2031
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2032
  }
2033
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2034
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
2035
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
2036
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
2037
- [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
2038
- [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
2039
- [encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
2040
- [encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
2041
- [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
2042
 
2043
  [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
2044
 
@@ -2056,13 +2062,16 @@ static void ggml_metal_encode_node(
2056
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
2057
  }
2058
 
2059
- // TODO: add ggml_metal_kargs struct
 
 
 
 
 
2060
  [encoder setComputePipelineState:pipeline];
2061
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2062
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2063
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
2064
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
2065
- [encoder setBytes:&n_past length:sizeof(int) atIndex:4];
2066
 
2067
  if (ne00%8 == 0) {
2068
  [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
@@ -2081,27 +2090,30 @@ static void ggml_metal_encode_node(
2081
 
2082
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
2083
 
2084
- // TODO: add ggml_metal_kargs struct
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2085
  [encoder setComputePipelineState:pipeline];
2086
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2087
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2088
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2089
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
2090
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
2091
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
2092
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
2093
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
2094
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
2095
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
2096
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
2097
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
2098
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
2099
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
2100
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
2101
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:15];
2102
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:16];
2103
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:17];
2104
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:18];
2105
 
2106
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2107
  } break;
@@ -2152,7 +2164,31 @@ static void ggml_metal_encode_node(
2152
 
2153
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
2154
 
2155
- // TODO: add ggml_metal_kargs struct
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2156
  [encoder setComputePipelineState:pipeline];
2157
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2158
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -2161,30 +2197,7 @@ static void ggml_metal_encode_node(
2161
  [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
2162
  [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
2163
  [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
2164
-
2165
- [encoder setBytes:&d_state length:sizeof(d_state) atIndex:7];
2166
- [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:8];
2167
- [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9];
2168
- [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:10];
2169
-
2170
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11];
2171
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12];
2172
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13];
2173
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
2174
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
2175
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
2176
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
2177
- [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18];
2178
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19];
2179
- [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20];
2180
- [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21];
2181
- [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22];
2182
- [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23];
2183
- [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24];
2184
- [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25];
2185
- [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26];
2186
- [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27];
2187
- [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28];
2188
 
2189
  [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2190
  } break;
@@ -3041,19 +3054,22 @@ static void ggml_metal_encode_node(
3041
  default: GGML_ABORT("not implemented");
3042
  }
3043
 
3044
- // TODO: add ggml_metal_kargs struct
 
 
 
 
 
 
 
 
 
 
3045
  [encoder setComputePipelineState:pipeline];
3046
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3047
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
3048
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
3049
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
3050
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
3051
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
3052
- [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
3053
- [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
3054
- [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
3055
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
3056
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
3057
 
3058
  [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
3059
  } break;
@@ -3110,18 +3126,21 @@ static void ggml_metal_encode_node(
3110
 
3111
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
3112
 
3113
- // TODO: add ggml_metal_kargs struct
 
 
 
 
 
 
 
 
 
 
3114
  [encoder setComputePipelineState:pipeline];
3115
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3116
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3117
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
3118
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
3119
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
3120
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
3121
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
3122
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
3123
- [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
3124
- [encoder setBytes:&eps length:sizeof( float) atIndex:9];
3125
  [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
3126
 
3127
  [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
@@ -3279,8 +3298,8 @@ static void ggml_metal_encode_node(
3279
 
3280
  const int32_t CHW = IC * KH * KW;
3281
 
3282
- const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
3283
- const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
3284
 
3285
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;
3286
 
@@ -3302,27 +3321,30 @@ static void ggml_metal_encode_node(
3302
  default: GGML_ABORT("fatal error");
3303
  };
3304
 
3305
- // TODO: add ggml_metal_kargs struct
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3306
  [encoder setComputePipelineState:pipeline];
3307
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
3308
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3309
- [encoder setBytes:&ofs0 length:sizeof(int32_t) atIndex:2];
3310
- [encoder setBytes:&ofs1 length:sizeof(int32_t) atIndex:3];
3311
- [encoder setBytes:&IW length:sizeof(int32_t) atIndex:4];
3312
- [encoder setBytes:&IH length:sizeof(int32_t) atIndex:5];
3313
- [encoder setBytes:&CHW length:sizeof(int32_t) atIndex:6];
3314
- [encoder setBytes:&s0 length:sizeof(int32_t) atIndex:7];
3315
- [encoder setBytes:&s1 length:sizeof(int32_t) atIndex:8];
3316
- [encoder setBytes:&p0 length:sizeof(int32_t) atIndex:9];
3317
- [encoder setBytes:&p1 length:sizeof(int32_t) atIndex:10];
3318
- [encoder setBytes:&d0 length:sizeof(int32_t) atIndex:11];
3319
- [encoder setBytes:&d1 length:sizeof(int32_t) atIndex:12];
3320
 
3321
  if (is_gt_mttpt) {
3322
- [encoder setBytes:&N length:sizeof(int32_t) atIndex:13];
3323
- [encoder setBytes:&KH length:sizeof(int32_t) atIndex:14];
3324
- [encoder setBytes:&KW length:sizeof(int32_t) atIndex:15];
3325
-
3326
  const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N);
3327
 
3328
  const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0);
@@ -3362,16 +3384,20 @@ static void ggml_metal_encode_node(
3362
  default: GGML_ABORT("fatal error");
3363
  };
3364
 
 
 
 
 
 
 
 
 
 
3365
  [encoder setComputePipelineState:pipeline];
3366
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3367
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
3368
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
3369
- [encoder setBytes:&IC length:sizeof( int32_t) atIndex:3];
3370
- [encoder setBytes:&IL length:sizeof( int32_t) atIndex:4];
3371
- [encoder setBytes:&K length:sizeof( int32_t) atIndex:5];
3372
- [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:6];
3373
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:7];
3374
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:8];
3375
 
3376
  [encoder dispatchThreadgroups:MTLSizeMake(OL, OC, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
3377
  } break;
@@ -3386,30 +3412,33 @@ static void ggml_metal_encode_node(
3386
 
3387
  const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
3388
 
3389
- // TODO: add ggml_metal_kargs struct
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3390
  [encoder setComputePipelineState:pipeline];
3391
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3392
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3393
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
3394
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
3395
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
3396
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
3397
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
3398
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
3399
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
3400
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
3401
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
3402
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
3403
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
3404
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
3405
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
3406
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
3407
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
3408
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
3409
- [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18];
3410
- [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19];
3411
- [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20];
3412
- [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21];
3413
 
3414
  const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
3415
 
@@ -3421,26 +3450,29 @@ static void ggml_metal_encode_node(
3421
 
3422
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
3423
 
3424
- // TODO: add ggml_metal_kargs struct
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3425
  [encoder setComputePipelineState:pipeline];
3426
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3427
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3428
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
3429
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
3430
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
3431
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
3432
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
3433
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
3434
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
3435
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
3436
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
3437
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
3438
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
3439
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
3440
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
3441
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
3442
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
3443
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
3444
 
3445
  const int nth = MIN(1024, ne0);
3446
 
@@ -3455,24 +3487,31 @@ static void ggml_metal_encode_node(
3455
 
3456
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32].pipeline;
3457
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3458
  [encoder setComputePipelineState:pipeline];
3459
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3460
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3461
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
3462
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
3463
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
3464
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
3465
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:6];
3466
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
3467
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
3468
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
3469
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
3470
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:11];
3471
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:12];
3472
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:13];
3473
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:14];
3474
- [encoder setBytes:&p0 length:sizeof(p0) atIndex:15];
3475
- [encoder setBytes:&p1 length:sizeof(p1) atIndex:16];
3476
 
3477
  const int nth = MIN(1024, ne0);
3478
 
@@ -3490,12 +3529,15 @@ static void ggml_metal_encode_node(
3490
 
3491
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
3492
 
3493
- // TODO: add ggml_metal_kargs struct
 
 
 
 
 
3494
  [encoder setComputePipelineState:pipeline];
3495
- [encoder setBuffer:id_dst offset:offs_dst atIndex:0];
3496
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1];
3497
- [encoder setBytes:&start length:sizeof(start) atIndex:2];
3498
- [encoder setBytes:&step length:sizeof(step) atIndex:3];
3499
 
3500
  const int nth = MIN(1024, ne0);
3501
 
@@ -3512,13 +3554,16 @@ static void ggml_metal_encode_node(
3512
 
3513
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
3514
 
3515
- // TODO: add ggml_metal_kargs struct
 
 
 
 
 
3516
  [encoder setComputePipelineState:pipeline];
3517
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3518
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3519
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2];
3520
- [encoder setBytes:&dim length:sizeof(dim) atIndex:3];
3521
- [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4];
3522
 
3523
  const int nth = MIN(1024, half);
3524
 
@@ -3551,12 +3596,15 @@ static void ggml_metal_encode_node(
3551
  default: GGML_ABORT("fatal error");
3552
  };
3553
 
3554
- // TODO: add ggml_metal_kargs struct
 
 
 
 
3555
  [encoder setComputePipelineState:pipeline];
3556
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3557
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3558
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
3559
- [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3];
3560
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
3561
 
3562
  [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
@@ -3570,11 +3618,14 @@ static void ggml_metal_encode_node(
3570
 
3571
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
3572
 
3573
- // TODO: add ggml_metal_kargs struct
 
 
 
3574
  [encoder setComputePipelineState:pipeline];
3575
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3576
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3577
- [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
3578
 
3579
  const int64_t n = ggml_nelements(dst);
3580
 
@@ -4150,21 +4201,24 @@ static void ggml_metal_encode_node(
4150
  const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
4151
  const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
4152
 
4153
- // TODO: add ggml_metal_kargs struct
 
 
 
 
 
 
 
 
 
 
 
 
 
4154
  [encoder setComputePipelineState:pipeline];
4155
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
4156
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
4157
- [encoder setBytes:&k0 length:sizeof(int32_t) atIndex:2];
4158
- [encoder setBytes:&k1 length:sizeof(int32_t) atIndex:3];
4159
- [encoder setBytes:&s0 length:sizeof(int32_t) atIndex:4];
4160
- [encoder setBytes:&s1 length:sizeof(int32_t) atIndex:5];
4161
- [encoder setBytes:&p0 length:sizeof(int32_t) atIndex:6];
4162
- [encoder setBytes:&p1 length:sizeof(int32_t) atIndex:7];
4163
- [encoder setBytes:&IH length:sizeof(int64_t) atIndex:8];
4164
- [encoder setBytes:&IW length:sizeof(int64_t) atIndex:9];
4165
- [encoder setBytes:&OH length:sizeof(int64_t) atIndex:10];
4166
- [encoder setBytes:&OW length:sizeof(int64_t) atIndex:11];
4167
- [encoder setBytes:&parallel_elements length:sizeof(int64_t) atIndex:12];
4168
 
4169
  [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
4170
  } break;
 
1945
 
1946
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
1947
 
1948
+
1949
+ ggml_metal_kargs_sum_rows args = {
1950
+ /*.ne00 =*/ ne00,
1951
+ /*.ne01 =*/ ne01,
1952
+ /*.ne02 =*/ ne02,
1953
+ /*.ne03 =*/ ne03,
1954
+ /*.nb00 =*/ nb00,
1955
+ /*.nb01 =*/ nb01,
1956
+ /*.nb02 =*/ nb02,
1957
+ /*.nb03 =*/ nb03,
1958
+ /*.ne10 =*/ ne10,
1959
+ /*.ne11 =*/ ne11,
1960
+ /*.ne12 =*/ ne12,
1961
+ /*.ne13 =*/ ne13,
1962
+ /*.nb10 =*/ nb10,
1963
+ /*.nb11 =*/ nb11,
1964
+ /*.nb12 =*/ nb12,
1965
+ /*.nb13 =*/ nb13,
1966
+ /*.ne0 =*/ ne0,
1967
+ /*.ne1 =*/ ne1,
1968
+ /*.ne2 =*/ ne2,
1969
+ /*.ne3 =*/ ne3,
1970
+ /*.nb0 =*/ nb0,
1971
+ /*.nb1 =*/ nb1,
1972
+ /*.nb2 =*/ nb2,
1973
+ /*.nb3 =*/ nb3,
1974
+ };
1975
+
1976
  [encoder setComputePipelineState:pipeline];
1977
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1978
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1979
+ [encoder setBytes:&args length:sizeof(args) atIndex:2];
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1980
 
1981
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1982
  } break;
 
2025
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
2026
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
2027
 
2028
+ ggml_metal_kargs_soft_max args = {
2029
+ /*.ne00 =*/ ne00,
2030
+ /*.ne01 =*/ ne01,
2031
+ /*.ne02 =*/ ne02,
2032
+ /*.scale =*/ scale,
2033
+ /*.max_bias =*/ max_bias,
2034
+ /*.m0 =*/ m0,
2035
+ /*.m1 =*/ m1,
2036
+ /*.n_head_log2 =*/ n_head_log2,
2037
+ };
2038
+
2039
  [encoder setComputePipelineState:pipeline];
2040
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2041
  if (id_src1) {
 
2044
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2045
  }
2046
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2047
+ [encoder setBytes:&args length:sizeof(args) atIndex:3];
 
 
 
 
 
 
 
2048
 
2049
  [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
2050
 
 
2062
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
2063
  }
2064
 
2065
+ ggml_metal_kargs_diag_mask_inf args = {
2066
+ /*.ne00 =*/ ne00,
2067
+ /*.ne01 =*/ ne01,
2068
+ /*.n_past =*/ n_past,
2069
+ };
2070
+
2071
  [encoder setComputePipelineState:pipeline];
2072
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2073
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2074
+ [encoder setBytes:&args length:sizeof(args) atIndex:2];
 
 
2075
 
2076
  if (ne00%8 == 0) {
2077
  [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
 
2090
 
2091
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
2092
 
2093
+ ggml_metal_kargs_ssm_conv args = {
2094
+ /*.ne00 =*/ ne00,
2095
+ /*.ne01 =*/ ne01,
2096
+ /*.ne02 =*/ ne02,
2097
+ /*.nb00 =*/ nb00,
2098
+ /*.nb01 =*/ nb01,
2099
+ /*.nb02 =*/ nb02,
2100
+ /*.ne10 =*/ ne10,
2101
+ /*.ne11 =*/ ne11,
2102
+ /*.nb10 =*/ nb10,
2103
+ /*.nb11 =*/ nb11,
2104
+ /*.ne0 =*/ ne0,
2105
+ /*.ne1 =*/ ne1,
2106
+ /*.ne2 =*/ ne2,
2107
+ /*.nb0 =*/ nb0,
2108
+ /*.nb1 =*/ nb1,
2109
+ /*.nb2 =*/ nb2,
2110
+ };
2111
+
2112
  [encoder setComputePipelineState:pipeline];
2113
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2114
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2115
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2116
+ [encoder setBytes:&args length:sizeof(args) atIndex:3];
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2117
 
2118
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2119
  } break;
 
2164
 
2165
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
2166
 
2167
+ ggml_metal_kargs_ssm_scan args = {
2168
+ /*.d_state =*/ d_state,
2169
+ /*.d_inner =*/ d_inner,
2170
+ /*.n_seq_tokens =*/ n_seq_tokens,
2171
+ /*.n_seqs =*/ n_seqs,
2172
+ /*.nb00 =*/ nb00,
2173
+ /*.nb01 =*/ nb01,
2174
+ /*.nb02 =*/ nb02,
2175
+ /*.nb10 =*/ nb10,
2176
+ /*.nb11 =*/ nb11,
2177
+ /*.nb12 =*/ nb12,
2178
+ /*.nb13 =*/ nb13,
2179
+ /*.nb20 =*/ nb20,
2180
+ /*.nb21 =*/ nb21,
2181
+ /*.nb22 =*/ nb22,
2182
+ /*.nb30 =*/ nb30,
2183
+ /*.nb31 =*/ nb31,
2184
+ /*.nb40 =*/ nb40,
2185
+ /*.nb41 =*/ nb41,
2186
+ /*.nb42 =*/ nb42,
2187
+ /*.nb50 =*/ nb50,
2188
+ /*.nb51 =*/ nb51,
2189
+ /*.nb52 =*/ nb52,
2190
+ };
2191
+
2192
  [encoder setComputePipelineState:pipeline];
2193
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2194
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
 
2197
  [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
2198
  [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
2199
  [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
2200
+ [encoder setBytes:&args length:sizeof(args) atIndex:7];
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2201
 
2202
  [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2203
  } break;
 
3054
  default: GGML_ABORT("not implemented");
3055
  }
3056
 
3057
+ ggml_metal_kargs_get_rows args = {
3058
+ /*.ne00 =*/ ne00,
3059
+ /*.nb01 =*/ nb01,
3060
+ /*.nb02 =*/ nb02,
3061
+ /*.ne10 =*/ ne10,
3062
+ /*.nb10 =*/ nb10,
3063
+ /*.nb11 =*/ nb11,
3064
+ /*.nb1 =*/ nb1,
3065
+ /*.nb2 =*/ nb2,
3066
+ };
3067
+
3068
  [encoder setComputePipelineState:pipeline];
3069
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3070
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
3071
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
3072
+ [encoder setBytes:&args length:sizeof(args) atIndex:3];
 
 
 
 
 
 
 
3073
 
3074
  [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
3075
  } break;
 
3126
 
3127
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
3128
 
3129
+ ggml_metal_kargs_group_norm args = {
3130
+ /*.ne00 =*/ ne00,
3131
+ /*.ne01 =*/ ne01,
3132
+ /*.ne02 =*/ ne02,
3133
+ /*.nb00 =*/ nb00,
3134
+ /*.nb01 =*/ nb01,
3135
+ /*.nb02 =*/ nb02,
3136
+ /*.n_groups =*/ n_groups,
3137
+ /*.eps =*/ eps,
3138
+ };
3139
+
3140
  [encoder setComputePipelineState:pipeline];
3141
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3142
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3143
+ [encoder setBytes:&args length:sizeof(args) atIndex:2];
 
 
 
 
 
 
 
3144
  [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
3145
 
3146
  [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
 
3298
 
3299
  const int32_t CHW = IC * KH * KW;
3300
 
3301
+ const uint64_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
3302
+ const uint64_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
3303
 
3304
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;
3305
 
 
3321
  default: GGML_ABORT("fatal error");
3322
  };
3323
 
3324
+ ggml_metal_kargs_im2col args = {
3325
+ /*.ofs0 =*/ ofs0,
3326
+ /*.ofs1 =*/ ofs1,
3327
+ /*.IW =*/ IW,
3328
+ /*.IH =*/ IH,
3329
+ /*.CHW =*/ CHW,
3330
+ /*.s0 =*/ s0,
3331
+ /*.s1 =*/ s1,
3332
+ /*.p0 =*/ p0,
3333
+ /*.p1 =*/ p1,
3334
+ /*.d0 =*/ d0,
3335
+ /*.d1 =*/ d1,
3336
+ /*.N =*/ N,
3337
+ /*.KH =*/ KH,
3338
+ /*.KW =*/ KW,
3339
+ /*.KHW =*/ KH * KW,
3340
+ };
3341
+
3342
  [encoder setComputePipelineState:pipeline];
3343
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
3344
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3345
+ [encoder setBytes:&args length:sizeof(args) atIndex:2];
 
 
 
 
 
 
 
 
 
 
3346
 
3347
  if (is_gt_mttpt) {
 
 
 
 
3348
  const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N);
3349
 
3350
  const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0);
 
3384
  default: GGML_ABORT("fatal error");
3385
  };
3386
 
3387
+ ggml_metal_kargs_conv_transpose_1d args = {
3388
+ /*.IC =*/ IC,
3389
+ /*.IL =*/ IL,
3390
+ /*.K =*/ K,
3391
+ /*.s0 =*/ s0,
3392
+ /*.nb0 =*/ nb0,
3393
+ /*.nb1 =*/ nb1,
3394
+ };
3395
+
3396
  [encoder setComputePipelineState:pipeline];
3397
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3398
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
3399
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
3400
+ [encoder setBytes:&args length:sizeof(args) atIndex:3];
 
 
 
 
 
3401
 
3402
  [encoder dispatchThreadgroups:MTLSizeMake(OL, OC, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
3403
  } break;
 
3412
 
3413
  const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
3414
 
3415
+ ggml_metal_kargs_upscale args = {
3416
+ /*.ne00 =*/ ne00,
3417
+ /*.ne01 =*/ ne01,
3418
+ /*.ne02 =*/ ne02,
3419
+ /*.ne03 =*/ ne03,
3420
+ /*.nb00 =*/ nb00,
3421
+ /*.nb01 =*/ nb01,
3422
+ /*.nb02 =*/ nb02,
3423
+ /*.nb03 =*/ nb03,
3424
+ /*.ne0 =*/ ne0,
3425
+ /*.ne1 =*/ ne1,
3426
+ /*.ne2 =*/ ne2,
3427
+ /*.ne3 =*/ ne3,
3428
+ /*.nb0 =*/ nb0,
3429
+ /*.nb1 =*/ nb1,
3430
+ /*.nb2 =*/ nb2,
3431
+ /*.nb3 =*/ nb3,
3432
+ /*.sf0 =*/ sf0,
3433
+ /*.sf1 =*/ sf1,
3434
+ /*.sf2 =*/ sf2,
3435
+ /*.sf3 =*/ sf3
3436
+ };
3437
+
3438
  [encoder setComputePipelineState:pipeline];
3439
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3440
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3441
+ [encoder setBytes:&args length:sizeof(args) atIndex:2];
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3442
 
3443
  const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
3444
 
 
3450
 
3451
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
3452
 
3453
+ ggml_metal_kargs_pad args = {
3454
+ /*.ne00 =*/ ne00,
3455
+ /*.ne01 =*/ ne01,
3456
+ /*.ne02 =*/ ne02,
3457
+ /*.ne03 =*/ ne03,
3458
+ /*.nb00 =*/ nb00,
3459
+ /*.nb01 =*/ nb01,
3460
+ /*.nb02 =*/ nb02,
3461
+ /*.nb03 =*/ nb03,
3462
+ /*.ne0 =*/ ne0,
3463
+ /*.ne1 =*/ ne1,
3464
+ /*.ne2 =*/ ne2,
3465
+ /*.ne3 =*/ ne3,
3466
+ /*.nb0 =*/ nb0,
3467
+ /*.nb1 =*/ nb1,
3468
+ /*.nb2 =*/ nb2,
3469
+ /*.nb3 =*/ nb3
3470
+ };
3471
+
3472
  [encoder setComputePipelineState:pipeline];
3473
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3474
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3475
+ [encoder setBytes:&args length:sizeof(args) atIndex:2];
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3476
 
3477
  const int nth = MIN(1024, ne0);
3478
 
 
3487
 
3488
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32].pipeline;
3489
 
3490
+ ggml_metal_kargs_pad_reflect_1d args = {
3491
+ /*.ne00 =*/ ne00,
3492
+ /*.ne01 =*/ ne01,
3493
+ /*.ne02 =*/ ne02,
3494
+ /*.ne03 =*/ ne03,
3495
+ /*.nb00 =*/ nb00,
3496
+ /*.nb01 =*/ nb01,
3497
+ /*.nb02 =*/ nb02,
3498
+ /*.nb03 =*/ nb03,
3499
+ /*.ne0 =*/ ne0,
3500
+ /*.ne1 =*/ ne1,
3501
+ /*.ne2 =*/ ne2,
3502
+ /*.ne3 =*/ ne3,
3503
+ /*.nb0 =*/ nb0,
3504
+ /*.nb1 =*/ nb1,
3505
+ /*.nb2 =*/ nb2,
3506
+ /*.nb3 =*/ nb3,
3507
+ /*.p0 =*/ p0,
3508
+ /*.p1 =*/ p1
3509
+ };
3510
+
3511
  [encoder setComputePipelineState:pipeline];
3512
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3513
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3514
+ [encoder setBytes:&args length:sizeof(args) atIndex:2];
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3515
 
3516
  const int nth = MIN(1024, ne0);
3517
 
 
3529
 
3530
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
3531
 
3532
+ ggml_metal_kargs_arange args = {
3533
+ /*.ne0 =*/ ne0,
3534
+ /*.start =*/ start,
3535
+ /*.step =*/ step
3536
+ };
3537
+
3538
  [encoder setComputePipelineState:pipeline];
3539
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:0];
3540
+ [encoder setBytes:&args length:sizeof(args) atIndex:1];
 
 
3541
 
3542
  const int nth = MIN(1024, ne0);
3543
 
 
3554
 
3555
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
3556
 
3557
+ ggml_metal_kargs_timestep_embedding args = {
3558
+ /*.nb1 =*/ nb1,
3559
+ /*.dim =*/ dim,
3560
+ /*.max_period =*/ max_period
3561
+ };
3562
+
3563
  [encoder setComputePipelineState:pipeline];
3564
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3565
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3566
+ [encoder setBytes:&args length:sizeof(args) atIndex:2];
 
 
3567
 
3568
  const int nth = MIN(1024, half);
3569
 
 
3596
  default: GGML_ABORT("fatal error");
3597
  };
3598
 
3599
+ ggml_metal_kargs_argsort args = {
3600
+ /*.ncols =*/ ne00,
3601
+ /*.ncols_pad =*/ ne00_padded
3602
+ };
3603
+
3604
  [encoder setComputePipelineState:pipeline];
3605
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3606
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3607
+ [encoder setBytes:&args length:sizeof(args) atIndex:2];
 
3608
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
3609
 
3610
  [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
 
3618
 
3619
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
3620
 
3621
+ ggml_metal_kargs_leaky_relu args = {
3622
+ /*.slope =*/ slope
3623
+ };
3624
+
3625
  [encoder setComputePipelineState:pipeline];
3626
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3627
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3628
+ [encoder setBytes:&args length:sizeof(args) atIndex:2];
3629
 
3630
  const int64_t n = ggml_nelements(dst);
3631
 
 
4201
  const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
4202
  const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
4203
 
4204
+ ggml_metal_kargs_pool_2d args_pool_2d = {
4205
+ /* .k0 = */ k0,
4206
+ /* .k1 = */ k1,
4207
+ /* .s0 = */ s0,
4208
+ /* .s1 = */ s1,
4209
+ /* .p0 = */ p0,
4210
+ /* .p1 = */ p1,
4211
+ /* .IH = */ IH,
4212
+ /* .IW = */ IW,
4213
+ /* .OH = */ OH,
4214
+ /* .OW = */ OW,
4215
+ /* .parallel_elements = */ parallel_elements
4216
+ };
4217
+
4218
  [encoder setComputePipelineState:pipeline];
4219
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
4220
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
4221
+ [encoder setBytes:&args_pool_2d length:sizeof(args_pool_2d) atIndex:2];
 
 
 
 
 
 
 
 
 
 
4222
 
4223
  [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
4224
  } break;
ggml/src/ggml-metal/ggml-metal.metal CHANGED
@@ -947,45 +947,22 @@ kernel void kernel_cos(
947
  kernel void kernel_sum_rows(
948
  device const float * src0,
949
  device float * dst,
950
- constant int64_t & ne00,
951
- constant int64_t & ne01,
952
- constant int64_t & ne02,
953
- constant int64_t & ne03,
954
- constant uint64_t & nb00,
955
- constant uint64_t & nb01,
956
- constant uint64_t & nb02,
957
- constant uint64_t & nb03,
958
- constant int64_t & ne10,
959
- constant int64_t & ne11,
960
- constant int64_t & ne12,
961
- constant int64_t & ne13,
962
- constant uint64_t & nb10,
963
- constant uint64_t & nb11,
964
- constant uint64_t & nb12,
965
- constant uint64_t & nb13,
966
- constant int64_t & ne0,
967
- constant int64_t & ne1,
968
- constant int64_t & ne2,
969
- constant int64_t & ne3,
970
- constant uint64_t & nb0,
971
- constant uint64_t & nb1,
972
- constant uint64_t & nb2,
973
- constant uint64_t & nb3,
974
  uint3 tpig[[thread_position_in_grid]]) {
975
  int64_t i3 = tpig.z;
976
  int64_t i2 = tpig.y;
977
  int64_t i1 = tpig.x;
978
 
979
- if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
980
  return;
981
  }
982
 
983
- device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
984
- device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
985
 
986
  float row_sum = 0;
987
 
988
- for (int64_t i0 = 0; i0 < ne00; i0++) {
989
  row_sum += src_row[i0];
990
  }
991
 
@@ -997,36 +974,29 @@ kernel void kernel_soft_max(
997
  device const char * src0,
998
  device const char * src1,
999
  device char * dst,
1000
- constant int64_t & ne00,
1001
- constant int64_t & ne01,
1002
- constant int64_t & ne02,
1003
- constant float & scale,
1004
- constant float & max_bias,
1005
- constant float & m0,
1006
- constant float & m1,
1007
- constant uint32_t & n_head_log2,
1008
  threadgroup float * buf [[threadgroup(0)]],
1009
  uint tgpig[[threadgroup_position_in_grid]],
1010
  uint tpitg[[thread_position_in_threadgroup]],
1011
  uint sgitg[[simdgroup_index_in_threadgroup]],
1012
  uint tiisg[[thread_index_in_simdgroup]],
1013
  uint ntg[[threads_per_threadgroup]]) {
1014
- const int64_t i03 = (tgpig) / (ne02*ne01);
1015
- const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
1016
- const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
1017
 
1018
- device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
1019
- device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
1020
- device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
1021
 
1022
  float slope = 1.0f;
1023
 
1024
  // ALiBi
1025
- if (max_bias > 0.0f) {
1026
  const int64_t h = i02;
1027
 
1028
- const float base = h < n_head_log2 ? m0 : m1;
1029
- const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
1030
 
1031
  slope = pow(base, exp);
1032
  }
@@ -1034,8 +1004,8 @@ kernel void kernel_soft_max(
1034
  // parallel max
1035
  float lmax = -INFINITY;
1036
 
1037
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
1038
- lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
1039
  }
1040
 
1041
  // find the max value in the block
@@ -1059,8 +1029,8 @@ kernel void kernel_soft_max(
1059
 
1060
  // parallel sum
1061
  float lsum = 0.0f;
1062
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
1063
- const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
1064
  lsum += exp_psrc0;
1065
  pdst[i00] = exp_psrc0;
1066
  }
@@ -1090,7 +1060,7 @@ kernel void kernel_soft_max(
1090
 
1091
  const float inv_sum = 1.0f/sum;
1092
 
1093
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
1094
  pdst[i00] *= inv_sum;
1095
  }
1096
  }
@@ -1100,35 +1070,28 @@ kernel void kernel_soft_max_4(
1100
  device const char * src0,
1101
  device const char * src1,
1102
  device char * dst,
1103
- constant int64_t & ne00,
1104
- constant int64_t & ne01,
1105
- constant int64_t & ne02,
1106
- constant float & scale,
1107
- constant float & max_bias,
1108
- constant float & m0,
1109
- constant float & m1,
1110
- constant uint32_t & n_head_log2,
1111
  threadgroup float * buf [[threadgroup(0)]],
1112
  uint tgpig[[threadgroup_position_in_grid]],
1113
  uint tpitg[[thread_position_in_threadgroup]],
1114
  uint sgitg[[simdgroup_index_in_threadgroup]],
1115
  uint tiisg[[thread_index_in_simdgroup]],
1116
  uint ntg[[threads_per_threadgroup]]) {
1117
- const int64_t i03 = (tgpig) / (ne02*ne01);
1118
- const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
1119
- const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
1120
 
1121
- device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
1122
- device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
1123
- device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
1124
 
1125
  float slope = 1.0f;
1126
 
1127
- if (max_bias > 0.0f) {
1128
  const int64_t h = i02;
1129
 
1130
- const float base = h < n_head_log2 ? m0 : m1;
1131
- const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
1132
 
1133
  slope = pow(base, exp);
1134
  }
@@ -1136,8 +1099,8 @@ kernel void kernel_soft_max_4(
1136
  // parallel max
1137
  float4 lmax4 = -INFINITY;
1138
 
1139
- for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
1140
- lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
1141
  }
1142
 
1143
  const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@@ -1162,8 +1125,8 @@ kernel void kernel_soft_max_4(
1162
 
1163
  // parallel sum
1164
  float4 lsum4 = 0.0f;
1165
- for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
1166
- const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
1167
  lsum4 += exp_psrc4;
1168
  pdst4[i00] = exp_psrc4;
1169
  }
@@ -1195,7 +1158,7 @@ kernel void kernel_soft_max_4(
1195
 
1196
  const float inv_sum = 1.0f/sum;
1197
 
1198
- for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
1199
  pdst4[i00] *= inv_sum;
1200
  }
1201
  }
@@ -1211,27 +1174,23 @@ template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kerne
1211
  kernel void kernel_diag_mask_inf(
1212
  device const float * src0,
1213
  device float * dst,
1214
- constant int64_t & ne00,
1215
- constant int64_t & ne01,
1216
- constant int & n_past,
1217
  uint3 tpig[[thread_position_in_grid]]) {
1218
  const int64_t i02 = tpig[2];
1219
  const int64_t i01 = tpig[1];
1220
  const int64_t i00 = tpig[0];
1221
 
1222
- if (i00 > n_past + i01) {
1223
- dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
1224
  } else {
1225
- dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
1226
  }
1227
  }
1228
 
1229
  kernel void kernel_diag_mask_inf_8(
1230
  device const float4 * src0,
1231
  device float4 * dst,
1232
- constant int64_t & ne00,
1233
- constant int64_t & ne01,
1234
- constant int & n_past,
1235
  uint3 tpig[[thread_position_in_grid]]) {
1236
 
1237
  const int64_t i = 2*tpig[0];
@@ -1239,42 +1198,26 @@ kernel void kernel_diag_mask_inf_8(
1239
  dst[i+0] = src0[i+0];
1240
  dst[i+1] = src0[i+1];
1241
  int64_t i4 = 4*i;
1242
- const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
1243
- const int64_t i01 = i4/(ne00); i4 -= i01*ne00;
1244
  const int64_t i00 = i4;
1245
  for (int k = 3; k >= 0; --k) {
1246
- if (i00 + 4 + k <= n_past + i01) {
1247
  break;
1248
  }
1249
  dst[i+1][k] = -INFINITY;
1250
- if (i00 + k > n_past + i01) {
1251
  dst[i][k] = -INFINITY;
1252
  }
1253
  }
1254
  }
1255
 
1256
  // ref: ggml.c:ggml_compute_forward_ssm_conv_f32
1257
- // TODO: optimize
1258
  kernel void kernel_ssm_conv_f32(
1259
  device const void * src0,
1260
  device const void * src1,
1261
  device float * dst,
1262
- constant int64_t & ne00,
1263
- constant int64_t & ne01,
1264
- constant int64_t & ne02,
1265
- constant uint64_t & nb00,
1266
- constant uint64_t & nb01,
1267
- constant uint64_t & nb02,
1268
- constant int64_t & ne10,
1269
- constant int64_t & ne11,
1270
- constant uint64_t & nb10,
1271
- constant uint64_t & nb11,
1272
- constant int64_t & ne0,
1273
- constant int64_t & ne1,
1274
- constant int64_t & ne2,
1275
- constant uint64_t & nb0,
1276
- constant uint64_t & nb1,
1277
- constant uint64_t & nb2,
1278
  uint3 tgpig[[threadgroup_position_in_grid]],
1279
  uint3 tpitg[[thread_position_in_threadgroup]],
1280
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -1282,15 +1225,15 @@ kernel void kernel_ssm_conv_f32(
1282
  const int64_t i2 = tgpig.y;
1283
  const int64_t i3 = tgpig.z;
1284
 
1285
- const int64_t nc = ne10;
1286
- //const int64_t ncs = ne00;
1287
- //const int64_t nr = ne01;
1288
- //const int64_t n_t = ne1;
1289
- //const int64_t n_s = ne2;
1290
 
1291
- device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02);
1292
- device const float * c = (device const float *) ((device const char *) src1 + ir*nb11);
1293
- device float * x = (device float *) ((device char *) dst + ir*nb0 + i2*nb1 + i3*nb2);
1294
 
1295
  float sumf = 0.0f;
1296
 
@@ -1302,7 +1245,6 @@ kernel void kernel_ssm_conv_f32(
1302
  }
1303
 
1304
  // ref: ggml.c:ggml_compute_forward_ssm_scan_f32
1305
- // TODO: optimize
1306
  kernel void kernel_ssm_scan_f32(
1307
  device const void * src0,
1308
  device const void * src1,
@@ -1311,48 +1253,27 @@ kernel void kernel_ssm_scan_f32(
1311
  device const void * src4,
1312
  device const void * src5,
1313
  device float * dst,
1314
- constant int64_t & d_state,
1315
- constant int64_t & d_inner,
1316
- constant int64_t & n_seq_tokens,
1317
- constant int64_t & n_seqs,
1318
- constant uint64_t & nb00,
1319
- constant uint64_t & nb01,
1320
- constant uint64_t & nb02,
1321
- constant uint64_t & nb10,
1322
- constant uint64_t & nb11,
1323
- constant uint64_t & nb12,
1324
- constant uint64_t & nb13,
1325
- constant uint64_t & nb20,
1326
- constant uint64_t & nb21,
1327
- constant uint64_t & nb22,
1328
- constant uint64_t & nb30,
1329
- constant uint64_t & nb31,
1330
- constant uint64_t & nb40,
1331
- constant uint64_t & nb41,
1332
- constant uint64_t & nb42,
1333
- constant uint64_t & nb50,
1334
- constant uint64_t & nb51,
1335
- constant uint64_t & nb52,
1336
  uint3 tgpig[[threadgroup_position_in_grid]],
1337
  uint3 tpitg[[thread_position_in_threadgroup]],
1338
  uint3 ntg[[threads_per_threadgroup]]) {
1339
  const int64_t ir = tgpig.x;
1340
  const int64_t i3 = tgpig.y;
1341
 
1342
- const int64_t nc = d_state;
1343
- //const int64_t nr = d_inner;
1344
- const int64_t n_t = n_seq_tokens;
1345
- //const int64_t n_s = n_seqs;
1346
 
1347
  for (int64_t i2 = 0; i2 < n_t; ++i2) {
1348
- device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02);
1349
- device const float * x = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12);
1350
- device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22);
1351
- device const float * A = (device const float *) ((device const char *) src3 + ir*nb31);
1352
- device const float * B = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42);
1353
- device const float * C = (device const float *) ((device const char *) src5 + i2*nb51 + i3*nb52);
1354
- device float * y = (device float *) ((device char *) dst + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides
1355
- device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb02 + nb13);
1356
 
1357
  if (i2 > 0) {
1358
  s0 = s;
@@ -1545,22 +1466,15 @@ kernel void kernel_rms_norm(
1545
  kernel void kernel_group_norm(
1546
  device const float * src0,
1547
  device float * dst,
1548
- constant int64_t & ne00,
1549
- constant int64_t & ne01,
1550
- constant int64_t & ne02,
1551
- constant uint64_t & nb00,
1552
- constant uint64_t & nb01,
1553
- constant uint64_t & nb02,
1554
- constant int32_t & n_groups,
1555
- constant float & eps,
1556
  threadgroup float * buf [[threadgroup(0)]],
1557
  uint tgpig[[threadgroup_position_in_grid]],
1558
  uint tpitg[[thread_position_in_threadgroup]],
1559
  uint sgitg[[simdgroup_index_in_threadgroup]],
1560
  uint tiisg[[thread_index_in_simdgroup]],
1561
  uint ntg[[threads_per_threadgroup]]) {
1562
- const int64_t ne = ne00*ne01*ne02;
1563
- const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
1564
 
1565
  int start = tgpig * gs;
1566
  int end = start + gs;
@@ -1624,7 +1538,7 @@ kernel void kernel_group_norm(
1624
  }
1625
 
1626
  const float variance = tmp / gs;
1627
- const float scale = 1.0f/sqrt(variance + eps);
1628
  for (int j = start; j < end; j += ntg) {
1629
  dst[j] *= scale;
1630
  }
@@ -2588,17 +2502,7 @@ template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_
2588
  typedef void (im2col_t)(
2589
  device const float * x,
2590
  device char * dst,
2591
- constant int32_t & ofs0,
2592
- constant int32_t & ofs1,
2593
- constant int32_t & IW,
2594
- constant int32_t & IH,
2595
- constant int32_t & CHW,
2596
- constant int32_t & s0,
2597
- constant int32_t & s1,
2598
- constant int32_t & p0,
2599
- constant int32_t & p1,
2600
- constant int32_t & d0,
2601
- constant int32_t & d1,
2602
  uint3 tgpig[[threadgroup_position_in_grid]],
2603
  uint3 tgpg[[threadgroups_per_grid]],
2604
  uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2608,17 +2512,7 @@ template <typename T>
2608
  kernel void kernel_im2col(
2609
  device const float * x,
2610
  device char * dst,
2611
- constant int32_t & ofs0,
2612
- constant int32_t & ofs1,
2613
- constant int32_t & IW,
2614
- constant int32_t & IH,
2615
- constant int32_t & CHW,
2616
- constant int32_t & s0,
2617
- constant int32_t & s1,
2618
- constant int32_t & p0,
2619
- constant int32_t & p1,
2620
- constant int32_t & d0,
2621
- constant int32_t & d1,
2622
  uint3 tgpig[[threadgroup_position_in_grid]],
2623
  uint3 tgpg[[threadgroups_per_grid]],
2624
  uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2639,17 +2533,17 @@ kernel void kernel_im2col(
2639
  const int64_t ioh = tgpig[1];
2640
  const int64_t iow = tgpig[2];
2641
 
2642
- const int64_t iiw = iow*s0 + ikw*d0 - p0;
2643
- const int64_t iih = ioh*s1 + ikh*d1 - p1;
2644
 
2645
- const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*CHW + (iic*(KH*KW) + ikh*KW + ikw);
2646
 
2647
  device T * pdst = (device T *) (dst);
2648
 
2649
- if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
2650
  pdst[offset_dst] = 0.0f;
2651
  } else {
2652
- const int64_t offset_src = in*ofs0 + iic*ofs1 + iih*IW + iiw;
2653
  pdst[offset_dst] = x[offset_src];
2654
  }
2655
  }
@@ -2660,20 +2554,7 @@ template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
2660
  typedef void (im2col_ext_t)(
2661
  device const float * x,
2662
  device char * dst,
2663
- constant int32_t & ofs0,
2664
- constant int32_t & ofs1,
2665
- constant int32_t & IW,
2666
- constant int32_t & IH,
2667
- constant int32_t & CHW,
2668
- constant int32_t & s0,
2669
- constant int32_t & s1,
2670
- constant int32_t & p0,
2671
- constant int32_t & p1,
2672
- constant int32_t & d0,
2673
- constant int32_t & d1,
2674
- constant int32_t & N,
2675
- constant int32_t & KH,
2676
- constant int32_t & KW,
2677
  uint3 tgpig[[threadgroup_position_in_grid]],
2678
  uint3 tgpg[[threadgroups_per_grid]],
2679
  uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2683,53 +2564,40 @@ template <typename T>
2683
  kernel void kernel_im2col_ext(
2684
  device const float * x,
2685
  device char * dst,
2686
- constant int32_t & ofs0,
2687
- constant int32_t & ofs1,
2688
- constant int32_t & IW,
2689
- constant int32_t & IH,
2690
- constant int32_t & CHW,
2691
- constant int32_t & s0,
2692
- constant int32_t & s1,
2693
- constant int32_t & p0,
2694
- constant int32_t & p1,
2695
- constant int32_t & d0,
2696
- constant int32_t & d1,
2697
- constant int32_t & N,
2698
- constant int32_t & KH,
2699
- constant int32_t & KW,
2700
  uint3 tgpig[[threadgroup_position_in_grid]],
2701
  uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
2702
  uint3 tpitg[[thread_position_in_threadgroup]],
2703
  uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
2704
- const int64_t KHW = KH * KW; // KHW == ntg[1] * ntg[2], KW == ntg[2]
2705
 
2706
- const int64_t d = tgpig[0] / CHW;
2707
- const int64_t chw = tgpig[0] % CHW;
2708
  const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
2709
  const int64_t HW = tgpig[0] % KHW;
2710
 
2711
  const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
2712
- if (tpitg_0 >= N) {
2713
  return;
2714
  }
2715
 
2716
- const int64_t tpitg_1 = HW / KW;
2717
- const int64_t tpitg_2 = HW % KW;
2718
 
2719
- const int64_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0;
2720
- const int64_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1;
2721
 
2722
  const int64_t offset_dst =
2723
- (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
2724
- (tgpig_0 * KHW + tpitg_1 * KW + tpitg_2);
2725
 
2726
  device T * pdst = (device T *) (dst);
2727
 
2728
- if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
2729
  pdst[offset_dst] = 0.0f;
2730
  } else {
2731
- const int64_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1;
2732
- pdst[offset_dst] = x[offset_src + iih * IW + iiw];
2733
  }
2734
  }
2735
 
@@ -2740,12 +2608,7 @@ typedef void (conv_transpose_1d_t)(
2740
  device const float * src0,
2741
  device const float * src1,
2742
  device char * dst,
2743
- constant int32_t & IC,
2744
- constant int32_t & IL,
2745
- constant int32_t & K,
2746
- constant int32_t & s0,
2747
- constant uint64_t & nb0,
2748
- constant uint64_t & nb1,
2749
  uint3 tgpig[[threadgroup_position_in_grid]],
2750
  uint3 tgpg[[threadgroups_per_grid]]);
2751
 
@@ -2754,29 +2617,24 @@ kernel void kernel_conv_transpose_1d(
2754
  device const T * src0,
2755
  device const float * src1,
2756
  device char * dst,
2757
- constant int32_t & IC,
2758
- constant int32_t & IL,
2759
- constant int32_t & K,
2760
- constant int32_t & s0,
2761
- constant uint64_t & nb0,
2762
- constant uint64_t & nb1,
2763
  uint3 tgpig[[threadgroup_position_in_grid]],
2764
  uint3 tgpg[[threadgroups_per_grid]]) {
2765
 
2766
  float v = 0.0f;
2767
 
2768
- for (int64_t c = 0; c < IC; c++) {
2769
- const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1];
2770
- const int32_t input_offset = c * IL;
2771
 
2772
- for (int64_t i = 0; i < IL; i++) {
2773
- if (tgpig[0] >= i * s0 && tgpig[0] < i * s0 + K) {
2774
- v += src0[kernel_offset + tgpig[0] - i * s0] * src1[input_offset + i];
2775
  }
2776
  }
2777
  }
2778
 
2779
- device float * dst_ptr = (device float *) (dst + tgpig[0] * nb0 + tgpig[1] * nb1);
2780
 
2781
  dst_ptr[0] = v;
2782
  }
@@ -2786,12 +2644,7 @@ kernel void kernel_conv_transpose_1d<float>(
2786
  device const float * src0,
2787
  device const float * src1,
2788
  device char * dst,
2789
- constant int32_t & IC,
2790
- constant int32_t & IL,
2791
- constant int32_t & K,
2792
- constant int32_t & s0,
2793
- constant uint64_t & nb0,
2794
- constant uint64_t & nb1,
2795
  uint3 tgpig[[threadgroup_position_in_grid]],
2796
  uint3 tgpg[[threadgroups_per_grid]]);
2797
 
@@ -2800,38 +2653,14 @@ kernel void kernel_conv_transpose_1d<half>(
2800
  device const half * src0,
2801
  device const float * src1,
2802
  device char * dst,
2803
- constant int32_t & IC,
2804
- constant int32_t & IL,
2805
- constant int32_t & K,
2806
- constant int32_t & s0,
2807
- constant uint64_t & nb0,
2808
- constant uint64_t & nb1,
2809
  uint3 tgpig[[threadgroup_position_in_grid]],
2810
  uint3 tgpg[[threadgroups_per_grid]]);
2811
 
2812
  kernel void kernel_upscale_f32(
2813
  device const char * src0,
2814
  device char * dst,
2815
- constant int64_t & ne00,
2816
- constant int64_t & ne01,
2817
- constant int64_t & ne02,
2818
- constant int64_t & ne03,
2819
- constant uint64_t & nb00,
2820
- constant uint64_t & nb01,
2821
- constant uint64_t & nb02,
2822
- constant uint64_t & nb03,
2823
- constant int64_t & ne0,
2824
- constant int64_t & ne1,
2825
- constant int64_t & ne2,
2826
- constant int64_t & ne3,
2827
- constant uint64_t & nb0,
2828
- constant uint64_t & nb1,
2829
- constant uint64_t & nb2,
2830
- constant uint64_t & nb3,
2831
- constant float & sf0,
2832
- constant float & sf1,
2833
- constant float & sf2,
2834
- constant float & sf3,
2835
  uint3 tgpig[[threadgroup_position_in_grid]],
2836
  uint3 tpitg[[thread_position_in_threadgroup]],
2837
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -2840,15 +2669,15 @@ kernel void kernel_upscale_f32(
2840
  const int64_t i2 = tgpig.y;
2841
  const int64_t i1 = tgpig.x;
2842
 
2843
- const int64_t i03 = i3/sf3;
2844
- const int64_t i02 = i2/sf2;
2845
- const int64_t i01 = i1/sf1;
2846
 
2847
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
2848
- const int64_t i00 = i0/sf0;
2849
 
2850
- device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2851
- device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2852
 
2853
  dst_ptr[0] = src0_ptr[0];
2854
  }
@@ -2857,22 +2686,7 @@ kernel void kernel_upscale_f32(
2857
  kernel void kernel_pad_f32(
2858
  device const char * src0,
2859
  device char * dst,
2860
- constant int64_t & ne00,
2861
- constant int64_t & ne01,
2862
- constant int64_t & ne02,
2863
- constant int64_t & ne03,
2864
- constant uint64_t & nb00,
2865
- constant uint64_t & nb01,
2866
- constant uint64_t & nb02,
2867
- constant uint64_t & nb03,
2868
- constant int64_t & ne0,
2869
- constant int64_t & ne1,
2870
- constant int64_t & ne2,
2871
- constant int64_t & ne3,
2872
- constant uint64_t & nb0,
2873
- constant uint64_t & nb1,
2874
- constant uint64_t & nb2,
2875
- constant uint64_t & nb3,
2876
  uint3 tgpig[[threadgroup_position_in_grid]],
2877
  uint3 tpitg[[thread_position_in_threadgroup]],
2878
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -2885,12 +2699,12 @@ kernel void kernel_pad_f32(
2885
  const int64_t i02 = i2;
2886
  const int64_t i01 = i1;
2887
 
2888
- device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
2889
- device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
2890
 
2891
- if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
2892
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
2893
- if (i0 < ne00) {
2894
  dst_ptr[i0] = src0_ptr[i0];
2895
  } else {
2896
  dst_ptr[i0] = 0.0f;
@@ -2900,7 +2714,7 @@ kernel void kernel_pad_f32(
2900
  return;
2901
  }
2902
 
2903
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
2904
  dst_ptr[i0] = 0.0f;
2905
  }
2906
  }
@@ -2908,21 +2722,7 @@ kernel void kernel_pad_f32(
2908
  kernel void kernel_pad_reflect_1d_f32(
2909
  device const char * src0,
2910
  device char * dst,
2911
- constant int64_t & ne00,
2912
- constant int64_t & ne01,
2913
- constant int64_t & ne02,
2914
- constant int64_t & ne03,
2915
- constant int64_t & ne0,
2916
- constant uint64_t & nb00,
2917
- constant uint64_t & nb01,
2918
- constant uint64_t & nb02,
2919
- constant uint64_t & nb03,
2920
- constant uint64_t & nb0,
2921
- constant uint64_t & nb1,
2922
- constant uint64_t & nb2,
2923
- constant uint64_t & nb3,
2924
- constant int32_t & p0,
2925
- constant int32_t & p1,
2926
  uint3 tgpig[[threadgroup_position_in_grid]],
2927
  uint3 tgpg[[threadgroups_per_grid]],
2928
  uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2936,17 +2736,17 @@ kernel void kernel_pad_reflect_1d_f32(
2936
  const int64_t i02 = i2;
2937
  const int64_t i01 = i1;
2938
 
2939
- device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
2940
- device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
2941
 
2942
- if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
2943
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
2944
- if (i0 < p0) {
2945
- dst_ptr[i0] = src0_ptr[p0 - i0];
2946
- } else if (i0 < ne0 - p1) {
2947
- dst_ptr[i0] = src0_ptr[i0 - p0];
2948
  } else {
2949
- dst_ptr[i0] = src0_ptr[(ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1];
2950
  }
2951
  }
2952
  }
@@ -2954,44 +2754,40 @@ kernel void kernel_pad_reflect_1d_f32(
2954
 
2955
  kernel void kernel_arange_f32(
2956
  device char * dst,
2957
- constant int64_t & ne0,
2958
- constant float & start,
2959
- constant float & step,
2960
  uint3 tgpig[[threadgroup_position_in_grid]],
2961
  uint3 tpitg[[thread_position_in_threadgroup]],
2962
  uint3 ntg[[threads_per_threadgroup]]) {
2963
 
2964
  device float * dst_ptr = (device float *) dst;
2965
 
2966
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
2967
- dst_ptr[i0] = start + step * i0;
2968
  }
2969
  }
2970
 
2971
  kernel void kernel_timestep_embedding_f32(
2972
  device const char * src0,
2973
  device char * dst,
2974
- constant uint64_t & nb1,
2975
- constant int & dim,
2976
- constant int & max_period,
2977
  uint3 tgpig[[threadgroup_position_in_grid]],
2978
  uint3 tpitg[[thread_position_in_threadgroup]],
2979
  uint3 ntg[[threads_per_threadgroup]]) {
2980
 
2981
  int i = tgpig.x;
2982
- device float * embed_data = (device float *)(dst + i*nb1);
2983
 
2984
- int half_ = dim / 2;
2985
  for (int j = tpitg.x; j < half_; j += ntg.x) {
2986
  float timestep = ((device float *)src0)[i];
2987
- float freq = (float)exp(-log((float)max_period) * j / half_);
2988
  float arg = timestep * freq;
2989
  embed_data[j ] = cos(arg);
2990
  embed_data[j + half_] = sin(arg);
2991
  }
2992
 
2993
- if (dim % 2 != 0 && tpitg.x == 0) {
2994
- embed_data[dim] = 0.f;
2995
  }
2996
  }
2997
 
@@ -2999,8 +2795,7 @@ kernel void kernel_timestep_embedding_f32(
2999
  typedef void (argsort_t)(
3000
  device const float * x,
3001
  device int32_t * dst,
3002
- constant int64_t & ncols,
3003
- constant int64_t & ncols_pad,
3004
  threadgroup int32_t * shared_values [[threadgroup(0)]],
3005
  uint3 tgpig[[threadgroup_position_in_grid]],
3006
  uint3 tpitg[[thread_position_in_threadgroup]]);
@@ -3009,8 +2804,7 @@ template<ggml_sort_order order>
3009
  kernel void kernel_argsort_f32_i32(
3010
  device const float * x,
3011
  device int32_t * dst,
3012
- constant int64_t & ncols,
3013
- constant int64_t & ncols_pad,
3014
  threadgroup int32_t * shared_values [[threadgroup(0)]],
3015
  uint3 tgpig[[threadgroup_position_in_grid]],
3016
  uint3 tpitg[[thread_position_in_threadgroup]]) {
@@ -3018,9 +2812,9 @@ kernel void kernel_argsort_f32_i32(
3018
  int col = tpitg[0];
3019
  int row = tgpig[1];
3020
 
3021
- if (col >= ncols_pad) return;
3022
 
3023
- device const float * x_row = x + row * ncols;
3024
  threadgroup int32_t * dst_row = shared_values;
3025
 
3026
  // initialize indices
@@ -3028,21 +2822,21 @@ kernel void kernel_argsort_f32_i32(
3028
 
3029
  threadgroup_barrier(mem_flags::mem_threadgroup);
3030
 
3031
- for (int k = 2; k <= ncols_pad; k *= 2) {
3032
  for (int j = k / 2; j > 0; j /= 2) {
3033
  int ixj = col ^ j;
3034
  if (ixj > col) {
3035
  if ((col & k) == 0) {
3036
- if (dst_row[col] >= ncols ||
3037
- (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
3038
  x_row[dst_row[col]] > x_row[dst_row[ixj]] :
3039
  x_row[dst_row[col]] < x_row[dst_row[ixj]]))
3040
  ) {
3041
  SWAP(dst_row[col], dst_row[ixj]);
3042
  }
3043
  } else {
3044
- if (dst_row[ixj] >= ncols ||
3045
- (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
3046
  x_row[dst_row[col]] < x_row[dst_row[ixj]] :
3047
  x_row[dst_row[col]] > x_row[dst_row[ixj]]))
3048
  ) {
@@ -3055,8 +2849,8 @@ kernel void kernel_argsort_f32_i32(
3055
  }
3056
 
3057
  // copy the result to dst without the padding
3058
- if (col < ncols) {
3059
- dst[row * ncols + col] = dst_row[col];
3060
  }
3061
  }
3062
 
@@ -3066,9 +2860,9 @@ template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_ar
3066
  kernel void kernel_leaky_relu_f32(
3067
  device const float * src0,
3068
  device float * dst,
3069
- constant float & slope,
3070
  uint tpig[[thread_position_in_grid]]) {
3071
- dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
3072
  }
3073
 
3074
  // ref: https://arxiv.org/pdf/2307.08691.pdf
@@ -6009,28 +5803,21 @@ kernel void kernel_get_rows_q(
6009
  device const void * src0,
6010
  device const void * src1,
6011
  device float * dst,
6012
- constant int64_t & ne00,
6013
- constant uint64_t & nb01,
6014
- constant uint64_t & nb02,
6015
- constant int64_t & ne10,
6016
- constant uint64_t & nb10,
6017
- constant uint64_t & nb11,
6018
- constant uint64_t & nb1,
6019
- constant uint64_t & nb2,
6020
  uint3 tgpig[[threadgroup_position_in_grid]],
6021
  uint tiitg[[thread_index_in_threadgroup]],
6022
  uint3 tptg [[threads_per_threadgroup]]) {
6023
  const int64_t i10 = tgpig.x;
6024
  const int64_t i11 = tgpig.y;
6025
 
6026
- const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
6027
 
6028
  const int64_t i02 = i11;
6029
 
6030
- for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
6031
  float4x4 temp;
6032
- dequantize_func(((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
6033
- *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
6034
  }
6035
  }
6036
 
@@ -6039,27 +5826,20 @@ kernel void kernel_get_rows_f(
6039
  device const void * src0,
6040
  device const void * src1,
6041
  device float * dst,
6042
- constant int64_t & ne00,
6043
- constant uint64_t & nb01,
6044
- constant uint64_t & nb02,
6045
- constant int64_t & ne10,
6046
- constant uint64_t & nb10,
6047
- constant uint64_t & nb11,
6048
- constant uint64_t & nb1,
6049
- constant uint64_t & nb2,
6050
  uint3 tgpig[[threadgroup_position_in_grid]],
6051
  uint tiitg[[thread_index_in_threadgroup]],
6052
  uint3 tptg [[threads_per_threadgroup]]) {
6053
  const int64_t i10 = tgpig.x;
6054
  const int64_t i11 = tgpig.y;
6055
 
6056
- const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
6057
 
6058
  const int64_t i02 = i11;
6059
 
6060
- for (int ind = tiitg; ind < ne00; ind += tptg.x) {
6061
- (( device float *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
6062
- ((const device T *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
6063
  }
6064
  }
6065
 
@@ -6067,27 +5847,20 @@ kernel void kernel_get_rows_i32(
6067
  device const void * src0,
6068
  device const void * src1,
6069
  device int32_t * dst,
6070
- constant int64_t & ne00,
6071
- constant uint64_t & nb01,
6072
- constant uint64_t & nb02,
6073
- constant int64_t & ne10,
6074
- constant uint64_t & nb10,
6075
- constant uint64_t & nb11,
6076
- constant uint64_t & nb1,
6077
- constant uint64_t & nb2,
6078
  uint3 tgpig[[threadgroup_position_in_grid]],
6079
  uint tiitg[[thread_index_in_threadgroup]],
6080
  uint3 tptg [[threads_per_threadgroup]]) {
6081
  const int64_t i10 = tgpig.x;
6082
  const int64_t i11 = tgpig.y;
6083
 
6084
- const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
6085
 
6086
  const int64_t i02 = i11;
6087
 
6088
- for (int ind = tiitg; ind < ne00; ind += tptg.x) {
6089
- (( device int32_t *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
6090
- ((const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
6091
  }
6092
  }
6093
 
@@ -6689,98 +6462,78 @@ template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t
6689
  kernel void kernel_pool_2d_max_f32(
6690
  device const float * src0,
6691
  device float * dst,
6692
- constant int32_t & k0,
6693
- constant int32_t & k1,
6694
- constant int32_t & s0,
6695
- constant int32_t & s1,
6696
- constant int32_t & p0,
6697
- constant int32_t & p1,
6698
- constant int64_t & IH,
6699
- constant int64_t & IW,
6700
- constant int64_t & OH,
6701
- constant int64_t & OW,
6702
- constant int64_t & parallel_elements,
6703
  uint gid[[thread_position_in_grid]]) {
6704
 
6705
- if (gid >= parallel_elements) {
6706
  return;
6707
  }
6708
 
6709
  const int idx = gid;
6710
- const int I_HW = IH * IW;
6711
- const int O_HW = OH * OW;
6712
  const int nc = idx / O_HW;
6713
- const int cur_oh = idx % O_HW / OW;
6714
- const int cur_ow = idx % O_HW % OW;
6715
 
6716
  device const float * i_ptr = src0 + nc * I_HW;
6717
  device float * o_ptr = dst + nc * O_HW;
6718
 
6719
- const int start_h = cur_oh * s1 - p1;
6720
  const int bh = MAX(0, start_h);
6721
- const int eh = MIN(IH, start_h + k1);
6722
- const int start_w = cur_ow * s0 - p0;
6723
  const int bw = MAX(0, start_w);
6724
- const int ew = MIN(IW, start_w + k0);
6725
 
6726
  float res = -INFINITY;
6727
 
6728
  for (int i = bh; i < eh; i += 1) {
6729
  for (int j = bw; j < ew; j += 1) {
6730
- res = MAX(res, i_ptr[i * IW + j]);
6731
  }
6732
  }
6733
 
6734
- o_ptr[cur_oh * OW + cur_ow] = res;
6735
  }
6736
 
6737
  kernel void kernel_pool_2d_avg_f32(
6738
  device const float * src0,
6739
  device float * dst,
6740
- constant int32_t & k0,
6741
- constant int32_t & k1,
6742
- constant int32_t & s0,
6743
- constant int32_t & s1,
6744
- constant int32_t & p0,
6745
- constant int32_t & p1,
6746
- constant int64_t & IH,
6747
- constant int64_t & IW,
6748
- constant int64_t & OH,
6749
- constant int64_t & OW,
6750
- constant int64_t & parallel_elements,
6751
  uint gid[[thread_position_in_grid]]) {
6752
 
6753
- if (gid >= parallel_elements) {
6754
  return;
6755
  }
6756
 
6757
  const int idx = gid;
6758
- const int I_HW = IH * IW;
6759
- const int O_HW = OH * OW;
6760
  const int nc = idx / O_HW;
6761
- const int cur_oh = idx % O_HW / OW;
6762
- const int cur_ow = idx % O_HW % OW;
6763
 
6764
  device const float * i_ptr = src0 + nc * I_HW;
6765
  device float * o_ptr = dst + nc * O_HW;
6766
 
6767
- const int start_h = cur_oh * s1 - p1;
6768
  const int bh = MAX(0, start_h);
6769
- const int eh = MIN(IH, start_h + k1);
6770
- const int start_w = cur_ow * s0 - p0;
6771
  const int bw = MAX(0, start_w);
6772
- const int ew = MIN(IW, start_w + k0);
6773
  // const float scale = 1. / ((eh - bh) * (ew - bw));
6774
- const float scale = 1. / (k0 * k1);
6775
 
6776
  float res = 0;
6777
 
6778
  for (int i = bh; i < eh; i += 1) {
6779
  for (int j = bw; j < ew; j += 1) {
6780
- float cur = i_ptr[i * IW + j];
6781
  res += cur * scale;
6782
  }
6783
  }
6784
 
6785
- o_ptr[cur_oh * OW + cur_ow] = res;
6786
  }
 
947
  kernel void kernel_sum_rows(
948
  device const float * src0,
949
  device float * dst,
950
+ constant ggml_metal_kargs_sum_rows & args,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
951
  uint3 tpig[[thread_position_in_grid]]) {
952
  int64_t i3 = tpig.z;
953
  int64_t i2 = tpig.y;
954
  int64_t i1 = tpig.x;
955
 
956
+ if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
957
  return;
958
  }
959
 
960
+ device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
961
+ device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
962
 
963
  float row_sum = 0;
964
 
965
+ for (int64_t i0 = 0; i0 < args.ne00; i0++) {
966
  row_sum += src_row[i0];
967
  }
968
 
 
974
  device const char * src0,
975
  device const char * src1,
976
  device char * dst,
977
+ constant ggml_metal_kargs_soft_max & args,
 
 
 
 
 
 
 
978
  threadgroup float * buf [[threadgroup(0)]],
979
  uint tgpig[[threadgroup_position_in_grid]],
980
  uint tpitg[[thread_position_in_threadgroup]],
981
  uint sgitg[[simdgroup_index_in_threadgroup]],
982
  uint tiisg[[thread_index_in_simdgroup]],
983
  uint ntg[[threads_per_threadgroup]]) {
984
+ const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
985
+ const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
986
+ const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
987
 
988
+ device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
989
+ device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00 : nullptr;
990
+ device float * pdst = (device float *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
991
 
992
  float slope = 1.0f;
993
 
994
  // ALiBi
995
+ if (args.max_bias > 0.0f) {
996
  const int64_t h = i02;
997
 
998
+ const float base = h < args.n_head_log2 ? args.m0 : args.m1;
999
+ const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
1000
 
1001
  slope = pow(base, exp);
1002
  }
 
1004
  // parallel max
1005
  float lmax = -INFINITY;
1006
 
1007
+ for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
1008
+ lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
1009
  }
1010
 
1011
  // find the max value in the block
 
1029
 
1030
  // parallel sum
1031
  float lsum = 0.0f;
1032
+ for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
1033
+ const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
1034
  lsum += exp_psrc0;
1035
  pdst[i00] = exp_psrc0;
1036
  }
 
1060
 
1061
  const float inv_sum = 1.0f/sum;
1062
 
1063
+ for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
1064
  pdst[i00] *= inv_sum;
1065
  }
1066
  }
 
1070
  device const char * src0,
1071
  device const char * src1,
1072
  device char * dst,
1073
+ constant ggml_metal_kargs_soft_max & args,
 
 
 
 
 
 
 
1074
  threadgroup float * buf [[threadgroup(0)]],
1075
  uint tgpig[[threadgroup_position_in_grid]],
1076
  uint tpitg[[thread_position_in_threadgroup]],
1077
  uint sgitg[[simdgroup_index_in_threadgroup]],
1078
  uint tiisg[[thread_index_in_simdgroup]],
1079
  uint ntg[[threads_per_threadgroup]]) {
1080
+ const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
1081
+ const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
1082
+ const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
1083
 
1084
+ device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
1085
+ device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00/4 : nullptr;
1086
+ device float4 * pdst4 = (device float4 *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
1087
 
1088
  float slope = 1.0f;
1089
 
1090
+ if (args.max_bias > 0.0f) {
1091
  const int64_t h = i02;
1092
 
1093
+ const float base = h < args.n_head_log2 ? args.m0 : args.m1;
1094
+ const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
1095
 
1096
  slope = pow(base, exp);
1097
  }
 
1099
  // parallel max
1100
  float4 lmax4 = -INFINITY;
1101
 
1102
+ for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
1103
+ lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
1104
  }
1105
 
1106
  const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
 
1125
 
1126
  // parallel sum
1127
  float4 lsum4 = 0.0f;
1128
+ for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
1129
+ const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
1130
  lsum4 += exp_psrc4;
1131
  pdst4[i00] = exp_psrc4;
1132
  }
 
1158
 
1159
  const float inv_sum = 1.0f/sum;
1160
 
1161
+ for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
1162
  pdst4[i00] *= inv_sum;
1163
  }
1164
  }
 
1174
  kernel void kernel_diag_mask_inf(
1175
  device const float * src0,
1176
  device float * dst,
1177
+ constant ggml_metal_kargs_diag_mask_inf & args,
 
 
1178
  uint3 tpig[[thread_position_in_grid]]) {
1179
  const int64_t i02 = tpig[2];
1180
  const int64_t i01 = tpig[1];
1181
  const int64_t i00 = tpig[0];
1182
 
1183
+ if (i00 > args.n_past + i01) {
1184
+ dst[i02*args.ne01*args.ne00 + i01*args.ne00 + i00] = -INFINITY;
1185
  } else {
1186
+ dst[i02*args.ne01*args.ne00 + i01*args.ne00 + i00] = src0[i02*args.ne01*args.ne00 + i01*args.ne00 + i00];
1187
  }
1188
  }
1189
 
1190
  kernel void kernel_diag_mask_inf_8(
1191
  device const float4 * src0,
1192
  device float4 * dst,
1193
+ constant ggml_metal_kargs_diag_mask_inf & args,
 
 
1194
  uint3 tpig[[thread_position_in_grid]]) {
1195
 
1196
  const int64_t i = 2*tpig[0];
 
1198
  dst[i+0] = src0[i+0];
1199
  dst[i+1] = src0[i+1];
1200
  int64_t i4 = 4*i;
1201
+ const int64_t i02 = i4/(args.ne00*args.ne01); i4 -= i02*args.ne00*args.ne01;
1202
+ const int64_t i01 = i4/(args.ne00); i4 -= i01*args.ne00;
1203
  const int64_t i00 = i4;
1204
  for (int k = 3; k >= 0; --k) {
1205
+ if (i00 + 4 + k <= args.n_past + i01) {
1206
  break;
1207
  }
1208
  dst[i+1][k] = -INFINITY;
1209
+ if (i00 + k > args.n_past + i01) {
1210
  dst[i][k] = -INFINITY;
1211
  }
1212
  }
1213
  }
1214
 
1215
  // ref: ggml.c:ggml_compute_forward_ssm_conv_f32
 
1216
  kernel void kernel_ssm_conv_f32(
1217
  device const void * src0,
1218
  device const void * src1,
1219
  device float * dst,
1220
+ constant ggml_metal_kargs_ssm_conv & args,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1221
  uint3 tgpig[[threadgroup_position_in_grid]],
1222
  uint3 tpitg[[thread_position_in_threadgroup]],
1223
  uint3 ntg[[threads_per_threadgroup]]) {
 
1225
  const int64_t i2 = tgpig.y;
1226
  const int64_t i3 = tgpig.z;
1227
 
1228
+ const int64_t nc = args.ne10;
1229
+ //const int64_t ncs = args.ne00;
1230
+ //const int64_t nr = args.ne01;
1231
+ //const int64_t n_t = args.ne1;
1232
+ //const int64_t n_s = args.ne2;
1233
 
1234
+ device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
1235
+ device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
1236
+ device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
1237
 
1238
  float sumf = 0.0f;
1239
 
 
1245
  }
1246
 
1247
  // ref: ggml.c:ggml_compute_forward_ssm_scan_f32
 
1248
  kernel void kernel_ssm_scan_f32(
1249
  device const void * src0,
1250
  device const void * src1,
 
1253
  device const void * src4,
1254
  device const void * src5,
1255
  device float * dst,
1256
+ constant ggml_metal_kargs_ssm_scan & args,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1257
  uint3 tgpig[[threadgroup_position_in_grid]],
1258
  uint3 tpitg[[thread_position_in_threadgroup]],
1259
  uint3 ntg[[threads_per_threadgroup]]) {
1260
  const int64_t ir = tgpig.x;
1261
  const int64_t i3 = tgpig.y;
1262
 
1263
+ const int64_t nc = args.d_state;
1264
+ // const int64_t nr = args.d_inner;
1265
+ const int64_t n_t = args.n_seq_tokens;
1266
+ // const int64_t n_s = args.n_seqs;
1267
 
1268
  for (int64_t i2 = 0; i2 < n_t; ++i2) {
1269
+ device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb01 + i3*args.nb02);
1270
+ device const float * x = (device const float *) ((device const char *) src1 + ir*args.nb10 + i2*args.nb11 + i3*args.nb12);
1271
+ device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22);
1272
+ device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31);
1273
+ device const float * B = (device const float *) ((device const char *) src4 + i2*args.nb41 + i3*args.nb42);
1274
+ device const float * C = (device const float *) ((device const char *) src5 + i2*args.nb51 + i3*args.nb52);
1275
+ device float * y = (device float *) ((device char *) dst + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); // TODO: do not use src1 strides
1276
+ device float * s = (device float *) ((device char *) dst + ir*args.nb01 + i3*args.nb02 + args.nb13);
1277
 
1278
  if (i2 > 0) {
1279
  s0 = s;
 
1466
  kernel void kernel_group_norm(
1467
  device const float * src0,
1468
  device float * dst,
1469
+ constant ggml_metal_kargs_group_norm & args,
 
 
 
 
 
 
 
1470
  threadgroup float * buf [[threadgroup(0)]],
1471
  uint tgpig[[threadgroup_position_in_grid]],
1472
  uint tpitg[[thread_position_in_threadgroup]],
1473
  uint sgitg[[simdgroup_index_in_threadgroup]],
1474
  uint tiisg[[thread_index_in_simdgroup]],
1475
  uint ntg[[threads_per_threadgroup]]) {
1476
+ const int64_t ne = args.ne00*args.ne01*args.ne02;
1477
+ const int64_t gs = args.ne00*args.ne01*((args.ne02 + args.n_groups - 1) / args.n_groups);
1478
 
1479
  int start = tgpig * gs;
1480
  int end = start + gs;
 
1538
  }
1539
 
1540
  const float variance = tmp / gs;
1541
+ const float scale = 1.0f/sqrt(variance + args.eps);
1542
  for (int j = start; j < end; j += ntg) {
1543
  dst[j] *= scale;
1544
  }
 
2502
  typedef void (im2col_t)(
2503
  device const float * x,
2504
  device char * dst,
2505
+ constant ggml_metal_kargs_im2col & args,
 
 
 
 
 
 
 
 
 
 
2506
  uint3 tgpig[[threadgroup_position_in_grid]],
2507
  uint3 tgpg[[threadgroups_per_grid]],
2508
  uint3 tpitg[[thread_position_in_threadgroup]],
 
2512
  kernel void kernel_im2col(
2513
  device const float * x,
2514
  device char * dst,
2515
+ constant ggml_metal_kargs_im2col & args,
 
 
 
 
 
 
 
 
 
 
2516
  uint3 tgpig[[threadgroup_position_in_grid]],
2517
  uint3 tgpg[[threadgroups_per_grid]],
2518
  uint3 tpitg[[thread_position_in_threadgroup]],
 
2533
  const int64_t ioh = tgpig[1];
2534
  const int64_t iow = tgpig[2];
2535
 
2536
+ const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0;
2537
+ const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1;
2538
 
2539
+ const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw);
2540
 
2541
  device T * pdst = (device T *) (dst);
2542
 
2543
+ if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
2544
  pdst[offset_dst] = 0.0f;
2545
  } else {
2546
+ const int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw;
2547
  pdst[offset_dst] = x[offset_src];
2548
  }
2549
  }
 
2554
  typedef void (im2col_ext_t)(
2555
  device const float * x,
2556
  device char * dst,
2557
+ constant ggml_metal_kargs_im2col & args,
 
 
 
 
 
 
 
 
 
 
 
 
 
2558
  uint3 tgpig[[threadgroup_position_in_grid]],
2559
  uint3 tgpg[[threadgroups_per_grid]],
2560
  uint3 tpitg[[thread_position_in_threadgroup]],
 
2564
// Extended im2col: unfolds convolution input windows into columns of dst.
// Used when the work size exceeds what the basic im2col kernel's grid can
// cover. The dst element type T comes from the enclosing template.
kernel void kernel_im2col_ext(
        device const float * x,
        device       char  * dst,
        constant ggml_metal_kargs_im2col & args,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3  tgpg[[threadgroups_per_grid]],      // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {  // [M, 1, 1]
    const int64_t ksize = (int64_t) args.KHW; // KH*KW

    // split the flattened x-dimension of the grid into (d, channel*tap)
    const int64_t d   = tgpig[0] / args.CHW;
    const int64_t chw = tgpig[0] % args.CHW;

    const int64_t ic = chw / ksize;       // channel index, 0 .. IC-1
    const int64_t hw = tgpig[0] % ksize;  // flat position inside the KHxKW window

    const int64_t in = (d * ntg[0]) + tpitg[0];
    if (in >= args.N) {
        return;
    }

    const int64_t kh = hw / args.KW;
    const int64_t kw = hw % args.KW;

    // input coordinates for this (output position, kernel tap) pair
    const int64_t iiw = tgpig[2] * args.s0 + kw * args.d0 - args.p0;
    const int64_t iih = tgpig[1] * args.s1 + kh * args.d1 - args.p1;

    const int64_t offset_dst =
        (in * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
        (ic * ksize + kh * args.KW + kw);

    device T * pdst = (device T *) (dst);

    if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
        // taps that fall in the padding region contribute zeros
        pdst[offset_dst] = 0.0f;
    } else {
        // ofs0/ofs1 are byte-like element offsets precomputed host-side
        const int64_t offset_src = in * args.ofs0 + ic * args.ofs1;
        pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
    }
}
2603
 
 
2608
  device const float * src0,
2609
  device const float * src1,
2610
  device char * dst,
2611
+ constant ggml_metal_kargs_conv_transpose_1d & args,
 
 
 
 
 
2612
  uint3 tgpig[[threadgroup_position_in_grid]],
2613
  uint3 tgpg[[threadgroups_per_grid]]);
2614
 
 
2617
  device const T * src0,
2618
  device const float * src1,
2619
  device char * dst,
2620
+ constant ggml_metal_kargs_conv_transpose_1d & args,
 
 
 
 
 
2621
  uint3 tgpig[[threadgroup_position_in_grid]],
2622
  uint3 tgpg[[threadgroups_per_grid]]) {
2623
 
2624
  float v = 0.0f;
2625
 
2626
+ for (int64_t c = 0; c < args.IC; c++) {
2627
+ const int32_t kernel_offset = c * tgpg[1] * args.K + args.K * tgpig[1];
2628
+ const int32_t input_offset = c * args.IL;
2629
 
2630
+ for (int64_t i = 0; i < args.IL; i++) {
2631
+ if (tgpig[0] >= i * args.s0 && tgpig[0] < i * args.s0 + args.K) {
2632
+ v += src0[kernel_offset + tgpig[0] - i * args.s0] * src1[input_offset + i];
2633
  }
2634
  }
2635
  }
2636
 
2637
+ device float * dst_ptr = (device float *) (dst + tgpig[0] * args.nb0 + tgpig[1] * args.nb1);
2638
 
2639
  dst_ptr[0] = v;
2640
  }
 
2644
  device const float * src0,
2645
  device const float * src1,
2646
  device char * dst,
2647
+ constant ggml_metal_kargs_conv_transpose_1d & args,
 
 
 
 
 
2648
  uint3 tgpig[[threadgroup_position_in_grid]],
2649
  uint3 tgpg[[threadgroups_per_grid]]);
2650
 
 
2653
  device const half * src0,
2654
  device const float * src1,
2655
  device char * dst,
2656
+ constant ggml_metal_kargs_conv_transpose_1d & args,
 
 
 
 
 
2657
  uint3 tgpig[[threadgroup_position_in_grid]],
2658
  uint3 tgpg[[threadgroups_per_grid]]);
2659
 
2660
// Nearest-neighbor upscale: each dst element (i0, i1, i2, i3) reads the
// source element obtained by dividing every index by its scale factor.
kernel void kernel_upscale_f32(
    device  const char * src0,
    device        char * dst,
    constant ggml_metal_kargs_upscale & args,
    uint3 tgpig[[threadgroup_position_in_grid]],
    uint3 tpitg[[thread_position_in_threadgroup]],
    uint3   ntg[[threads_per_threadgroup]]) {

    const int64_t i3 = tgpig.z;
    const int64_t i2 = tgpig.y;
    const int64_t i1 = tgpig.x;

    // map the dst indices of dims 1..3 back to the source
    const int64_t i03 = i3/args.sf3;
    const int64_t i02 = i2/args.sf2;
    const int64_t i01 = i1/args.sf1;

    // threads of the group stride across the innermost dimension
    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
        const int64_t i00 = i0/args.sf0;

        device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
        device       float * dst_ptr  = (device       float *) (dst  +  i3*args.nb3  +  i2*args.nb2  +  i1*args.nb1  +  i0*args.nb0);

        dst_ptr[0] = src0_ptr[0];
    }
}
 
2686
// Zero-pad src0 up to dst's shape (ne0..ne3): rows that overlap the source
// copy the first ne00 elements and zero the remainder; rows completely
// outside the source are zero-filled.
kernel void kernel_pad_f32(
    device  const char * src0,
    device        char * dst,
    constant ggml_metal_kargs_pad & args,
    uint3 tgpig[[threadgroup_position_in_grid]],
    uint3 tpitg[[thread_position_in_threadgroup]],
    uint3   ntg[[threads_per_threadgroup]]) {

    const int64_t i3 = tgpig.z;
    const int64_t i2 = tgpig.y;
    const int64_t i1 = tgpig.x;

    const int64_t i03 = i3;
    const int64_t i02 = i2;
    const int64_t i01 = i1;

    device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
    device       float * dst_ptr  = (device       float *) (dst  +  i3*args.nb3  +  i2*args.nb2  +  i1*args.nb1);

    if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
        // row overlaps the source: copy the prefix, zero the padded tail
        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
            dst_ptr[i0] = i0 < args.ne00 ? src0_ptr[i0] : 0.0f;
        }

        return;
    }

    // row lies entirely in the padding region: all zeros
    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
        dst_ptr[i0] = 0.0f;
    }
}
 
2722
// Reflect-pad the innermost dimension by p0 elements on the left and p1 on
// the right; the mirror does not repeat the edge element (PyTorch
// ReflectionPad1d semantics).
kernel void kernel_pad_reflect_1d_f32(
    device  const char * src0,
    device        char * dst,
    constant   ggml_metal_kargs_pad_reflect_1d & args,
    uint3 tgpig[[threadgroup_position_in_grid]],
    uint3  tgpg[[threadgroups_per_grid]],
    uint3 tpitg[[thread_position_in_threadgroup]],
    uint3   ntg[[threads_per_threadgroup]]) {

    const int64_t i3 = tgpig.z;
    const int64_t i2 = tgpig.y;
    const int64_t i1 = tgpig.x;

    const int64_t i03 = i3;
    const int64_t i02 = i2;
    const int64_t i01 = i1;

    device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
    device       float * dst_ptr  = (device       float *) (dst  +  i3*args.nb3  +  i2*args.nb2  +  i1*args.nb1);

    if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
            if (i0 < args.p0) {
                // left padding: mirrored read from the start of the row
                dst_ptr[i0] = src0_ptr[args.p0 - i0];
            } else if (i0 < args.ne0 - args.p1) {
                // interior: plain copy shifted by the left pad
                dst_ptr[i0] = src0_ptr[i0 - args.p0];
            } else {
                // right padding: mirrored read against the end of the row
                dst_ptr[i0] = src0_ptr[(args.ne0 - args.p1 - args.p0) - (args.p1 + 1 - (args.ne0 - i0)) - 1];
            }
        }
    }
}
 
2754
 
2755
// Fill dst with the arithmetic sequence start, start + step, ... (ne0 values).
kernel void kernel_arange_f32(
    device        char * dst,
    constant ggml_metal_kargs_arange & args,
    uint3 tgpig[[threadgroup_position_in_grid]],
    uint3 tpitg[[thread_position_in_threadgroup]],
    uint3   ntg[[threads_per_threadgroup]]) {

    device float * out = (device float *) dst;

    // each thread writes every ntg.x-th element of the sequence
    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
        out[i0] = args.start + args.step * i0;
    }
}
2768
 
2769
// Sinusoidal timestep embedding: for each timestep src0[i] writes a row of
// args.dim values -- cos(t*freq_j) in the first half, sin(t*freq_j) in the
// second -- with frequencies log-spaced down from 1 by args.max_period.
// One threadgroup handles one timestep (row), threads stride over j.
kernel void kernel_timestep_embedding_f32(
    device  const char * src0,
    device        char * dst,
    constant ggml_metal_kargs_timestep_embedding & args,
    uint3 tgpig[[threadgroup_position_in_grid]],
    uint3 tpitg[[thread_position_in_threadgroup]],
    uint3   ntg[[threads_per_threadgroup]]) {

    int i = tgpig.x;
    device float * embed_data = (device float *)(dst + i*args.nb1);

    int half_ = args.dim / 2;
    for (int j = tpitg.x; j < half_; j += ntg.x) {
        float timestep = ((device float *)src0)[i];
        float freq = (float)exp(-log((float)args.max_period) * j / half_);
        float arg = timestep * freq;
        embed_data[j        ] = cos(arg);
        embed_data[j + half_] = sin(arg);
    }

    if (args.dim % 2 != 0 && tpitg.x == 0) {
        // For odd dim the cos/sin pairs fill indices [0, 2*half_), so the
        // one remaining slot is 2*half_ == dim - 1. Writing embed_data[args.dim]
        // (the previous code) lands one element past the logical row.
        embed_data[2 * half_] = 0.f;
    }
}
2793
 
 
2795
  typedef void (argsort_t)(
2796
  device const float * x,
2797
  device int32_t * dst,
2798
+ constant ggml_metal_kargs_argsort & args,
 
2799
  threadgroup int32_t * shared_values [[threadgroup(0)]],
2800
  uint3 tgpig[[threadgroup_position_in_grid]],
2801
  uint3 tpitg[[thread_position_in_threadgroup]]);
 
2804
// Bitonic argsort of one row of x per threadgroup; `order` (ASC/DESC) is a
// parameter of the enclosing template. args.ncols_pad is ncols rounded up to
// a power of two; padding indices (>= ncols) compare as "largest" so they
// collect at the end and are dropped when copying the result out.
kernel void kernel_argsort_f32_i32(
        device const float   * x,
        device       int32_t * dst,
        constant ggml_metal_kargs_argsort & args,
        threadgroup int32_t  * shared_values [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]]) {
    // bitonic sort
    int col = tpitg[0];
    int row = tgpig[1];

    if (col >= args.ncols_pad) return;

    device const float  * x_row   = x + row * args.ncols;
    threadgroup int32_t * dst_row = shared_values;

    // initialize indices
    dst_row[col] = col;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (int k = 2; k <= args.ncols_pad; k *= 2) {
        for (int j = k / 2; j > 0; j /= 2) {
            int ixj = col ^ j;
            if (ixj > col) {
                if ((col & k) == 0) {
                    if (dst_row[col] >= args.ncols ||
                        (dst_row[ixj] < args.ncols && (order == GGML_SORT_ORDER_ASC ?
                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
                    ) {
                        SWAP(dst_row[col], dst_row[ixj]);
                    }
                } else {
                    if (dst_row[ixj] >= args.ncols ||
                        (dst_row[col] < args.ncols && (order == GGML_SORT_ORDER_ASC ?
                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
                    ) {
                        SWAP(dst_row[col], dst_row[ixj]);
                    }
                }
            }

            threadgroup_barrier(mem_flags::mem_threadgroup);
        }
    }

    // copy the result to dst without the padding
    if (col < args.ncols) {
        dst[row * args.ncols + col] = dst_row[col];
    }
}
2856
 
 
2860
// Element-wise leaky ReLU: negative inputs are scaled by args.slope.
kernel void kernel_leaky_relu_f32(
        device const float * src0,
        device       float * dst,
        constant ggml_metal_kargs_leaky_relu & args,
        uint tpig[[thread_position_in_grid]]) {
    const float v = src0[tpig];

    dst[tpig] = v > 0.0f ? v : v * args.slope;
}
2867
 
2868
  // ref: https://arxiv.org/pdf/2307.08691.pdf
 
5803
  device const void * src0,
5804
  device const void * src1,
5805
  device float * dst,
5806
+ constant ggml_metal_kargs_get_rows & args,
 
 
 
 
 
 
 
5807
  uint3 tgpig[[threadgroup_position_in_grid]],
5808
  uint tiitg[[thread_index_in_threadgroup]],
5809
  uint3 tptg [[threads_per_threadgroup]]) {
5810
  const int64_t i10 = tgpig.x;
5811
  const int64_t i11 = tgpig.y;
5812
 
5813
+ const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
5814
 
5815
  const int64_t i02 = i11;
5816
 
5817
+ for (int64_t ind = tiitg; ind < args.ne00/16; ind += tptg.x) {
5818
  float4x4 temp;
5819
+ dequantize_func(((device const block_q *) ((const device char *) src0 + r*args.nb01 + i02*args.nb02)) + ind/nl, ind%nl, temp);
5820
+ *(((device float4x4 *) ((device char *) dst + i11*args.nb2 + i10*args.nb1)) + ind) = temp;
5821
  }
5822
  }
5823
 
 
5826
  device const void * src0,
5827
  device const void * src1,
5828
  device float * dst,
5829
+ constant ggml_metal_kargs_get_rows & args,
 
 
 
 
 
 
 
5830
  uint3 tgpig[[threadgroup_position_in_grid]],
5831
  uint tiitg[[thread_index_in_threadgroup]],
5832
  uint3 tptg [[threads_per_threadgroup]]) {
5833
  const int64_t i10 = tgpig.x;
5834
  const int64_t i11 = tgpig.y;
5835
 
5836
+ const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
5837
 
5838
  const int64_t i02 = i11;
5839
 
5840
+ for (int ind = tiitg; ind < args.ne00; ind += tptg.x) {
5841
+ (( device float *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] =
5842
+ ((const device T *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind];
5843
  }
5844
  }
5845
 
 
5847
  device const void * src0,
5848
  device const void * src1,
5849
  device int32_t * dst,
5850
+ constant ggml_metal_kargs_get_rows & args,
 
 
 
 
 
 
 
5851
  uint3 tgpig[[threadgroup_position_in_grid]],
5852
  uint tiitg[[thread_index_in_threadgroup]],
5853
  uint3 tptg [[threads_per_threadgroup]]) {
5854
  const int64_t i10 = tgpig.x;
5855
  const int64_t i11 = tgpig.y;
5856
 
5857
+ const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
5858
 
5859
  const int64_t i02 = i11;
5860
 
5861
+ for (int ind = tiitg; ind < args.ne00; ind += tptg.x) {
5862
+ (( device int32_t *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] =
5863
+ ((const device int32_t *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind];
5864
  }
5865
  }
5866
 
 
6462
// 2D max pooling over contiguous IHxIW planes: one thread per output
// element, with the flat thread id decomposed into (plane, oh, ow).
kernel void kernel_pool_2d_max_f32(
        device  const float * src0,
        device        float * dst,
        constant ggml_metal_kargs_pool_2d & args,
        uint gid[[thread_position_in_grid]]) {

    if (gid >= args.parallel_elements) {
        return;
    }

    const int idx  = gid;
    const int I_HW = args.IH * args.IW;
    const int O_HW = args.OH * args.OW;

    // decompose the flat index into (plane, output row, output column)
    const int nc     = idx / O_HW;
    const int cur_oh = idx % O_HW / args.OW;
    const int cur_ow = idx % O_HW % args.OW;

    device const float * i_ptr = src0 + nc * I_HW;
    device       float * o_ptr = dst  + nc * O_HW;

    // pooling window clamped to the input plane (padding is excluded)
    const int start_h = cur_oh * args.s1 - args.p1;
    const int bh = MAX(0, start_h);
    const int eh = MIN(args.IH, start_h + args.k1);
    const int start_w = cur_ow * args.s0 - args.p0;
    const int bw = MAX(0, start_w);
    const int ew = MIN(args.IW, start_w + args.k0);

    float res = -INFINITY;

    for (int i = bh; i < eh; i += 1) {
        for (int j = bw; j < ew; j += 1) {
            res = MAX(res, i_ptr[i * args.IW + j]);
        }
    }

    o_ptr[cur_oh * args.OW + cur_ow] = res;
}
6499
 
6500
// 2D average pooling over contiguous IHxIW planes. The divisor is the full
// kernel area k0*k1, i.e. padded positions count toward the average
// (count-include-pad semantics); the commented alternative would divide by
// the clamped window area instead.
kernel void kernel_pool_2d_avg_f32(
        device  const float * src0,
        device        float * dst,
        constant ggml_metal_kargs_pool_2d & args,
        uint gid[[thread_position_in_grid]]) {

    if (gid >= args.parallel_elements) {
        return;
    }

    const int idx  = gid;
    const int I_HW = args.IH * args.IW;
    const int O_HW = args.OH * args.OW;

    // decompose the flat index into (plane, output row, output column)
    const int nc     = idx / O_HW;
    const int cur_oh = idx % O_HW / args.OW;
    const int cur_ow = idx % O_HW % args.OW;

    device const float * i_ptr = src0 + nc * I_HW;
    device       float * o_ptr = dst  + nc * O_HW;

    // pooling window clamped to the input plane
    const int start_h = cur_oh * args.s1 - args.p1;
    const int bh = MAX(0, start_h);
    const int eh = MIN(args.IH, start_h + args.k1);
    const int start_w = cur_ow * args.s0 - args.p0;
    const int bw = MAX(0, start_w);
    const int ew = MIN(args.IW, start_w + args.k0);
    // const float scale = 1. / ((eh - bh) * (ew - bw));
    const float scale = 1. / (args.k0 * args.k1);

    float res = 0;

    for (int i = bh; i < eh; i += 1) {
        for (int j = bw; j < ew; j += 1) {
            float cur = i_ptr[i * args.IW + j];
            res += cur * scale;
        }
    }

    o_ptr[cur_oh * args.OW + cur_ow] = res;
}