Spaces:
Running
Running
Commit
·
131a21e
1
Parent(s):
fdb1fe5
RoPE: fix back, CUDA support for back + noncont. (llama/11240)
Browse files* RoPE: fix back, CUDA support for back + noncont.
* fix comments reg. non-cont. RoPE support [no-ci]
- ggml/include/ggml.h +18 -1
- ggml/src/ggml-cpu/ggml-cpu.c +1 -0
- ggml/src/ggml-cpu/ggml-cpu.cpp +0 -2
- ggml/src/ggml-cuda/ggml-cuda.cu +9 -1
- ggml/src/ggml-cuda/rope.cu +162 -204
- ggml/src/ggml-cuda/rope.cuh +2 -0
- ggml/src/ggml.c +33 -25
ggml/include/ggml.h
CHANGED
|
@@ -1500,7 +1500,7 @@ extern "C" {
|
|
| 1500 |
|
| 1501 |
// rotary position embedding backward, i.e compute dx from dy
|
| 1502 |
// a - dy
|
| 1503 |
-
GGML_API struct ggml_tensor *
|
| 1504 |
struct ggml_context * ctx,
|
| 1505 |
struct ggml_tensor * a, // gradients of ggml_rope result
|
| 1506 |
struct ggml_tensor * b, // positions
|
|
@@ -1515,6 +1515,23 @@ extern "C" {
|
|
| 1515 |
float beta_fast,
|
| 1516 |
float beta_slow);
|
| 1517 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1518 |
// clamp
|
| 1519 |
// in-place, returns view(a)
|
| 1520 |
GGML_API struct ggml_tensor * ggml_clamp(
|
|
|
|
| 1500 |
|
| 1501 |
// rotary position embedding backward, i.e compute dx from dy
|
| 1502 |
// a - dy
|
| 1503 |
+
GGML_API struct ggml_tensor * ggml_rope_ext_back(
|
| 1504 |
struct ggml_context * ctx,
|
| 1505 |
struct ggml_tensor * a, // gradients of ggml_rope result
|
| 1506 |
struct ggml_tensor * b, // positions
|
|
|
|
| 1515 |
float beta_fast,
|
| 1516 |
float beta_slow);
|
| 1517 |
|
| 1518 |
+
GGML_API struct ggml_tensor * ggml_rope_multi_back(
|
| 1519 |
+
struct ggml_context * ctx,
|
| 1520 |
+
struct ggml_tensor * a,
|
| 1521 |
+
struct ggml_tensor * b,
|
| 1522 |
+
struct ggml_tensor * c,
|
| 1523 |
+
int n_dims,
|
| 1524 |
+
int sections[4],
|
| 1525 |
+
int mode,
|
| 1526 |
+
int n_ctx_orig,
|
| 1527 |
+
float freq_base,
|
| 1528 |
+
float freq_scale,
|
| 1529 |
+
float ext_factor,
|
| 1530 |
+
float attn_factor,
|
| 1531 |
+
float beta_fast,
|
| 1532 |
+
float beta_slow);
|
| 1533 |
+
|
| 1534 |
+
|
| 1535 |
// clamp
|
| 1536 |
// in-place, returns view(a)
|
| 1537 |
GGML_API struct ggml_tensor * ggml_clamp(
|
ggml/src/ggml-cpu/ggml-cpu.c
CHANGED
|
@@ -13668,6 +13668,7 @@ struct ggml_cplan ggml_graph_plan(
|
|
| 13668 |
} break;
|
| 13669 |
case GGML_OP_SOFT_MAX:
|
| 13670 |
case GGML_OP_ROPE:
|
|
|
|
| 13671 |
{
|
| 13672 |
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
|
| 13673 |
} break;
|
|
|
|
| 13668 |
} break;
|
| 13669 |
case GGML_OP_SOFT_MAX:
|
| 13670 |
case GGML_OP_ROPE:
|
| 13671 |
+
case GGML_OP_ROPE_BACK:
|
| 13672 |
{
|
| 13673 |
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
|
| 13674 |
} break;
|
ggml/src/ggml-cpu/ggml-cpu.cpp
CHANGED
|
@@ -403,8 +403,6 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
|
|
| 403 |
op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
|
| 404 |
case GGML_OP_MUL_MAT:
|
| 405 |
return src1->type == GGML_TYPE_F32 || src1->type == ggml_get_type_traits_cpu(src0->type)->vec_dot_type;
|
| 406 |
-
case GGML_OP_ROPE_BACK:
|
| 407 |
-
return op->src[2] == NULL && (op->op_params[2] & 4) == 0;
|
| 408 |
case GGML_OP_IM2COL_BACK:
|
| 409 |
return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
|
| 410 |
case GGML_OP_OUT_PROD:
|
|
|
|
| 403 |
op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
|
| 404 |
case GGML_OP_MUL_MAT:
|
| 405 |
return src1->type == GGML_TYPE_F32 || src1->type == ggml_get_type_traits_cpu(src0->type)->vec_dot_type;
|
|
|
|
|
|
|
| 406 |
case GGML_OP_IM2COL_BACK:
|
| 407 |
return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
|
| 408 |
case GGML_OP_OUT_PROD:
|
ggml/src/ggml-cuda/ggml-cuda.cu
CHANGED
|
@@ -2141,6 +2141,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|
| 2141 |
case GGML_OP_ROPE:
|
| 2142 |
ggml_cuda_op_rope(ctx, dst);
|
| 2143 |
break;
|
|
|
|
|
|
|
|
|
|
| 2144 |
case GGML_OP_IM2COL:
|
| 2145 |
ggml_cuda_op_im2col(ctx, dst);
|
| 2146 |
break;
|
|
@@ -3025,7 +3028,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|
| 3025 |
case GGML_OP_SOFT_MAX:
|
| 3026 |
return true;
|
| 3027 |
case GGML_OP_ROPE:
|
| 3028 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3029 |
case GGML_OP_IM2COL:
|
| 3030 |
case GGML_OP_POOL_2D:
|
| 3031 |
case GGML_OP_SUM:
|
|
@@ -3081,6 +3088,7 @@ static int64_t get_op_batch_size(const ggml_tensor * op) {
|
|
| 3081 |
return op->ne[1];
|
| 3082 |
case GGML_OP_MUL_MAT_ID:
|
| 3083 |
case GGML_OP_ROPE:
|
|
|
|
| 3084 |
return op->ne[2];
|
| 3085 |
default:
|
| 3086 |
return ggml_nrows(op);
|
|
|
|
| 2141 |
case GGML_OP_ROPE:
|
| 2142 |
ggml_cuda_op_rope(ctx, dst);
|
| 2143 |
break;
|
| 2144 |
+
case GGML_OP_ROPE_BACK:
|
| 2145 |
+
ggml_cuda_op_rope_back(ctx, dst);
|
| 2146 |
+
break;
|
| 2147 |
case GGML_OP_IM2COL:
|
| 2148 |
ggml_cuda_op_im2col(ctx, dst);
|
| 2149 |
break;
|
|
|
|
| 3028 |
case GGML_OP_SOFT_MAX:
|
| 3029 |
return true;
|
| 3030 |
case GGML_OP_ROPE:
|
| 3031 |
+
case GGML_OP_ROPE_BACK: {
|
| 3032 |
+
const size_t ts = ggml_type_size(op->src[0]->type);
|
| 3033 |
+
const int64_t ne0_012 = op->src[0]->ne[0] * op->src[0]->ne[1] * op->src[0]->ne[2];
|
| 3034 |
+
return op->src[0]->nb[0] == ts && op->src[0]->nb[3] == ne0_012*ts;
|
| 3035 |
+
}
|
| 3036 |
case GGML_OP_IM2COL:
|
| 3037 |
case GGML_OP_POOL_2D:
|
| 3038 |
case GGML_OP_SUM:
|
|
|
|
| 3088 |
return op->ne[1];
|
| 3089 |
case GGML_OP_MUL_MAT_ID:
|
| 3090 |
case GGML_OP_ROPE:
|
| 3091 |
+
case GGML_OP_ROPE_BACK:
|
| 3092 |
return op->ne[2];
|
| 3093 |
default:
|
| 3094 |
return ggml_nrows(op);
|
ggml/src/ggml-cuda/rope.cu
CHANGED
|
@@ -16,9 +16,10 @@ static __device__ float rope_yarn_ramp(const float low, const float high, const
|
|
| 16 |
|
| 17 |
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
|
| 18 |
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
|
|
|
| 19 |
static __device__ void rope_yarn(
|
| 20 |
-
|
| 21 |
-
|
| 22 |
// Get n-d rotational scaling corrected for extrapolation
|
| 23 |
float theta_interp = freq_scale * theta_extrap;
|
| 24 |
float theta = theta_interp;
|
|
@@ -29,24 +30,28 @@ static __device__ void rope_yarn(
|
|
| 29 |
// Get n-d magnitude scaling corrected for interpolation
|
| 30 |
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
|
| 31 |
}
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
| 34 |
}
|
| 35 |
|
| 36 |
-
template<
|
| 37 |
static __global__ void rope_norm(
|
| 38 |
-
|
| 39 |
-
|
|
|
|
| 40 |
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
| 41 |
|
| 42 |
if (i0 >= ne0) {
|
| 43 |
return;
|
| 44 |
}
|
| 45 |
|
| 46 |
-
const int
|
| 47 |
|
| 48 |
if (i0 >= n_dims) {
|
| 49 |
-
const int i =
|
| 50 |
|
| 51 |
dst[i + 0] = x[i + 0];
|
| 52 |
dst[i + 1] = x[i + 1];
|
|
@@ -54,39 +59,43 @@ static __global__ void rope_norm(
|
|
| 54 |
return;
|
| 55 |
}
|
| 56 |
|
| 57 |
-
const int
|
| 58 |
-
const int
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
-
const float theta_base = pos[
|
| 61 |
|
| 62 |
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
| 63 |
|
| 64 |
float cos_theta;
|
| 65 |
float sin_theta;
|
| 66 |
|
| 67 |
-
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor,
|
| 68 |
|
| 69 |
-
const float x0 = x[
|
| 70 |
-
const float x1 = x[
|
| 71 |
|
| 72 |
-
dst[
|
| 73 |
-
dst[
|
| 74 |
}
|
| 75 |
|
| 76 |
-
template<
|
| 77 |
static __global__ void rope_neox(
|
| 78 |
-
|
| 79 |
-
|
|
|
|
| 80 |
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
| 81 |
|
| 82 |
if (i0 >= ne0) {
|
| 83 |
return;
|
| 84 |
}
|
| 85 |
|
| 86 |
-
const int
|
| 87 |
|
| 88 |
if (i0 >= n_dims) {
|
| 89 |
-
const int i =
|
| 90 |
|
| 91 |
dst[i + 0] = x[i + 0];
|
| 92 |
dst[i + 1] = x[i + 1];
|
|
@@ -94,39 +103,43 @@ static __global__ void rope_neox(
|
|
| 94 |
return;
|
| 95 |
}
|
| 96 |
|
| 97 |
-
const int
|
| 98 |
-
const int
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
-
const float theta_base = pos[
|
| 101 |
|
| 102 |
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
| 103 |
|
| 104 |
float cos_theta;
|
| 105 |
float sin_theta;
|
| 106 |
|
| 107 |
-
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor,
|
| 108 |
|
| 109 |
-
const float x0 = x[
|
| 110 |
-
const float x1 = x[
|
| 111 |
|
| 112 |
-
dst[
|
| 113 |
-
dst[
|
| 114 |
}
|
| 115 |
|
| 116 |
-
template<
|
| 117 |
static __global__ void rope_multi(
|
| 118 |
-
|
| 119 |
-
|
|
|
|
| 120 |
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
| 121 |
|
| 122 |
if (i0 >= ne0) {
|
| 123 |
return;
|
| 124 |
}
|
| 125 |
|
| 126 |
-
const int
|
| 127 |
|
| 128 |
if (i0 >= n_dims) {
|
| 129 |
-
const int i =
|
| 130 |
|
| 131 |
dst[i + 0] = x[i + 0];
|
| 132 |
dst[i + 1] = x[i + 1];
|
|
@@ -134,25 +147,28 @@ static __global__ void rope_multi(
|
|
| 134 |
return;
|
| 135 |
}
|
| 136 |
|
| 137 |
-
const int
|
| 138 |
-
const int
|
| 139 |
|
| 140 |
-
int
|
| 141 |
-
int
|
| 142 |
-
|
|
|
|
|
|
|
|
|
|
| 143 |
|
| 144 |
float theta_base = 0.0;
|
| 145 |
if (sector < sections.v[0]) {
|
| 146 |
-
theta_base = pos[
|
| 147 |
}
|
| 148 |
else if (sector >= sections.v[0] && sector < sec_w) {
|
| 149 |
-
theta_base = pos[
|
| 150 |
}
|
| 151 |
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
|
| 152 |
-
theta_base = pos[
|
| 153 |
}
|
| 154 |
else if (sector >= sec_w + sections.v[2]) {
|
| 155 |
-
theta_base = pos[
|
| 156 |
}
|
| 157 |
|
| 158 |
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
|
@@ -160,42 +176,46 @@ static __global__ void rope_multi(
|
|
| 160 |
float cos_theta;
|
| 161 |
float sin_theta;
|
| 162 |
|
| 163 |
-
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor,
|
| 164 |
|
| 165 |
-
const float x0 = x[
|
| 166 |
-
const float x1 = x[
|
| 167 |
|
| 168 |
-
dst[
|
| 169 |
-
dst[
|
| 170 |
}
|
| 171 |
|
| 172 |
-
template<
|
| 173 |
static __global__ void rope_vision(
|
| 174 |
-
|
| 175 |
-
|
|
|
|
| 176 |
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
| 177 |
|
| 178 |
if (i0 >= ne0) {
|
| 179 |
return;
|
| 180 |
}
|
| 181 |
|
| 182 |
-
const int
|
|
|
|
|
|
|
|
|
|
| 183 |
|
| 184 |
-
const int
|
| 185 |
-
const int
|
| 186 |
|
| 187 |
-
int sect_dims = sections.v[0] + sections.v[1];
|
| 188 |
-
int sec_w = sections.v[1] + sections.v[0];
|
| 189 |
-
int sector = (i0 / 2) % sect_dims;
|
| 190 |
|
| 191 |
float theta_base = 0.0;
|
| 192 |
if (sector < sections.v[0]) {
|
| 193 |
const int p = sector;
|
| 194 |
-
theta_base = pos[
|
| 195 |
}
|
| 196 |
else if (sector >= sections.v[0] && sector < sec_w) {
|
| 197 |
const int p = sector - sections.v[0];
|
| 198 |
-
theta_base = pos[
|
| 199 |
}
|
| 200 |
|
| 201 |
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
|
@@ -203,19 +223,20 @@ static __global__ void rope_vision(
|
|
| 203 |
float cos_theta;
|
| 204 |
float sin_theta;
|
| 205 |
|
| 206 |
-
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor,
|
| 207 |
|
| 208 |
-
const float x0 = x[
|
| 209 |
-
const float x1 = x[
|
| 210 |
|
| 211 |
-
dst[
|
| 212 |
-
dst[
|
| 213 |
}
|
| 214 |
|
| 215 |
-
template<typename T>
|
| 216 |
static void rope_norm_cuda(
|
| 217 |
-
|
| 218 |
-
|
|
|
|
| 219 |
GGML_ASSERT(ne0 % 2 == 0);
|
| 220 |
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
| 221 |
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
|
@@ -224,22 +245,21 @@ static void rope_norm_cuda(
|
|
| 224 |
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
| 225 |
|
| 226 |
if (freq_factors == nullptr) {
|
| 227 |
-
rope_norm<
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
);
|
| 231 |
} else {
|
| 232 |
-
rope_norm<
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
);
|
| 236 |
}
|
| 237 |
}
|
| 238 |
|
| 239 |
-
template<typename T>
|
| 240 |
static void rope_neox_cuda(
|
| 241 |
-
|
| 242 |
-
|
|
|
|
| 243 |
GGML_ASSERT(ne0 % 2 == 0);
|
| 244 |
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
| 245 |
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
|
@@ -248,22 +268,21 @@ static void rope_neox_cuda(
|
|
| 248 |
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
| 249 |
|
| 250 |
if (freq_factors == nullptr) {
|
| 251 |
-
rope_neox<
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
);
|
| 255 |
} else {
|
| 256 |
-
rope_neox<
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
);
|
| 260 |
}
|
| 261 |
}
|
| 262 |
|
| 263 |
-
template<typename T>
|
| 264 |
static void rope_multi_cuda(
|
| 265 |
-
|
| 266 |
-
|
|
|
|
| 267 |
GGML_ASSERT(ne0 % 2 == 0);
|
| 268 |
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
| 269 |
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
|
@@ -272,22 +291,21 @@ static void rope_multi_cuda(
|
|
| 272 |
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
| 273 |
|
| 274 |
if (freq_factors == nullptr) {
|
| 275 |
-
rope_multi<
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
);
|
| 279 |
} else {
|
| 280 |
-
rope_multi<
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
);
|
| 284 |
}
|
| 285 |
}
|
| 286 |
|
| 287 |
-
template<typename T>
|
| 288 |
static void rope_vision_cuda(
|
| 289 |
-
|
| 290 |
-
|
|
|
|
| 291 |
GGML_ASSERT(ne0 % 2 == 0);
|
| 292 |
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
| 293 |
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
|
@@ -298,80 +316,18 @@ static void rope_vision_cuda(
|
|
| 298 |
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
| 299 |
|
| 300 |
if (freq_factors == nullptr) {
|
| 301 |
-
rope_vision<
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
);
|
| 305 |
} else {
|
| 306 |
-
rope_vision<
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
);
|
| 310 |
}
|
| 311 |
}
|
| 312 |
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
|
| 316 |
-
|
| 317 |
-
rope_norm_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
|
| 318 |
-
}
|
| 319 |
-
|
| 320 |
-
static void rope_norm_cuda_f32(
|
| 321 |
-
const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
| 322 |
-
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
|
| 323 |
-
|
| 324 |
-
rope_norm_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
|
| 325 |
-
}
|
| 326 |
-
|
| 327 |
-
static void rope_neox_cuda_f16(
|
| 328 |
-
const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
| 329 |
-
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
|
| 330 |
-
|
| 331 |
-
rope_neox_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
|
| 332 |
-
}
|
| 333 |
-
|
| 334 |
-
static void rope_neox_cuda_f32(
|
| 335 |
-
const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
| 336 |
-
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
|
| 337 |
-
) {
|
| 338 |
-
|
| 339 |
-
rope_neox_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
|
| 340 |
-
}
|
| 341 |
-
|
| 342 |
-
static void rope_multi_cuda_f16(
|
| 343 |
-
const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
| 344 |
-
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
|
| 345 |
-
) {
|
| 346 |
-
|
| 347 |
-
rope_multi_cuda<half>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
|
| 348 |
-
}
|
| 349 |
-
|
| 350 |
-
static void rope_multi_cuda_f32(
|
| 351 |
-
const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
| 352 |
-
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
|
| 353 |
-
) {
|
| 354 |
-
|
| 355 |
-
rope_multi_cuda<float>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
|
| 356 |
-
}
|
| 357 |
-
|
| 358 |
-
static void rope_vision_cuda_f16(
|
| 359 |
-
const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
| 360 |
-
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
|
| 361 |
-
) {
|
| 362 |
-
|
| 363 |
-
rope_vision_cuda<half>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
|
| 364 |
-
}
|
| 365 |
-
|
| 366 |
-
static void rope_vision_cuda_f32(
|
| 367 |
-
const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
| 368 |
-
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
|
| 369 |
-
) {
|
| 370 |
-
|
| 371 |
-
rope_vision_cuda<float>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
|
| 372 |
-
}
|
| 373 |
-
|
| 374 |
-
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 375 |
const ggml_tensor * src0 = dst->src[0];
|
| 376 |
const ggml_tensor * src1 = dst->src[1];
|
| 377 |
const ggml_tensor * src2 = dst->src[2];
|
|
@@ -382,7 +338,6 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
| 382 |
float * dst_d = (float *)dst->data;
|
| 383 |
cudaStream_t stream = ctx.stream();
|
| 384 |
|
| 385 |
-
GGML_ASSERT(ggml_is_contiguous(src0));
|
| 386 |
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
| 387 |
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
| 388 |
GGML_ASSERT(src0->type == dst->type);
|
|
@@ -392,6 +347,9 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
| 392 |
const int64_t ne02 = src0->ne[2]; // num heads
|
| 393 |
const int64_t nr = ggml_nrows(src0);
|
| 394 |
|
|
|
|
|
|
|
|
|
|
| 395 |
//const int n_past = ((int32_t *) dst->op_params)[0];
|
| 396 |
const int n_dims = ((int32_t *) dst->op_params)[1];
|
| 397 |
const int mode = ((int32_t *) dst->op_params)[2];
|
|
@@ -440,59 +398,59 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
| 440 |
// compute
|
| 441 |
if (is_neox) {
|
| 442 |
if (src0->type == GGML_TYPE_F32) {
|
| 443 |
-
|
| 444 |
-
(const float *)src0_d, (float *)dst_d, ne00,
|
| 445 |
-
attn_factor, corr_dims, freq_factors, stream
|
| 446 |
-
);
|
| 447 |
} else if (src0->type == GGML_TYPE_F16) {
|
| 448 |
-
|
| 449 |
-
(const half *)src0_d, (half *)dst_d, ne00,
|
| 450 |
-
attn_factor, corr_dims, freq_factors, stream
|
| 451 |
-
);
|
| 452 |
} else {
|
| 453 |
GGML_ABORT("fatal error");
|
| 454 |
}
|
| 455 |
} else if (is_mrope && !is_vision) {
|
| 456 |
if (src0->type == GGML_TYPE_F32) {
|
| 457 |
-
|
| 458 |
-
(const float *)src0_d, (float *)dst_d, ne00,
|
| 459 |
-
attn_factor, corr_dims, freq_factors, sections, stream
|
| 460 |
-
);
|
| 461 |
} else if (src0->type == GGML_TYPE_F16) {
|
| 462 |
-
|
| 463 |
-
(const half *)src0_d, (half *)dst_d, ne00,
|
| 464 |
-
attn_factor, corr_dims, freq_factors, sections, stream
|
| 465 |
-
);
|
| 466 |
} else {
|
| 467 |
GGML_ABORT("fatal error");
|
| 468 |
}
|
| 469 |
} else if (is_vision) {
|
| 470 |
if (src0->type == GGML_TYPE_F32) {
|
| 471 |
-
|
| 472 |
-
(const float *)src0_d, (float *)dst_d, ne00,
|
| 473 |
-
attn_factor, corr_dims, freq_factors, sections, stream
|
| 474 |
-
);
|
| 475 |
} else if (src0->type == GGML_TYPE_F16) {
|
| 476 |
-
|
| 477 |
-
(const half *)src0_d, (half *)dst_d, ne00,
|
| 478 |
-
attn_factor, corr_dims, freq_factors, sections, stream
|
| 479 |
-
);
|
| 480 |
} else {
|
| 481 |
GGML_ABORT("fatal error");
|
| 482 |
}
|
| 483 |
} else {
|
| 484 |
if (src0->type == GGML_TYPE_F32) {
|
| 485 |
-
|
| 486 |
-
(const float *)src0_d, (float *)dst_d, ne00,
|
| 487 |
-
attn_factor, corr_dims, freq_factors, stream
|
| 488 |
-
);
|
| 489 |
} else if (src0->type == GGML_TYPE_F16) {
|
| 490 |
-
|
| 491 |
-
(const half *)src0_d, (half *)dst_d, ne00,
|
| 492 |
-
attn_factor, corr_dims, freq_factors, stream
|
| 493 |
-
);
|
| 494 |
} else {
|
| 495 |
GGML_ABORT("fatal error");
|
| 496 |
}
|
| 497 |
}
|
| 498 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
|
| 18 |
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
| 19 |
+
template<bool forward>
|
| 20 |
static __device__ void rope_yarn(
|
| 21 |
+
const float theta_extrap, const float freq_scale, const rope_corr_dims corr_dims, const int64_t i0, const float ext_factor,
|
| 22 |
+
float mscale, float & cos_theta, float & sin_theta) {
|
| 23 |
// Get n-d rotational scaling corrected for extrapolation
|
| 24 |
float theta_interp = freq_scale * theta_extrap;
|
| 25 |
float theta = theta_interp;
|
|
|
|
| 30 |
// Get n-d magnitude scaling corrected for interpolation
|
| 31 |
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
|
| 32 |
}
|
| 33 |
+
cos_theta = cosf(theta) * mscale;
|
| 34 |
+
sin_theta = sinf(theta) * mscale;
|
| 35 |
+
if (!forward) {
|
| 36 |
+
sin_theta *= -1.0f;
|
| 37 |
+
}
|
| 38 |
}
|
| 39 |
|
| 40 |
+
template<bool forward, bool has_ff, typename T>
|
| 41 |
static __global__ void rope_norm(
|
| 42 |
+
const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
|
| 43 |
+
const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor,
|
| 44 |
+
const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors) {
|
| 45 |
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
| 46 |
|
| 47 |
if (i0 >= ne0) {
|
| 48 |
return;
|
| 49 |
}
|
| 50 |
|
| 51 |
+
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
|
| 52 |
|
| 53 |
if (i0 >= n_dims) {
|
| 54 |
+
const int i = row_dst*ne0 + i0;
|
| 55 |
|
| 56 |
dst[i + 0] = x[i + 0];
|
| 57 |
dst[i + 1] = x[i + 1];
|
|
|
|
| 59 |
return;
|
| 60 |
}
|
| 61 |
|
| 62 |
+
const int row_x = row_dst % ne1;
|
| 63 |
+
const int channel_x = row_dst / ne1;
|
| 64 |
+
|
| 65 |
+
const int idst = row_dst*ne0 + i0;
|
| 66 |
+
const int ix = channel_x*s2 + row_x*s1 + i0;
|
| 67 |
|
| 68 |
+
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
|
| 69 |
|
| 70 |
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
| 71 |
|
| 72 |
float cos_theta;
|
| 73 |
float sin_theta;
|
| 74 |
|
| 75 |
+
rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
|
| 76 |
|
| 77 |
+
const float x0 = x[ix + 0];
|
| 78 |
+
const float x1 = x[ix + 1];
|
| 79 |
|
| 80 |
+
dst[idst + 0] = x0*cos_theta - x1*sin_theta;
|
| 81 |
+
dst[idst + 1] = x0*sin_theta + x1*cos_theta;
|
| 82 |
}
|
| 83 |
|
| 84 |
+
template<bool forward, bool has_ff, typename T>
|
| 85 |
static __global__ void rope_neox(
|
| 86 |
+
const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
|
| 87 |
+
const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor,
|
| 88 |
+
const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors) {
|
| 89 |
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
| 90 |
|
| 91 |
if (i0 >= ne0) {
|
| 92 |
return;
|
| 93 |
}
|
| 94 |
|
| 95 |
+
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
|
| 96 |
|
| 97 |
if (i0 >= n_dims) {
|
| 98 |
+
const int i = row_dst*ne0 + i0;
|
| 99 |
|
| 100 |
dst[i + 0] = x[i + 0];
|
| 101 |
dst[i + 1] = x[i + 1];
|
|
|
|
| 103 |
return;
|
| 104 |
}
|
| 105 |
|
| 106 |
+
const int row_x = row_dst % ne1;
|
| 107 |
+
const int channel_x = row_dst / ne1;
|
| 108 |
+
|
| 109 |
+
const int idst = row_dst*ne0 + i0/2;
|
| 110 |
+
const int ix = channel_x*s2 + row_x*s1 + i0/2;
|
| 111 |
|
| 112 |
+
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
|
| 113 |
|
| 114 |
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
| 115 |
|
| 116 |
float cos_theta;
|
| 117 |
float sin_theta;
|
| 118 |
|
| 119 |
+
rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
|
| 120 |
|
| 121 |
+
const float x0 = x[ix + 0];
|
| 122 |
+
const float x1 = x[ix + n_dims/2];
|
| 123 |
|
| 124 |
+
dst[idst + 0] = x0*cos_theta - x1*sin_theta;
|
| 125 |
+
dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
|
| 126 |
}
|
| 127 |
|
| 128 |
+
template<bool forward, bool has_ff, typename T>
|
| 129 |
static __global__ void rope_multi(
|
| 130 |
+
const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
|
| 131 |
+
const int n_dims, const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor,
|
| 132 |
+
const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors, const mrope_sections sections) {
|
| 133 |
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
| 134 |
|
| 135 |
if (i0 >= ne0) {
|
| 136 |
return;
|
| 137 |
}
|
| 138 |
|
| 139 |
+
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
|
| 140 |
|
| 141 |
if (i0 >= n_dims) {
|
| 142 |
+
const int i = row_dst*ne0 + i0;
|
| 143 |
|
| 144 |
dst[i + 0] = x[i + 0];
|
| 145 |
dst[i + 1] = x[i + 1];
|
|
|
|
| 147 |
return;
|
| 148 |
}
|
| 149 |
|
| 150 |
+
const int row_x = row_dst % ne1;
|
| 151 |
+
const int channel_x = row_dst / ne1;
|
| 152 |
|
| 153 |
+
const int idst = row_dst*ne0 + i0/2;
|
| 154 |
+
const int ix = channel_x*s2 + row_x*s1 + i0/2;
|
| 155 |
+
|
| 156 |
+
const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
|
| 157 |
+
const int sec_w = sections.v[1] + sections.v[0];
|
| 158 |
+
const int sector = (i0 / 2) % sect_dims;
|
| 159 |
|
| 160 |
float theta_base = 0.0;
|
| 161 |
if (sector < sections.v[0]) {
|
| 162 |
+
theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
|
| 163 |
}
|
| 164 |
else if (sector >= sections.v[0] && sector < sec_w) {
|
| 165 |
+
theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
|
| 166 |
}
|
| 167 |
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
|
| 168 |
+
theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
|
| 169 |
}
|
| 170 |
else if (sector >= sec_w + sections.v[2]) {
|
| 171 |
+
theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
|
| 172 |
}
|
| 173 |
|
| 174 |
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
|
|
|
| 176 |
float cos_theta;
|
| 177 |
float sin_theta;
|
| 178 |
|
| 179 |
+
rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
|
| 180 |
|
| 181 |
+
const float x0 = x[ix + 0];
|
| 182 |
+
const float x1 = x[ix + n_dims/2];
|
| 183 |
|
| 184 |
+
dst[idst + 0] = x0*cos_theta - x1*sin_theta;
|
| 185 |
+
dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
|
| 186 |
}
|
| 187 |
|
| 188 |
+
template<bool forward, bool has_ff, typename T>
|
| 189 |
static __global__ void rope_vision(
|
| 190 |
+
const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims,
|
| 191 |
+
const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
|
| 192 |
+
const float theta_scale, const float * __restrict__ freq_factors, const mrope_sections sections) {
|
| 193 |
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
| 194 |
|
| 195 |
if (i0 >= ne0) {
|
| 196 |
return;
|
| 197 |
}
|
| 198 |
|
| 199 |
+
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
|
| 200 |
+
|
| 201 |
+
const int row_x = row_dst % ne1;
|
| 202 |
+
const int channel_x = row_dst / ne1;
|
| 203 |
|
| 204 |
+
const int idst = row_dst*ne0 + i0/2;
|
| 205 |
+
const int ix = channel_x*s2 + row_x*s1 + i0/2;
|
| 206 |
|
| 207 |
+
const int sect_dims = sections.v[0] + sections.v[1];
|
| 208 |
+
const int sec_w = sections.v[1] + sections.v[0];
|
| 209 |
+
const int sector = (i0 / 2) % sect_dims;
|
| 210 |
|
| 211 |
float theta_base = 0.0;
|
| 212 |
if (sector < sections.v[0]) {
|
| 213 |
const int p = sector;
|
| 214 |
+
theta_base = pos[channel_x]*powf(theta_scale, p);
|
| 215 |
}
|
| 216 |
else if (sector >= sections.v[0] && sector < sec_w) {
|
| 217 |
const int p = sector - sections.v[0];
|
| 218 |
+
theta_base = pos[channel_x + ne2]*powf(theta_scale, p);
|
| 219 |
}
|
| 220 |
|
| 221 |
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
|
|
|
| 223 |
float cos_theta;
|
| 224 |
float sin_theta;
|
| 225 |
|
| 226 |
+
rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
|
| 227 |
|
| 228 |
+
const float x0 = x[ix + 0];
|
| 229 |
+
const float x1 = x[ix + n_dims];
|
| 230 |
|
| 231 |
+
dst[idst + 0] = x0*cos_theta - x1*sin_theta;
|
| 232 |
+
dst[idst + n_dims] = x0*sin_theta + x1*cos_theta;
|
| 233 |
}
|
| 234 |
|
| 235 |
+
template<bool forward, typename T>
|
| 236 |
static void rope_norm_cuda(
|
| 237 |
+
const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
|
| 238 |
+
const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
|
| 239 |
+
const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, cudaStream_t stream) {
|
| 240 |
GGML_ASSERT(ne0 % 2 == 0);
|
| 241 |
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
| 242 |
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
|
|
|
| 245 |
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
| 246 |
|
| 247 |
if (freq_factors == nullptr) {
|
| 248 |
+
rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>(
|
| 249 |
+
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
| 250 |
+
attn_factor, corr_dims, theta_scale, freq_factors);
|
|
|
|
| 251 |
} else {
|
| 252 |
+
rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>(
|
| 253 |
+
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
| 254 |
+
attn_factor, corr_dims, theta_scale, freq_factors);
|
|
|
|
| 255 |
}
|
| 256 |
}
|
| 257 |
|
| 258 |
+
template<bool forward, typename T>
|
| 259 |
static void rope_neox_cuda(
|
| 260 |
+
const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
|
| 261 |
+
const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
|
| 262 |
+
const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, cudaStream_t stream) {
|
| 263 |
GGML_ASSERT(ne0 % 2 == 0);
|
| 264 |
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
| 265 |
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
|
|
|
| 268 |
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
| 269 |
|
| 270 |
if (freq_factors == nullptr) {
|
| 271 |
+
rope_neox<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
|
| 272 |
+
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
| 273 |
+
attn_factor, corr_dims, theta_scale, freq_factors);
|
|
|
|
| 274 |
} else {
|
| 275 |
+
rope_neox<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
|
| 276 |
+
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
| 277 |
+
attn_factor, corr_dims, theta_scale, freq_factors);
|
|
|
|
| 278 |
}
|
| 279 |
}
|
| 280 |
|
| 281 |
+
template<bool forward, typename T>
|
| 282 |
static void rope_multi_cuda(
|
| 283 |
+
const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
|
| 284 |
+
const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
|
| 285 |
+
const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, const mrope_sections sections, cudaStream_t stream) {
|
| 286 |
GGML_ASSERT(ne0 % 2 == 0);
|
| 287 |
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
| 288 |
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
|
|
|
| 291 |
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
| 292 |
|
| 293 |
if (freq_factors == nullptr) {
|
| 294 |
+
rope_multi<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
|
| 295 |
+
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
| 296 |
+
attn_factor, corr_dims, theta_scale, freq_factors, sections);
|
|
|
|
| 297 |
} else {
|
| 298 |
+
rope_multi<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
|
| 299 |
+
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
| 300 |
+
attn_factor, corr_dims, theta_scale, freq_factors, sections);
|
|
|
|
| 301 |
}
|
| 302 |
}
|
| 303 |
|
| 304 |
+
template<bool forward, typename T>
|
| 305 |
static void rope_vision_cuda(
|
| 306 |
+
const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
|
| 307 |
+
const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
|
| 308 |
+
const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, const mrope_sections sections, cudaStream_t stream) {
|
| 309 |
GGML_ASSERT(ne0 % 2 == 0);
|
| 310 |
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
| 311 |
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
|
|
|
| 316 |
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
| 317 |
|
| 318 |
if (freq_factors == nullptr) {
|
| 319 |
+
rope_vision<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
|
| 320 |
+
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
| 321 |
+
attn_factor, corr_dims, theta_scale, freq_factors, sections);
|
|
|
|
| 322 |
} else {
|
| 323 |
+
rope_vision<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
|
| 324 |
+
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
| 325 |
+
attn_factor, corr_dims, theta_scale, freq_factors, sections);
|
|
|
|
| 326 |
}
|
| 327 |
}
|
| 328 |
|
| 329 |
+
template <bool forward>
|
| 330 |
+
void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 331 |
const ggml_tensor * src0 = dst->src[0];
|
| 332 |
const ggml_tensor * src1 = dst->src[1];
|
| 333 |
const ggml_tensor * src2 = dst->src[2];
|
|
|
|
| 338 |
float * dst_d = (float *)dst->data;
|
| 339 |
cudaStream_t stream = ctx.stream();
|
| 340 |
|
|
|
|
| 341 |
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
| 342 |
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
| 343 |
GGML_ASSERT(src0->type == dst->type);
|
|
|
|
| 347 |
const int64_t ne02 = src0->ne[2]; // num heads
|
| 348 |
const int64_t nr = ggml_nrows(src0);
|
| 349 |
|
| 350 |
+
const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);
|
| 351 |
+
const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);
|
| 352 |
+
|
| 353 |
//const int n_past = ((int32_t *) dst->op_params)[0];
|
| 354 |
const int n_dims = ((int32_t *) dst->op_params)[1];
|
| 355 |
const int mode = ((int32_t *) dst->op_params)[2];
|
|
|
|
| 398 |
// compute
|
| 399 |
if (is_neox) {
|
| 400 |
if (src0->type == GGML_TYPE_F32) {
|
| 401 |
+
rope_neox_cuda<forward>(
|
| 402 |
+
(const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
|
| 403 |
+
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
|
|
|
|
| 404 |
} else if (src0->type == GGML_TYPE_F16) {
|
| 405 |
+
rope_neox_cuda<forward>(
|
| 406 |
+
(const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
|
| 407 |
+
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
|
|
|
|
| 408 |
} else {
|
| 409 |
GGML_ABORT("fatal error");
|
| 410 |
}
|
| 411 |
} else if (is_mrope && !is_vision) {
|
| 412 |
if (src0->type == GGML_TYPE_F32) {
|
| 413 |
+
rope_multi_cuda<forward>(
|
| 414 |
+
(const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
|
| 415 |
+
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
|
|
|
|
| 416 |
} else if (src0->type == GGML_TYPE_F16) {
|
| 417 |
+
rope_multi_cuda<forward>(
|
| 418 |
+
(const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
|
| 419 |
+
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
|
|
|
|
| 420 |
} else {
|
| 421 |
GGML_ABORT("fatal error");
|
| 422 |
}
|
| 423 |
} else if (is_vision) {
|
| 424 |
if (src0->type == GGML_TYPE_F32) {
|
| 425 |
+
rope_vision_cuda<forward>(
|
| 426 |
+
(const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
|
| 427 |
+
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
|
|
|
|
| 428 |
} else if (src0->type == GGML_TYPE_F16) {
|
| 429 |
+
rope_vision_cuda<forward>(
|
| 430 |
+
(const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
|
| 431 |
+
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
|
|
|
|
| 432 |
} else {
|
| 433 |
GGML_ABORT("fatal error");
|
| 434 |
}
|
| 435 |
} else {
|
| 436 |
if (src0->type == GGML_TYPE_F32) {
|
| 437 |
+
rope_norm_cuda<forward>(
|
| 438 |
+
(const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
|
| 439 |
+
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
|
|
|
|
| 440 |
} else if (src0->type == GGML_TYPE_F16) {
|
| 441 |
+
rope_norm_cuda<forward>(
|
| 442 |
+
(const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
|
| 443 |
+
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
|
|
|
|
| 444 |
} else {
|
| 445 |
GGML_ABORT("fatal error");
|
| 446 |
}
|
| 447 |
}
|
| 448 |
}
|
| 449 |
+
|
| 450 |
+
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 451 |
+
ggml_cuda_op_rope_impl<true>(ctx, dst);
|
| 452 |
+
}
|
| 453 |
+
|
| 454 |
+
void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 455 |
+
ggml_cuda_op_rope_impl<false>(ctx, dst);
|
| 456 |
+
}
|
ggml/src/ggml-cuda/rope.cuh
CHANGED
|
@@ -3,3 +3,5 @@
|
|
| 3 |
#define CUDA_ROPE_BLOCK_SIZE 256
|
| 4 |
|
| 5 |
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
|
|
|
|
|
|
|
|
| 3 |
#define CUDA_ROPE_BLOCK_SIZE 256
|
| 4 |
|
| 5 |
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
| 6 |
+
|
| 7 |
+
void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
ggml/src/ggml.c
CHANGED
|
@@ -3699,7 +3699,7 @@ void ggml_rope_yarn_corr_dims(
|
|
| 3699 |
|
| 3700 |
// ggml_rope_back
|
| 3701 |
|
| 3702 |
-
struct ggml_tensor *
|
| 3703 |
struct ggml_context * ctx,
|
| 3704 |
struct ggml_tensor * a,
|
| 3705 |
struct ggml_tensor * b,
|
|
@@ -3713,29 +3713,32 @@ struct ggml_tensor * ggml_rope_back(
|
|
| 3713 |
float attn_factor,
|
| 3714 |
float beta_fast,
|
| 3715 |
float beta_slow) {
|
| 3716 |
-
|
| 3717 |
-
|
| 3718 |
-
|
| 3719 |
-
|
| 3720 |
-
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
|
| 3721 |
-
|
| 3722 |
-
int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
|
| 3723 |
-
memcpy(params + 5, &freq_base, sizeof(float));
|
| 3724 |
-
memcpy(params + 6, &freq_scale, sizeof(float));
|
| 3725 |
-
memcpy(params + 7, &ext_factor, sizeof(float));
|
| 3726 |
-
memcpy(params + 8, &attn_factor, sizeof(float));
|
| 3727 |
-
memcpy(params + 9, &beta_fast, sizeof(float));
|
| 3728 |
-
memcpy(params + 10, &beta_slow, sizeof(float));
|
| 3729 |
-
ggml_set_op_params(result, params, sizeof(params));
|
| 3730 |
-
|
| 3731 |
-
result->op = GGML_OP_ROPE_BACK;
|
| 3732 |
-
result->src[0] = a;
|
| 3733 |
-
result->src[1] = b;
|
| 3734 |
-
result->src[2] = c;
|
| 3735 |
-
|
| 3736 |
return result;
|
| 3737 |
}
|
| 3738 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3739 |
// ggml_clamp
|
| 3740 |
|
| 3741 |
struct ggml_tensor * ggml_clamp(
|
|
@@ -5598,6 +5601,7 @@ static void ggml_compute_backward(
|
|
| 5598 |
//const int n_ctx = ((int32_t *) tensor->op_params)[3];
|
| 5599 |
const int n_ctx_orig = ((const int32_t *) tensor->op_params)[4];
|
| 5600 |
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
|
|
|
| 5601 |
|
| 5602 |
memcpy(&freq_base, (const float *) tensor->op_params + 5, sizeof(float));
|
| 5603 |
memcpy(&freq_scale, (const float *) tensor->op_params + 6, sizeof(float));
|
|
@@ -5605,10 +5609,14 @@ static void ggml_compute_backward(
|
|
| 5605 |
memcpy(&attn_factor, (const float *) tensor->op_params + 8, sizeof(float));
|
| 5606 |
memcpy(&beta_fast, (const float *) tensor->op_params + 9, sizeof(float));
|
| 5607 |
memcpy(&beta_slow, (const float *) tensor->op_params + 10, sizeof(float));
|
| 5608 |
-
|
| 5609 |
-
|
| 5610 |
-
|
| 5611 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5612 |
}
|
| 5613 |
GGML_ASSERT((!src2 || !src2_needs_grads) && "gradients for freq factors not implemented");
|
| 5614 |
} break;
|
|
|
|
| 3699 |
|
| 3700 |
// ggml_rope_back
|
| 3701 |
|
| 3702 |
+
struct ggml_tensor * ggml_rope_ext_back(
|
| 3703 |
struct ggml_context * ctx,
|
| 3704 |
struct ggml_tensor * a,
|
| 3705 |
struct ggml_tensor * b,
|
|
|
|
| 3713 |
float attn_factor,
|
| 3714 |
float beta_fast,
|
| 3715 |
float beta_slow) {
|
| 3716 |
+
struct ggml_tensor * result = ggml_rope_ext(
|
| 3717 |
+
ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
|
| 3718 |
+
result->op = GGML_OP_ROPE_BACK;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3719 |
return result;
|
| 3720 |
}
|
| 3721 |
|
| 3722 |
+
struct ggml_tensor * ggml_rope_multi_back(
|
| 3723 |
+
struct ggml_context * ctx,
|
| 3724 |
+
struct ggml_tensor * a,
|
| 3725 |
+
struct ggml_tensor * b,
|
| 3726 |
+
struct ggml_tensor * c,
|
| 3727 |
+
int n_dims,
|
| 3728 |
+
int sections[4],
|
| 3729 |
+
int mode,
|
| 3730 |
+
int n_ctx_orig,
|
| 3731 |
+
float freq_base,
|
| 3732 |
+
float freq_scale,
|
| 3733 |
+
float ext_factor,
|
| 3734 |
+
float attn_factor,
|
| 3735 |
+
float beta_fast,
|
| 3736 |
+
float beta_slow) {
|
| 3737 |
+
struct ggml_tensor * result = ggml_rope_multi(
|
| 3738 |
+
ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
|
| 3739 |
+
result->op = GGML_OP_ROPE_BACK;
|
| 3740 |
+
return result;
|
| 3741 |
+
}
|
| 3742 |
// ggml_clamp
|
| 3743 |
|
| 3744 |
struct ggml_tensor * ggml_clamp(
|
|
|
|
| 5601 |
//const int n_ctx = ((int32_t *) tensor->op_params)[3];
|
| 5602 |
const int n_ctx_orig = ((const int32_t *) tensor->op_params)[4];
|
| 5603 |
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
| 5604 |
+
int sections[4] = {0, 0, 0, 0};
|
| 5605 |
|
| 5606 |
memcpy(&freq_base, (const float *) tensor->op_params + 5, sizeof(float));
|
| 5607 |
memcpy(&freq_scale, (const float *) tensor->op_params + 6, sizeof(float));
|
|
|
|
| 5609 |
memcpy(&attn_factor, (const float *) tensor->op_params + 8, sizeof(float));
|
| 5610 |
memcpy(&beta_fast, (const float *) tensor->op_params + 9, sizeof(float));
|
| 5611 |
memcpy(&beta_slow, (const float *) tensor->op_params + 10, sizeof(float));
|
| 5612 |
+
memcpy(§ions, tensor->op_params + 11, sizeof(sections));
|
| 5613 |
+
|
| 5614 |
+
struct ggml_tensor * rope_back = grad->ne[2] == src1->ne[0] ?
|
| 5615 |
+
ggml_rope_ext_back(ctx, grad, src1, src2, n_dims,
|
| 5616 |
+
mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow) :
|
| 5617 |
+
ggml_rope_multi_back(ctx, grad, src1, src2, n_dims, sections,
|
| 5618 |
+
mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
|
| 5619 |
+
ggml_add_or_set(ctx, cgraph, isrc0, rope_back);
|
| 5620 |
}
|
| 5621 |
GGML_ASSERT((!src2 || !src2_needs_grads) && "gradients for freq factors not implemented");
|
| 5622 |
} break;
|