Spaces:
Running
metal : simplify kernel arguments using a struct (ggml/3229) (llama/12194)
Browse files
* metal : refactor im2col parameters into a struct
* metal: Change im2col offset types from int32_t to uint64_t to support larger memory offsets
* metal : refactor sum_rows parameters into a struct
* metal : refactor soft_max parameters into a struct
* metal : refactor diag_mask_inf parameters into a struct
* metal : refactor ssm_conv parameters into a struct
* metal : refactor ssm_scan parameters into a struct
* metal : refactor get_rows parameters into a struct
* metal : refactor group_norm parameters into a struct
* metal : refactor conv_transpose_1d parameters into a struct
* metal : refactor upscale parameters into a struct
* metal : refactor pad parameters into a struct
* metal : refactor pad_reflect_1d parameters into a struct
* metal : refactor arange parameters into a struct
* metal : refactor timestep_embedding parameters into a struct
* metal : refactor argsort parameters into a struct
* metal : refactor leaky_relu parameters into a struct
* metal : refactor pool_2d parameters into a struct
* metal : fix trailing whitespace
---------
Co-authored-by: alexju <[email protected]>
- ggml/src/ggml-metal/ggml-metal-impl.h +235 -0
- ggml/src/ggml-metal/ggml-metal.m +260 -206
- ggml/src/ggml-metal/ggml-metal.metal +190 -437
|
@@ -285,4 +285,239 @@ typedef struct {
|
|
| 285 |
float eps;
|
| 286 |
} ggml_metal_kargs_rms_norm;
|
| 287 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 288 |
#endif // GGML_METAL_IMPL
|
|
|
|
| 285 |
float eps;
|
| 286 |
} ggml_metal_kargs_rms_norm;
|
| 287 |
|
| 288 |
+
typedef struct {
|
| 289 |
+
int64_t ne00;
|
| 290 |
+
int64_t ne01;
|
| 291 |
+
int64_t ne02;
|
| 292 |
+
uint64_t nb00;
|
| 293 |
+
uint64_t nb01;
|
| 294 |
+
uint64_t nb02;
|
| 295 |
+
int32_t n_groups;
|
| 296 |
+
float eps;
|
| 297 |
+
} ggml_metal_kargs_group_norm;
|
| 298 |
+
|
| 299 |
+
typedef struct {
|
| 300 |
+
int32_t IC;
|
| 301 |
+
int32_t IL;
|
| 302 |
+
int32_t K;
|
| 303 |
+
int32_t s0;
|
| 304 |
+
uint64_t nb0;
|
| 305 |
+
uint64_t nb1;
|
| 306 |
+
} ggml_metal_kargs_conv_transpose_1d;
|
| 307 |
+
|
| 308 |
+
typedef struct {
|
| 309 |
+
uint64_t ofs0;
|
| 310 |
+
uint64_t ofs1;
|
| 311 |
+
int32_t IW;
|
| 312 |
+
int32_t IH;
|
| 313 |
+
int32_t CHW;
|
| 314 |
+
int32_t s0;
|
| 315 |
+
int32_t s1;
|
| 316 |
+
int32_t p0;
|
| 317 |
+
int32_t p1;
|
| 318 |
+
int32_t d0;
|
| 319 |
+
int32_t d1;
|
| 320 |
+
int32_t N;
|
| 321 |
+
int32_t KH;
|
| 322 |
+
int32_t KW;
|
| 323 |
+
int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources
|
| 324 |
+
} ggml_metal_kargs_im2col;
|
| 325 |
+
|
| 326 |
+
typedef struct {
|
| 327 |
+
int64_t ne00;
|
| 328 |
+
int64_t ne01;
|
| 329 |
+
int64_t ne02;
|
| 330 |
+
int64_t ne03;
|
| 331 |
+
uint64_t nb00;
|
| 332 |
+
uint64_t nb01;
|
| 333 |
+
uint64_t nb02;
|
| 334 |
+
uint64_t nb03;
|
| 335 |
+
int64_t ne10;
|
| 336 |
+
int64_t ne11;
|
| 337 |
+
int64_t ne12;
|
| 338 |
+
int64_t ne13;
|
| 339 |
+
uint64_t nb10;
|
| 340 |
+
uint64_t nb11;
|
| 341 |
+
uint64_t nb12;
|
| 342 |
+
uint64_t nb13;
|
| 343 |
+
int64_t ne0;
|
| 344 |
+
int64_t ne1;
|
| 345 |
+
int64_t ne2;
|
| 346 |
+
int64_t ne3;
|
| 347 |
+
uint64_t nb0;
|
| 348 |
+
uint64_t nb1;
|
| 349 |
+
uint64_t nb2;
|
| 350 |
+
uint64_t nb3;
|
| 351 |
+
} ggml_metal_kargs_sum_rows;
|
| 352 |
+
|
| 353 |
+
typedef struct {
|
| 354 |
+
int64_t ne00;
|
| 355 |
+
int64_t ne01;
|
| 356 |
+
int64_t ne02;
|
| 357 |
+
float scale;
|
| 358 |
+
float max_bias;
|
| 359 |
+
float m0;
|
| 360 |
+
float m1;
|
| 361 |
+
uint32_t n_head_log2;
|
| 362 |
+
} ggml_metal_kargs_soft_max;
|
| 363 |
+
|
| 364 |
+
typedef struct {
|
| 365 |
+
int64_t ne00;
|
| 366 |
+
int64_t ne01;
|
| 367 |
+
int n_past;
|
| 368 |
+
} ggml_metal_kargs_diag_mask_inf;
|
| 369 |
+
|
| 370 |
+
typedef struct {
|
| 371 |
+
int64_t ne00;
|
| 372 |
+
int64_t ne01;
|
| 373 |
+
int64_t ne02;
|
| 374 |
+
uint64_t nb00;
|
| 375 |
+
uint64_t nb01;
|
| 376 |
+
uint64_t nb02;
|
| 377 |
+
int64_t ne10;
|
| 378 |
+
int64_t ne11;
|
| 379 |
+
uint64_t nb10;
|
| 380 |
+
uint64_t nb11;
|
| 381 |
+
int64_t ne0;
|
| 382 |
+
int64_t ne1;
|
| 383 |
+
int64_t ne2;
|
| 384 |
+
uint64_t nb0;
|
| 385 |
+
uint64_t nb1;
|
| 386 |
+
uint64_t nb2;
|
| 387 |
+
} ggml_metal_kargs_ssm_conv;
|
| 388 |
+
|
| 389 |
+
typedef struct {
|
| 390 |
+
int64_t d_state;
|
| 391 |
+
int64_t d_inner;
|
| 392 |
+
int64_t n_seq_tokens;
|
| 393 |
+
int64_t n_seqs;
|
| 394 |
+
uint64_t nb00;
|
| 395 |
+
uint64_t nb01;
|
| 396 |
+
uint64_t nb02;
|
| 397 |
+
uint64_t nb10;
|
| 398 |
+
uint64_t nb11;
|
| 399 |
+
uint64_t nb12;
|
| 400 |
+
uint64_t nb13;
|
| 401 |
+
uint64_t nb20;
|
| 402 |
+
uint64_t nb21;
|
| 403 |
+
uint64_t nb22;
|
| 404 |
+
uint64_t nb30;
|
| 405 |
+
uint64_t nb31;
|
| 406 |
+
uint64_t nb40;
|
| 407 |
+
uint64_t nb41;
|
| 408 |
+
uint64_t nb42;
|
| 409 |
+
uint64_t nb50;
|
| 410 |
+
uint64_t nb51;
|
| 411 |
+
uint64_t nb52;
|
| 412 |
+
} ggml_metal_kargs_ssm_scan;
|
| 413 |
+
|
| 414 |
+
typedef struct {
|
| 415 |
+
int64_t ne00;
|
| 416 |
+
uint64_t nb01;
|
| 417 |
+
uint64_t nb02;
|
| 418 |
+
int64_t ne10;
|
| 419 |
+
uint64_t nb10;
|
| 420 |
+
uint64_t nb11;
|
| 421 |
+
uint64_t nb1;
|
| 422 |
+
uint64_t nb2;
|
| 423 |
+
} ggml_metal_kargs_get_rows;
|
| 424 |
+
|
| 425 |
+
typedef struct {
|
| 426 |
+
int64_t ne00;
|
| 427 |
+
int64_t ne01;
|
| 428 |
+
int64_t ne02;
|
| 429 |
+
int64_t ne03;
|
| 430 |
+
uint64_t nb00;
|
| 431 |
+
uint64_t nb01;
|
| 432 |
+
uint64_t nb02;
|
| 433 |
+
uint64_t nb03;
|
| 434 |
+
int64_t ne0;
|
| 435 |
+
int64_t ne1;
|
| 436 |
+
int64_t ne2;
|
| 437 |
+
int64_t ne3;
|
| 438 |
+
uint64_t nb0;
|
| 439 |
+
uint64_t nb1;
|
| 440 |
+
uint64_t nb2;
|
| 441 |
+
uint64_t nb3;
|
| 442 |
+
float sf0;
|
| 443 |
+
float sf1;
|
| 444 |
+
float sf2;
|
| 445 |
+
float sf3;
|
| 446 |
+
} ggml_metal_kargs_upscale;
|
| 447 |
+
|
| 448 |
+
typedef struct {
|
| 449 |
+
int64_t ne00;
|
| 450 |
+
int64_t ne01;
|
| 451 |
+
int64_t ne02;
|
| 452 |
+
int64_t ne03;
|
| 453 |
+
uint64_t nb00;
|
| 454 |
+
uint64_t nb01;
|
| 455 |
+
uint64_t nb02;
|
| 456 |
+
uint64_t nb03;
|
| 457 |
+
int64_t ne0;
|
| 458 |
+
int64_t ne1;
|
| 459 |
+
int64_t ne2;
|
| 460 |
+
int64_t ne3;
|
| 461 |
+
uint64_t nb0;
|
| 462 |
+
uint64_t nb1;
|
| 463 |
+
uint64_t nb2;
|
| 464 |
+
uint64_t nb3;
|
| 465 |
+
} ggml_metal_kargs_pad;
|
| 466 |
+
|
| 467 |
+
typedef struct {
|
| 468 |
+
int64_t ne00;
|
| 469 |
+
int64_t ne01;
|
| 470 |
+
int64_t ne02;
|
| 471 |
+
int64_t ne03;
|
| 472 |
+
uint64_t nb00;
|
| 473 |
+
uint64_t nb01;
|
| 474 |
+
uint64_t nb02;
|
| 475 |
+
uint64_t nb03;
|
| 476 |
+
int64_t ne0;
|
| 477 |
+
int64_t ne1;
|
| 478 |
+
int64_t ne2;
|
| 479 |
+
int64_t ne3;
|
| 480 |
+
uint64_t nb0;
|
| 481 |
+
uint64_t nb1;
|
| 482 |
+
uint64_t nb2;
|
| 483 |
+
uint64_t nb3;
|
| 484 |
+
int32_t p0;
|
| 485 |
+
int32_t p1;
|
| 486 |
+
} ggml_metal_kargs_pad_reflect_1d;
|
| 487 |
+
|
| 488 |
+
typedef struct {
|
| 489 |
+
uint64_t nb1;
|
| 490 |
+
int dim;
|
| 491 |
+
int max_period;
|
| 492 |
+
} ggml_metal_kargs_timestep_embedding;
|
| 493 |
+
|
| 494 |
+
typedef struct {
|
| 495 |
+
float slope;
|
| 496 |
+
} ggml_metal_kargs_leaky_relu;
|
| 497 |
+
|
| 498 |
+
typedef struct {
|
| 499 |
+
int64_t ncols;
|
| 500 |
+
int64_t ncols_pad;
|
| 501 |
+
} ggml_metal_kargs_argsort;
|
| 502 |
+
|
| 503 |
+
typedef struct {
|
| 504 |
+
int64_t ne0;
|
| 505 |
+
float start;
|
| 506 |
+
float step;
|
| 507 |
+
} ggml_metal_kargs_arange;
|
| 508 |
+
|
| 509 |
+
typedef struct {
|
| 510 |
+
int32_t k0;
|
| 511 |
+
int32_t k1;
|
| 512 |
+
int32_t s0;
|
| 513 |
+
int32_t s1;
|
| 514 |
+
int32_t p0;
|
| 515 |
+
int32_t p1;
|
| 516 |
+
int64_t IH;
|
| 517 |
+
int64_t IW;
|
| 518 |
+
int64_t OH;
|
| 519 |
+
int64_t OW;
|
| 520 |
+
int64_t parallel_elements;
|
| 521 |
+
} ggml_metal_kargs_pool_2d;
|
| 522 |
+
|
| 523 |
#endif // GGML_METAL_IMPL
|
|
@@ -1945,34 +1945,38 @@ static void ggml_metal_encode_node(
|
|
| 1945 |
|
| 1946 |
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
|
| 1947 |
|
| 1948 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1949 |
[encoder setComputePipelineState:pipeline];
|
| 1950 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 1951 |
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 1952 |
-
[encoder setBytes:&
|
| 1953 |
-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
| 1954 |
-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
| 1955 |
-
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
| 1956 |
-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
| 1957 |
-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
| 1958 |
-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
| 1959 |
-
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
| 1960 |
-
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
|
| 1961 |
-
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
|
| 1962 |
-
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
|
| 1963 |
-
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
|
| 1964 |
-
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
| 1965 |
-
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
| 1966 |
-
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
| 1967 |
-
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
|
| 1968 |
-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
|
| 1969 |
-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
|
| 1970 |
-
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
|
| 1971 |
-
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
|
| 1972 |
-
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
|
| 1973 |
-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
|
| 1974 |
-
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
|
| 1975 |
-
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
|
| 1976 |
|
| 1977 |
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
| 1978 |
} break;
|
|
@@ -2021,8 +2025,17 @@ static void ggml_metal_encode_node(
|
|
| 2021 |
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
| 2022 |
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
| 2023 |
|
| 2024 |
-
|
| 2025 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2026 |
[encoder setComputePipelineState:pipeline];
|
| 2027 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 2028 |
if (id_src1) {
|
|
@@ -2031,14 +2044,7 @@ static void ggml_metal_encode_node(
|
|
| 2031 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
| 2032 |
}
|
| 2033 |
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
| 2034 |
-
[encoder setBytes:&
|
| 2035 |
-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
| 2036 |
-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
| 2037 |
-
[encoder setBytes:&scale length:sizeof(scale) atIndex:6];
|
| 2038 |
-
[encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
|
| 2039 |
-
[encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
|
| 2040 |
-
[encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
|
| 2041 |
-
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
|
| 2042 |
|
| 2043 |
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
| 2044 |
|
|
@@ -2056,13 +2062,16 @@ static void ggml_metal_encode_node(
|
|
| 2056 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
|
| 2057 |
}
|
| 2058 |
|
| 2059 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2060 |
[encoder setComputePipelineState:pipeline];
|
| 2061 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 2062 |
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 2063 |
-
[encoder setBytes:&
|
| 2064 |
-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
| 2065 |
-
[encoder setBytes:&n_past length:sizeof(int) atIndex:4];
|
| 2066 |
|
| 2067 |
if (ne00%8 == 0) {
|
| 2068 |
[encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
@@ -2081,27 +2090,30 @@ static void ggml_metal_encode_node(
|
|
| 2081 |
|
| 2082 |
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
|
| 2083 |
|
| 2084 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2085 |
[encoder setComputePipelineState:pipeline];
|
| 2086 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 2087 |
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
| 2088 |
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
| 2089 |
-
[encoder setBytes:&
|
| 2090 |
-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
| 2091 |
-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
| 2092 |
-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
| 2093 |
-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
| 2094 |
-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
| 2095 |
-
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
|
| 2096 |
-
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
|
| 2097 |
-
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
|
| 2098 |
-
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
|
| 2099 |
-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
|
| 2100 |
-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
|
| 2101 |
-
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:15];
|
| 2102 |
-
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:16];
|
| 2103 |
-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:17];
|
| 2104 |
-
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:18];
|
| 2105 |
|
| 2106 |
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
| 2107 |
} break;
|
|
@@ -2152,7 +2164,31 @@ static void ggml_metal_encode_node(
|
|
| 2152 |
|
| 2153 |
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
|
| 2154 |
|
| 2155 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2156 |
[encoder setComputePipelineState:pipeline];
|
| 2157 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 2158 |
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
@@ -2161,30 +2197,7 @@ static void ggml_metal_encode_node(
|
|
| 2161 |
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
|
| 2162 |
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
|
| 2163 |
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
|
| 2164 |
-
|
| 2165 |
-
[encoder setBytes:&d_state length:sizeof(d_state) atIndex:7];
|
| 2166 |
-
[encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:8];
|
| 2167 |
-
[encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9];
|
| 2168 |
-
[encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:10];
|
| 2169 |
-
|
| 2170 |
-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11];
|
| 2171 |
-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12];
|
| 2172 |
-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13];
|
| 2173 |
-
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
| 2174 |
-
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
| 2175 |
-
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
| 2176 |
-
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
|
| 2177 |
-
[encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18];
|
| 2178 |
-
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19];
|
| 2179 |
-
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20];
|
| 2180 |
-
[encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21];
|
| 2181 |
-
[encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22];
|
| 2182 |
-
[encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23];
|
| 2183 |
-
[encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24];
|
| 2184 |
-
[encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25];
|
| 2185 |
-
[encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26];
|
| 2186 |
-
[encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27];
|
| 2187 |
-
[encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28];
|
| 2188 |
|
| 2189 |
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
| 2190 |
} break;
|
|
@@ -3041,19 +3054,22 @@ static void ggml_metal_encode_node(
|
|
| 3041 |
default: GGML_ABORT("not implemented");
|
| 3042 |
}
|
| 3043 |
|
| 3044 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3045 |
[encoder setComputePipelineState:pipeline];
|
| 3046 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 3047 |
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
| 3048 |
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
| 3049 |
-
[encoder setBytes:&
|
| 3050 |
-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
|
| 3051 |
-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
|
| 3052 |
-
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
|
| 3053 |
-
[encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
|
| 3054 |
-
[encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
|
| 3055 |
-
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
|
| 3056 |
-
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
|
| 3057 |
|
| 3058 |
[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
| 3059 |
} break;
|
|
@@ -3110,18 +3126,21 @@ static void ggml_metal_encode_node(
|
|
| 3110 |
|
| 3111 |
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
|
| 3112 |
|
| 3113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3114 |
[encoder setComputePipelineState:pipeline];
|
| 3115 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 3116 |
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 3117 |
-
[encoder setBytes:&
|
| 3118 |
-
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
| 3119 |
-
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
| 3120 |
-
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
|
| 3121 |
-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
|
| 3122 |
-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
|
| 3123 |
-
[encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
|
| 3124 |
-
[encoder setBytes:&eps length:sizeof( float) atIndex:9];
|
| 3125 |
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
| 3126 |
|
| 3127 |
[encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
@@ -3279,8 +3298,8 @@ static void ggml_metal_encode_node(
|
|
| 3279 |
|
| 3280 |
const int32_t CHW = IC * KH * KW;
|
| 3281 |
|
| 3282 |
-
const
|
| 3283 |
-
const
|
| 3284 |
|
| 3285 |
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;
|
| 3286 |
|
|
@@ -3302,27 +3321,30 @@ static void ggml_metal_encode_node(
|
|
| 3302 |
default: GGML_ABORT("fatal error");
|
| 3303 |
};
|
| 3304 |
|
| 3305 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3306 |
[encoder setComputePipelineState:pipeline];
|
| 3307 |
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
|
| 3308 |
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 3309 |
-
[encoder setBytes:&
|
| 3310 |
-
[encoder setBytes:&ofs1 length:sizeof(int32_t) atIndex:3];
|
| 3311 |
-
[encoder setBytes:&IW length:sizeof(int32_t) atIndex:4];
|
| 3312 |
-
[encoder setBytes:&IH length:sizeof(int32_t) atIndex:5];
|
| 3313 |
-
[encoder setBytes:&CHW length:sizeof(int32_t) atIndex:6];
|
| 3314 |
-
[encoder setBytes:&s0 length:sizeof(int32_t) atIndex:7];
|
| 3315 |
-
[encoder setBytes:&s1 length:sizeof(int32_t) atIndex:8];
|
| 3316 |
-
[encoder setBytes:&p0 length:sizeof(int32_t) atIndex:9];
|
| 3317 |
-
[encoder setBytes:&p1 length:sizeof(int32_t) atIndex:10];
|
| 3318 |
-
[encoder setBytes:&d0 length:sizeof(int32_t) atIndex:11];
|
| 3319 |
-
[encoder setBytes:&d1 length:sizeof(int32_t) atIndex:12];
|
| 3320 |
|
| 3321 |
if (is_gt_mttpt) {
|
| 3322 |
-
[encoder setBytes:&N length:sizeof(int32_t) atIndex:13];
|
| 3323 |
-
[encoder setBytes:&KH length:sizeof(int32_t) atIndex:14];
|
| 3324 |
-
[encoder setBytes:&KW length:sizeof(int32_t) atIndex:15];
|
| 3325 |
-
|
| 3326 |
const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N);
|
| 3327 |
|
| 3328 |
const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0);
|
|
@@ -3362,16 +3384,20 @@ static void ggml_metal_encode_node(
|
|
| 3362 |
default: GGML_ABORT("fatal error");
|
| 3363 |
};
|
| 3364 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3365 |
[encoder setComputePipelineState:pipeline];
|
| 3366 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 3367 |
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
| 3368 |
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
| 3369 |
-
[encoder setBytes:&
|
| 3370 |
-
[encoder setBytes:&IL length:sizeof( int32_t) atIndex:4];
|
| 3371 |
-
[encoder setBytes:&K length:sizeof( int32_t) atIndex:5];
|
| 3372 |
-
[encoder setBytes:&s0 length:sizeof( int32_t) atIndex:6];
|
| 3373 |
-
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:7];
|
| 3374 |
-
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:8];
|
| 3375 |
|
| 3376 |
[encoder dispatchThreadgroups:MTLSizeMake(OL, OC, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
| 3377 |
} break;
|
|
@@ -3386,30 +3412,33 @@ static void ggml_metal_encode_node(
|
|
| 3386 |
|
| 3387 |
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
|
| 3388 |
|
| 3389 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3390 |
[encoder setComputePipelineState:pipeline];
|
| 3391 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 3392 |
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 3393 |
-
[encoder setBytes:&
|
| 3394 |
-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
| 3395 |
-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
| 3396 |
-
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
| 3397 |
-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
| 3398 |
-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
| 3399 |
-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
| 3400 |
-
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
| 3401 |
-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
| 3402 |
-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
| 3403 |
-
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
| 3404 |
-
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
| 3405 |
-
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
| 3406 |
-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
| 3407 |
-
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
| 3408 |
-
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
| 3409 |
-
[encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18];
|
| 3410 |
-
[encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19];
|
| 3411 |
-
[encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20];
|
| 3412 |
-
[encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21];
|
| 3413 |
|
| 3414 |
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
|
| 3415 |
|
|
@@ -3421,26 +3450,29 @@ static void ggml_metal_encode_node(
|
|
| 3421 |
|
| 3422 |
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
|
| 3423 |
|
| 3424 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3425 |
[encoder setComputePipelineState:pipeline];
|
| 3426 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 3427 |
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 3428 |
-
[encoder setBytes:&
|
| 3429 |
-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
| 3430 |
-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
| 3431 |
-
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
| 3432 |
-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
| 3433 |
-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
| 3434 |
-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
| 3435 |
-
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
| 3436 |
-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
| 3437 |
-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
| 3438 |
-
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
| 3439 |
-
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
| 3440 |
-
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
| 3441 |
-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
| 3442 |
-
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
| 3443 |
-
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
| 3444 |
|
| 3445 |
const int nth = MIN(1024, ne0);
|
| 3446 |
|
|
@@ -3455,24 +3487,31 @@ static void ggml_metal_encode_node(
|
|
| 3455 |
|
| 3456 |
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32].pipeline;
|
| 3457 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3458 |
[encoder setComputePipelineState:pipeline];
|
| 3459 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 3460 |
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 3461 |
-
[encoder setBytes:&
|
| 3462 |
-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
| 3463 |
-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
| 3464 |
-
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
| 3465 |
-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:6];
|
| 3466 |
-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
|
| 3467 |
-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
|
| 3468 |
-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
|
| 3469 |
-
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
|
| 3470 |
-
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:11];
|
| 3471 |
-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:12];
|
| 3472 |
-
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:13];
|
| 3473 |
-
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:14];
|
| 3474 |
-
[encoder setBytes:&p0 length:sizeof(p0) atIndex:15];
|
| 3475 |
-
[encoder setBytes:&p1 length:sizeof(p1) atIndex:16];
|
| 3476 |
|
| 3477 |
const int nth = MIN(1024, ne0);
|
| 3478 |
|
|
@@ -3490,12 +3529,15 @@ static void ggml_metal_encode_node(
|
|
| 3490 |
|
| 3491 |
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
|
| 3492 |
|
| 3493 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3494 |
[encoder setComputePipelineState:pipeline];
|
| 3495 |
-
[encoder setBuffer:id_dst offset:offs_dst
|
| 3496 |
-
[encoder setBytes:&
|
| 3497 |
-
[encoder setBytes:&start length:sizeof(start) atIndex:2];
|
| 3498 |
-
[encoder setBytes:&step length:sizeof(step) atIndex:3];
|
| 3499 |
|
| 3500 |
const int nth = MIN(1024, ne0);
|
| 3501 |
|
|
@@ -3512,13 +3554,16 @@ static void ggml_metal_encode_node(
|
|
| 3512 |
|
| 3513 |
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
|
| 3514 |
|
| 3515 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3516 |
[encoder setComputePipelineState:pipeline];
|
| 3517 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 3518 |
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 3519 |
-
[encoder setBytes:&
|
| 3520 |
-
[encoder setBytes:&dim length:sizeof(dim) atIndex:3];
|
| 3521 |
-
[encoder setBytes:&max_period length:sizeof(max_period) atIndex:4];
|
| 3522 |
|
| 3523 |
const int nth = MIN(1024, half);
|
| 3524 |
|
|
@@ -3551,12 +3596,15 @@ static void ggml_metal_encode_node(
|
|
| 3551 |
default: GGML_ABORT("fatal error");
|
| 3552 |
};
|
| 3553 |
|
| 3554 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3555 |
[encoder setComputePipelineState:pipeline];
|
| 3556 |
-
[encoder setBuffer:id_src0
|
| 3557 |
-
[encoder setBuffer:id_dst
|
| 3558 |
-
[encoder setBytes:&
|
| 3559 |
-
[encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3];
|
| 3560 |
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
| 3561 |
|
| 3562 |
[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
|
|
@@ -3570,11 +3618,14 @@ static void ggml_metal_encode_node(
|
|
| 3570 |
|
| 3571 |
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
|
| 3572 |
|
| 3573 |
-
|
|
|
|
|
|
|
|
|
|
| 3574 |
[encoder setComputePipelineState:pipeline];
|
| 3575 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 3576 |
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 3577 |
-
[encoder setBytes:&
|
| 3578 |
|
| 3579 |
const int64_t n = ggml_nelements(dst);
|
| 3580 |
|
|
@@ -4150,21 +4201,24 @@ static void ggml_metal_encode_node(
|
|
| 4150 |
const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
|
| 4151 |
const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
|
| 4152 |
|
| 4153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4154 |
[encoder setComputePipelineState:pipeline];
|
| 4155 |
-
[encoder setBuffer:id_src0 offset:offs_src0
|
| 4156 |
-
[encoder setBuffer:id_dst offset:offs_dst
|
| 4157 |
-
[encoder setBytes:&
|
| 4158 |
-
[encoder setBytes:&k1 length:sizeof(int32_t) atIndex:3];
|
| 4159 |
-
[encoder setBytes:&s0 length:sizeof(int32_t) atIndex:4];
|
| 4160 |
-
[encoder setBytes:&s1 length:sizeof(int32_t) atIndex:5];
|
| 4161 |
-
[encoder setBytes:&p0 length:sizeof(int32_t) atIndex:6];
|
| 4162 |
-
[encoder setBytes:&p1 length:sizeof(int32_t) atIndex:7];
|
| 4163 |
-
[encoder setBytes:&IH length:sizeof(int64_t) atIndex:8];
|
| 4164 |
-
[encoder setBytes:&IW length:sizeof(int64_t) atIndex:9];
|
| 4165 |
-
[encoder setBytes:&OH length:sizeof(int64_t) atIndex:10];
|
| 4166 |
-
[encoder setBytes:&OW length:sizeof(int64_t) atIndex:11];
|
| 4167 |
-
[encoder setBytes:&parallel_elements length:sizeof(int64_t) atIndex:12];
|
| 4168 |
|
| 4169 |
[encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
|
| 4170 |
} break;
|
|
|
|
| 1945 |
|
| 1946 |
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
|
| 1947 |
|
| 1948 |
+
|
| 1949 |
+
ggml_metal_kargs_sum_rows args = {
|
| 1950 |
+
/*.ne00 =*/ ne00,
|
| 1951 |
+
/*.ne01 =*/ ne01,
|
| 1952 |
+
/*.ne02 =*/ ne02,
|
| 1953 |
+
/*.ne03 =*/ ne03,
|
| 1954 |
+
/*.nb00 =*/ nb00,
|
| 1955 |
+
/*.nb01 =*/ nb01,
|
| 1956 |
+
/*.nb02 =*/ nb02,
|
| 1957 |
+
/*.nb03 =*/ nb03,
|
| 1958 |
+
/*.ne10 =*/ ne10,
|
| 1959 |
+
/*.ne11 =*/ ne11,
|
| 1960 |
+
/*.ne12 =*/ ne12,
|
| 1961 |
+
/*.ne13 =*/ ne13,
|
| 1962 |
+
/*.nb10 =*/ nb10,
|
| 1963 |
+
/*.nb11 =*/ nb11,
|
| 1964 |
+
/*.nb12 =*/ nb12,
|
| 1965 |
+
/*.nb13 =*/ nb13,
|
| 1966 |
+
/*.ne0 =*/ ne0,
|
| 1967 |
+
/*.ne1 =*/ ne1,
|
| 1968 |
+
/*.ne2 =*/ ne2,
|
| 1969 |
+
/*.ne3 =*/ ne3,
|
| 1970 |
+
/*.nb0 =*/ nb0,
|
| 1971 |
+
/*.nb1 =*/ nb1,
|
| 1972 |
+
/*.nb2 =*/ nb2,
|
| 1973 |
+
/*.nb3 =*/ nb3,
|
| 1974 |
+
};
|
| 1975 |
+
|
| 1976 |
[encoder setComputePipelineState:pipeline];
|
| 1977 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 1978 |
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 1979 |
+
[encoder setBytes:&args length:sizeof(args) atIndex:2];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1980 |
|
| 1981 |
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
| 1982 |
} break;
|
|
|
|
| 2025 |
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
| 2026 |
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
| 2027 |
|
| 2028 |
+
ggml_metal_kargs_soft_max args = {
|
| 2029 |
+
/*.ne00 =*/ ne00,
|
| 2030 |
+
/*.ne01 =*/ ne01,
|
| 2031 |
+
/*.ne02 =*/ ne02,
|
| 2032 |
+
/*.scale =*/ scale,
|
| 2033 |
+
/*.max_bias =*/ max_bias,
|
| 2034 |
+
/*.m0 =*/ m0,
|
| 2035 |
+
/*.m1 =*/ m1,
|
| 2036 |
+
/*.n_head_log2 =*/ n_head_log2,
|
| 2037 |
+
};
|
| 2038 |
+
|
| 2039 |
[encoder setComputePipelineState:pipeline];
|
| 2040 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 2041 |
if (id_src1) {
|
|
|
|
| 2044 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
| 2045 |
}
|
| 2046 |
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
| 2047 |
+
[encoder setBytes:&args length:sizeof(args) atIndex:3];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2048 |
|
| 2049 |
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
| 2050 |
|
|
|
|
| 2062 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
|
| 2063 |
}
|
| 2064 |
|
| 2065 |
+
ggml_metal_kargs_diag_mask_inf args = {
|
| 2066 |
+
/*.ne00 =*/ ne00,
|
| 2067 |
+
/*.ne01 =*/ ne01,
|
| 2068 |
+
/*.n_past =*/ n_past,
|
| 2069 |
+
};
|
| 2070 |
+
|
| 2071 |
[encoder setComputePipelineState:pipeline];
|
| 2072 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 2073 |
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 2074 |
+
[encoder setBytes:&args length:sizeof(args) atIndex:2];
|
|
|
|
|
|
|
| 2075 |
|
| 2076 |
if (ne00%8 == 0) {
|
| 2077 |
[encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
|
|
| 2090 |
|
| 2091 |
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
|
| 2092 |
|
| 2093 |
+
ggml_metal_kargs_ssm_conv args = {
|
| 2094 |
+
/*.ne00 =*/ ne00,
|
| 2095 |
+
/*.ne01 =*/ ne01,
|
| 2096 |
+
/*.ne02 =*/ ne02,
|
| 2097 |
+
/*.nb00 =*/ nb00,
|
| 2098 |
+
/*.nb01 =*/ nb01,
|
| 2099 |
+
/*.nb02 =*/ nb02,
|
| 2100 |
+
/*.ne10 =*/ ne10,
|
| 2101 |
+
/*.ne11 =*/ ne11,
|
| 2102 |
+
/*.nb10 =*/ nb10,
|
| 2103 |
+
/*.nb11 =*/ nb11,
|
| 2104 |
+
/*.ne0 =*/ ne0,
|
| 2105 |
+
/*.ne1 =*/ ne1,
|
| 2106 |
+
/*.ne2 =*/ ne2,
|
| 2107 |
+
/*.nb0 =*/ nb0,
|
| 2108 |
+
/*.nb1 =*/ nb1,
|
| 2109 |
+
/*.nb2 =*/ nb2,
|
| 2110 |
+
};
|
| 2111 |
+
|
| 2112 |
[encoder setComputePipelineState:pipeline];
|
| 2113 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 2114 |
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
| 2115 |
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
| 2116 |
+
[encoder setBytes:&args length:sizeof(args) atIndex:3];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2117 |
|
| 2118 |
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
| 2119 |
} break;
|
|
|
|
| 2164 |
|
| 2165 |
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
|
| 2166 |
|
| 2167 |
+
ggml_metal_kargs_ssm_scan args = {
|
| 2168 |
+
/*.d_state =*/ d_state,
|
| 2169 |
+
/*.d_inner =*/ d_inner,
|
| 2170 |
+
/*.n_seq_tokens =*/ n_seq_tokens,
|
| 2171 |
+
/*.n_seqs =*/ n_seqs,
|
| 2172 |
+
/*.nb00 =*/ nb00,
|
| 2173 |
+
/*.nb01 =*/ nb01,
|
| 2174 |
+
/*.nb02 =*/ nb02,
|
| 2175 |
+
/*.nb10 =*/ nb10,
|
| 2176 |
+
/*.nb11 =*/ nb11,
|
| 2177 |
+
/*.nb12 =*/ nb12,
|
| 2178 |
+
/*.nb13 =*/ nb13,
|
| 2179 |
+
/*.nb20 =*/ nb20,
|
| 2180 |
+
/*.nb21 =*/ nb21,
|
| 2181 |
+
/*.nb22 =*/ nb22,
|
| 2182 |
+
/*.nb30 =*/ nb30,
|
| 2183 |
+
/*.nb31 =*/ nb31,
|
| 2184 |
+
/*.nb40 =*/ nb40,
|
| 2185 |
+
/*.nb41 =*/ nb41,
|
| 2186 |
+
/*.nb42 =*/ nb42,
|
| 2187 |
+
/*.nb50 =*/ nb50,
|
| 2188 |
+
/*.nb51 =*/ nb51,
|
| 2189 |
+
/*.nb52 =*/ nb52,
|
| 2190 |
+
};
|
| 2191 |
+
|
| 2192 |
[encoder setComputePipelineState:pipeline];
|
| 2193 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 2194 |
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
|
|
| 2197 |
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
|
| 2198 |
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
|
| 2199 |
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
|
| 2200 |
+
[encoder setBytes:&args length:sizeof(args) atIndex:7];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2201 |
|
| 2202 |
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
| 2203 |
} break;
|
|
|
|
| 3054 |
default: GGML_ABORT("not implemented");
|
| 3055 |
}
|
| 3056 |
|
| 3057 |
+
ggml_metal_kargs_get_rows args = {
|
| 3058 |
+
/*.ne00 =*/ ne00,
|
| 3059 |
+
/*.nb01 =*/ nb01,
|
| 3060 |
+
/*.nb02 =*/ nb02,
|
| 3061 |
+
/*.ne10 =*/ ne10,
|
| 3062 |
+
/*.nb10 =*/ nb10,
|
| 3063 |
+
/*.nb11 =*/ nb11,
|
| 3064 |
+
/*.nb1 =*/ nb1,
|
| 3065 |
+
/*.nb2 =*/ nb2,
|
| 3066 |
+
};
|
| 3067 |
+
|
| 3068 |
[encoder setComputePipelineState:pipeline];
|
| 3069 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 3070 |
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
| 3071 |
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
| 3072 |
+
[encoder setBytes:&args length:sizeof(args) atIndex:3];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3073 |
|
| 3074 |
[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
| 3075 |
} break;
|
|
|
|
| 3126 |
|
| 3127 |
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
|
| 3128 |
|
| 3129 |
+
ggml_metal_kargs_group_norm args = {
|
| 3130 |
+
/*.ne00 =*/ ne00,
|
| 3131 |
+
/*.ne01 =*/ ne01,
|
| 3132 |
+
/*.ne02 =*/ ne02,
|
| 3133 |
+
/*.nb00 =*/ nb00,
|
| 3134 |
+
/*.nb01 =*/ nb01,
|
| 3135 |
+
/*.nb02 =*/ nb02,
|
| 3136 |
+
/*.n_groups =*/ n_groups,
|
| 3137 |
+
/*.eps =*/ eps,
|
| 3138 |
+
};
|
| 3139 |
+
|
| 3140 |
[encoder setComputePipelineState:pipeline];
|
| 3141 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 3142 |
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 3143 |
+
[encoder setBytes:&args length:sizeof(args) atIndex:2];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3144 |
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
| 3145 |
|
| 3146 |
[encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
|
|
| 3298 |
|
| 3299 |
const int32_t CHW = IC * KH * KW;
|
| 3300 |
|
| 3301 |
+
const uint64_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
|
| 3302 |
+
const uint64_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
|
| 3303 |
|
| 3304 |
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;
|
| 3305 |
|
|
|
|
| 3321 |
default: GGML_ABORT("fatal error");
|
| 3322 |
};
|
| 3323 |
|
| 3324 |
+
ggml_metal_kargs_im2col args = {
|
| 3325 |
+
/*.ofs0 =*/ ofs0,
|
| 3326 |
+
/*.ofs1 =*/ ofs1,
|
| 3327 |
+
/*.IW =*/ IW,
|
| 3328 |
+
/*.IH =*/ IH,
|
| 3329 |
+
/*.CHW =*/ CHW,
|
| 3330 |
+
/*.s0 =*/ s0,
|
| 3331 |
+
/*.s1 =*/ s1,
|
| 3332 |
+
/*.p0 =*/ p0,
|
| 3333 |
+
/*.p1 =*/ p1,
|
| 3334 |
+
/*.d0 =*/ d0,
|
| 3335 |
+
/*.d1 =*/ d1,
|
| 3336 |
+
/*.N =*/ N,
|
| 3337 |
+
/*.KH =*/ KH,
|
| 3338 |
+
/*.KW =*/ KW,
|
| 3339 |
+
/*.KHW =*/ KH * KW,
|
| 3340 |
+
};
|
| 3341 |
+
|
| 3342 |
[encoder setComputePipelineState:pipeline];
|
| 3343 |
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
|
| 3344 |
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 3345 |
+
[encoder setBytes:&args length:sizeof(args) atIndex:2];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3346 |
|
| 3347 |
if (is_gt_mttpt) {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3348 |
const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N);
|
| 3349 |
|
| 3350 |
const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0);
|
|
|
|
| 3384 |
default: GGML_ABORT("fatal error");
|
| 3385 |
};
|
| 3386 |
|
| 3387 |
+
ggml_metal_kargs_conv_transpose_1d args = {
|
| 3388 |
+
/*.IC =*/ IC,
|
| 3389 |
+
/*.IL =*/ IL,
|
| 3390 |
+
/*.K =*/ K,
|
| 3391 |
+
/*.s0 =*/ s0,
|
| 3392 |
+
/*.nb0 =*/ nb0,
|
| 3393 |
+
/*.nb1 =*/ nb1,
|
| 3394 |
+
};
|
| 3395 |
+
|
| 3396 |
[encoder setComputePipelineState:pipeline];
|
| 3397 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 3398 |
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
| 3399 |
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
| 3400 |
+
[encoder setBytes:&args length:sizeof(args) atIndex:3];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3401 |
|
| 3402 |
[encoder dispatchThreadgroups:MTLSizeMake(OL, OC, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
| 3403 |
} break;
|
|
|
|
| 3412 |
|
| 3413 |
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
|
| 3414 |
|
| 3415 |
+
ggml_metal_kargs_upscale args = {
|
| 3416 |
+
/*.ne00 =*/ ne00,
|
| 3417 |
+
/*.ne01 =*/ ne01,
|
| 3418 |
+
/*.ne02 =*/ ne02,
|
| 3419 |
+
/*.ne03 =*/ ne03,
|
| 3420 |
+
/*.nb00 =*/ nb00,
|
| 3421 |
+
/*.nb01 =*/ nb01,
|
| 3422 |
+
/*.nb02 =*/ nb02,
|
| 3423 |
+
/*.nb03 =*/ nb03,
|
| 3424 |
+
/*.ne0 =*/ ne0,
|
| 3425 |
+
/*.ne1 =*/ ne1,
|
| 3426 |
+
/*.ne2 =*/ ne2,
|
| 3427 |
+
/*.ne3 =*/ ne3,
|
| 3428 |
+
/*.nb0 =*/ nb0,
|
| 3429 |
+
/*.nb1 =*/ nb1,
|
| 3430 |
+
/*.nb2 =*/ nb2,
|
| 3431 |
+
/*.nb3 =*/ nb3,
|
| 3432 |
+
/*.sf0 =*/ sf0,
|
| 3433 |
+
/*.sf1 =*/ sf1,
|
| 3434 |
+
/*.sf2 =*/ sf2,
|
| 3435 |
+
/*.sf3 =*/ sf3
|
| 3436 |
+
};
|
| 3437 |
+
|
| 3438 |
[encoder setComputePipelineState:pipeline];
|
| 3439 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 3440 |
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 3441 |
+
[encoder setBytes:&args length:sizeof(args) atIndex:2];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3442 |
|
| 3443 |
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
|
| 3444 |
|
|
|
|
| 3450 |
|
| 3451 |
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
|
| 3452 |
|
| 3453 |
+
ggml_metal_kargs_pad args = {
|
| 3454 |
+
/*.ne00 =*/ ne00,
|
| 3455 |
+
/*.ne01 =*/ ne01,
|
| 3456 |
+
/*.ne02 =*/ ne02,
|
| 3457 |
+
/*.ne03 =*/ ne03,
|
| 3458 |
+
/*.nb00 =*/ nb00,
|
| 3459 |
+
/*.nb01 =*/ nb01,
|
| 3460 |
+
/*.nb02 =*/ nb02,
|
| 3461 |
+
/*.nb03 =*/ nb03,
|
| 3462 |
+
/*.ne0 =*/ ne0,
|
| 3463 |
+
/*.ne1 =*/ ne1,
|
| 3464 |
+
/*.ne2 =*/ ne2,
|
| 3465 |
+
/*.ne3 =*/ ne3,
|
| 3466 |
+
/*.nb0 =*/ nb0,
|
| 3467 |
+
/*.nb1 =*/ nb1,
|
| 3468 |
+
/*.nb2 =*/ nb2,
|
| 3469 |
+
/*.nb3 =*/ nb3
|
| 3470 |
+
};
|
| 3471 |
+
|
| 3472 |
[encoder setComputePipelineState:pipeline];
|
| 3473 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 3474 |
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 3475 |
+
[encoder setBytes:&args length:sizeof(args) atIndex:2];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3476 |
|
| 3477 |
const int nth = MIN(1024, ne0);
|
| 3478 |
|
|
|
|
| 3487 |
|
| 3488 |
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32].pipeline;
|
| 3489 |
|
| 3490 |
+
ggml_metal_kargs_pad_reflect_1d args = {
|
| 3491 |
+
/*.ne00 =*/ ne00,
|
| 3492 |
+
/*.ne01 =*/ ne01,
|
| 3493 |
+
/*.ne02 =*/ ne02,
|
| 3494 |
+
/*.ne03 =*/ ne03,
|
| 3495 |
+
/*.nb00 =*/ nb00,
|
| 3496 |
+
/*.nb01 =*/ nb01,
|
| 3497 |
+
/*.nb02 =*/ nb02,
|
| 3498 |
+
/*.nb03 =*/ nb03,
|
| 3499 |
+
/*.ne0 =*/ ne0,
|
| 3500 |
+
/*.ne1 =*/ ne1,
|
| 3501 |
+
/*.ne2 =*/ ne2,
|
| 3502 |
+
/*.ne3 =*/ ne3,
|
| 3503 |
+
/*.nb0 =*/ nb0,
|
| 3504 |
+
/*.nb1 =*/ nb1,
|
| 3505 |
+
/*.nb2 =*/ nb2,
|
| 3506 |
+
/*.nb3 =*/ nb3,
|
| 3507 |
+
/*.p0 =*/ p0,
|
| 3508 |
+
/*.p1 =*/ p1
|
| 3509 |
+
};
|
| 3510 |
+
|
| 3511 |
[encoder setComputePipelineState:pipeline];
|
| 3512 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 3513 |
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 3514 |
+
[encoder setBytes:&args length:sizeof(args) atIndex:2];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3515 |
|
| 3516 |
const int nth = MIN(1024, ne0);
|
| 3517 |
|
|
|
|
| 3529 |
|
| 3530 |
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
|
| 3531 |
|
| 3532 |
+
ggml_metal_kargs_arange args = {
|
| 3533 |
+
/*.ne0 =*/ ne0,
|
| 3534 |
+
/*.start =*/ start,
|
| 3535 |
+
/*.step =*/ step
|
| 3536 |
+
};
|
| 3537 |
+
|
| 3538 |
[encoder setComputePipelineState:pipeline];
|
| 3539 |
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:0];
|
| 3540 |
+
[encoder setBytes:&args length:sizeof(args) atIndex:1];
|
|
|
|
|
|
|
| 3541 |
|
| 3542 |
const int nth = MIN(1024, ne0);
|
| 3543 |
|
|
|
|
| 3554 |
|
| 3555 |
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
|
| 3556 |
|
| 3557 |
+
ggml_metal_kargs_timestep_embedding args = {
|
| 3558 |
+
/*.nb1 =*/ nb1,
|
| 3559 |
+
/*.dim =*/ dim,
|
| 3560 |
+
/*.max_period =*/ max_period
|
| 3561 |
+
};
|
| 3562 |
+
|
| 3563 |
[encoder setComputePipelineState:pipeline];
|
| 3564 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 3565 |
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 3566 |
+
[encoder setBytes:&args length:sizeof(args) atIndex:2];
|
|
|
|
|
|
|
| 3567 |
|
| 3568 |
const int nth = MIN(1024, half);
|
| 3569 |
|
|
|
|
| 3596 |
default: GGML_ABORT("fatal error");
|
| 3597 |
};
|
| 3598 |
|
| 3599 |
+
ggml_metal_kargs_argsort args = {
|
| 3600 |
+
/*.ncols =*/ ne00,
|
| 3601 |
+
/*.ncols_pad =*/ ne00_padded
|
| 3602 |
+
};
|
| 3603 |
+
|
| 3604 |
[encoder setComputePipelineState:pipeline];
|
| 3605 |
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 3606 |
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 3607 |
+
[encoder setBytes:&args length:sizeof(args) atIndex:2];
|
|
|
|
| 3608 |
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
| 3609 |
|
| 3610 |
[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
|
|
|
|
| 3618 |
|
| 3619 |
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
|
| 3620 |
|
| 3621 |
+
ggml_metal_kargs_leaky_relu args = {
|
| 3622 |
+
/*.slope =*/ slope
|
| 3623 |
+
};
|
| 3624 |
+
|
| 3625 |
[encoder setComputePipelineState:pipeline];
|
| 3626 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 3627 |
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 3628 |
+
[encoder setBytes:&args length:sizeof(args) atIndex:2];
|
| 3629 |
|
| 3630 |
const int64_t n = ggml_nelements(dst);
|
| 3631 |
|
|
|
|
| 4201 |
const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
|
| 4202 |
const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
|
| 4203 |
|
| 4204 |
+
ggml_metal_kargs_pool_2d args_pool_2d = {
|
| 4205 |
+
/* .k0 = */ k0,
|
| 4206 |
+
/* .k1 = */ k1,
|
| 4207 |
+
/* .s0 = */ s0,
|
| 4208 |
+
/* .s1 = */ s1,
|
| 4209 |
+
/* .p0 = */ p0,
|
| 4210 |
+
/* .p1 = */ p1,
|
| 4211 |
+
/* .IH = */ IH,
|
| 4212 |
+
/* .IW = */ IW,
|
| 4213 |
+
/* .OH = */ OH,
|
| 4214 |
+
/* .OW = */ OW,
|
| 4215 |
+
/* .parallel_elements = */ parallel_elements
|
| 4216 |
+
};
|
| 4217 |
+
|
| 4218 |
[encoder setComputePipelineState:pipeline];
|
| 4219 |
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 4220 |
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 4221 |
+
[encoder setBytes:&args_pool_2d length:sizeof(args_pool_2d) atIndex:2];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4222 |
|
| 4223 |
[encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
|
| 4224 |
} break;
|
|
@@ -947,45 +947,22 @@ kernel void kernel_cos(
|
|
| 947 |
kernel void kernel_sum_rows(
|
| 948 |
device const float * src0,
|
| 949 |
device float * dst,
|
| 950 |
-
constant
|
| 951 |
-
constant int64_t & ne01,
|
| 952 |
-
constant int64_t & ne02,
|
| 953 |
-
constant int64_t & ne03,
|
| 954 |
-
constant uint64_t & nb00,
|
| 955 |
-
constant uint64_t & nb01,
|
| 956 |
-
constant uint64_t & nb02,
|
| 957 |
-
constant uint64_t & nb03,
|
| 958 |
-
constant int64_t & ne10,
|
| 959 |
-
constant int64_t & ne11,
|
| 960 |
-
constant int64_t & ne12,
|
| 961 |
-
constant int64_t & ne13,
|
| 962 |
-
constant uint64_t & nb10,
|
| 963 |
-
constant uint64_t & nb11,
|
| 964 |
-
constant uint64_t & nb12,
|
| 965 |
-
constant uint64_t & nb13,
|
| 966 |
-
constant int64_t & ne0,
|
| 967 |
-
constant int64_t & ne1,
|
| 968 |
-
constant int64_t & ne2,
|
| 969 |
-
constant int64_t & ne3,
|
| 970 |
-
constant uint64_t & nb0,
|
| 971 |
-
constant uint64_t & nb1,
|
| 972 |
-
constant uint64_t & nb2,
|
| 973 |
-
constant uint64_t & nb3,
|
| 974 |
uint3 tpig[[thread_position_in_grid]]) {
|
| 975 |
int64_t i3 = tpig.z;
|
| 976 |
int64_t i2 = tpig.y;
|
| 977 |
int64_t i1 = tpig.x;
|
| 978 |
|
| 979 |
-
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
|
| 980 |
return;
|
| 981 |
}
|
| 982 |
|
| 983 |
-
device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
|
| 984 |
-
device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
|
| 985 |
|
| 986 |
float row_sum = 0;
|
| 987 |
|
| 988 |
-
for (int64_t i0 = 0; i0 < ne00; i0++) {
|
| 989 |
row_sum += src_row[i0];
|
| 990 |
}
|
| 991 |
|
|
@@ -997,36 +974,29 @@ kernel void kernel_soft_max(
|
|
| 997 |
device const char * src0,
|
| 998 |
device const char * src1,
|
| 999 |
device char * dst,
|
| 1000 |
-
constant
|
| 1001 |
-
constant int64_t & ne01,
|
| 1002 |
-
constant int64_t & ne02,
|
| 1003 |
-
constant float & scale,
|
| 1004 |
-
constant float & max_bias,
|
| 1005 |
-
constant float & m0,
|
| 1006 |
-
constant float & m1,
|
| 1007 |
-
constant uint32_t & n_head_log2,
|
| 1008 |
threadgroup float * buf [[threadgroup(0)]],
|
| 1009 |
uint tgpig[[threadgroup_position_in_grid]],
|
| 1010 |
uint tpitg[[thread_position_in_threadgroup]],
|
| 1011 |
uint sgitg[[simdgroup_index_in_threadgroup]],
|
| 1012 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 1013 |
uint ntg[[threads_per_threadgroup]]) {
|
| 1014 |
-
const int64_t i03 = (tgpig) / (ne02*ne01);
|
| 1015 |
-
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
| 1016 |
-
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
| 1017 |
|
| 1018 |
-
device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
| 1019 |
-
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
|
| 1020 |
-
device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
| 1021 |
|
| 1022 |
float slope = 1.0f;
|
| 1023 |
|
| 1024 |
// ALiBi
|
| 1025 |
-
if (max_bias > 0.0f) {
|
| 1026 |
const int64_t h = i02;
|
| 1027 |
|
| 1028 |
-
const float base = h < n_head_log2 ? m0 : m1;
|
| 1029 |
-
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
| 1030 |
|
| 1031 |
slope = pow(base, exp);
|
| 1032 |
}
|
|
@@ -1034,8 +1004,8 @@ kernel void kernel_soft_max(
|
|
| 1034 |
// parallel max
|
| 1035 |
float lmax = -INFINITY;
|
| 1036 |
|
| 1037 |
-
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
| 1038 |
-
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
|
| 1039 |
}
|
| 1040 |
|
| 1041 |
// find the max value in the block
|
|
@@ -1059,8 +1029,8 @@ kernel void kernel_soft_max(
|
|
| 1059 |
|
| 1060 |
// parallel sum
|
| 1061 |
float lsum = 0.0f;
|
| 1062 |
-
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
| 1063 |
-
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
|
| 1064 |
lsum += exp_psrc0;
|
| 1065 |
pdst[i00] = exp_psrc0;
|
| 1066 |
}
|
|
@@ -1090,7 +1060,7 @@ kernel void kernel_soft_max(
|
|
| 1090 |
|
| 1091 |
const float inv_sum = 1.0f/sum;
|
| 1092 |
|
| 1093 |
-
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
| 1094 |
pdst[i00] *= inv_sum;
|
| 1095 |
}
|
| 1096 |
}
|
|
@@ -1100,35 +1070,28 @@ kernel void kernel_soft_max_4(
|
|
| 1100 |
device const char * src0,
|
| 1101 |
device const char * src1,
|
| 1102 |
device char * dst,
|
| 1103 |
-
constant
|
| 1104 |
-
constant int64_t & ne01,
|
| 1105 |
-
constant int64_t & ne02,
|
| 1106 |
-
constant float & scale,
|
| 1107 |
-
constant float & max_bias,
|
| 1108 |
-
constant float & m0,
|
| 1109 |
-
constant float & m1,
|
| 1110 |
-
constant uint32_t & n_head_log2,
|
| 1111 |
threadgroup float * buf [[threadgroup(0)]],
|
| 1112 |
uint tgpig[[threadgroup_position_in_grid]],
|
| 1113 |
uint tpitg[[thread_position_in_threadgroup]],
|
| 1114 |
uint sgitg[[simdgroup_index_in_threadgroup]],
|
| 1115 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 1116 |
uint ntg[[threads_per_threadgroup]]) {
|
| 1117 |
-
const int64_t i03 = (tgpig) / (ne02*ne01);
|
| 1118 |
-
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
| 1119 |
-
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
| 1120 |
|
| 1121 |
-
device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
|
| 1122 |
-
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
|
| 1123 |
-
device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
|
| 1124 |
|
| 1125 |
float slope = 1.0f;
|
| 1126 |
|
| 1127 |
-
if (max_bias > 0.0f) {
|
| 1128 |
const int64_t h = i02;
|
| 1129 |
|
| 1130 |
-
const float base = h < n_head_log2 ? m0 : m1;
|
| 1131 |
-
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
| 1132 |
|
| 1133 |
slope = pow(base, exp);
|
| 1134 |
}
|
|
@@ -1136,8 +1099,8 @@ kernel void kernel_soft_max_4(
|
|
| 1136 |
// parallel max
|
| 1137 |
float4 lmax4 = -INFINITY;
|
| 1138 |
|
| 1139 |
-
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
| 1140 |
-
lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
|
| 1141 |
}
|
| 1142 |
|
| 1143 |
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
|
@@ -1162,8 +1125,8 @@ kernel void kernel_soft_max_4(
|
|
| 1162 |
|
| 1163 |
// parallel sum
|
| 1164 |
float4 lsum4 = 0.0f;
|
| 1165 |
-
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
| 1166 |
-
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
|
| 1167 |
lsum4 += exp_psrc4;
|
| 1168 |
pdst4[i00] = exp_psrc4;
|
| 1169 |
}
|
|
@@ -1195,7 +1158,7 @@ kernel void kernel_soft_max_4(
|
|
| 1195 |
|
| 1196 |
const float inv_sum = 1.0f/sum;
|
| 1197 |
|
| 1198 |
-
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
| 1199 |
pdst4[i00] *= inv_sum;
|
| 1200 |
}
|
| 1201 |
}
|
|
@@ -1211,27 +1174,23 @@ template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kerne
|
|
| 1211 |
kernel void kernel_diag_mask_inf(
|
| 1212 |
device const float * src0,
|
| 1213 |
device float * dst,
|
| 1214 |
-
constant
|
| 1215 |
-
constant int64_t & ne01,
|
| 1216 |
-
constant int & n_past,
|
| 1217 |
uint3 tpig[[thread_position_in_grid]]) {
|
| 1218 |
const int64_t i02 = tpig[2];
|
| 1219 |
const int64_t i01 = tpig[1];
|
| 1220 |
const int64_t i00 = tpig[0];
|
| 1221 |
|
| 1222 |
-
if (i00 > n_past + i01) {
|
| 1223 |
-
dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
|
| 1224 |
} else {
|
| 1225 |
-
dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
|
| 1226 |
}
|
| 1227 |
}
|
| 1228 |
|
| 1229 |
kernel void kernel_diag_mask_inf_8(
|
| 1230 |
device const float4 * src0,
|
| 1231 |
device float4 * dst,
|
| 1232 |
-
constant
|
| 1233 |
-
constant int64_t & ne01,
|
| 1234 |
-
constant int & n_past,
|
| 1235 |
uint3 tpig[[thread_position_in_grid]]) {
|
| 1236 |
|
| 1237 |
const int64_t i = 2*tpig[0];
|
|
@@ -1239,42 +1198,26 @@ kernel void kernel_diag_mask_inf_8(
|
|
| 1239 |
dst[i+0] = src0[i+0];
|
| 1240 |
dst[i+1] = src0[i+1];
|
| 1241 |
int64_t i4 = 4*i;
|
| 1242 |
-
const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
|
| 1243 |
-
const int64_t i01 = i4/(ne00); i4 -= i01*ne00;
|
| 1244 |
const int64_t i00 = i4;
|
| 1245 |
for (int k = 3; k >= 0; --k) {
|
| 1246 |
-
if (i00 + 4 + k <= n_past + i01) {
|
| 1247 |
break;
|
| 1248 |
}
|
| 1249 |
dst[i+1][k] = -INFINITY;
|
| 1250 |
-
if (i00 + k > n_past + i01) {
|
| 1251 |
dst[i][k] = -INFINITY;
|
| 1252 |
}
|
| 1253 |
}
|
| 1254 |
}
|
| 1255 |
|
| 1256 |
// ref: ggml.c:ggml_compute_forward_ssm_conv_f32
|
| 1257 |
-
// TODO: optimize
|
| 1258 |
kernel void kernel_ssm_conv_f32(
|
| 1259 |
device const void * src0,
|
| 1260 |
device const void * src1,
|
| 1261 |
device float * dst,
|
| 1262 |
-
constant
|
| 1263 |
-
constant int64_t & ne01,
|
| 1264 |
-
constant int64_t & ne02,
|
| 1265 |
-
constant uint64_t & nb00,
|
| 1266 |
-
constant uint64_t & nb01,
|
| 1267 |
-
constant uint64_t & nb02,
|
| 1268 |
-
constant int64_t & ne10,
|
| 1269 |
-
constant int64_t & ne11,
|
| 1270 |
-
constant uint64_t & nb10,
|
| 1271 |
-
constant uint64_t & nb11,
|
| 1272 |
-
constant int64_t & ne0,
|
| 1273 |
-
constant int64_t & ne1,
|
| 1274 |
-
constant int64_t & ne2,
|
| 1275 |
-
constant uint64_t & nb0,
|
| 1276 |
-
constant uint64_t & nb1,
|
| 1277 |
-
constant uint64_t & nb2,
|
| 1278 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1279 |
uint3 tpitg[[thread_position_in_threadgroup]],
|
| 1280 |
uint3 ntg[[threads_per_threadgroup]]) {
|
|
@@ -1282,15 +1225,15 @@ kernel void kernel_ssm_conv_f32(
|
|
| 1282 |
const int64_t i2 = tgpig.y;
|
| 1283 |
const int64_t i3 = tgpig.z;
|
| 1284 |
|
| 1285 |
-
const int64_t nc = ne10;
|
| 1286 |
-
//const int64_t ncs = ne00;
|
| 1287 |
-
//const int64_t nr = ne01;
|
| 1288 |
-
//const int64_t n_t = ne1;
|
| 1289 |
-
//const int64_t n_s = ne2;
|
| 1290 |
|
| 1291 |
-
device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02);
|
| 1292 |
-
device const float * c = (device const float *) ((device const char *) src1 + ir*nb11);
|
| 1293 |
-
device float * x = (device float *) ((device char *) dst + ir*nb0 + i2*nb1 + i3*nb2);
|
| 1294 |
|
| 1295 |
float sumf = 0.0f;
|
| 1296 |
|
|
@@ -1302,7 +1245,6 @@ kernel void kernel_ssm_conv_f32(
|
|
| 1302 |
}
|
| 1303 |
|
| 1304 |
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32
|
| 1305 |
-
// TODO: optimize
|
| 1306 |
kernel void kernel_ssm_scan_f32(
|
| 1307 |
device const void * src0,
|
| 1308 |
device const void * src1,
|
|
@@ -1311,48 +1253,27 @@ kernel void kernel_ssm_scan_f32(
|
|
| 1311 |
device const void * src4,
|
| 1312 |
device const void * src5,
|
| 1313 |
device float * dst,
|
| 1314 |
-
constant
|
| 1315 |
-
constant int64_t & d_inner,
|
| 1316 |
-
constant int64_t & n_seq_tokens,
|
| 1317 |
-
constant int64_t & n_seqs,
|
| 1318 |
-
constant uint64_t & nb00,
|
| 1319 |
-
constant uint64_t & nb01,
|
| 1320 |
-
constant uint64_t & nb02,
|
| 1321 |
-
constant uint64_t & nb10,
|
| 1322 |
-
constant uint64_t & nb11,
|
| 1323 |
-
constant uint64_t & nb12,
|
| 1324 |
-
constant uint64_t & nb13,
|
| 1325 |
-
constant uint64_t & nb20,
|
| 1326 |
-
constant uint64_t & nb21,
|
| 1327 |
-
constant uint64_t & nb22,
|
| 1328 |
-
constant uint64_t & nb30,
|
| 1329 |
-
constant uint64_t & nb31,
|
| 1330 |
-
constant uint64_t & nb40,
|
| 1331 |
-
constant uint64_t & nb41,
|
| 1332 |
-
constant uint64_t & nb42,
|
| 1333 |
-
constant uint64_t & nb50,
|
| 1334 |
-
constant uint64_t & nb51,
|
| 1335 |
-
constant uint64_t & nb52,
|
| 1336 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1337 |
uint3 tpitg[[thread_position_in_threadgroup]],
|
| 1338 |
uint3 ntg[[threads_per_threadgroup]]) {
|
| 1339 |
const int64_t ir = tgpig.x;
|
| 1340 |
const int64_t i3 = tgpig.y;
|
| 1341 |
|
| 1342 |
-
const int64_t nc = d_state;
|
| 1343 |
-
|
| 1344 |
-
const int64_t n_t = n_seq_tokens;
|
| 1345 |
-
|
| 1346 |
|
| 1347 |
for (int64_t i2 = 0; i2 < n_t; ++i2) {
|
| 1348 |
-
device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02);
|
| 1349 |
-
device const float * x = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12);
|
| 1350 |
-
device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22);
|
| 1351 |
-
device const float * A = (device const float *) ((device const char *) src3 + ir*nb31);
|
| 1352 |
-
device const float * B = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42);
|
| 1353 |
-
device const float * C = (device const float *) ((device const char *) src5 + i2*nb51 + i3*nb52);
|
| 1354 |
-
device float * y = (device float *) ((device char *) dst + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides
|
| 1355 |
-
device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb02 + nb13);
|
| 1356 |
|
| 1357 |
if (i2 > 0) {
|
| 1358 |
s0 = s;
|
|
@@ -1545,22 +1466,15 @@ kernel void kernel_rms_norm(
|
|
| 1545 |
kernel void kernel_group_norm(
|
| 1546 |
device const float * src0,
|
| 1547 |
device float * dst,
|
| 1548 |
-
constant
|
| 1549 |
-
constant int64_t & ne01,
|
| 1550 |
-
constant int64_t & ne02,
|
| 1551 |
-
constant uint64_t & nb00,
|
| 1552 |
-
constant uint64_t & nb01,
|
| 1553 |
-
constant uint64_t & nb02,
|
| 1554 |
-
constant int32_t & n_groups,
|
| 1555 |
-
constant float & eps,
|
| 1556 |
threadgroup float * buf [[threadgroup(0)]],
|
| 1557 |
uint tgpig[[threadgroup_position_in_grid]],
|
| 1558 |
uint tpitg[[thread_position_in_threadgroup]],
|
| 1559 |
uint sgitg[[simdgroup_index_in_threadgroup]],
|
| 1560 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 1561 |
uint ntg[[threads_per_threadgroup]]) {
|
| 1562 |
-
const int64_t ne = ne00*ne01*ne02;
|
| 1563 |
-
const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
|
| 1564 |
|
| 1565 |
int start = tgpig * gs;
|
| 1566 |
int end = start + gs;
|
|
@@ -1624,7 +1538,7 @@ kernel void kernel_group_norm(
|
|
| 1624 |
}
|
| 1625 |
|
| 1626 |
const float variance = tmp / gs;
|
| 1627 |
-
const float scale = 1.0f/sqrt(variance + eps);
|
| 1628 |
for (int j = start; j < end; j += ntg) {
|
| 1629 |
dst[j] *= scale;
|
| 1630 |
}
|
|
@@ -2588,17 +2502,7 @@ template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_
|
|
| 2588 |
typedef void (im2col_t)(
|
| 2589 |
device const float * x,
|
| 2590 |
device char * dst,
|
| 2591 |
-
constant
|
| 2592 |
-
constant int32_t & ofs1,
|
| 2593 |
-
constant int32_t & IW,
|
| 2594 |
-
constant int32_t & IH,
|
| 2595 |
-
constant int32_t & CHW,
|
| 2596 |
-
constant int32_t & s0,
|
| 2597 |
-
constant int32_t & s1,
|
| 2598 |
-
constant int32_t & p0,
|
| 2599 |
-
constant int32_t & p1,
|
| 2600 |
-
constant int32_t & d0,
|
| 2601 |
-
constant int32_t & d1,
|
| 2602 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2603 |
uint3 tgpg[[threadgroups_per_grid]],
|
| 2604 |
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
@@ -2608,17 +2512,7 @@ template <typename T>
|
|
| 2608 |
kernel void kernel_im2col(
|
| 2609 |
device const float * x,
|
| 2610 |
device char * dst,
|
| 2611 |
-
constant
|
| 2612 |
-
constant int32_t & ofs1,
|
| 2613 |
-
constant int32_t & IW,
|
| 2614 |
-
constant int32_t & IH,
|
| 2615 |
-
constant int32_t & CHW,
|
| 2616 |
-
constant int32_t & s0,
|
| 2617 |
-
constant int32_t & s1,
|
| 2618 |
-
constant int32_t & p0,
|
| 2619 |
-
constant int32_t & p1,
|
| 2620 |
-
constant int32_t & d0,
|
| 2621 |
-
constant int32_t & d1,
|
| 2622 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2623 |
uint3 tgpg[[threadgroups_per_grid]],
|
| 2624 |
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
@@ -2639,17 +2533,17 @@ kernel void kernel_im2col(
|
|
| 2639 |
const int64_t ioh = tgpig[1];
|
| 2640 |
const int64_t iow = tgpig[2];
|
| 2641 |
|
| 2642 |
-
const int64_t iiw = iow*s0 + ikw*d0 - p0;
|
| 2643 |
-
const int64_t iih = ioh*s1 + ikh*d1 - p1;
|
| 2644 |
|
| 2645 |
-
const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*CHW + (iic*(KH*KW) + ikh*KW + ikw);
|
| 2646 |
|
| 2647 |
device T * pdst = (device T *) (dst);
|
| 2648 |
|
| 2649 |
-
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
| 2650 |
pdst[offset_dst] = 0.0f;
|
| 2651 |
} else {
|
| 2652 |
-
const int64_t offset_src = in*ofs0 + iic*ofs1 + iih*IW + iiw;
|
| 2653 |
pdst[offset_dst] = x[offset_src];
|
| 2654 |
}
|
| 2655 |
}
|
|
@@ -2660,20 +2554,7 @@ template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
|
|
| 2660 |
typedef void (im2col_ext_t)(
|
| 2661 |
device const float * x,
|
| 2662 |
device char * dst,
|
| 2663 |
-
constant
|
| 2664 |
-
constant int32_t & ofs1,
|
| 2665 |
-
constant int32_t & IW,
|
| 2666 |
-
constant int32_t & IH,
|
| 2667 |
-
constant int32_t & CHW,
|
| 2668 |
-
constant int32_t & s0,
|
| 2669 |
-
constant int32_t & s1,
|
| 2670 |
-
constant int32_t & p0,
|
| 2671 |
-
constant int32_t & p1,
|
| 2672 |
-
constant int32_t & d0,
|
| 2673 |
-
constant int32_t & d1,
|
| 2674 |
-
constant int32_t & N,
|
| 2675 |
-
constant int32_t & KH,
|
| 2676 |
-
constant int32_t & KW,
|
| 2677 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2678 |
uint3 tgpg[[threadgroups_per_grid]],
|
| 2679 |
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
@@ -2683,53 +2564,40 @@ template <typename T>
|
|
| 2683 |
kernel void kernel_im2col_ext(
|
| 2684 |
device const float * x,
|
| 2685 |
device char * dst,
|
| 2686 |
-
constant
|
| 2687 |
-
constant int32_t & ofs1,
|
| 2688 |
-
constant int32_t & IW,
|
| 2689 |
-
constant int32_t & IH,
|
| 2690 |
-
constant int32_t & CHW,
|
| 2691 |
-
constant int32_t & s0,
|
| 2692 |
-
constant int32_t & s1,
|
| 2693 |
-
constant int32_t & p0,
|
| 2694 |
-
constant int32_t & p1,
|
| 2695 |
-
constant int32_t & d0,
|
| 2696 |
-
constant int32_t & d1,
|
| 2697 |
-
constant int32_t & N,
|
| 2698 |
-
constant int32_t & KH,
|
| 2699 |
-
constant int32_t & KW,
|
| 2700 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2701 |
uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
|
| 2702 |
uint3 tpitg[[thread_position_in_threadgroup]],
|
| 2703 |
uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
|
| 2704 |
-
const int64_t KHW =
|
| 2705 |
|
| 2706 |
-
const int64_t d = tgpig[0] / CHW;
|
| 2707 |
-
const int64_t chw = tgpig[0] % CHW;
|
| 2708 |
const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
|
| 2709 |
const int64_t HW = tgpig[0] % KHW;
|
| 2710 |
|
| 2711 |
const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
|
| 2712 |
-
if (tpitg_0 >= N) {
|
| 2713 |
return;
|
| 2714 |
}
|
| 2715 |
|
| 2716 |
-
const int64_t tpitg_1 = HW / KW;
|
| 2717 |
-
const int64_t tpitg_2 = HW % KW;
|
| 2718 |
|
| 2719 |
-
const int64_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0;
|
| 2720 |
-
const int64_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1;
|
| 2721 |
|
| 2722 |
const int64_t offset_dst =
|
| 2723 |
-
(tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
|
| 2724 |
-
(tgpig_0 * KHW + tpitg_1 * KW + tpitg_2);
|
| 2725 |
|
| 2726 |
device T * pdst = (device T *) (dst);
|
| 2727 |
|
| 2728 |
-
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
| 2729 |
pdst[offset_dst] = 0.0f;
|
| 2730 |
} else {
|
| 2731 |
-
const int64_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1;
|
| 2732 |
-
pdst[offset_dst] = x[offset_src + iih * IW + iiw];
|
| 2733 |
}
|
| 2734 |
}
|
| 2735 |
|
|
@@ -2740,12 +2608,7 @@ typedef void (conv_transpose_1d_t)(
|
|
| 2740 |
device const float * src0,
|
| 2741 |
device const float * src1,
|
| 2742 |
device char * dst,
|
| 2743 |
-
constant
|
| 2744 |
-
constant int32_t & IL,
|
| 2745 |
-
constant int32_t & K,
|
| 2746 |
-
constant int32_t & s0,
|
| 2747 |
-
constant uint64_t & nb0,
|
| 2748 |
-
constant uint64_t & nb1,
|
| 2749 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2750 |
uint3 tgpg[[threadgroups_per_grid]]);
|
| 2751 |
|
|
@@ -2754,29 +2617,24 @@ kernel void kernel_conv_transpose_1d(
|
|
| 2754 |
device const T * src0,
|
| 2755 |
device const float * src1,
|
| 2756 |
device char * dst,
|
| 2757 |
-
constant
|
| 2758 |
-
constant int32_t & IL,
|
| 2759 |
-
constant int32_t & K,
|
| 2760 |
-
constant int32_t & s0,
|
| 2761 |
-
constant uint64_t & nb0,
|
| 2762 |
-
constant uint64_t & nb1,
|
| 2763 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2764 |
uint3 tgpg[[threadgroups_per_grid]]) {
|
| 2765 |
|
| 2766 |
float v = 0.0f;
|
| 2767 |
|
| 2768 |
-
for (int64_t c = 0; c < IC; c++) {
|
| 2769 |
-
const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1];
|
| 2770 |
-
const int32_t input_offset = c * IL;
|
| 2771 |
|
| 2772 |
-
for (int64_t i = 0; i < IL; i++) {
|
| 2773 |
-
if (tgpig[0] >= i * s0 && tgpig[0] < i * s0 + K) {
|
| 2774 |
-
v += src0[kernel_offset + tgpig[0] - i * s0] * src1[input_offset + i];
|
| 2775 |
}
|
| 2776 |
}
|
| 2777 |
}
|
| 2778 |
|
| 2779 |
-
device float * dst_ptr = (device float *) (dst + tgpig[0] * nb0 + tgpig[1] * nb1);
|
| 2780 |
|
| 2781 |
dst_ptr[0] = v;
|
| 2782 |
}
|
|
@@ -2786,12 +2644,7 @@ kernel void kernel_conv_transpose_1d<float>(
|
|
| 2786 |
device const float * src0,
|
| 2787 |
device const float * src1,
|
| 2788 |
device char * dst,
|
| 2789 |
-
constant
|
| 2790 |
-
constant int32_t & IL,
|
| 2791 |
-
constant int32_t & K,
|
| 2792 |
-
constant int32_t & s0,
|
| 2793 |
-
constant uint64_t & nb0,
|
| 2794 |
-
constant uint64_t & nb1,
|
| 2795 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2796 |
uint3 tgpg[[threadgroups_per_grid]]);
|
| 2797 |
|
|
@@ -2800,38 +2653,14 @@ kernel void kernel_conv_transpose_1d<half>(
|
|
| 2800 |
device const half * src0,
|
| 2801 |
device const float * src1,
|
| 2802 |
device char * dst,
|
| 2803 |
-
constant
|
| 2804 |
-
constant int32_t & IL,
|
| 2805 |
-
constant int32_t & K,
|
| 2806 |
-
constant int32_t & s0,
|
| 2807 |
-
constant uint64_t & nb0,
|
| 2808 |
-
constant uint64_t & nb1,
|
| 2809 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2810 |
uint3 tgpg[[threadgroups_per_grid]]);
|
| 2811 |
|
| 2812 |
kernel void kernel_upscale_f32(
|
| 2813 |
device const char * src0,
|
| 2814 |
device char * dst,
|
| 2815 |
-
constant
|
| 2816 |
-
constant int64_t & ne01,
|
| 2817 |
-
constant int64_t & ne02,
|
| 2818 |
-
constant int64_t & ne03,
|
| 2819 |
-
constant uint64_t & nb00,
|
| 2820 |
-
constant uint64_t & nb01,
|
| 2821 |
-
constant uint64_t & nb02,
|
| 2822 |
-
constant uint64_t & nb03,
|
| 2823 |
-
constant int64_t & ne0,
|
| 2824 |
-
constant int64_t & ne1,
|
| 2825 |
-
constant int64_t & ne2,
|
| 2826 |
-
constant int64_t & ne3,
|
| 2827 |
-
constant uint64_t & nb0,
|
| 2828 |
-
constant uint64_t & nb1,
|
| 2829 |
-
constant uint64_t & nb2,
|
| 2830 |
-
constant uint64_t & nb3,
|
| 2831 |
-
constant float & sf0,
|
| 2832 |
-
constant float & sf1,
|
| 2833 |
-
constant float & sf2,
|
| 2834 |
-
constant float & sf3,
|
| 2835 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2836 |
uint3 tpitg[[thread_position_in_threadgroup]],
|
| 2837 |
uint3 ntg[[threads_per_threadgroup]]) {
|
|
@@ -2840,15 +2669,15 @@ kernel void kernel_upscale_f32(
|
|
| 2840 |
const int64_t i2 = tgpig.y;
|
| 2841 |
const int64_t i1 = tgpig.x;
|
| 2842 |
|
| 2843 |
-
const int64_t i03 = i3/sf3;
|
| 2844 |
-
const int64_t i02 = i2/sf2;
|
| 2845 |
-
const int64_t i01 = i1/sf1;
|
| 2846 |
|
| 2847 |
-
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
| 2848 |
-
const int64_t i00 = i0/sf0;
|
| 2849 |
|
| 2850 |
-
device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
| 2851 |
-
device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
| 2852 |
|
| 2853 |
dst_ptr[0] = src0_ptr[0];
|
| 2854 |
}
|
|
@@ -2857,22 +2686,7 @@ kernel void kernel_upscale_f32(
|
|
| 2857 |
kernel void kernel_pad_f32(
|
| 2858 |
device const char * src0,
|
| 2859 |
device char * dst,
|
| 2860 |
-
constant
|
| 2861 |
-
constant int64_t & ne01,
|
| 2862 |
-
constant int64_t & ne02,
|
| 2863 |
-
constant int64_t & ne03,
|
| 2864 |
-
constant uint64_t & nb00,
|
| 2865 |
-
constant uint64_t & nb01,
|
| 2866 |
-
constant uint64_t & nb02,
|
| 2867 |
-
constant uint64_t & nb03,
|
| 2868 |
-
constant int64_t & ne0,
|
| 2869 |
-
constant int64_t & ne1,
|
| 2870 |
-
constant int64_t & ne2,
|
| 2871 |
-
constant int64_t & ne3,
|
| 2872 |
-
constant uint64_t & nb0,
|
| 2873 |
-
constant uint64_t & nb1,
|
| 2874 |
-
constant uint64_t & nb2,
|
| 2875 |
-
constant uint64_t & nb3,
|
| 2876 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2877 |
uint3 tpitg[[thread_position_in_threadgroup]],
|
| 2878 |
uint3 ntg[[threads_per_threadgroup]]) {
|
|
@@ -2885,12 +2699,12 @@ kernel void kernel_pad_f32(
|
|
| 2885 |
const int64_t i02 = i2;
|
| 2886 |
const int64_t i01 = i1;
|
| 2887 |
|
| 2888 |
-
device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
|
| 2889 |
-
device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
|
| 2890 |
|
| 2891 |
-
if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
|
| 2892 |
-
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
| 2893 |
-
if (i0 < ne00) {
|
| 2894 |
dst_ptr[i0] = src0_ptr[i0];
|
| 2895 |
} else {
|
| 2896 |
dst_ptr[i0] = 0.0f;
|
|
@@ -2900,7 +2714,7 @@ kernel void kernel_pad_f32(
|
|
| 2900 |
return;
|
| 2901 |
}
|
| 2902 |
|
| 2903 |
-
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
| 2904 |
dst_ptr[i0] = 0.0f;
|
| 2905 |
}
|
| 2906 |
}
|
|
@@ -2908,21 +2722,7 @@ kernel void kernel_pad_f32(
|
|
| 2908 |
kernel void kernel_pad_reflect_1d_f32(
|
| 2909 |
device const char * src0,
|
| 2910 |
device char * dst,
|
| 2911 |
-
constant
|
| 2912 |
-
constant int64_t & ne01,
|
| 2913 |
-
constant int64_t & ne02,
|
| 2914 |
-
constant int64_t & ne03,
|
| 2915 |
-
constant int64_t & ne0,
|
| 2916 |
-
constant uint64_t & nb00,
|
| 2917 |
-
constant uint64_t & nb01,
|
| 2918 |
-
constant uint64_t & nb02,
|
| 2919 |
-
constant uint64_t & nb03,
|
| 2920 |
-
constant uint64_t & nb0,
|
| 2921 |
-
constant uint64_t & nb1,
|
| 2922 |
-
constant uint64_t & nb2,
|
| 2923 |
-
constant uint64_t & nb3,
|
| 2924 |
-
constant int32_t & p0,
|
| 2925 |
-
constant int32_t & p1,
|
| 2926 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2927 |
uint3 tgpg[[threadgroups_per_grid]],
|
| 2928 |
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
@@ -2936,17 +2736,17 @@ kernel void kernel_pad_reflect_1d_f32(
|
|
| 2936 |
const int64_t i02 = i2;
|
| 2937 |
const int64_t i01 = i1;
|
| 2938 |
|
| 2939 |
-
device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
|
| 2940 |
-
device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
|
| 2941 |
|
| 2942 |
-
if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
|
| 2943 |
-
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
| 2944 |
-
if (i0 < p0) {
|
| 2945 |
-
dst_ptr[i0] = src0_ptr[p0 - i0];
|
| 2946 |
-
} else if (i0 < ne0 - p1) {
|
| 2947 |
-
dst_ptr[i0] = src0_ptr[i0 - p0];
|
| 2948 |
} else {
|
| 2949 |
-
dst_ptr[i0] = src0_ptr[(ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1];
|
| 2950 |
}
|
| 2951 |
}
|
| 2952 |
}
|
|
@@ -2954,44 +2754,40 @@ kernel void kernel_pad_reflect_1d_f32(
|
|
| 2954 |
|
| 2955 |
kernel void kernel_arange_f32(
|
| 2956 |
device char * dst,
|
| 2957 |
-
constant
|
| 2958 |
-
constant float & start,
|
| 2959 |
-
constant float & step,
|
| 2960 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2961 |
uint3 tpitg[[thread_position_in_threadgroup]],
|
| 2962 |
uint3 ntg[[threads_per_threadgroup]]) {
|
| 2963 |
|
| 2964 |
device float * dst_ptr = (device float *) dst;
|
| 2965 |
|
| 2966 |
-
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
| 2967 |
-
dst_ptr[i0] = start + step * i0;
|
| 2968 |
}
|
| 2969 |
}
|
| 2970 |
|
| 2971 |
kernel void kernel_timestep_embedding_f32(
|
| 2972 |
device const char * src0,
|
| 2973 |
device char * dst,
|
| 2974 |
-
constant
|
| 2975 |
-
constant int & dim,
|
| 2976 |
-
constant int & max_period,
|
| 2977 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2978 |
uint3 tpitg[[thread_position_in_threadgroup]],
|
| 2979 |
uint3 ntg[[threads_per_threadgroup]]) {
|
| 2980 |
|
| 2981 |
int i = tgpig.x;
|
| 2982 |
-
device float * embed_data = (device float *)(dst +
|
| 2983 |
|
| 2984 |
-
int half_ = dim / 2;
|
| 2985 |
for (int j = tpitg.x; j < half_; j += ntg.x) {
|
| 2986 |
float timestep = ((device float *)src0)[i];
|
| 2987 |
-
float freq = (float)exp(-log((float)max_period) * j / half_);
|
| 2988 |
float arg = timestep * freq;
|
| 2989 |
embed_data[j ] = cos(arg);
|
| 2990 |
embed_data[j + half_] = sin(arg);
|
| 2991 |
}
|
| 2992 |
|
| 2993 |
-
if (dim % 2 != 0 && tpitg.x == 0) {
|
| 2994 |
-
embed_data[dim] = 0.f;
|
| 2995 |
}
|
| 2996 |
}
|
| 2997 |
|
|
@@ -2999,8 +2795,7 @@ kernel void kernel_timestep_embedding_f32(
|
|
| 2999 |
typedef void (argsort_t)(
|
| 3000 |
device const float * x,
|
| 3001 |
device int32_t * dst,
|
| 3002 |
-
constant
|
| 3003 |
-
constant int64_t & ncols_pad,
|
| 3004 |
threadgroup int32_t * shared_values [[threadgroup(0)]],
|
| 3005 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 3006 |
uint3 tpitg[[thread_position_in_threadgroup]]);
|
|
@@ -3009,8 +2804,7 @@ template<ggml_sort_order order>
|
|
| 3009 |
kernel void kernel_argsort_f32_i32(
|
| 3010 |
device const float * x,
|
| 3011 |
device int32_t * dst,
|
| 3012 |
-
constant
|
| 3013 |
-
constant int64_t & ncols_pad,
|
| 3014 |
threadgroup int32_t * shared_values [[threadgroup(0)]],
|
| 3015 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 3016 |
uint3 tpitg[[thread_position_in_threadgroup]]) {
|
|
@@ -3018,9 +2812,9 @@ kernel void kernel_argsort_f32_i32(
|
|
| 3018 |
int col = tpitg[0];
|
| 3019 |
int row = tgpig[1];
|
| 3020 |
|
| 3021 |
-
if (col >= ncols_pad) return;
|
| 3022 |
|
| 3023 |
-
device const float * x_row = x + row * ncols;
|
| 3024 |
threadgroup int32_t * dst_row = shared_values;
|
| 3025 |
|
| 3026 |
// initialize indices
|
|
@@ -3028,21 +2822,21 @@ kernel void kernel_argsort_f32_i32(
|
|
| 3028 |
|
| 3029 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 3030 |
|
| 3031 |
-
for (int k = 2; k <= ncols_pad; k *= 2) {
|
| 3032 |
for (int j = k / 2; j > 0; j /= 2) {
|
| 3033 |
int ixj = col ^ j;
|
| 3034 |
if (ixj > col) {
|
| 3035 |
if ((col & k) == 0) {
|
| 3036 |
-
if (dst_row[col] >= ncols ||
|
| 3037 |
-
(dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
|
| 3038 |
x_row[dst_row[col]] > x_row[dst_row[ixj]] :
|
| 3039 |
x_row[dst_row[col]] < x_row[dst_row[ixj]]))
|
| 3040 |
) {
|
| 3041 |
SWAP(dst_row[col], dst_row[ixj]);
|
| 3042 |
}
|
| 3043 |
} else {
|
| 3044 |
-
if (dst_row[ixj] >= ncols ||
|
| 3045 |
-
(dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
|
| 3046 |
x_row[dst_row[col]] < x_row[dst_row[ixj]] :
|
| 3047 |
x_row[dst_row[col]] > x_row[dst_row[ixj]]))
|
| 3048 |
) {
|
|
@@ -3055,8 +2849,8 @@ kernel void kernel_argsort_f32_i32(
|
|
| 3055 |
}
|
| 3056 |
|
| 3057 |
// copy the result to dst without the padding
|
| 3058 |
-
if (col < ncols) {
|
| 3059 |
-
dst[row * ncols + col] = dst_row[col];
|
| 3060 |
}
|
| 3061 |
}
|
| 3062 |
|
|
@@ -3066,9 +2860,9 @@ template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_ar
|
|
| 3066 |
kernel void kernel_leaky_relu_f32(
|
| 3067 |
device const float * src0,
|
| 3068 |
device float * dst,
|
| 3069 |
-
constant
|
| 3070 |
uint tpig[[thread_position_in_grid]]) {
|
| 3071 |
-
dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
|
| 3072 |
}
|
| 3073 |
|
| 3074 |
// ref: https://arxiv.org/pdf/2307.08691.pdf
|
|
@@ -6009,28 +5803,21 @@ kernel void kernel_get_rows_q(
|
|
| 6009 |
device const void * src0,
|
| 6010 |
device const void * src1,
|
| 6011 |
device float * dst,
|
| 6012 |
-
constant
|
| 6013 |
-
constant uint64_t & nb01,
|
| 6014 |
-
constant uint64_t & nb02,
|
| 6015 |
-
constant int64_t & ne10,
|
| 6016 |
-
constant uint64_t & nb10,
|
| 6017 |
-
constant uint64_t & nb11,
|
| 6018 |
-
constant uint64_t & nb1,
|
| 6019 |
-
constant uint64_t & nb2,
|
| 6020 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6021 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 6022 |
uint3 tptg [[threads_per_threadgroup]]) {
|
| 6023 |
const int64_t i10 = tgpig.x;
|
| 6024 |
const int64_t i11 = tgpig.y;
|
| 6025 |
|
| 6026 |
-
const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
|
| 6027 |
|
| 6028 |
const int64_t i02 = i11;
|
| 6029 |
|
| 6030 |
-
for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
|
| 6031 |
float4x4 temp;
|
| 6032 |
-
dequantize_func(((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
|
| 6033 |
-
*(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
|
| 6034 |
}
|
| 6035 |
}
|
| 6036 |
|
|
@@ -6039,27 +5826,20 @@ kernel void kernel_get_rows_f(
|
|
| 6039 |
device const void * src0,
|
| 6040 |
device const void * src1,
|
| 6041 |
device float * dst,
|
| 6042 |
-
constant
|
| 6043 |
-
constant uint64_t & nb01,
|
| 6044 |
-
constant uint64_t & nb02,
|
| 6045 |
-
constant int64_t & ne10,
|
| 6046 |
-
constant uint64_t & nb10,
|
| 6047 |
-
constant uint64_t & nb11,
|
| 6048 |
-
constant uint64_t & nb1,
|
| 6049 |
-
constant uint64_t & nb2,
|
| 6050 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6051 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 6052 |
uint3 tptg [[threads_per_threadgroup]]) {
|
| 6053 |
const int64_t i10 = tgpig.x;
|
| 6054 |
const int64_t i11 = tgpig.y;
|
| 6055 |
|
| 6056 |
-
const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
|
| 6057 |
|
| 6058 |
const int64_t i02 = i11;
|
| 6059 |
|
| 6060 |
-
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
| 6061 |
-
(( device float *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
| 6062 |
-
((const device T *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
|
| 6063 |
}
|
| 6064 |
}
|
| 6065 |
|
|
@@ -6067,27 +5847,20 @@ kernel void kernel_get_rows_i32(
|
|
| 6067 |
device const void * src0,
|
| 6068 |
device const void * src1,
|
| 6069 |
device int32_t * dst,
|
| 6070 |
-
constant
|
| 6071 |
-
constant uint64_t & nb01,
|
| 6072 |
-
constant uint64_t & nb02,
|
| 6073 |
-
constant int64_t & ne10,
|
| 6074 |
-
constant uint64_t & nb10,
|
| 6075 |
-
constant uint64_t & nb11,
|
| 6076 |
-
constant uint64_t & nb1,
|
| 6077 |
-
constant uint64_t & nb2,
|
| 6078 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6079 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 6080 |
uint3 tptg [[threads_per_threadgroup]]) {
|
| 6081 |
const int64_t i10 = tgpig.x;
|
| 6082 |
const int64_t i11 = tgpig.y;
|
| 6083 |
|
| 6084 |
-
const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
|
| 6085 |
|
| 6086 |
const int64_t i02 = i11;
|
| 6087 |
|
| 6088 |
-
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
| 6089 |
-
(( device int32_t *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
| 6090 |
-
((const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
|
| 6091 |
}
|
| 6092 |
}
|
| 6093 |
|
|
@@ -6689,98 +6462,78 @@ template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t
|
|
| 6689 |
kernel void kernel_pool_2d_max_f32(
|
| 6690 |
device const float * src0,
|
| 6691 |
device float * dst,
|
| 6692 |
-
constant
|
| 6693 |
-
constant int32_t & k1,
|
| 6694 |
-
constant int32_t & s0,
|
| 6695 |
-
constant int32_t & s1,
|
| 6696 |
-
constant int32_t & p0,
|
| 6697 |
-
constant int32_t & p1,
|
| 6698 |
-
constant int64_t & IH,
|
| 6699 |
-
constant int64_t & IW,
|
| 6700 |
-
constant int64_t & OH,
|
| 6701 |
-
constant int64_t & OW,
|
| 6702 |
-
constant int64_t & parallel_elements,
|
| 6703 |
uint gid[[thread_position_in_grid]]) {
|
| 6704 |
|
| 6705 |
-
if (gid >= parallel_elements) {
|
| 6706 |
return;
|
| 6707 |
}
|
| 6708 |
|
| 6709 |
const int idx = gid;
|
| 6710 |
-
const int I_HW = IH * IW;
|
| 6711 |
-
const int O_HW = OH * OW;
|
| 6712 |
const int nc = idx / O_HW;
|
| 6713 |
-
const int cur_oh = idx % O_HW / OW;
|
| 6714 |
-
const int cur_ow = idx % O_HW % OW;
|
| 6715 |
|
| 6716 |
device const float * i_ptr = src0 + nc * I_HW;
|
| 6717 |
device float * o_ptr = dst + nc * O_HW;
|
| 6718 |
|
| 6719 |
-
const int start_h = cur_oh * s1 - p1;
|
| 6720 |
const int bh = MAX(0, start_h);
|
| 6721 |
-
const int eh = MIN(IH, start_h + k1);
|
| 6722 |
-
const int start_w = cur_ow * s0 - p0;
|
| 6723 |
const int bw = MAX(0, start_w);
|
| 6724 |
-
const int ew = MIN(IW, start_w + k0);
|
| 6725 |
|
| 6726 |
float res = -INFINITY;
|
| 6727 |
|
| 6728 |
for (int i = bh; i < eh; i += 1) {
|
| 6729 |
for (int j = bw; j < ew; j += 1) {
|
| 6730 |
-
res = MAX(res, i_ptr[i * IW + j]);
|
| 6731 |
}
|
| 6732 |
}
|
| 6733 |
|
| 6734 |
-
o_ptr[cur_oh * OW + cur_ow] = res;
|
| 6735 |
}
|
| 6736 |
|
| 6737 |
kernel void kernel_pool_2d_avg_f32(
|
| 6738 |
device const float * src0,
|
| 6739 |
device float * dst,
|
| 6740 |
-
constant
|
| 6741 |
-
constant int32_t & k1,
|
| 6742 |
-
constant int32_t & s0,
|
| 6743 |
-
constant int32_t & s1,
|
| 6744 |
-
constant int32_t & p0,
|
| 6745 |
-
constant int32_t & p1,
|
| 6746 |
-
constant int64_t & IH,
|
| 6747 |
-
constant int64_t & IW,
|
| 6748 |
-
constant int64_t & OH,
|
| 6749 |
-
constant int64_t & OW,
|
| 6750 |
-
constant int64_t & parallel_elements,
|
| 6751 |
uint gid[[thread_position_in_grid]]) {
|
| 6752 |
|
| 6753 |
-
if (gid >= parallel_elements) {
|
| 6754 |
return;
|
| 6755 |
}
|
| 6756 |
|
| 6757 |
const int idx = gid;
|
| 6758 |
-
const int I_HW = IH * IW;
|
| 6759 |
-
const int O_HW = OH * OW;
|
| 6760 |
const int nc = idx / O_HW;
|
| 6761 |
-
const int cur_oh = idx % O_HW / OW;
|
| 6762 |
-
const int cur_ow = idx % O_HW % OW;
|
| 6763 |
|
| 6764 |
device const float * i_ptr = src0 + nc * I_HW;
|
| 6765 |
device float * o_ptr = dst + nc * O_HW;
|
| 6766 |
|
| 6767 |
-
const int start_h = cur_oh * s1 - p1;
|
| 6768 |
const int bh = MAX(0, start_h);
|
| 6769 |
-
const int eh = MIN(IH, start_h + k1);
|
| 6770 |
-
const int start_w = cur_ow * s0 - p0;
|
| 6771 |
const int bw = MAX(0, start_w);
|
| 6772 |
-
const int ew = MIN(IW, start_w + k0);
|
| 6773 |
// const float scale = 1. / ((eh - bh) * (ew - bw));
|
| 6774 |
-
const float scale = 1. / (k0 * k1);
|
| 6775 |
|
| 6776 |
float res = 0;
|
| 6777 |
|
| 6778 |
for (int i = bh; i < eh; i += 1) {
|
| 6779 |
for (int j = bw; j < ew; j += 1) {
|
| 6780 |
-
float cur = i_ptr[i * IW + j];
|
| 6781 |
res += cur * scale;
|
| 6782 |
}
|
| 6783 |
}
|
| 6784 |
|
| 6785 |
-
o_ptr[cur_oh * OW + cur_ow] = res;
|
| 6786 |
}
|
|
|
|
| 947 |
kernel void kernel_sum_rows(
|
| 948 |
device const float * src0,
|
| 949 |
device float * dst,
|
| 950 |
+
constant ggml_metal_kargs_sum_rows & args,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 951 |
uint3 tpig[[thread_position_in_grid]]) {
|
| 952 |
int64_t i3 = tpig.z;
|
| 953 |
int64_t i2 = tpig.y;
|
| 954 |
int64_t i1 = tpig.x;
|
| 955 |
|
| 956 |
+
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
|
| 957 |
return;
|
| 958 |
}
|
| 959 |
|
| 960 |
+
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
|
| 961 |
+
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
|
| 962 |
|
| 963 |
float row_sum = 0;
|
| 964 |
|
| 965 |
+
for (int64_t i0 = 0; i0 < args.ne00; i0++) {
|
| 966 |
row_sum += src_row[i0];
|
| 967 |
}
|
| 968 |
|
|
|
|
| 974 |
device const char * src0,
|
| 975 |
device const char * src1,
|
| 976 |
device char * dst,
|
| 977 |
+
constant ggml_metal_kargs_soft_max & args,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 978 |
threadgroup float * buf [[threadgroup(0)]],
|
| 979 |
uint tgpig[[threadgroup_position_in_grid]],
|
| 980 |
uint tpitg[[thread_position_in_threadgroup]],
|
| 981 |
uint sgitg[[simdgroup_index_in_threadgroup]],
|
| 982 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 983 |
uint ntg[[threads_per_threadgroup]]) {
|
| 984 |
+
const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
|
| 985 |
+
const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
|
| 986 |
+
const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
|
| 987 |
|
| 988 |
+
device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
|
| 989 |
+
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00 : nullptr;
|
| 990 |
+
device float * pdst = (device float *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
|
| 991 |
|
| 992 |
float slope = 1.0f;
|
| 993 |
|
| 994 |
// ALiBi
|
| 995 |
+
if (args.max_bias > 0.0f) {
|
| 996 |
const int64_t h = i02;
|
| 997 |
|
| 998 |
+
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
|
| 999 |
+
const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
|
| 1000 |
|
| 1001 |
slope = pow(base, exp);
|
| 1002 |
}
|
|
|
|
| 1004 |
// parallel max
|
| 1005 |
float lmax = -INFINITY;
|
| 1006 |
|
| 1007 |
+
for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
|
| 1008 |
+
lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
|
| 1009 |
}
|
| 1010 |
|
| 1011 |
// find the max value in the block
|
|
|
|
| 1029 |
|
| 1030 |
// parallel sum
|
| 1031 |
float lsum = 0.0f;
|
| 1032 |
+
for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
|
| 1033 |
+
const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
|
| 1034 |
lsum += exp_psrc0;
|
| 1035 |
pdst[i00] = exp_psrc0;
|
| 1036 |
}
|
|
|
|
| 1060 |
|
| 1061 |
const float inv_sum = 1.0f/sum;
|
| 1062 |
|
| 1063 |
+
for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
|
| 1064 |
pdst[i00] *= inv_sum;
|
| 1065 |
}
|
| 1066 |
}
|
|
|
|
| 1070 |
device const char * src0,
|
| 1071 |
device const char * src1,
|
| 1072 |
device char * dst,
|
| 1073 |
+
constant ggml_metal_kargs_soft_max & args,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1074 |
threadgroup float * buf [[threadgroup(0)]],
|
| 1075 |
uint tgpig[[threadgroup_position_in_grid]],
|
| 1076 |
uint tpitg[[thread_position_in_threadgroup]],
|
| 1077 |
uint sgitg[[simdgroup_index_in_threadgroup]],
|
| 1078 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 1079 |
uint ntg[[threads_per_threadgroup]]) {
|
| 1080 |
+
const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
|
| 1081 |
+
const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
|
| 1082 |
+
const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
|
| 1083 |
|
| 1084 |
+
device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
|
| 1085 |
+
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00/4 : nullptr;
|
| 1086 |
+
device float4 * pdst4 = (device float4 *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
|
| 1087 |
|
| 1088 |
float slope = 1.0f;
|
| 1089 |
|
| 1090 |
+
if (args.max_bias > 0.0f) {
|
| 1091 |
const int64_t h = i02;
|
| 1092 |
|
| 1093 |
+
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
|
| 1094 |
+
const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
|
| 1095 |
|
| 1096 |
slope = pow(base, exp);
|
| 1097 |
}
|
|
|
|
| 1099 |
// parallel max
|
| 1100 |
float4 lmax4 = -INFINITY;
|
| 1101 |
|
| 1102 |
+
for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
|
| 1103 |
+
lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
|
| 1104 |
}
|
| 1105 |
|
| 1106 |
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
|
|
|
| 1125 |
|
| 1126 |
// parallel sum
|
| 1127 |
float4 lsum4 = 0.0f;
|
| 1128 |
+
for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
|
| 1129 |
+
const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
|
| 1130 |
lsum4 += exp_psrc4;
|
| 1131 |
pdst4[i00] = exp_psrc4;
|
| 1132 |
}
|
|
|
|
| 1158 |
|
| 1159 |
const float inv_sum = 1.0f/sum;
|
| 1160 |
|
| 1161 |
+
for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
|
| 1162 |
pdst4[i00] *= inv_sum;
|
| 1163 |
}
|
| 1164 |
}
|
|
|
|
| 1174 |
// Causal (diagonal) mask: copies src0 into dst, except that every element whose
// column index is greater than n_past + row index is replaced with -INFINITY.
//
//   src0, dst : float buffers addressed with the same flat index
//   args      : ne00 / ne01 give the extents used to flatten (plane, row, col);
//               n_past shifts the diagonal to the right
//   tpig      : one thread per element; x = column, y = row, z = plane
kernel void kernel_diag_mask_inf(
        device const float * src0,
        device       float * dst,
        constant ggml_metal_kargs_diag_mask_inf & args,
        uint3 tpig[[thread_position_in_grid]]) {
    const int64_t col   = tpig[0];
    const int64_t row   = tpig[1];
    const int64_t plane = tpig[2];

    // source and destination share the same flat offset
    const int64_t idx = plane*args.ne01*args.ne00 + row*args.ne00 + col;

    // elements strictly to the right of the shifted diagonal are masked out
    const bool masked = col > args.n_past + row;

    dst[idx] = masked ? -INFINITY : src0[idx];
}
|
| 1189 |
|
| 1190 |
kernel void kernel_diag_mask_inf_8(
|
| 1191 |
device const float4 * src0,
|
| 1192 |
device float4 * dst,
|
| 1193 |
+
constant ggml_metal_kargs_diag_mask_inf & args,
|
|
|
|
|
|
|
| 1194 |
uint3 tpig[[thread_position_in_grid]]) {
|
| 1195 |
|
| 1196 |
const int64_t i = 2*tpig[0];
|
|
|
|
| 1198 |
dst[i+0] = src0[i+0];
|
| 1199 |
dst[i+1] = src0[i+1];
|
| 1200 |
int64_t i4 = 4*i;
|
| 1201 |
+
const int64_t i02 = i4/(args.ne00*args.ne01); i4 -= i02*args.ne00*args.ne01;
|
| 1202 |
+
const int64_t i01 = i4/(args.ne00); i4 -= i01*args.ne00;
|
| 1203 |
const int64_t i00 = i4;
|
| 1204 |
for (int k = 3; k >= 0; --k) {
|
| 1205 |
+
if (i00 + 4 + k <= args.n_past + i01) {
|
| 1206 |
break;
|
| 1207 |
}
|
| 1208 |
dst[i+1][k] = -INFINITY;
|
| 1209 |
+
if (i00 + k > args.n_past + i01) {
|
| 1210 |
dst[i][k] = -INFINITY;
|
| 1211 |
}
|
| 1212 |
}
|
| 1213 |
}
|
| 1214 |
|
| 1215 |
// ref: ggml.c:ggml_compute_forward_ssm_conv_f32
|
|
|
|
| 1216 |
kernel void kernel_ssm_conv_f32(
|
| 1217 |
device const void * src0,
|
| 1218 |
device const void * src1,
|
| 1219 |
device float * dst,
|
| 1220 |
+
constant ggml_metal_kargs_ssm_conv & args,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1221 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1222 |
uint3 tpitg[[thread_position_in_threadgroup]],
|
| 1223 |
uint3 ntg[[threads_per_threadgroup]]) {
|
|
|
|
| 1225 |
const int64_t i2 = tgpig.y;
|
| 1226 |
const int64_t i3 = tgpig.z;
|
| 1227 |
|
| 1228 |
+
const int64_t nc = args.ne10;
|
| 1229 |
+
//const int64_t ncs = args.ne00;
|
| 1230 |
+
//const int64_t nr = args.ne01;
|
| 1231 |
+
//const int64_t n_t = args.ne1;
|
| 1232 |
+
//const int64_t n_s = args.ne2;
|
| 1233 |
|
| 1234 |
+
device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
|
| 1235 |
+
device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
|
| 1236 |
+
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
|
| 1237 |
|
| 1238 |
float sumf = 0.0f;
|
| 1239 |
|
|
|
|
| 1245 |
}
|
| 1246 |
|
| 1247 |
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32
|
|
|
|
| 1248 |
kernel void kernel_ssm_scan_f32(
|
| 1249 |
device const void * src0,
|
| 1250 |
device const void * src1,
|
|
|
|
| 1253 |
device const void * src4,
|
| 1254 |
device const void * src5,
|
| 1255 |
device float * dst,
|
| 1256 |
+
constant ggml_metal_kargs_ssm_scan & args,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1257 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1258 |
uint3 tpitg[[thread_position_in_threadgroup]],
|
| 1259 |
uint3 ntg[[threads_per_threadgroup]]) {
|
| 1260 |
const int64_t ir = tgpig.x;
|
| 1261 |
const int64_t i3 = tgpig.y;
|
| 1262 |
|
| 1263 |
+
const int64_t nc = args.d_state;
|
| 1264 |
+
// const int64_t nr = args.d_inner;
|
| 1265 |
+
const int64_t n_t = args.n_seq_tokens;
|
| 1266 |
+
// const int64_t n_s = args.n_seqs;
|
| 1267 |
|
| 1268 |
for (int64_t i2 = 0; i2 < n_t; ++i2) {
|
| 1269 |
+
device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb01 + i3*args.nb02);
|
| 1270 |
+
device const float * x = (device const float *) ((device const char *) src1 + ir*args.nb10 + i2*args.nb11 + i3*args.nb12);
|
| 1271 |
+
device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22);
|
| 1272 |
+
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31);
|
| 1273 |
+
device const float * B = (device const float *) ((device const char *) src4 + i2*args.nb41 + i3*args.nb42);
|
| 1274 |
+
device const float * C = (device const float *) ((device const char *) src5 + i2*args.nb51 + i3*args.nb52);
|
| 1275 |
+
device float * y = (device float *) ((device char *) dst + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); // TODO: do not use src1 strides
|
| 1276 |
+
device float * s = (device float *) ((device char *) dst + ir*args.nb01 + i3*args.nb02 + args.nb13);
|
| 1277 |
|
| 1278 |
if (i2 > 0) {
|
| 1279 |
s0 = s;
|
|
|
|
| 1466 |
kernel void kernel_group_norm(
|
| 1467 |
device const float * src0,
|
| 1468 |
device float * dst,
|
| 1469 |
+
constant ggml_metal_kargs_group_norm & args,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1470 |
threadgroup float * buf [[threadgroup(0)]],
|
| 1471 |
uint tgpig[[threadgroup_position_in_grid]],
|
| 1472 |
uint tpitg[[thread_position_in_threadgroup]],
|
| 1473 |
uint sgitg[[simdgroup_index_in_threadgroup]],
|
| 1474 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 1475 |
uint ntg[[threads_per_threadgroup]]) {
|
| 1476 |
+
const int64_t ne = args.ne00*args.ne01*args.ne02;
|
| 1477 |
+
const int64_t gs = args.ne00*args.ne01*((args.ne02 + args.n_groups - 1) / args.n_groups);
|
| 1478 |
|
| 1479 |
int start = tgpig * gs;
|
| 1480 |
int end = start + gs;
|
|
|
|
| 1538 |
}
|
| 1539 |
|
| 1540 |
const float variance = tmp / gs;
|
| 1541 |
+
const float scale = 1.0f/sqrt(variance + args.eps);
|
| 1542 |
for (int j = start; j < end; j += ntg) {
|
| 1543 |
dst[j] *= scale;
|
| 1544 |
}
|
|
|
|
| 2502 |
typedef void (im2col_t)(
|
| 2503 |
device const float * x,
|
| 2504 |
device char * dst,
|
| 2505 |
+
constant ggml_metal_kargs_im2col & args,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2506 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2507 |
uint3 tgpg[[threadgroups_per_grid]],
|
| 2508 |
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
|
|
| 2512 |
kernel void kernel_im2col(
|
| 2513 |
device const float * x,
|
| 2514 |
device char * dst,
|
| 2515 |
+
constant ggml_metal_kargs_im2col & args,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2516 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2517 |
uint3 tgpg[[threadgroups_per_grid]],
|
| 2518 |
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
|
|
| 2533 |
const int64_t ioh = tgpig[1];
|
| 2534 |
const int64_t iow = tgpig[2];
|
| 2535 |
|
| 2536 |
+
const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0;
|
| 2537 |
+
const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1;
|
| 2538 |
|
| 2539 |
+
const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw);
|
| 2540 |
|
| 2541 |
device T * pdst = (device T *) (dst);
|
| 2542 |
|
| 2543 |
+
if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
|
| 2544 |
pdst[offset_dst] = 0.0f;
|
| 2545 |
} else {
|
| 2546 |
+
const int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw;
|
| 2547 |
pdst[offset_dst] = x[offset_src];
|
| 2548 |
}
|
| 2549 |
}
|
|
|
|
| 2554 |
typedef void (im2col_ext_t)(
|
| 2555 |
device const float * x,
|
| 2556 |
device char * dst,
|
| 2557 |
+
constant ggml_metal_kargs_im2col & args,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2558 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2559 |
uint3 tgpg[[threadgroups_per_grid]],
|
| 2560 |
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
|
|
| 2564 |
kernel void kernel_im2col_ext(
|
| 2565 |
device const float * x,
|
| 2566 |
device char * dst,
|
| 2567 |
+
constant ggml_metal_kargs_im2col & args,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2568 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2569 |
uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
|
| 2570 |
uint3 tpitg[[thread_position_in_threadgroup]],
|
| 2571 |
uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
|
| 2572 |
+
const int64_t KHW = (int64_t)args.KHW;
|
| 2573 |
|
| 2574 |
+
const int64_t d = tgpig[0] / args.CHW;
|
| 2575 |
+
const int64_t chw = tgpig[0] % args.CHW;
|
| 2576 |
const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
|
| 2577 |
const int64_t HW = tgpig[0] % KHW;
|
| 2578 |
|
| 2579 |
const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
|
| 2580 |
+
if (tpitg_0 >= args.N) {
|
| 2581 |
return;
|
| 2582 |
}
|
| 2583 |
|
| 2584 |
+
const int64_t tpitg_1 = HW / args.KW;
|
| 2585 |
+
const int64_t tpitg_2 = HW % args.KW;
|
| 2586 |
|
| 2587 |
+
const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
|
| 2588 |
+
const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
|
| 2589 |
|
| 2590 |
const int64_t offset_dst =
|
| 2591 |
+
(tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
|
| 2592 |
+
(tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
|
| 2593 |
|
| 2594 |
device T * pdst = (device T *) (dst);
|
| 2595 |
|
| 2596 |
+
if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
|
| 2597 |
pdst[offset_dst] = 0.0f;
|
| 2598 |
} else {
|
| 2599 |
+
const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
|
| 2600 |
+
pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
|
| 2601 |
}
|
| 2602 |
}
|
| 2603 |
|
|
|
|
| 2608 |
device const float * src0,
|
| 2609 |
device const float * src1,
|
| 2610 |
device char * dst,
|
| 2611 |
+
constant ggml_metal_kargs_conv_transpose_1d & args,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2612 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2613 |
uint3 tgpg[[threadgroups_per_grid]]);
|
| 2614 |
|
|
|
|
| 2617 |
device const T * src0,
|
| 2618 |
device const float * src1,
|
| 2619 |
device char * dst,
|
| 2620 |
+
constant ggml_metal_kargs_conv_transpose_1d & args,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2621 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2622 |
uint3 tgpg[[threadgroups_per_grid]]) {
|
| 2623 |
|
| 2624 |
float v = 0.0f;
|
| 2625 |
|
| 2626 |
+
for (int64_t c = 0; c < args.IC; c++) {
|
| 2627 |
+
const int32_t kernel_offset = c * tgpg[1] * args.K + args.K * tgpig[1];
|
| 2628 |
+
const int32_t input_offset = c * args.IL;
|
| 2629 |
|
| 2630 |
+
for (int64_t i = 0; i < args.IL; i++) {
|
| 2631 |
+
if (tgpig[0] >= i * args.s0 && tgpig[0] < i * args.s0 + args.K) {
|
| 2632 |
+
v += src0[kernel_offset + tgpig[0] - i * args.s0] * src1[input_offset + i];
|
| 2633 |
}
|
| 2634 |
}
|
| 2635 |
}
|
| 2636 |
|
| 2637 |
+
device float * dst_ptr = (device float *) (dst + tgpig[0] * args.nb0 + tgpig[1] * args.nb1);
|
| 2638 |
|
| 2639 |
dst_ptr[0] = v;
|
| 2640 |
}
|
|
|
|
| 2644 |
device const float * src0,
|
| 2645 |
device const float * src1,
|
| 2646 |
device char * dst,
|
| 2647 |
+
constant ggml_metal_kargs_conv_transpose_1d & args,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2648 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2649 |
uint3 tgpg[[threadgroups_per_grid]]);
|
| 2650 |
|
|
|
|
| 2653 |
device const half * src0,
|
| 2654 |
device const float * src1,
|
| 2655 |
device char * dst,
|
| 2656 |
+
constant ggml_metal_kargs_conv_transpose_1d & args,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2657 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2658 |
uint3 tgpg[[threadgroups_per_grid]]);
|
| 2659 |
|
| 2660 |
kernel void kernel_upscale_f32(
|
| 2661 |
device const char * src0,
|
| 2662 |
device char * dst,
|
| 2663 |
+
constant ggml_metal_kargs_upscale & args,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2664 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2665 |
uint3 tpitg[[thread_position_in_threadgroup]],
|
| 2666 |
uint3 ntg[[threads_per_threadgroup]]) {
|
|
|
|
| 2669 |
const int64_t i2 = tgpig.y;
|
| 2670 |
const int64_t i1 = tgpig.x;
|
| 2671 |
|
| 2672 |
+
const int64_t i03 = i3/args.sf3;
|
| 2673 |
+
const int64_t i02 = i2/args.sf2;
|
| 2674 |
+
const int64_t i01 = i1/args.sf1;
|
| 2675 |
|
| 2676 |
+
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
| 2677 |
+
const int64_t i00 = i0/args.sf0;
|
| 2678 |
|
| 2679 |
+
device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
| 2680 |
+
device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
| 2681 |
|
| 2682 |
dst_ptr[0] = src0_ptr[0];
|
| 2683 |
}
|
|
|
|
| 2686 |
kernel void kernel_pad_f32(
|
| 2687 |
device const char * src0,
|
| 2688 |
device char * dst,
|
| 2689 |
+
constant ggml_metal_kargs_pad & args,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2690 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2691 |
uint3 tpitg[[thread_position_in_threadgroup]],
|
| 2692 |
uint3 ntg[[threads_per_threadgroup]]) {
|
|
|
|
| 2699 |
const int64_t i02 = i2;
|
| 2700 |
const int64_t i01 = i1;
|
| 2701 |
|
| 2702 |
+
device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
|
| 2703 |
+
device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
|
| 2704 |
|
| 2705 |
+
if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
|
| 2706 |
+
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
| 2707 |
+
if (i0 < args.ne00) {
|
| 2708 |
dst_ptr[i0] = src0_ptr[i0];
|
| 2709 |
} else {
|
| 2710 |
dst_ptr[i0] = 0.0f;
|
|
|
|
| 2714 |
return;
|
| 2715 |
}
|
| 2716 |
|
| 2717 |
+
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
| 2718 |
dst_ptr[i0] = 0.0f;
|
| 2719 |
}
|
| 2720 |
}
|
|
|
|
| 2722 |
kernel void kernel_pad_reflect_1d_f32(
|
| 2723 |
device const char * src0,
|
| 2724 |
device char * dst,
|
| 2725 |
+
constant ggml_metal_kargs_pad_reflect_1d & args,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2726 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2727 |
uint3 tgpg[[threadgroups_per_grid]],
|
| 2728 |
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
|
|
| 2736 |
const int64_t i02 = i2;
|
| 2737 |
const int64_t i01 = i1;
|
| 2738 |
|
| 2739 |
+
device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
|
| 2740 |
+
device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
|
| 2741 |
|
| 2742 |
+
if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
|
| 2743 |
+
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
| 2744 |
+
if (i0 < args.p0) {
|
| 2745 |
+
dst_ptr[i0] = src0_ptr[args.p0 - i0];
|
| 2746 |
+
} else if (i0 < args.ne0 - args.p1) {
|
| 2747 |
+
dst_ptr[i0] = src0_ptr[i0 - args.p0];
|
| 2748 |
} else {
|
| 2749 |
+
dst_ptr[i0] = src0_ptr[(args.ne0 - args.p1 - args.p0) - (args.p1 + 1 - (args.ne0 - i0)) - 1];
|
| 2750 |
}
|
| 2751 |
}
|
| 2752 |
}
|
|
|
|
| 2754 |
|
| 2755 |
// Fills dst with the arithmetic sequence start + step * i for i in [0, ne0).
//
//   dst   : output buffer, written as float
//   args  : ne0 (element count), start, step
//   tpitg : starting index for this thread (x component)
//   ntg   : index increment between successive elements of one thread (x component)
//   tgpig : not used by this kernel
kernel void kernel_arange_f32(
        device char * dst,
        constant ggml_metal_kargs_arange & args,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {

    device float * out = (device float *) dst;

    int k = tpitg.x;
    while (k < args.ne0) {
        out[k] = args.start + args.step * k;
        k += ntg.x;
    }
}
|
| 2768 |
|
| 2769 |
// Sinusoidal timestep embedding.
//
// Each threadgroup (index tgpig.x) handles one timestep read from src0 and
// writes one output row located at dst + row*nb1. With half_ = dim / 2 the row
// is laid out as:
//   [0, half_)        cos(t * freq_j)
//   [half_, 2*half_)  sin(t * freq_j)
// where freq_j = exp(-log(max_period) * j / half_).
//
//   src0  : timesteps, read as float
//   dst   : output buffer, rows separated by args.nb1 bytes
//   args  : nb1, dim, max_period
//   tpitg : first j handled by this thread (x component)
//   ntg   : increment of j between iterations (x component)
kernel void kernel_timestep_embedding_f32(
        device const char * src0,
        device       char * dst,
        constant ggml_metal_kargs_timestep_embedding & args,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {

    int row = tgpig.x;
    device float * out_row = (device float *)(dst + row*args.nb1);

    int half_ = args.dim / 2;

    int j = tpitg.x;
    while (j < half_) {
        float t        = ((device float *)src0)[row];
        float exponent = -log((float)args.max_period) * j / half_;
        float freq     = (float)exp(exponent);
        float angle    = t * freq;

        out_row[j        ] = cos(angle);
        out_row[j + half_] = sin(angle);

        j += ntg.x;
    }

    // odd dim: a single thread writes one trailing zero at index dim
    // NOTE(review): for odd dim the loop above covers indices [0, dim-1), so index
    // dim-1 is not written here and the zero lands at index dim - this assumes each
    // row holds at least dim+1 floats; confirm against the host-side allocation.
    if (args.dim % 2 != 0 && tpitg.x == 0) {
        out_row[args.dim] = 0.f;
    }
}
|
| 2793 |
|
|
|
|
| 2795 |
// Function type with the same parameter list as kernel_argsort_f32_i32:
// input values, int32 output indices, the argsort kargs struct, a threadgroup
// scratch buffer bound at threadgroup(0), and the threadgroup / thread positions.
// NOTE(review): presumably used to declare the kernel's template instantiations - confirm at the use site.
typedef void (argsort_t)(
|
| 2796 |
device const float * x,
|
| 2797 |
device int32_t * dst,
|
| 2798 |
+
constant ggml_metal_kargs_argsort & args,
|
|
|
|
| 2799 |
threadgroup int32_t * shared_values [[threadgroup(0)]],
|
| 2800 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2801 |
uint3 tpitg[[thread_position_in_threadgroup]]);
|
|
|
|
| 2804 |
kernel void kernel_argsort_f32_i32(
|
| 2805 |
device const float * x,
|
| 2806 |
device int32_t * dst,
|
| 2807 |
+
constant ggml_metal_kargs_argsort & args,
|
|
|
|
| 2808 |
threadgroup int32_t * shared_values [[threadgroup(0)]],
|
| 2809 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2810 |
uint3 tpitg[[thread_position_in_threadgroup]]) {
|
|
|
|
| 2812 |
int col = tpitg[0];
|
| 2813 |
int row = tgpig[1];
|
| 2814 |
|
| 2815 |
+
if (col >= args.ncols_pad) return;
|
| 2816 |
|
| 2817 |
+
device const float * x_row = x + row * args.ncols;
|
| 2818 |
threadgroup int32_t * dst_row = shared_values;
|
| 2819 |
|
| 2820 |
// initialize indices
|
|
|
|
| 2822 |
|
| 2823 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 2824 |
|
| 2825 |
+
for (int k = 2; k <= args.ncols_pad; k *= 2) {
|
| 2826 |
for (int j = k / 2; j > 0; j /= 2) {
|
| 2827 |
int ixj = col ^ j;
|
| 2828 |
if (ixj > col) {
|
| 2829 |
if ((col & k) == 0) {
|
| 2830 |
+
if (dst_row[col] >= args.ncols ||
|
| 2831 |
+
(dst_row[ixj] < args.ncols && (order == GGML_SORT_ORDER_ASC ?
|
| 2832 |
x_row[dst_row[col]] > x_row[dst_row[ixj]] :
|
| 2833 |
x_row[dst_row[col]] < x_row[dst_row[ixj]]))
|
| 2834 |
) {
|
| 2835 |
SWAP(dst_row[col], dst_row[ixj]);
|
| 2836 |
}
|
| 2837 |
} else {
|
| 2838 |
+
if (dst_row[ixj] >= args.ncols ||
|
| 2839 |
+
(dst_row[col] < args.ncols && (order == GGML_SORT_ORDER_ASC ?
|
| 2840 |
x_row[dst_row[col]] < x_row[dst_row[ixj]] :
|
| 2841 |
x_row[dst_row[col]] > x_row[dst_row[ixj]]))
|
| 2842 |
) {
|
|
|
|
| 2849 |
}
|
| 2850 |
|
| 2851 |
// copy the result to dst without the padding
|
| 2852 |
+
if (col < args.ncols) {
|
| 2853 |
+
dst[row * args.ncols + col] = dst_row[col];
|
| 2854 |
}
|
| 2855 |
}
|
| 2856 |
|
|
|
|
| 2860 |
// Element-wise leaky ReLU: values greater than zero pass through unchanged,
// all other values are multiplied by args.slope.
//
//   src0 : input floats
//   dst  : output floats, same indexing as src0
//   args : slope applied to non-positive inputs
//   tpig : flat element index handled by this thread
kernel void kernel_leaky_relu_f32(
        device const float * src0,
        device       float * dst,
        constant ggml_metal_kargs_leaky_relu & args,
        uint tpig[[thread_position_in_grid]]) {
    const float v = src0[tpig];

    if (v > 0.0f) {
        dst[tpig] = v;
    } else {
        dst[tpig] = v * args.slope;
    }
}
|
| 2867 |
|
| 2868 |
// ref: https://arxiv.org/pdf/2307.08691.pdf
|
|
|
|
| 5803 |
device const void * src0,
|
| 5804 |
device const void * src1,
|
| 5805 |
device float * dst,
|
| 5806 |
+
constant ggml_metal_kargs_get_rows & args,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5807 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 5808 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 5809 |
uint3 tptg [[threads_per_threadgroup]]) {
|
| 5810 |
const int64_t i10 = tgpig.x;
|
| 5811 |
const int64_t i11 = tgpig.y;
|
| 5812 |
|
| 5813 |
+
const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
|
| 5814 |
|
| 5815 |
const int64_t i02 = i11;
|
| 5816 |
|
| 5817 |
+
for (int64_t ind = tiitg; ind < args.ne00/16; ind += tptg.x) {
|
| 5818 |
float4x4 temp;
|
| 5819 |
+
dequantize_func(((device const block_q *) ((const device char *) src0 + r*args.nb01 + i02*args.nb02)) + ind/nl, ind%nl, temp);
|
| 5820 |
+
*(((device float4x4 *) ((device char *) dst + i11*args.nb2 + i10*args.nb1)) + ind) = temp;
|
| 5821 |
}
|
| 5822 |
}
|
| 5823 |
|
|
|
|
| 5826 |
device const void * src0,
|
| 5827 |
device const void * src1,
|
| 5828 |
device float * dst,
|
| 5829 |
+
constant ggml_metal_kargs_get_rows & args,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5830 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 5831 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 5832 |
uint3 tptg [[threads_per_threadgroup]]) {
|
| 5833 |
const int64_t i10 = tgpig.x;
|
| 5834 |
const int64_t i11 = tgpig.y;
|
| 5835 |
|
| 5836 |
+
const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
|
| 5837 |
|
| 5838 |
const int64_t i02 = i11;
|
| 5839 |
|
| 5840 |
+
for (int ind = tiitg; ind < args.ne00; ind += tptg.x) {
|
| 5841 |
+
(( device float *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] =
|
| 5842 |
+
((const device T *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind];
|
| 5843 |
}
|
| 5844 |
}
|
| 5845 |
|
|
|
|
| 5847 |
device const void * src0,
|
| 5848 |
device const void * src1,
|
| 5849 |
device int32_t * dst,
|
| 5850 |
+
constant ggml_metal_kargs_get_rows & args,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5851 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 5852 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 5853 |
uint3 tptg [[threads_per_threadgroup]]) {
|
| 5854 |
const int64_t i10 = tgpig.x;
|
| 5855 |
const int64_t i11 = tgpig.y;
|
| 5856 |
|
| 5857 |
+
const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
|
| 5858 |
|
| 5859 |
const int64_t i02 = i11;
|
| 5860 |
|
| 5861 |
+
for (int ind = tiitg; ind < args.ne00; ind += tptg.x) {
|
| 5862 |
+
(( device int32_t *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] =
|
| 5863 |
+
((const device int32_t *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind];
|
| 5864 |
}
|
| 5865 |
}
|
| 5866 |
|
|
|
|
| 6462 |
// 2D max pooling over float planes.
//
// Each thread with gid < args.parallel_elements produces one output element.
// The flat index is split into (plane, output row, output column); the pooling
// window is k1 x k0 with strides s1 / s0 and padding p1 / p0, clipped to the
// IH x IW input plane. An empty (fully clipped) window yields -INFINITY.
//
//   src0 : input, planes of IH*IW floats
//   dst  : output, planes of OH*OW floats
//   args : IH, IW, OH, OW, k0, k1, s0, s1, p0, p1, parallel_elements
//   gid  : flat output element index
kernel void kernel_pool_2d_max_f32(
        device  const float * src0,
        device        float * dst,
        constant ggml_metal_kargs_pool_2d & args,
        uint        gid[[thread_position_in_grid]]) {

    // threads past the last output element have nothing to do
    if (gid >= args.parallel_elements) {
        return;
    }

    const int idx  = gid;
    const int I_HW = args.IH * args.IW;
    const int O_HW = args.OH * args.OW;

    // flat index -> (plane, output row, output column)
    const int plane    = idx / O_HW;
    const int in_plane = idx % O_HW;
    const int out_y    = in_plane / args.OW;
    const int out_x    = in_plane % args.OW;

    device const float * src_plane = src0 + plane * I_HW;
    device       float * dst_plane = dst  + plane * O_HW;

    // window origin in input coordinates (may be negative because of padding)
    const int win_y0 = out_y * args.s1 - args.p1;
    const int win_x0 = out_x * args.s0 - args.p0;

    // clip the window to the input plane
    const int y_lo = MAX(0, win_y0);
    const int y_hi = MIN(args.IH, win_y0 + args.k1);
    const int x_lo = MAX(0, win_x0);
    const int x_hi = MIN(args.IW, win_x0 + args.k0);

    float best = -INFINITY;

    for (int y = y_lo; y < y_hi; ++y) {
        device const float * src_row = src_plane + y * args.IW;
        for (int x = x_lo; x < x_hi; ++x) {
            best = MAX(best, src_row[x]);
        }
    }

    dst_plane[out_y * args.OW + out_x] = best;
}
|
| 6499 |
|
| 6500 |
// 2D average pooling over float planes.
//
// Each thread with gid < args.parallel_elements produces one output element.
// The flat index is split into (plane, output row, output column); the pooling
// window is k1 x k0 with strides s1 / s0 and padding p1 / p0, clipped to the
// IH x IW input plane. Every in-bounds sample is weighted by 1 / (k0 * k1),
// i.e. the divisor is the full kernel area, not the clipped window area.
//
//   src0 : input, planes of IH*IW floats
//   dst  : output, planes of OH*OW floats
//   args : IH, IW, OH, OW, k0, k1, s0, s1, p0, p1, parallel_elements
//   gid  : flat output element index
kernel void kernel_pool_2d_avg_f32(
        device  const float * src0,
        device        float * dst,
        constant ggml_metal_kargs_pool_2d & args,
        uint        gid[[thread_position_in_grid]]) {

    // threads past the last output element have nothing to do
    if (gid >= args.parallel_elements) {
        return;
    }

    const int idx  = gid;
    const int I_HW = args.IH * args.IW;
    const int O_HW = args.OH * args.OW;

    // flat index -> (plane, output row, output column)
    const int plane    = idx / O_HW;
    const int in_plane = idx % O_HW;
    const int out_y    = in_plane / args.OW;
    const int out_x    = in_plane % args.OW;

    device const float * src_plane = src0 + plane * I_HW;
    device       float * dst_plane = dst  + plane * O_HW;

    // window origin in input coordinates (may be negative because of padding)
    const int win_y0 = out_y * args.s1 - args.p1;
    const int win_x0 = out_x * args.s0 - args.p0;

    // clip the window to the input plane
    const int y_lo = MAX(0, win_y0);
    const int y_hi = MIN(args.IH, win_y0 + args.k1);
    const int x_lo = MAX(0, win_x0);
    const int x_hi = MIN(args.IW, win_x0 + args.k0);

    // fixed weight per sample: reciprocal of the unclipped kernel area
    const float scale = 1. / (args.k0 * args.k1);

    float acc = 0;

    for (int y = y_lo; y < y_hi; ++y) {
        device const float * src_row = src_plane + y * args.IW;
        for (int x = x_lo; x < x_hi; ++x) {
            float sample = src_row[x];
            acc += sample * scale;
        }
    }

    dst_plane[out_y * args.OW + out_x] = acc;
}
|