jeffbolznv committed
Commit 9821f43 · 1 Parent(s): 04b631e

vulkan: support SET_ROWS (llama/14587)


* vulkan: support SET_ROWS

Add variants of the copy_to_quant shader that do the SET_ROWS operation.
Change these shaders to spread the work across the workgroup.
The memory access pattern is probably not great (one thread per quant block),
but should be fine for now.

* vulkan: optimize set_rows

Larger workgroups for non-quant types.
Set "norepeat" (there is manual repeat logic).
Use fastmod.
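
For context, SET_ROWS scatters rows of an F32 source tensor into selected rows of a destination tensor (F16/BF16 or a quantized type), with the destination row indices supplied by a second input; the index tensor is broadcast over the two outer dimensions, which is the manual repeat logic mentioned above. A minimal CPU sketch of the float-to-float case (the function name and the flat row-major layout are illustrative, not the ggml API):

// Illustrative CPU reference for the SET_ROWS scatter, float -> float only;
// the real op also converts/quantizes to the destination type on the fly.
#include <cstdint>
#include <vector>

void set_rows_ref(const std::vector<float>   & src,   // [ne00, ne01, ne02, ne03] source rows
                  const std::vector<int64_t> & idx,   // [ne01, ne11, ne12] destination row indices
                  std::vector<float>         & dst,   // [ne00, nrows_dst, ne02, ne03]
                  int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
                  int64_t ne11, int64_t ne12, int64_t nrows_dst) {
    for (int64_t i03 = 0; i03 < ne03; ++i03) {
        for (int64_t i02 = 0; i02 < ne02; ++i02) {
            for (int64_t i01 = 0; i01 < ne01; ++i01) {
                // the index tensor broadcasts over the outer dims (the "repeat" logic)
                const int64_t i12 = i03 % ne12;
                const int64_t i11 = i02 % ne11;
                const int64_t i1  = idx[(i12 * ne11 + i11) * ne01 + i01];
                for (int64_t i00 = 0; i00 < ne00; ++i00) {
                    dst[((i03 * ne02 + i02) * nrows_dst + i1) * ne00 + i00] =
                        src[((i03 * ne02 + i02) * ne01 + i01) * ne00 + i00];
                }
            }
        }
    }
}

The Vulkan path below implements the same mapping inside the copy_to_quant shader, with quantize() converting each block to the destination type as it is written.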

ggml/src/ggml-vulkan/ggml-vulkan.cpp CHANGED
@@ -437,6 +437,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16;
     vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
     vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
+    vk_pipeline pipeline_set_rows[GGML_TYPE_COUNT];
     vk_pipeline pipeline_norm_f32;
     vk_pipeline pipeline_group_norm_f32;
     vk_pipeline pipeline_rms_norm_f32;
@@ -2749,19 +2750,41 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
     if (device->float_controls_rte_fp16) {
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
     } else {
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
+    }
+
+    if (device->float_controls_rte_fp16) {
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F32], "set_rows_f32", set_rows_f32_rte_len, set_rows_f32_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F16], "set_rows_f16", set_rows_f16_rte_len, set_rows_f16_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_BF16], "set_rows_bf16", set_rows_bf16_rte_len, set_rows_bf16_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_0], "set_rows_q4_0", set_rows_q4_0_rte_len, set_rows_q4_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_1], "set_rows_q4_1", set_rows_q4_1_rte_len, set_rows_q4_1_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_0], "set_rows_q5_0", set_rows_q5_0_rte_len, set_rows_q5_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_1], "set_rows_q5_1", set_rows_q5_1_rte_len, set_rows_q5_1_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q8_0], "set_rows_q8_0", set_rows_q8_0_rte_len, set_rows_q8_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_IQ4_NL], "set_rows_iq4_nl", set_rows_iq4_nl_rte_len, set_rows_iq4_nl_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+    } else {
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F32], "set_rows_f32", set_rows_f32_len, set_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F16], "set_rows_f16", set_rows_f16_len, set_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_BF16], "set_rows_bf16", set_rows_bf16_len, set_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_0], "set_rows_q4_0", set_rows_q4_0_len, set_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_1], "set_rows_q4_1", set_rows_q4_1_len, set_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_0], "set_rows_q5_0", set_rows_q5_0_len, set_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_1], "set_rows_q5_1", set_rows_q5_1_len, set_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q8_0], "set_rows_q8_0", set_rows_q8_0_len, set_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_IQ4_NL], "set_rows_iq4_nl", set_rows_iq4_nl_len, set_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
     }
 
     ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
@@ -6527,6 +6550,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
     case GGML_OP_CONT:
     case GGML_OP_DUP:
         return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
+    case GGML_OP_SET_ROWS:
+        return ctx->device->pipeline_set_rows[dst->type];
     case GGML_OP_SILU_BACK:
         if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_silu_back_f32;
@@ -6765,6 +6790,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_RMS_NORM:
     case GGML_OP_CONV_2D_DW:
     case GGML_OP_IM2COL:
+    case GGML_OP_SET_ROWS:
         return true;
     default:
         return false;
@@ -7078,6 +7104,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
                 ne *= ggml_type_size(src0->type) / 2;
             }
         }
+        // copy_to_quant has block size of 32, and each thread does QUANT_K elements.
+        // Splitting into 512x512xZ wouldn't work well since each workgroup does 1024 elements.
+        // So divide by block size here before splitting into 512x512 groups.
+        if (op == GGML_OP_CPY && !ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
+            ne = CEIL_DIV(ne, ggml_blck_size(dst->type));
+        }
         if (ne > 262144) {
             elements = { 512, 512, CEIL_DIV(ne, 262144) };
         } else if (ne > 512) {
@@ -7086,6 +7118,25 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
             elements = { ne, 1, 1 };
         }
     } break;
+    case GGML_OP_SET_ROWS:
+        {
+            uint32_t ne = ggml_nelements(src0);
+            if (ggml_is_quantized(dst->type)) {
+                // quants run 32 threads each doing QUANT_K elements
+                ne = CEIL_DIV(ne, 32 * ggml_blck_size(dst->type));
+            } else {
+                // scalar types do one element per thread, running 512 threads
+                ne = CEIL_DIV(ne, 512);
+            }
+            if (ne > 262144) {
+                elements = { 512, 512, CEIL_DIV(ne, 262144) };
+            } else if (ne > 512) {
+                elements = { 512, CEIL_DIV(ne, 512), 1 };
+            } else {
+                elements = { ne, 1, 1 };
+            }
+        }
+        break;
     default:
         elements = { (uint32_t)ggml_nelements(src0), 1, 1 };
         break;
@@ -7648,6 +7699,21 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
     }, dryrun);
 }
 
+static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t src1_type_size = ggml_type_size(src1->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SET_ROWS, {
+        (uint32_t)ggml_nelements(src0),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+        0,
+        0.0f, 0.0f, 0,
+    }, dryrun);
+}
+
 static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
 }
@@ -8968,6 +9034,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_CLAMP:
     case GGML_OP_PAD:
     case GGML_OP_CPY:
+    case GGML_OP_SET_ROWS:
     case GGML_OP_CONT:
     case GGML_OP_DUP:
     case GGML_OP_SILU_BACK:
@@ -9034,6 +9101,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_CLAMP:
    case GGML_OP_PAD:
     case GGML_OP_CPY:
+    case GGML_OP_SET_ROWS:
     case GGML_OP_CONT:
     case GGML_OP_DUP:
     case GGML_OP_SILU_BACK:
@@ -9142,6 +9210,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_DUP:
         ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun);
 
+        break;
+    case GGML_OP_SET_ROWS:
+        ggml_vk_set_rows(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_SILU_BACK:
         ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -9357,6 +9429,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
     case GGML_OP_CLAMP:
     case GGML_OP_PAD:
     case GGML_OP_CPY:
+    case GGML_OP_SET_ROWS:
     case GGML_OP_CONT:
     case GGML_OP_DUP:
     case GGML_OP_SILU_BACK:
@@ -10422,9 +10495,20 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         } break;
     case GGML_OP_SET_ROWS:
         {
-            // TODO: add support
-            // ref: https://github.com/ggml-org/llama.cpp/pull/14274
-            return false;
+            switch (op->type) {
+            case GGML_TYPE_F32:
+            case GGML_TYPE_F16:
+            case GGML_TYPE_BF16:
+            case GGML_TYPE_Q4_0:
+            case GGML_TYPE_Q4_1:
+            case GGML_TYPE_Q5_0:
+            case GGML_TYPE_Q5_1:
+            case GGML_TYPE_Q8_0:
+            case GGML_TYPE_IQ4_NL:
+                return true;
+            default:
+                return false;
+            }
         } break;
     case GGML_OP_CONT:
     case GGML_OP_CPY:
@@ -11039,6 +11123,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         } else {
             tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]);
         }
+    } else if (tensor->op == GGML_OP_SET_ROWS) {
+        tensor_clone = ggml_set_rows(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_CONT) {
         tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
     } else if (tensor->op == GGML_OP_RESHAPE) {
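
The dispatch sizing in the GGML_OP_SET_ROWS branch above can be read in isolation: the element count is first divided by the work one workgroup covers (32 threads x QUANT_K elements for quantized destinations, 512 threads x 1 element for scalar ones), and the resulting workgroup count is folded into a grid of at most 512 x 512 x Z. A self-contained C++ sketch of that computation (ceil_div and set_rows_dispatch are local helpers written for illustration, not functions from ggml-vulkan.cpp):

#include <array>
#include <cstdint>

static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

// nelements: total elements in src0.
// block_size: threads per workgroup (32 for quantized destinations, 512 for scalar ones).
// quant_k: elements handled per thread (block size of the destination quant type, or 1 for scalar types).
static std::array<uint32_t, 3> set_rows_dispatch(uint32_t nelements, uint32_t block_size, uint32_t quant_k) {
    uint32_t ne = ceil_div(nelements, block_size * quant_k); // workgroups needed
    if (ne > 262144) {                                       // 262144 = 512 * 512
        return { 512, 512, ceil_div(ne, 262144) };
    } else if (ne > 512) {
        return { 512, ceil_div(ne, 512), 1 };
    } else {
        return { ne, 1, 1 };
    }
}

For example, 1,048,576 elements going into a Q8_0 destination (32 threads x 32 elements per workgroup) need 1024 workgroups, which dispatches as { 512, 2, 1 }.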
ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp CHANGED
@@ -6,17 +6,25 @@ spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bi
 #endif // RTE16
 
 #include "types.comp"
-#include "generic_unary_head.comp"
 
-#if defined(DATA_A_IQ4_NL)
-// 16 invocations needed for init_iq4nl_shmem
-layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
+#if defined(SET_ROWS) && QUANT_K == 1
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+const uint BLOCK_SIZE = 512;
 #else
-layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
+const uint BLOCK_SIZE = 32;
 #endif
 
 layout (binding = 0) readonly buffer S {float data_s[];};
+
+#if defined(SET_ROWS)
+#include "generic_binary_head.comp"
+layout (binding = 1) readonly buffer C {uvec2 data_i[];};
+layout (binding = 2) writeonly buffer Q {A_TYPE data_q[];};
+#else
+#include "generic_unary_head.comp"
 layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];};
+#endif
 
 #if defined(DATA_A_Q4_0)
 void quantize(uint dst_idx, uint src_idx)
@@ -221,15 +229,56 @@ void quantize(uint dst_idx, uint src_idx)
 }
 #endif
 
+#if defined(DATA_A_F32) || defined(DATA_A_F16)
+void quantize(uint dst_idx, uint src_idx)
+{
+    data_q[dst_idx] = A_TYPE(data_s[src_idx]);
+}
+#endif
+
+#if defined(DATA_A_BF16)
+void quantize(uint dst_idx, uint src_idx)
+{
+    data_q[dst_idx] = A_TYPE(fp32_to_bf16(data_s[src_idx]));
+}
+#endif
+
+#if defined(SET_ROWS)
+
 void main() {
 #ifdef NEEDS_INIT_IQ_SHMEM
     init_iq_shmem(gl_WorkGroupSize);
-    if (gl_LocalInvocationIndex.x != 0) {
+#endif
+
+    const uint idx = ((gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x) * BLOCK_SIZE + gl_LocalInvocationID.x) * QUANT_K;
+
+    if (idx >= p.ne) {
         return;
     }
+
+    uint i00, i01, i02, i03;
+    get_indices(idx, i00, i01, i02, i03);
+
+    uint i12 = fastmod(i03, p.ne12);
+    uint i11 = fastmod(i02, p.ne11);
+    uint i10 = i01;
+
+    uint i1 = data_i[src1_idx(i10, i11, i12, 0) + get_boffset()].x;
+
+    uint src0_idx = src0_idx(i00, i01, i02, i03) + get_aoffset();
+    uint dst_idx = dst_idx(i00 / QUANT_K, i1, i02, i03) + get_doffset();
+
+    quantize(dst_idx, src0_idx);
+}
+
+#else
+
+void main() {
+#ifdef NEEDS_INIT_IQ_SHMEM
+    init_iq_shmem(gl_WorkGroupSize);
 #endif
 
-    const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K;
+    const uint idx = (gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x) * QUANT_K;
 
     if (idx >= p.ne) {
         return;
@@ -240,3 +289,5 @@ void main() {
 
     quantize(dst_idx, src_idx);
 }
+
+#endif
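
The two main() variants differ mainly in how they map the dispatch back to a flat element index; the same formulas restated as plain C++ for reference (the helper names are illustrative):

#include <cstdint>

// SET_ROWS path: each workgroup has BLOCK_SIZE threads and each thread covers
// QUANT_K consecutive source elements (QUANT_K == 1 for f32/f16/bf16).
uint32_t set_rows_flat_index(uint32_t wg_x, uint32_t wg_y, uint32_t wg_z,
                             uint32_t local_x, uint32_t block_size, uint32_t quant_k) {
    return ((wg_z * 262144u + wg_y * 512u + wg_x) * block_size + local_x) * quant_k;
}

// Plain copy_to_quant path: fixed 32 threads per workgroup, QUANT_K elements per thread.
uint32_t cpy_to_quant_flat_index(uint32_t wg_x, uint32_t wg_y, uint32_t wg_z,
                                 uint32_t local_x, uint32_t quant_k) {
    return (wg_z * 262144u + wg_y * 512u + wg_x * 32u + local_x) * quant_k;
}

In the SET_ROWS path, BLOCK_SIZE is 512 for f32/f16/bf16 (QUANT_K == 1) and 32 for quantized destinations; the plain copy path always runs 32 threads per workgroup.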
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp CHANGED
@@ -518,6 +518,11 @@ void process_shaders() {
         string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
     }
 
+    for (std::string t : {"f32", "f16", "bf16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
+        string_to_spv("set_rows_" + t, "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+        string_to_spv("set_rows_" + t + "_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
+    }
+
     auto get_type_str = [](bool f16) {
         return f16 ? "float16_t" : "float";
     };
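
Each iteration of the loop above emits two SPIR-V variants per destination type, which correspond to the set_rows_*_len/_data arrays consumed in ggml_vk_load_shaders earlier in this commit. A throwaway sketch that just lists the generated variant names:

#include <cstdio>

int main() {
    const char * types[] = {"f32", "f16", "bf16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"};
    for (const char * t : types) {
        std::printf("set_rows_%s\nset_rows_%s_rte\n", t, t); // plain and round-to-nearest-even fp16 variants
    }
    return 0;
}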