JohannesGaessler committed on
Commit
131a21e
·
1 Parent(s): fdb1fe5

RoPE: fix back, CUDA support for back + noncont. (llama/11240)

Browse files

* RoPE: fix back, CUDA support for back + noncont.

* fix comments reg. non-cont. RoPE support [no-ci]

ggml/include/ggml.h CHANGED
@@ -1500,7 +1500,7 @@ extern "C" {
1500
 
1501
  // rotary position embedding backward, i.e compute dx from dy
1502
  // a - dy
1503
- GGML_API struct ggml_tensor * ggml_rope_back(
1504
  struct ggml_context * ctx,
1505
  struct ggml_tensor * a, // gradients of ggml_rope result
1506
  struct ggml_tensor * b, // positions
@@ -1515,6 +1515,23 @@ extern "C" {
1515
  float beta_fast,
1516
  float beta_slow);
1517
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1518
  // clamp
1519
  // in-place, returns view(a)
1520
  GGML_API struct ggml_tensor * ggml_clamp(
 
1500
 
1501
  // rotary position embedding backward, i.e compute dx from dy
1502
  // a - dy
1503
+ GGML_API struct ggml_tensor * ggml_rope_ext_back(
1504
  struct ggml_context * ctx,
1505
  struct ggml_tensor * a, // gradients of ggml_rope result
1506
  struct ggml_tensor * b, // positions
 
1515
  float beta_fast,
1516
  float beta_slow);
1517
 
1518
+ GGML_API struct ggml_tensor * ggml_rope_multi_back(
1519
+ struct ggml_context * ctx,
1520
+ struct ggml_tensor * a,
1521
+ struct ggml_tensor * b,
1522
+ struct ggml_tensor * c,
1523
+ int n_dims,
1524
+ int sections[4],
1525
+ int mode,
1526
+ int n_ctx_orig,
1527
+ float freq_base,
1528
+ float freq_scale,
1529
+ float ext_factor,
1530
+ float attn_factor,
1531
+ float beta_fast,
1532
+ float beta_slow);
1533
+
1534
+
1535
  // clamp
1536
  // in-place, returns view(a)
1537
  GGML_API struct ggml_tensor * ggml_clamp(
ggml/src/ggml-cpu/ggml-cpu.c CHANGED
@@ -13668,6 +13668,7 @@ struct ggml_cplan ggml_graph_plan(
13668
  } break;
13669
  case GGML_OP_SOFT_MAX:
13670
  case GGML_OP_ROPE:
 
13671
  {
13672
  cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
13673
  } break;
 
13668
  } break;
13669
  case GGML_OP_SOFT_MAX:
13670
  case GGML_OP_ROPE:
13671
+ case GGML_OP_ROPE_BACK:
13672
  {
13673
  cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
13674
  } break;
ggml/src/ggml-cpu/ggml-cpu.cpp CHANGED
@@ -403,8 +403,6 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
403
  op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
404
  case GGML_OP_MUL_MAT:
405
  return src1->type == GGML_TYPE_F32 || src1->type == ggml_get_type_traits_cpu(src0->type)->vec_dot_type;
406
- case GGML_OP_ROPE_BACK:
407
- return op->src[2] == NULL && (op->op_params[2] & 4) == 0;
408
  case GGML_OP_IM2COL_BACK:
409
  return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
410
  case GGML_OP_OUT_PROD:
 
403
  op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
404
  case GGML_OP_MUL_MAT:
405
  return src1->type == GGML_TYPE_F32 || src1->type == ggml_get_type_traits_cpu(src0->type)->vec_dot_type;
 
 
406
  case GGML_OP_IM2COL_BACK:
407
  return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
408
  case GGML_OP_OUT_PROD:
ggml/src/ggml-cuda/ggml-cuda.cu CHANGED
@@ -2141,6 +2141,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2141
  case GGML_OP_ROPE:
2142
  ggml_cuda_op_rope(ctx, dst);
2143
  break;
 
 
 
2144
  case GGML_OP_IM2COL:
2145
  ggml_cuda_op_im2col(ctx, dst);
2146
  break;
@@ -3025,7 +3028,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3025
  case GGML_OP_SOFT_MAX:
3026
  return true;
3027
  case GGML_OP_ROPE:
3028
- return ggml_is_contiguous(op->src[0]);
 
 
 
 
3029
  case GGML_OP_IM2COL:
3030
  case GGML_OP_POOL_2D:
3031
  case GGML_OP_SUM:
@@ -3081,6 +3088,7 @@ static int64_t get_op_batch_size(const ggml_tensor * op) {
3081
  return op->ne[1];
3082
  case GGML_OP_MUL_MAT_ID:
3083
  case GGML_OP_ROPE:
 
3084
  return op->ne[2];
3085
  default:
3086
  return ggml_nrows(op);
 
2141
  case GGML_OP_ROPE:
2142
  ggml_cuda_op_rope(ctx, dst);
2143
  break;
2144
+ case GGML_OP_ROPE_BACK:
2145
+ ggml_cuda_op_rope_back(ctx, dst);
2146
+ break;
2147
  case GGML_OP_IM2COL:
2148
  ggml_cuda_op_im2col(ctx, dst);
2149
  break;
 
3028
  case GGML_OP_SOFT_MAX:
3029
  return true;
3030
  case GGML_OP_ROPE:
3031
+ case GGML_OP_ROPE_BACK: {
3032
+ const size_t ts = ggml_type_size(op->src[0]->type);
3033
+ const int64_t ne0_012 = op->src[0]->ne[0] * op->src[0]->ne[1] * op->src[0]->ne[2];
3034
+ return op->src[0]->nb[0] == ts && op->src[0]->nb[3] == ne0_012*ts;
3035
+ }
3036
  case GGML_OP_IM2COL:
3037
  case GGML_OP_POOL_2D:
3038
  case GGML_OP_SUM:
 
3088
  return op->ne[1];
3089
  case GGML_OP_MUL_MAT_ID:
3090
  case GGML_OP_ROPE:
3091
+ case GGML_OP_ROPE_BACK:
3092
  return op->ne[2];
3093
  default:
3094
  return ggml_nrows(op);
ggml/src/ggml-cuda/rope.cu CHANGED
@@ -16,9 +16,10 @@ static __device__ float rope_yarn_ramp(const float low, const float high, const
16
 
17
  // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
18
  // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
 
19
  static __device__ void rope_yarn(
20
- float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
21
- float * cos_theta, float * sin_theta) {
22
  // Get n-d rotational scaling corrected for extrapolation
23
  float theta_interp = freq_scale * theta_extrap;
24
  float theta = theta_interp;
@@ -29,24 +30,28 @@ static __device__ void rope_yarn(
29
  // Get n-d magnitude scaling corrected for interpolation
30
  mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
31
  }
32
- *cos_theta = cosf(theta) * mscale;
33
- *sin_theta = sinf(theta) * mscale;
 
 
 
34
  }
35
 
36
- template<typename T, bool has_ff>
37
  static __global__ void rope_norm(
38
- const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
39
- float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
 
40
  const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
41
 
42
  if (i0 >= ne0) {
43
  return;
44
  }
45
 
46
- const int row = blockDim.x*blockIdx.x + threadIdx.x;
47
 
48
  if (i0 >= n_dims) {
49
- const int i = row*ne0 + i0;
50
 
51
  dst[i + 0] = x[i + 0];
52
  dst[i + 1] = x[i + 1];
@@ -54,39 +59,43 @@ static __global__ void rope_norm(
54
  return;
55
  }
56
 
57
- const int i = row*ne0 + i0;
58
- const int i2 = row/p_delta_rows;
 
 
 
59
 
60
- const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
61
 
62
  const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
63
 
64
  float cos_theta;
65
  float sin_theta;
66
 
67
- rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
68
 
69
- const float x0 = x[i + 0];
70
- const float x1 = x[i + 1];
71
 
72
- dst[i + 0] = x0*cos_theta - x1*sin_theta;
73
- dst[i + 1] = x0*sin_theta + x1*cos_theta;
74
  }
75
 
76
- template<typename T, bool has_ff>
77
  static __global__ void rope_neox(
78
- const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
79
- float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
 
80
  const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
81
 
82
  if (i0 >= ne0) {
83
  return;
84
  }
85
 
86
- const int row = blockDim.x*blockIdx.x + threadIdx.x;
87
 
88
  if (i0 >= n_dims) {
89
- const int i = row*ne0 + i0;
90
 
91
  dst[i + 0] = x[i + 0];
92
  dst[i + 1] = x[i + 1];
@@ -94,39 +103,43 @@ static __global__ void rope_neox(
94
  return;
95
  }
96
 
97
- const int i = row*ne0 + i0/2;
98
- const int i2 = row/p_delta_rows;
 
 
 
99
 
100
- const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
101
 
102
  const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
103
 
104
  float cos_theta;
105
  float sin_theta;
106
 
107
- rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
108
 
109
- const float x0 = x[i + 0];
110
- const float x1 = x[i + n_dims/2];
111
 
112
- dst[i + 0] = x0*cos_theta - x1*sin_theta;
113
- dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
114
  }
115
 
116
- template<typename T, bool has_ff>
117
  static __global__ void rope_multi(
118
- const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
119
- float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) {
 
120
  const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
121
 
122
  if (i0 >= ne0) {
123
  return;
124
  }
125
 
126
- const int row = blockDim.x*blockIdx.x + threadIdx.x;
127
 
128
  if (i0 >= n_dims) {
129
- const int i = row*ne0 + i0;
130
 
131
  dst[i + 0] = x[i + 0];
132
  dst[i + 1] = x[i + 1];
@@ -134,25 +147,28 @@ static __global__ void rope_multi(
134
  return;
135
  }
136
 
137
- const int i = row*ne0 + i0/2;
138
- const int i2 = row/p_delta_rows;
139
 
140
- int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
141
- int sec_w = sections.v[1] + sections.v[0];
142
- int sector = (i0 / 2) % sect_dims;
 
 
 
143
 
144
  float theta_base = 0.0;
145
  if (sector < sections.v[0]) {
146
- theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
147
  }
148
  else if (sector >= sections.v[0] && sector < sec_w) {
149
- theta_base = pos[i2 + ne2 * 1]*powf(theta_scale, i0/2.0f);
150
  }
151
  else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
152
- theta_base = pos[i2 + ne2 * 2]*powf(theta_scale, i0/2.0f);
153
  }
154
  else if (sector >= sec_w + sections.v[2]) {
155
- theta_base = pos[i2 + ne2 * 3]*powf(theta_scale, i0/2.0f);
156
  }
157
 
158
  const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@@ -160,42 +176,46 @@ static __global__ void rope_multi(
160
  float cos_theta;
161
  float sin_theta;
162
 
163
- rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
164
 
165
- const float x0 = x[i + 0];
166
- const float x1 = x[i + n_dims/2];
167
 
168
- dst[i + 0] = x0*cos_theta - x1*sin_theta;
169
- dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
170
  }
171
 
172
- template<typename T, bool has_ff>
173
  static __global__ void rope_vision(
174
- const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
175
- float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) {
 
176
  const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
177
 
178
  if (i0 >= ne0) {
179
  return;
180
  }
181
 
182
- const int row = blockDim.x*blockIdx.x + threadIdx.x;
 
 
 
183
 
184
- const int i = row*ne0 + i0/2;
185
- const int i2 = row/p_delta_rows; // i2-th tokens
186
 
187
- int sect_dims = sections.v[0] + sections.v[1];
188
- int sec_w = sections.v[1] + sections.v[0];
189
- int sector = (i0 / 2) % sect_dims;
190
 
191
  float theta_base = 0.0;
192
  if (sector < sections.v[0]) {
193
  const int p = sector;
194
- theta_base = pos[i2]*powf(theta_scale, p);
195
  }
196
  else if (sector >= sections.v[0] && sector < sec_w) {
197
  const int p = sector - sections.v[0];
198
- theta_base = pos[i2 + ne2]*powf(theta_scale, p);
199
  }
200
 
201
  const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@@ -203,19 +223,20 @@ static __global__ void rope_vision(
203
  float cos_theta;
204
  float sin_theta;
205
 
206
- rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
207
 
208
- const float x0 = x[i + 0];
209
- const float x1 = x[i + n_dims];
210
 
211
- dst[i + 0] = x0*cos_theta - x1*sin_theta;
212
- dst[i + n_dims] = x0*sin_theta + x1*cos_theta;
213
  }
214
 
215
- template<typename T>
216
  static void rope_norm_cuda(
217
- const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
218
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
 
219
  GGML_ASSERT(ne0 % 2 == 0);
220
  const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
221
  const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -224,22 +245,21 @@ static void rope_norm_cuda(
224
  const float theta_scale = powf(freq_base, -2.0f/n_dims);
225
 
226
  if (freq_factors == nullptr) {
227
- rope_norm<T, false><<<block_nums, block_dims, 0, stream>>>(
228
- x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
229
- theta_scale, freq_factors
230
- );
231
  } else {
232
- rope_norm<T, true><<<block_nums, block_dims, 0, stream>>>(
233
- x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
234
- theta_scale, freq_factors
235
- );
236
  }
237
  }
238
 
239
- template<typename T>
240
  static void rope_neox_cuda(
241
- const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
242
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
 
243
  GGML_ASSERT(ne0 % 2 == 0);
244
  const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
245
  const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -248,22 +268,21 @@ static void rope_neox_cuda(
248
  const float theta_scale = powf(freq_base, -2.0f/n_dims);
249
 
250
  if (freq_factors == nullptr) {
251
- rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
252
- x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
253
- theta_scale, freq_factors
254
- );
255
  } else {
256
- rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
257
- x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
258
- theta_scale, freq_factors
259
- );
260
  }
261
  }
262
 
263
- template<typename T>
264
  static void rope_multi_cuda(
265
- const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
266
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
 
267
  GGML_ASSERT(ne0 % 2 == 0);
268
  const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
269
  const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -272,22 +291,21 @@ static void rope_multi_cuda(
272
  const float theta_scale = powf(freq_base, -2.0f/n_dims);
273
 
274
  if (freq_factors == nullptr) {
275
- rope_multi<T, false><<<block_nums, block_dims, 0, stream>>>(
276
- x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
277
- theta_scale, freq_factors, sections
278
- );
279
  } else {
280
- rope_multi<T, true><<<block_nums, block_dims, 0, stream>>>(
281
- x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
282
- theta_scale, freq_factors, sections
283
- );
284
  }
285
  }
286
 
287
- template<typename T>
288
  static void rope_vision_cuda(
289
- const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
290
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
 
291
  GGML_ASSERT(ne0 % 2 == 0);
292
  const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
293
  const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -298,80 +316,18 @@ static void rope_vision_cuda(
298
  const float theta_scale = powf(freq_base, -2.0f/n_dims);
299
 
300
  if (freq_factors == nullptr) {
301
- rope_vision<T, false><<<block_nums, block_dims, 0, stream>>>(
302
- x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
303
- theta_scale, freq_factors, sections
304
- );
305
  } else {
306
- rope_vision<T, true><<<block_nums, block_dims, 0, stream>>>(
307
- x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
308
- theta_scale, freq_factors, sections
309
- );
310
  }
311
  }
312
 
313
- static void rope_norm_cuda_f16(
314
- const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
315
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
316
-
317
- rope_norm_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
318
- }
319
-
320
- static void rope_norm_cuda_f32(
321
- const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
322
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
323
-
324
- rope_norm_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
325
- }
326
-
327
- static void rope_neox_cuda_f16(
328
- const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
329
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
330
-
331
- rope_neox_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
332
- }
333
-
334
- static void rope_neox_cuda_f32(
335
- const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
336
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
337
- ) {
338
-
339
- rope_neox_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
340
- }
341
-
342
- static void rope_multi_cuda_f16(
343
- const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
344
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
345
- ) {
346
-
347
- rope_multi_cuda<half>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
348
- }
349
-
350
- static void rope_multi_cuda_f32(
351
- const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
352
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
353
- ) {
354
-
355
- rope_multi_cuda<float>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
356
- }
357
-
358
- static void rope_vision_cuda_f16(
359
- const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
360
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
361
- ) {
362
-
363
- rope_vision_cuda<half>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
364
- }
365
-
366
- static void rope_vision_cuda_f32(
367
- const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
368
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
369
- ) {
370
-
371
- rope_vision_cuda<float>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
372
- }
373
-
374
- void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
375
  const ggml_tensor * src0 = dst->src[0];
376
  const ggml_tensor * src1 = dst->src[1];
377
  const ggml_tensor * src2 = dst->src[2];
@@ -382,7 +338,6 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
382
  float * dst_d = (float *)dst->data;
383
  cudaStream_t stream = ctx.stream();
384
 
385
- GGML_ASSERT(ggml_is_contiguous(src0));
386
  GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
387
  GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
388
  GGML_ASSERT(src0->type == dst->type);
@@ -392,6 +347,9 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
392
  const int64_t ne02 = src0->ne[2]; // num heads
393
  const int64_t nr = ggml_nrows(src0);
394
 
 
 
 
395
  //const int n_past = ((int32_t *) dst->op_params)[0];
396
  const int n_dims = ((int32_t *) dst->op_params)[1];
397
  const int mode = ((int32_t *) dst->op_params)[2];
@@ -440,59 +398,59 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
440
  // compute
441
  if (is_neox) {
442
  if (src0->type == GGML_TYPE_F32) {
443
- rope_neox_cuda_f32(
444
- (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
445
- attn_factor, corr_dims, freq_factors, stream
446
- );
447
  } else if (src0->type == GGML_TYPE_F16) {
448
- rope_neox_cuda_f16(
449
- (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
450
- attn_factor, corr_dims, freq_factors, stream
451
- );
452
  } else {
453
  GGML_ABORT("fatal error");
454
  }
455
  } else if (is_mrope && !is_vision) {
456
  if (src0->type == GGML_TYPE_F32) {
457
- rope_multi_cuda_f32(
458
- (const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
459
- attn_factor, corr_dims, freq_factors, sections, stream
460
- );
461
  } else if (src0->type == GGML_TYPE_F16) {
462
- rope_multi_cuda_f16(
463
- (const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
464
- attn_factor, corr_dims, freq_factors, sections, stream
465
- );
466
  } else {
467
  GGML_ABORT("fatal error");
468
  }
469
  } else if (is_vision) {
470
  if (src0->type == GGML_TYPE_F32) {
471
- rope_vision_cuda_f32(
472
- (const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
473
- attn_factor, corr_dims, freq_factors, sections, stream
474
- );
475
  } else if (src0->type == GGML_TYPE_F16) {
476
- rope_vision_cuda_f16(
477
- (const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
478
- attn_factor, corr_dims, freq_factors, sections, stream
479
- );
480
  } else {
481
  GGML_ABORT("fatal error");
482
  }
483
  } else {
484
  if (src0->type == GGML_TYPE_F32) {
485
- rope_norm_cuda_f32(
486
- (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
487
- attn_factor, corr_dims, freq_factors, stream
488
- );
489
  } else if (src0->type == GGML_TYPE_F16) {
490
- rope_norm_cuda_f16(
491
- (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
492
- attn_factor, corr_dims, freq_factors, stream
493
- );
494
  } else {
495
  GGML_ABORT("fatal error");
496
  }
497
  }
498
  }
 
 
 
 
 
 
 
 
 
16
 
17
  // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
18
  // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
19
+ template<bool forward>
20
  static __device__ void rope_yarn(
21
+ const float theta_extrap, const float freq_scale, const rope_corr_dims corr_dims, const int64_t i0, const float ext_factor,
22
+ float mscale, float & cos_theta, float & sin_theta) {
23
  // Get n-d rotational scaling corrected for extrapolation
24
  float theta_interp = freq_scale * theta_extrap;
25
  float theta = theta_interp;
 
30
  // Get n-d magnitude scaling corrected for interpolation
31
  mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
32
  }
33
+ cos_theta = cosf(theta) * mscale;
34
+ sin_theta = sinf(theta) * mscale;
35
+ if (!forward) {
36
+ sin_theta *= -1.0f;
37
+ }
38
  }
39
 
40
+ template<bool forward, bool has_ff, typename T>
41
  static __global__ void rope_norm(
42
+ const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
43
+ const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor,
44
+ const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors) {
45
  const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
46
 
47
  if (i0 >= ne0) {
48
  return;
49
  }
50
 
51
+ const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
52
 
53
  if (i0 >= n_dims) {
54
+ const int i = row_dst*ne0 + i0;
55
 
56
  dst[i + 0] = x[i + 0];
57
  dst[i + 1] = x[i + 1];
 
59
  return;
60
  }
61
 
62
+ const int row_x = row_dst % ne1;
63
+ const int channel_x = row_dst / ne1;
64
+
65
+ const int idst = row_dst*ne0 + i0;
66
+ const int ix = channel_x*s2 + row_x*s1 + i0;
67
 
68
+ const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
69
 
70
  const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
71
 
72
  float cos_theta;
73
  float sin_theta;
74
 
75
+ rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
76
 
77
+ const float x0 = x[ix + 0];
78
+ const float x1 = x[ix + 1];
79
 
80
+ dst[idst + 0] = x0*cos_theta - x1*sin_theta;
81
+ dst[idst + 1] = x0*sin_theta + x1*cos_theta;
82
  }
83
 
84
+ template<bool forward, bool has_ff, typename T>
85
  static __global__ void rope_neox(
86
+ const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
87
+ const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor,
88
+ const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors) {
89
  const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
90
 
91
  if (i0 >= ne0) {
92
  return;
93
  }
94
 
95
+ const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
96
 
97
  if (i0 >= n_dims) {
98
+ const int i = row_dst*ne0 + i0;
99
 
100
  dst[i + 0] = x[i + 0];
101
  dst[i + 1] = x[i + 1];
 
103
  return;
104
  }
105
 
106
+ const int row_x = row_dst % ne1;
107
+ const int channel_x = row_dst / ne1;
108
+
109
+ const int idst = row_dst*ne0 + i0/2;
110
+ const int ix = channel_x*s2 + row_x*s1 + i0/2;
111
 
112
+ const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
113
 
114
  const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
115
 
116
  float cos_theta;
117
  float sin_theta;
118
 
119
+ rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
120
 
121
+ const float x0 = x[ix + 0];
122
+ const float x1 = x[ix + n_dims/2];
123
 
124
+ dst[idst + 0] = x0*cos_theta - x1*sin_theta;
125
+ dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
126
  }
127
 
128
+ template<bool forward, bool has_ff, typename T>
129
  static __global__ void rope_multi(
130
+ const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
131
+ const int n_dims, const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor,
132
+ const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors, const mrope_sections sections) {
133
  const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
134
 
135
  if (i0 >= ne0) {
136
  return;
137
  }
138
 
139
+ const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
140
 
141
  if (i0 >= n_dims) {
142
+ const int i = row_dst*ne0 + i0;
143
 
144
  dst[i + 0] = x[i + 0];
145
  dst[i + 1] = x[i + 1];
 
147
  return;
148
  }
149
 
150
+ const int row_x = row_dst % ne1;
151
+ const int channel_x = row_dst / ne1;
152
 
153
+ const int idst = row_dst*ne0 + i0/2;
154
+ const int ix = channel_x*s2 + row_x*s1 + i0/2;
155
+
156
+ const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
157
+ const int sec_w = sections.v[1] + sections.v[0];
158
+ const int sector = (i0 / 2) % sect_dims;
159
 
160
  float theta_base = 0.0;
161
  if (sector < sections.v[0]) {
162
+ theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
163
  }
164
  else if (sector >= sections.v[0] && sector < sec_w) {
165
+ theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
166
  }
167
  else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
168
+ theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
169
  }
170
  else if (sector >= sec_w + sections.v[2]) {
171
+ theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
172
  }
173
 
174
  const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
 
176
  float cos_theta;
177
  float sin_theta;
178
 
179
+ rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
180
 
181
+ const float x0 = x[ix + 0];
182
+ const float x1 = x[ix + n_dims/2];
183
 
184
+ dst[idst + 0] = x0*cos_theta - x1*sin_theta;
185
+ dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
186
  }
187
 
188
+ template<bool forward, bool has_ff, typename T>
189
  static __global__ void rope_vision(
190
+ const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims,
191
+ const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
192
+ const float theta_scale, const float * __restrict__ freq_factors, const mrope_sections sections) {
193
  const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
194
 
195
  if (i0 >= ne0) {
196
  return;
197
  }
198
 
199
+ const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
200
+
201
+ const int row_x = row_dst % ne1;
202
+ const int channel_x = row_dst / ne1;
203
 
204
+ const int idst = row_dst*ne0 + i0/2;
205
+ const int ix = channel_x*s2 + row_x*s1 + i0/2;
206
 
207
+ const int sect_dims = sections.v[0] + sections.v[1];
208
+ const int sec_w = sections.v[1] + sections.v[0];
209
+ const int sector = (i0 / 2) % sect_dims;
210
 
211
  float theta_base = 0.0;
212
  if (sector < sections.v[0]) {
213
  const int p = sector;
214
+ theta_base = pos[channel_x]*powf(theta_scale, p);
215
  }
216
  else if (sector >= sections.v[0] && sector < sec_w) {
217
  const int p = sector - sections.v[0];
218
+ theta_base = pos[channel_x + ne2]*powf(theta_scale, p);
219
  }
220
 
221
  const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
 
223
  float cos_theta;
224
  float sin_theta;
225
 
226
+ rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
227
 
228
+ const float x0 = x[ix + 0];
229
+ const float x1 = x[ix + n_dims];
230
 
231
+ dst[idst + 0] = x0*cos_theta - x1*sin_theta;
232
+ dst[idst + n_dims] = x0*sin_theta + x1*cos_theta;
233
  }
234
 
235
// Host-side launcher for the "normal" (adjacent-pair) RoPE kernel.
// forward selects forward vs backward (inverse) rotation; s1/s2 are element
// strides of x, enabling non-contiguous inputs; nr is the number of rows.
+ template<bool forward, typename T>
236
  static void rope_norm_cuda(
237
+ const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
238
+ const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
239
+ const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, cudaStream_t stream) {
240
  GGML_ASSERT(ne0 % 2 == 0);
241
  const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
242
  const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
// NOTE(review): the block_nums definition is elided from this diff view
// (presumably dim3 block_nums(nr, n_blocks_x, 1) on the omitted lines).

245
  const float theta_scale = powf(freq_base, -2.0f/n_dims);
246
 
247
// Dispatch the has_ff template flag at compile time so the kernel never
// branches on a null freq_factors pointer.
  if (freq_factors == nullptr) {
248
+ rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>(
249
+ x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
250
+ attn_factor, corr_dims, theta_scale, freq_factors);

251
  } else {
252
+ rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>(
253
+ x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
254
+ attn_factor, corr_dims, theta_scale, freq_factors);

255
  }
256
  }
257
 
258
// Host-side launcher for the NeoX-style RoPE kernel (rotated pairs are
// (i0/2, i0/2 + n_dims/2) split-half style — see rope_neox, not in view).
// forward=false launches the backward/inverse variant.
+ template<bool forward, typename T>
259
  static void rope_neox_cuda(
260
+ const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
261
+ const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
262
+ const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, cudaStream_t stream) {
263
  GGML_ASSERT(ne0 % 2 == 0);
264
  const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
265
  const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
// NOTE(review): block_nums definition elided by the diff rendering here.

268
  const float theta_scale = powf(freq_base, -2.0f/n_dims);
269
 
270
// Compile-time has_ff dispatch, same pattern as rope_norm_cuda.
  if (freq_factors == nullptr) {
271
+ rope_neox<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
272
+ x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
273
+ attn_factor, corr_dims, theta_scale, freq_factors);

274
  } else {
275
+ rope_neox<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
276
+ x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
277
+ attn_factor, corr_dims, theta_scale, freq_factors);

278
  }
279
  }
280
 
281
// Host-side launcher for the multi-section (mrope) RoPE kernel; takes ne2
// (head-count dimension) and the per-section split sizes in `sections`.
// forward=false launches the backward/inverse variant.
+ template<bool forward, typename T>
282
  static void rope_multi_cuda(
283
+ const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
284
+ const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
285
+ const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, const mrope_sections sections, cudaStream_t stream) {
286
  GGML_ASSERT(ne0 % 2 == 0);
287
  const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
288
  const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
// NOTE(review): block_nums definition elided by the diff rendering here.

291
  const float theta_scale = powf(freq_base, -2.0f/n_dims);
292
 
293
// Compile-time has_ff dispatch, same pattern as the other launchers.
  if (freq_factors == nullptr) {
294
+ rope_multi<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
295
+ x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
296
+ attn_factor, corr_dims, theta_scale, freq_factors, sections);

297
  } else {
298
+ rope_multi<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
299
+ x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
300
+ attn_factor, corr_dims, theta_scale, freq_factors, sections);

301
  }
302
  }
303
 
304
// Host-side launcher for the rope_vision kernel above; mirrors
// rope_multi_cuda (ne2 + sections) but dispatches the vision variant.
// forward=false launches the backward/inverse variant.
+ template<bool forward, typename T>
305
  static void rope_vision_cuda(
306
+ const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
307
+ const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
308
+ const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, const mrope_sections sections, cudaStream_t stream) {
309
  GGML_ASSERT(ne0 % 2 == 0);
310
  const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
311
  const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
// NOTE(review): several lines (incl. the block_nums definition) are elided
// from this diff view between here and theta_scale.

316
  const float theta_scale = powf(freq_base, -2.0f/n_dims);
317
 
318
// Compile-time has_ff dispatch, same pattern as the other launchers.
  if (freq_factors == nullptr) {
319
+ rope_vision<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
320
+ x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
321
+ attn_factor, corr_dims, theta_scale, freq_factors, sections);

322
  } else {
323
+ rope_vision<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
324
+ x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
325
+ attn_factor, corr_dims, theta_scale, freq_factors, sections);

326
  }
327
  }
328
 
329
+ template <bool forward>
330
+ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
331
  const ggml_tensor * src0 = dst->src[0];
332
  const ggml_tensor * src1 = dst->src[1];
333
  const ggml_tensor * src2 = dst->src[2];
 
338
  float * dst_d = (float *)dst->data;
339
  cudaStream_t stream = ctx.stream();
340
 
 
341
  GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
342
  GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
343
  GGML_ASSERT(src0->type == dst->type);
 
347
  const int64_t ne02 = src0->ne[2]; // num heads
348
  const int64_t nr = ggml_nrows(src0);
349
 
350
+ const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);
351
+ const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);
352
+
353
  //const int n_past = ((int32_t *) dst->op_params)[0];
354
  const int n_dims = ((int32_t *) dst->op_params)[1];
355
  const int mode = ((int32_t *) dst->op_params)[2];
 
398
  // compute
399
  if (is_neox) {
400
  if (src0->type == GGML_TYPE_F32) {
401
+ rope_neox_cuda<forward>(
402
+ (const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
403
+ freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 
404
  } else if (src0->type == GGML_TYPE_F16) {
405
+ rope_neox_cuda<forward>(
406
+ (const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
407
+ freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 
408
  } else {
409
  GGML_ABORT("fatal error");
410
  }
411
  } else if (is_mrope && !is_vision) {
412
  if (src0->type == GGML_TYPE_F32) {
413
+ rope_multi_cuda<forward>(
414
+ (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
415
+ freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
 
416
  } else if (src0->type == GGML_TYPE_F16) {
417
+ rope_multi_cuda<forward>(
418
+ (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
419
+ freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
 
420
  } else {
421
  GGML_ABORT("fatal error");
422
  }
423
  } else if (is_vision) {
424
  if (src0->type == GGML_TYPE_F32) {
425
+ rope_vision_cuda<forward>(
426
+ (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
427
+ freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
 
428
  } else if (src0->type == GGML_TYPE_F16) {
429
+ rope_vision_cuda<forward>(
430
+ (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
431
+ freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
 
432
  } else {
433
  GGML_ABORT("fatal error");
434
  }
435
  } else {
436
  if (src0->type == GGML_TYPE_F32) {
437
+ rope_norm_cuda<forward>(
438
+ (const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
439
+ freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 
440
  } else if (src0->type == GGML_TYPE_F16) {
441
+ rope_norm_cuda<forward>(
442
+ (const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
443
+ freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 
444
  } else {
445
  GGML_ABORT("fatal error");
446
  }
447
  }
448
  }
449
+
450
// Forward RoPE entry point for the CUDA backend: runs the shared
// implementation with forward=true (apply the rotation).
+ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
451
+ ggml_cuda_op_rope_impl<true>(ctx, dst);
452
+ }
453
+
454
// Backward RoPE entry point (GGML_OP_ROPE_BACK): same implementation with
// forward=false, which applies the inverse rotation to compute dx from dy.
+ void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
455
+ ggml_cuda_op_rope_impl<false>(ctx, dst);
456
+ }
ggml/src/ggml-cuda/rope.cuh CHANGED
@@ -3,3 +3,5 @@
3
  #define CUDA_ROPE_BLOCK_SIZE 256
4
 
5
  void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 
 
3
  #define CUDA_ROPE_BLOCK_SIZE 256
4
 
5
  void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
6
+
7
+ void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml.c CHANGED
@@ -3699,7 +3699,7 @@ void ggml_rope_yarn_corr_dims(
3699
 
3700
  // ggml_rope_back
3701
 
3702
- struct ggml_tensor * ggml_rope_back(
3703
  struct ggml_context * ctx,
3704
  struct ggml_tensor * a,
3705
  struct ggml_tensor * b,
@@ -3713,29 +3713,32 @@ struct ggml_tensor * ggml_rope_back(
3713
  float attn_factor,
3714
  float beta_fast,
3715
  float beta_slow) {
3716
- GGML_ASSERT(ggml_is_vector(b));
3717
- GGML_ASSERT(b->type == GGML_TYPE_I32);
3718
- GGML_ASSERT(a->ne[2] == b->ne[0]);
3719
-
3720
- struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
3721
-
3722
- int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
3723
- memcpy(params + 5, &freq_base, sizeof(float));
3724
- memcpy(params + 6, &freq_scale, sizeof(float));
3725
- memcpy(params + 7, &ext_factor, sizeof(float));
3726
- memcpy(params + 8, &attn_factor, sizeof(float));
3727
- memcpy(params + 9, &beta_fast, sizeof(float));
3728
- memcpy(params + 10, &beta_slow, sizeof(float));
3729
- ggml_set_op_params(result, params, sizeof(params));
3730
-
3731
- result->op = GGML_OP_ROPE_BACK;
3732
- result->src[0] = a;
3733
- result->src[1] = b;
3734
- result->src[2] = c;
3735
-
3736
  return result;
3737
  }
3738
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3739
  // ggml_clamp
3740
 
3741
  struct ggml_tensor * ggml_clamp(
@@ -5598,6 +5601,7 @@ static void ggml_compute_backward(
5598
  //const int n_ctx = ((int32_t *) tensor->op_params)[3];
5599
  const int n_ctx_orig = ((const int32_t *) tensor->op_params)[4];
5600
  float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
 
5601
 
5602
  memcpy(&freq_base, (const float *) tensor->op_params + 5, sizeof(float));
5603
  memcpy(&freq_scale, (const float *) tensor->op_params + 6, sizeof(float));
@@ -5605,10 +5609,14 @@ static void ggml_compute_backward(
5605
  memcpy(&attn_factor, (const float *) tensor->op_params + 8, sizeof(float));
5606
  memcpy(&beta_fast, (const float *) tensor->op_params + 9, sizeof(float));
5607
  memcpy(&beta_slow, (const float *) tensor->op_params + 10, sizeof(float));
5608
-
5609
- ggml_add_or_set(ctx, cgraph, isrc0,
5610
- ggml_rope_back(ctx, grad, src1, src2, n_dims, mode, n_ctx_orig, freq_base,
5611
- freq_scale, ext_factor, attn_factor, beta_fast, beta_slow));
 
 
 
 
5612
  }
5613
  GGML_ASSERT((!src2 || !src2_needs_grads) && "gradients for freq factors not implemented");
5614
  } break;
 
3699
 
3700
  // ggml_rope_back
3701
 
3702
+ struct ggml_tensor * ggml_rope_ext_back(
3703
  struct ggml_context * ctx,
3704
  struct ggml_tensor * a,
3705
  struct ggml_tensor * b,
 
3713
  float attn_factor,
3714
  float beta_fast,
3715
  float beta_slow) {
3716
+ struct ggml_tensor * result = ggml_rope_ext(
3717
+ ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
3718
+ result->op = GGML_OP_ROPE_BACK;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3719
  return result;
3720
  }
3721
 
3722
// ggml_rope_multi_back: gradient (backward) counterpart of ggml_rope_multi.
//   a - gradients (dy) of the forward mrope result
//   b - positions (I32)
//   c - optional frequency factors (may be NULL)
// Returns a node tagged GGML_OP_ROPE_BACK whose op params match the forward op.
+ struct ggml_tensor * ggml_rope_multi_back(
3723
+ struct ggml_context * ctx,
3724
+ struct ggml_tensor * a,
3725
+ struct ggml_tensor * b,
3726
+ struct ggml_tensor * c,
3727
+ int n_dims,
3728
+ int sections[4],
3729
+ int mode,
3730
+ int n_ctx_orig,
3731
+ float freq_base,
3732
+ float freq_scale,
3733
+ float ext_factor,
3734
+ float attn_factor,
3735
+ float beta_fast,
3736
+ float beta_slow) {
3737
// Build a regular forward mrope node first — this reuses ggml_rope_multi's
// argument validation and op-param packing — then retag the op so the
// backends compute the inverse rotation instead of the forward one.
+ struct ggml_tensor * result = ggml_rope_multi(
3738
+ ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
3739
+ result->op = GGML_OP_ROPE_BACK;
3740
+ return result;
3741
+ }
3742
  // ggml_clamp
3743
 
3744
  struct ggml_tensor * ggml_clamp(
 
5601
  //const int n_ctx = ((int32_t *) tensor->op_params)[3];
5602
  const int n_ctx_orig = ((const int32_t *) tensor->op_params)[4];
5603
  float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
5604
+ int sections[4] = {0, 0, 0, 0};
5605
 
5606
  memcpy(&freq_base, (const float *) tensor->op_params + 5, sizeof(float));
5607
  memcpy(&freq_scale, (const float *) tensor->op_params + 6, sizeof(float));
 
5609
  memcpy(&attn_factor, (const float *) tensor->op_params + 8, sizeof(float));
5610
  memcpy(&beta_fast, (const float *) tensor->op_params + 9, sizeof(float));
5611
  memcpy(&beta_slow, (const float *) tensor->op_params + 10, sizeof(float));
5612
+ memcpy(&sections, tensor->op_params + 11, sizeof(sections));
5613
+
5614
+ struct ggml_tensor * rope_back = grad->ne[2] == src1->ne[0] ?
5615
+ ggml_rope_ext_back(ctx, grad, src1, src2, n_dims,
5616
+ mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow) :
5617
+ ggml_rope_multi_back(ctx, grad, src1, src2, n_dims, sections,
5618
+ mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
5619
+ ggml_add_or_set(ctx, cgraph, isrc0, rope_back);
5620
  }
5621
  GGML_ASSERT((!src2 || !src2_needs_grads) && "gradients for freq factors not implemented");
5622
  } break;