jeffbolznv committed on
Commit
bac21a7
·
1 Parent(s): 1e145c7

vulkan: add RTE variants for glu/add/sub/mul/div (llama/14653)

Browse files
ggml/src/ggml-vulkan/ggml-vulkan.cpp CHANGED
@@ -2835,10 +2835,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
2835
  return s;
2836
  };
2837
 
 
2838
  #define CREATE_BINARY(name, namemod, spec) \
2839
  for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
2840
  ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
2841
- #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d], name ## _data[s0][s1][d], \
2842
  "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
2843
 
2844
  CREATE_BINARY(add, , {0})
@@ -2890,8 +2891,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
2890
  #undef CREATE_UNARY
2891
 
2892
  #define CREATE_GLU(name) \
2893
- ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
2894
- ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);
 
 
 
 
 
2895
 
2896
  CREATE_GLU(geglu)
2897
  CREATE_GLU(reglu)
 
2835
  return s;
2836
  };
2837
 
2838
+ bool rte = device->float_controls_rte_fp16;
2839
  #define CREATE_BINARY(name, namemod, spec) \
2840
  for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
2841
  ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
2842
+ #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \
2843
  "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
2844
 
2845
  CREATE_BINARY(add, , {0})
 
2891
  #undef CREATE_UNARY
2892
 
2893
  #define CREATE_GLU(name) \
2894
+ if (device->float_controls_rte_fp16) { \
2895
+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
2896
+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
2897
+ } else { \
2898
+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
2899
+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
2900
+ }
2901
 
2902
  CREATE_GLU(geglu)
2903
  CREATE_GLU(reglu)
ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp CHANGED
@@ -1,10 +1,6 @@
1
  #version 450
2
 
3
- #if RTE16
4
- #extension GL_EXT_spirv_intrinsics : enable
5
- spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
6
- #endif // RTE16
7
-
8
  #include "types.comp"
9
 
10
  #if defined(SET_ROWS) && QUANT_K == 1
 
1
  #version 450
2
 
3
+ #include "rte.comp"
 
 
 
 
4
  #include "types.comp"
5
 
6
  #if defined(SET_ROWS) && QUANT_K == 1
ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp CHANGED
@@ -1,6 +1,8 @@
1
  #extension GL_EXT_shader_16bit_storage : require
2
  #extension GL_EXT_control_flow_attributes : require
3
 
 
 
4
  layout (push_constant) uniform parameter
5
  {
6
  uint ne;
 
1
  #extension GL_EXT_shader_16bit_storage : require
2
  #extension GL_EXT_control_flow_attributes : require
3
 
4
+ #include "rte.comp"
5
+
6
  layout (push_constant) uniform parameter
7
  {
8
  uint ne;
ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp CHANGED
@@ -1,5 +1,7 @@
1
  #extension GL_EXT_shader_16bit_storage : require
2
 
 
 
3
  layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
4
 
5
  layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
 
1
  #extension GL_EXT_shader_16bit_storage : require
2
 
3
+ #include "rte.comp"
4
+
5
  layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
6
 
7
  layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp CHANGED
@@ -1,12 +1,9 @@
1
  #version 450
2
 
3
  #extension GL_EXT_shader_16bit_storage : require
4
- #extension GL_EXT_spirv_intrinsics: enable
5
  #extension GL_EXT_control_flow_attributes : require
6
 
7
- #if RTE16
8
- spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
9
- #endif
10
 
11
  layout (push_constant) uniform parameter
12
  {
 
1
  #version 450
2
 
3
  #extension GL_EXT_shader_16bit_storage : require
 
4
  #extension GL_EXT_control_flow_attributes : require
5
 
6
+ #include "rte.comp"
 
 
7
 
8
  layout (push_constant) uniform parameter
9
  {
ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp CHANGED
@@ -1,11 +1,8 @@
1
  #include "types.comp"
2
 
3
  #extension GL_EXT_shader_16bit_storage : require
4
- #extension GL_EXT_spirv_intrinsics: enable
5
 
6
- #if RTE16
7
- spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
8
- #endif
9
 
10
  layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
11
 
 
1
  #include "types.comp"
2
 
3
  #extension GL_EXT_shader_16bit_storage : require
 
4
 
5
+ #include "rte.comp"
 
 
6
 
7
  layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
8
 
ggml/src/ggml-vulkan/vulkan-shaders/rte.comp ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+
2
+ #if RTE16
3
+ #extension GL_EXT_spirv_intrinsics : enable
4
+ spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
5
+ #endif // RTE16
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp CHANGED
@@ -537,8 +537,10 @@ void process_shaders() {
537
  for (auto src0_f16 : {false, true}) {
538
  for (auto src1_f16 : {false, true}) {
539
  for (auto dst_f16 : {false, true}) {
540
- auto name = op + get_suffix(src0_f16, src1_f16, dst_f16);
541
- string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}});
 
 
542
  }
543
  }
544
  }
@@ -592,16 +594,19 @@ void process_shaders() {
592
  string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
593
  string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
594
 
595
- string_to_spv("geglu_f16", "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
596
- string_to_spv("geglu_f32", "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
597
- string_to_spv("reglu_f16", "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
598
- string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
599
- string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
600
- string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
601
- string_to_spv("geglu_erf_f16", "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
602
- string_to_spv("geglu_erf_f32", "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
603
- string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
604
- string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
 
 
 
605
 
606
  string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
607
  string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
@@ -709,11 +714,59 @@ void write_output_files() {
709
  std::remove(path.c_str());
710
  }
711
  }
 
 
712
  for (const char *op : {"add", "sub", "mul", "div"}) {
713
- fprintf(hdr, "extern unsigned char *%s_data[2][2][2];\n", op);
714
- fprintf(hdr, "extern uint64_t %s_len[2][2][2];\n", op);
715
- fprintf(src, "unsigned char *%s_data[2][2][2] = {{{%s_f32_f32_f32_data, %s_f32_f32_f16_data}, {%s_f32_f16_f32_data, %s_f32_f16_f16_data}}, {{%s_f16_f32_f32_data, %s_f16_f32_f16_data}, {%s_f16_f16_f32_data, %s_f16_f16_f16_data}}};\n", op, op, op, op, op, op, op, op, op);
716
- fprintf(src, "uint64_t %s_len[2][2][2] = {{{%s_f32_f32_f32_len, %s_f32_f32_f16_len}, {%s_f32_f16_f32_len, %s_f32_f16_f16_len}}, {{%s_f16_f32_f32_len, %s_f16_f32_f16_len}, {%s_f16_f16_f32_len, %s_f16_f16_f16_len}}};\n", op, op, op, op, op, op, op, op, op);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
717
  }
718
  fclose(hdr);
719
  fclose(src);
 
537
  for (auto src0_f16 : {false, true}) {
538
  for (auto src1_f16 : {false, true}) {
539
  for (auto dst_f16 : {false, true}) {
540
+ for (auto rte : {false, true}) {
541
+ auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? "_rte" : "");
542
+ string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
543
+ }
544
  }
545
  }
546
  }
 
594
  string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
595
  string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
596
 
597
+ for (auto rte : {false, true}) {
598
+ std::string suffix = rte ? "_rte" : "";
599
+ string_to_spv("geglu_f16" + suffix, "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
600
+ string_to_spv("geglu_f32" + suffix, "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
601
+ string_to_spv("reglu_f16" + suffix, "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
602
+ string_to_spv("reglu_f32" + suffix, "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
603
+ string_to_spv("swiglu_f16" + suffix, "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
604
+ string_to_spv("swiglu_f32" + suffix, "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
605
+ string_to_spv("geglu_erf_f16" + suffix, "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
606
+ string_to_spv("geglu_erf_f32" + suffix, "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
607
+ string_to_spv("geglu_quick_f16" + suffix,"geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
608
+ string_to_spv("geglu_quick_f32" + suffix,"geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
609
+ }
610
 
611
  string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
612
  string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
 
714
  std::remove(path.c_str());
715
  }
716
  }
717
+
718
+ std::string suffixes[2] = {"_f32", "_f16"};
719
  for (const char *op : {"add", "sub", "mul", "div"}) {
720
+ fprintf(hdr, "extern unsigned char *%s_data[2][2][2][2];\n", op);
721
+ fprintf(hdr, "extern uint64_t %s_len[2][2][2][2];\n", op);
722
+ std::string data = "unsigned char *" + std::string(op) + "_data[2][2][2][2] = ";
723
+ std::string len = "uint64_t " + std::string(op) + "_len[2][2][2][2] = ";
724
+ for (uint32_t t0 = 0; t0 < 2; ++t0) {
725
+ if (t0 == 0) {
726
+ data += "{";
727
+ len += "{";
728
+ }
729
+ for (uint32_t t1 = 0; t1 < 2; ++t1) {
730
+ if (t1 == 0) {
731
+ data += "{";
732
+ len += "{";
733
+ }
734
+ for (uint32_t t2 = 0; t2 < 2; ++t2) {
735
+ if (t2 == 0) {
736
+ data += "{";
737
+ len += "{";
738
+ }
739
+ for (uint32_t rte = 0; rte < 2; ++rte) {
740
+ if (rte == 0) {
741
+ data += "{";
742
+ len += "{";
743
+ }
744
+ data += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0) ? "_rte" : "");
745
+ len += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0) ? "_rte" : "");
746
+ data += "_data,";
747
+ len += "_len,";
748
+ if (rte == 1) {
749
+ data += "}, ";
750
+ len += "}, ";
751
+ }
752
+ }
753
+ if (t2 == 1) {
754
+ data += "}, ";
755
+ len += "}, ";
756
+ }
757
+ }
758
+ if (t1 == 1) {
759
+ data += "}, ";
760
+ len += "}, ";
761
+ }
762
+ }
763
+ if (t0 == 1) {
764
+ data += "};\n";
765
+ len += "};\n";
766
+ }
767
+ }
768
+ fputs(data.c_str(), src); // emit verbatim: the generated table text is data, not a printf format string
769
+ fputs(len.c_str(), src); // same for the matching length table
770
  }
771
  fclose(hdr);
772
  fclose(src);