Akarshan Biswas committed on
Commit
b305121
·
1 Parent(s): 2722bea

SYCL: use 1D kernel for set_rows (llama/14618)

Browse files

* SYCL: Use 1D kernel for set_rows

* Remove dangling comment

* Refactor and use ceil_div

Files changed (1) hide show
  1. ggml/src/ggml-sycl/set_rows.cpp +43 -43
ggml/src/ggml-sycl/set_rows.cpp CHANGED
@@ -6,46 +6,49 @@ static constexpr bool is_arithmetic_v() {
6
  return std::is_arithmetic_v<T> || std::is_same_v<T, sycl::half> || std::is_same_v<T, sycl::ext::oneapi::bfloat16>;
7
  }
8
  }
 
9
  template<typename TIn, typename TOut>
10
  static inline std::enable_if_t<utils::is_arithmetic_v<TIn>() && utils::is_arithmetic_v<TOut>(), void>
11
  convert (const char* src, char* dst) {
12
  auto src_val = *reinterpret_cast<const TIn*>(src);
13
  auto dst_val = sycl::vec<TIn, 1>(src_val).template convert<TOut, sycl::rounding_mode::automatic>()[0];
14
- *reinterpret_cast<TOut*>(dst) = dst_val;;
15
  }
16
 
17
  template<typename TIn, typename TOut>
18
  static void k_set_rows(
19
  const char * __restrict__ src0, const int64_t * __restrict__ src1, char * __restrict__ dst,
20
- const int64_t ne00, const int64_t ne01, const int64_t ne11, const int64_t ne12,
 
21
  const size_t nb01, const size_t nb02, const size_t nb03,
22
  const size_t nb10, const size_t nb11, const size_t nb12,
23
  const size_t nb1, const size_t nb2, const size_t nb3,
24
  const size_t src_type_size, const size_t dst_type_size,
25
- const sycl::nd_item<3> & item_ct1) {
26
-
27
- const int i03 = item_ct1.get_group(0);
28
- const int i02 = item_ct1.get_group(1);
29
- const int i01 = item_ct1.get_group(2) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1); // Row index
30
 
31
- if (i01 >= ne01) {
 
32
  return;
33
  }
34
 
35
- const int i12 = i03 % ne12;
36
- const int i11 = i02 % ne11;
37
- const int i10 = i01;
 
 
 
 
 
38
 
39
  const int64_t dst_row = *(const int64_t *)((const char *)src1 + calculate_offset<3>({nb10, nb11, nb12}, {i10, i11, i12}));
40
 
41
  const char * src0_row = src0 + calculate_offset<3>({nb01, nb02, nb03}, {i01, i02, i03});
42
- char * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3;
 
 
43
 
44
- for (int col = item_ct1.get_local_id(0); col < ne00; col += item_ct1.get_local_range(0)) {
45
- const char * src_elem = src0_row + col * src_type_size;
46
- char * dst_elem = dst_row_ptr + col * dst_type_size;
47
- convert<TIn, TOut>(src_elem, dst_elem);
48
- }
49
  }
50
 
51
  template<typename TIn, typename TOut>
@@ -58,32 +61,29 @@ static void set_rows_sycl(
58
  const size_t src_type_size, const size_t dst_type_size,
59
  queue_ptr stream) {
60
 
61
- constexpr int max_threads_per_row = 64; // KEEPING 64 for now
62
- const int threads_per_row = std::min((int)ne00, max_threads_per_row);
63
-
64
- constexpr int max_threads_per_block = 64;
65
- const int rows_per_block = std::max(1, max_threads_per_block / threads_per_row);
66
-
67
- const sycl::range<3> block_size(1, rows_per_block, threads_per_row);
68
- const sycl::range<3> grid_size(ne03, ne02, (ne01 + rows_per_block - 1) / rows_per_block);
69
-
70
- sycl_parallel_for(
71
- stream,
72
- sycl::nd_range<3>(grid_size * block_size, block_size),
73
- [=](sycl::nd_item<3> item_ct1) {
74
- k_set_rows<TIn, TOut>(
75
- src0_d, src1_d, dst_d,
76
- ne00, ne01, ne11, ne12,
77
- nb01, nb02, nb03,
78
- nb10, nb11, nb12,
79
- nb1, nb2, nb3,
80
- src_type_size, dst_type_size,
81
- item_ct1
82
- );
83
- }
84
- );
85
- }
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
  void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
89
  scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
@@ -122,7 +122,7 @@ void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
122
  nb1, nb2, nb3,
123
  sizeof(float), sizeof(sycl::half),
124
  stream
125
- );
126
  break;
127
  default:
128
  GGML_ABORT("Unsupported tensor type!");
 
6
  return std::is_arithmetic_v<T> || std::is_same_v<T, sycl::half> || std::is_same_v<T, sycl::ext::oneapi::bfloat16>;
7
  }
8
  }
9
+
10
  template<typename TIn, typename TOut>
11
  static inline std::enable_if_t<utils::is_arithmetic_v<TIn>() && utils::is_arithmetic_v<TOut>(), void>
12
  convert (const char* src, char* dst) {
13
  auto src_val = *reinterpret_cast<const TIn*>(src);
14
  auto dst_val = sycl::vec<TIn, 1>(src_val).template convert<TOut, sycl::rounding_mode::automatic>()[0];
15
+ *reinterpret_cast<TOut*>(dst) = dst_val;
16
  }
17
 
18
  template<typename TIn, typename TOut>
19
  static void k_set_rows(
20
  const char * __restrict__ src0, const int64_t * __restrict__ src1, char * __restrict__ dst,
21
+ const int64_t ne00, const int64_t ne01, const int64_t ne02,
22
+ const int64_t ne11, const int64_t ne12,
23
  const size_t nb01, const size_t nb02, const size_t nb03,
24
  const size_t nb10, const size_t nb11, const size_t nb12,
25
  const size_t nb1, const size_t nb2, const size_t nb3,
26
  const size_t src_type_size, const size_t dst_type_size,
27
+ const int64_t total_elements,
28
+ const sycl::nd_item<1> & item_ct1) {
 
 
 
29
 
30
+ const int64_t i = item_ct1.get_global_linear_id();
31
+ if (i >= total_elements) {
32
  return;
33
  }
34
 
35
+ const int64_t i03 = i / (ne00 * ne01 * ne02);
36
+ const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
37
+ const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
38
+ const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
39
+
40
+ const int64_t i12 = i03 % ne12;
41
+ const int64_t i11 = i02 % ne11;
42
+ const int64_t i10 = i01;
43
 
44
  const int64_t dst_row = *(const int64_t *)((const char *)src1 + calculate_offset<3>({nb10, nb11, nb12}, {i10, i11, i12}));
45
 
46
  const char * src0_row = src0 + calculate_offset<3>({nb01, nb02, nb03}, {i01, i02, i03});
47
+ const char * src_elem = src0_row + i00 * src_type_size;
48
+ char * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3;
49
+ char * dst_elem = dst_row_ptr + i00 * dst_type_size;
50
 
51
+ convert<TIn, TOut>(src_elem, dst_elem);
 
 
 
 
52
  }
53
 
54
  template<typename TIn, typename TOut>
 
61
  const size_t src_type_size, const size_t dst_type_size,
62
  queue_ptr stream) {
63
 
64
+ const int64_t total_elements = ne00 * ne01 * ne02 * ne03;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
+ constexpr int block_size = 64;
67
+ const int64_t grid_size = ceil_div(total_elements, block_size);
68
+
69
+ sycl_parallel_for(
70
+ stream,
71
+ sycl::nd_range<1>(grid_size * block_size, block_size),
72
+ [=](sycl::nd_item<1> item_ct1) {
73
+ k_set_rows<TIn, TOut>(
74
+ src0_d, src1_d, dst_d,
75
+ ne00, ne01, ne02,
76
+ ne11, ne12,
77
+ nb01, nb02, nb03,
78
+ nb10, nb11, nb12,
79
+ nb1, nb2, nb3,
80
+ src_type_size, dst_type_size,
81
+ total_elements,
82
+ item_ct1
83
+ );
84
+ }
85
+ );
86
+ }
87
 
88
  void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
89
  scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
 
122
  nb1, nb2, nb3,
123
  sizeof(float), sizeof(sycl::half),
124
  stream
125
+ );
126
  break;
127
  default:
128
  GGML_ABORT("Unsupported tensor type!");