Jonathan Graehl, OccamRazor, JohannesGaessler committed
Commit f585fe7 · 1 Parent(s): 58a3802

finetune: SGD optimizer, more CLI args (llama/13873)


* examples/finetune -opt SGD (stochastic gradient descent) memory opt

add unit-tested GGML_OPT_OPTIMIZER_SGD to ggml - avoids allocating the
AdamW m, v momentum tensors.

support the finetune.cpp arg -opt SGD (or sgd); the default remains adamw as before.

llama 3.2-1b F32 result: observed 11 GB GPU RAM (41 sec/epoch)
when using SGD instead of 19 GB (55 sec/epoch) using adamw
(finetuning on 100 lines of Wikipedia).

(with the same GPU memory, adamw can only fit a 512 batch/context
before OOM, reaching:
train: [███████▉] data=0000140/0000140 loss=0.02575±0.00099 acc=99.52±0.03% t=00:00:47 ETA=00:00:00
val:   [███████▉] data=0000008/0000008 loss=4.76565±0.28810 acc=41.46±0.77% t=00:00:00 ETA=00:00:00

SGD comes out ahead, though it converges more slowly, with a max of 1728
batch/context before OOM (note especially the better validation accuracy):
train: [███████▉] data=0000039/0000039 loss=0.00371±0.00010 acc=99.96±0.01% t=00:00:41 ETA=00:00:00
val:   [███████▉] data=0000003/0000003 loss=5.11406±0.76034 acc=48.01±0.69% t=00:00:01 ETA=00:00:00
)

note: when finetuning long enough (or with a high enough -lr),
validation accuracy *eventually* drops ('catastrophic forgetting')

the -lr-half (half-life) option is useful for SGD to avoid oscillation or
extremely slow underdamped learning (it makes setting -lr more forgiving).
the terminal -lr is for now set via -lr-halvings, i.e. if you want at most
1/8 of the initial -lr, set -lr-halvings 3.
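
for intuition, a minimal sketch of such a half-life schedule with the floor
set by -lr-halvings (illustrative only - the exact schedule finetune.cpp uses
is not shown in this commit, and the names below are made up):

#include <math.h>

// halve the learning rate every lr_half epochs, but never go below
// lr0 / 2^halvings (so -lr-halvings 3 means at most 1/8 of the initial -lr)
static float lr_at_epoch(float lr0, float lr_half, int halvings, float epoch) {
    const float decayed  = lr0 * powf(0.5f, epoch / lr_half);
    const float terminal = lr0 / (float) (1 << halvings);
    return decayed > terminal ? decayed : terminal;
}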

note: the objective loss may not be directly comparable between adamw and
sgd - check perplexity or accuracy, or consider relative improvement, when
judging convergence.

new finetune args: -wd 1e-9 enables weight decay in sgd or adamw, and
-epochs N caps the number of epochs (default 2 as before)

caching (1 - wd*alpha) in the 'adamw' opt struct gave no noticeable perf
benefit and is disabled (it is still done for the new SGD, though)
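
for reference, the update this factor feeds into is plain SGD with decoupled
weight decay, the same per-element step the new CPU kernel performs (sketch
with illustrative names):

#include <stdint.h>

// keep = 1 - alpha*wd is computed once per tensor, then reused per element
static void sgd_step(float * w, const float * g, int64_t n, float alpha, float wd) {
    const float keep = 1.0f - alpha * wd;
    for (int64_t i = 0; i < n; ++i) {
        w[i] = w[i] * keep - alpha * g[i];
    }
}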

since optimizer memory is pre-allocated, the ggml_opt_get_optimizer_params
callback could probably switch between SGD and AdamW with each epoch, but it
would need to use adamw for the first epoch (unconfirmed - there is no
cmdline arg to set such a policy yet)
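
for context, the hook this refers to is the ggml_opt_get_optimizer_params
callback; a minimal sketch (not part of this commit, with a made-up schedule)
of returning per-epoch SGD parameters:

#include "ggml-opt.h"
#include <stdint.h>

// userdata is the pointer to the epoch counter (int64_t) that ggml_opt_fit
// passes through to the callback; the decay schedule here is hypothetical
static struct ggml_opt_optimizer_params my_opt_pars(void * userdata) {
    const int64_t epoch = *(const int64_t *) userdata;
    struct ggml_opt_optimizer_params p = ggml_opt_get_default_optimizer_params(NULL);
    p.sgd.alpha = 1e-3f / (1.0f + (float) epoch);
    p.sgd.wd    = 1e-9f;
    return p;
}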

test-opt checks adamw as before and now also sgd (except for a few tests
disabled for sgd only; those probably just need the values logged and
alternate reference values added); the tolerance on the 'regression' test is
broader for sgd (so we don't need many more epochs)

* Vulkan: Implement GGML_OP_OPT_STEP_SGD

* tests: Fix OPT_STEP_SGD test-backend-ops

* SGD op param stores weight-decay and not 1-alpha*wd

* minor + cosmetic changes

* fix vulkan sgd

* try CI fix

---------

Co-authored-by: 0cc4m <[email protected]>
Co-authored-by: Johannes Gäßler <[email protected]>

ggml/include/ggml-opt.h CHANGED
@@ -74,16 +74,26 @@ extern "C" {
         GGML_OPT_BUILD_TYPE_OPT = 30,
     };
 
+    enum ggml_opt_optimizer_type {
+        GGML_OPT_OPTIMIZER_TYPE_ADAMW,
+        GGML_OPT_OPTIMIZER_TYPE_SGD,
+
+        GGML_OPT_OPTIMIZER_TYPE_COUNT
+    };
+
     // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
     struct ggml_opt_optimizer_params {
-        // AdamW optimizer parameters
         struct {
             float alpha; // learning rate
-            float beta1;
-            float beta2;
+            float beta1; // first AdamW momentum
+            float beta2; // second AdamW momentum
             float eps; // epsilon for numerical stability
-            float wd; // weight decay for AdamW, use 0.0f to disable
+            float wd; // weight decay - 0.0f to disable
         } adamw;
+        struct {
+            float alpha; // learning rate
+            float wd; // weight decay
+        } sgd;
     };
 
     // callback to calculate optimizer parameters prior to a backward pass
@@ -112,8 +122,11 @@ extern "C" {
 
         int32_t opt_period; // after how many gradient accumulation steps an optimizer step should be done
 
-        ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
-        void * get_opt_pars_ud; // userdata for calculating optimizer parameters
+        ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
+        void * get_opt_pars_ud; // userdata for calculating optimizer parameters
+
+        // only GGML_OPT_OPTIMIZER_TYPE_ADAMW needs m, v momenta per parameter tensor
+        enum ggml_opt_optimizer_type optimizer;
     };
 
     // get parameters for an optimization context with defaults set where possible
@@ -142,6 +155,10 @@ extern "C" {
     // get the gradient accumulator for a node from the forward graph
     GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node);
 
+    GGML_API enum ggml_opt_optimizer_type ggml_opt_context_optimizer_type(ggml_opt_context_t); //TODO consistent naming scheme
+
+    GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type);
+
     // ====== Optimization Result ======
 
     GGML_API ggml_opt_result_t ggml_opt_result_init(void);
@@ -226,12 +243,14 @@ extern "C" {
         struct ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used
         ggml_opt_dataset_t dataset, // dataset with data and optionally also labels
         enum ggml_opt_loss_type loss_type, // loss to minimize
+        enum ggml_opt_optimizer_type optimizer, // sgd or adamw
         ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
         int64_t nepoch, // how many times the dataset should be iterated over
         int64_t nbatch_logical, // datapoints optimizer step, must be a multiple of ndata_batch in inputs/outputs
         float val_split, // fraction of the dataset to use for validation, must be in [0.0f, 1.0f)
         bool silent); // whether or not info prints to stderr should be suppressed
 
+
 #ifdef __cplusplus
 }
 #endif
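
usage note (not part of the diff): with the new optimizer field and helpers
above, a caller can select SGD on an optimization context roughly as below;
the signature of ggml_opt_default_params (backend scheduler plus loss type)
is assumed here and backend setup is omitted:

#include "ggml-opt.h"
#include <stdio.h>

// sketch: pick SGD instead of AdamW for an optimization context
static ggml_opt_context_t init_sgd_ctx(ggml_backend_sched_t sched) {
    // assumed helper signature: ggml_opt_default_params(backend_sched, loss_type)
    struct ggml_opt_params params = ggml_opt_default_params(sched, GGML_OPT_LOSS_TYPE_CROSS_ENTROPY);
    params.optimizer = GGML_OPT_OPTIMIZER_TYPE_SGD; // new field added in this commit
    ggml_opt_context_t opt_ctx = ggml_opt_init(params);
    // new introspection helpers added in this commit:
    printf("optimizer: %s\n", ggml_opt_optimizer_name(ggml_opt_context_optimizer_type(opt_ctx)));
    return opt_ctx;
}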
ggml/include/ggml.h CHANGED
@@ -542,6 +542,7 @@ extern "C" {
         GGML_OP_CROSS_ENTROPY_LOSS,
         GGML_OP_CROSS_ENTROPY_LOSS_BACK,
         GGML_OP_OPT_STEP_ADAMW,
+        GGML_OP_OPT_STEP_SGD,
 
         GGML_OP_GLU,
 
@@ -2311,7 +2312,14 @@ extern "C" {
             struct ggml_tensor * grad,
             struct ggml_tensor * m,
             struct ggml_tensor * v,
-            struct ggml_tensor * adamw_params); // parameters such a the learning rate
+            struct ggml_tensor * adamw_params); // parameters such as the learning rate
+
+    // stochastic gradient descent step (with weight decay)
+    GGML_API struct ggml_tensor * ggml_opt_step_sgd(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            struct ggml_tensor * grad,
+            struct ggml_tensor * sgd_params); // alpha, weight decay
 
     //
     // automatic differentiation
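
usage note (not part of the diff): the new op is wired into a graph the same
way as GGML_OP_OPT_STEP_ADAMW, except that the params tensor holds just the
two floats {alpha, wd}; a rough sketch, assuming the weight tensor already
has GGML_TENSOR_FLAG_PARAM set and that `gb` is the optimizer graph:

#include "ggml.h"

// build an SGD step for one trainable tensor `w` with gradient `grad` into graph `gb`
static void add_sgd_step(struct ggml_context * ctx, struct ggml_cgraph * gb,
                         struct ggml_tensor * w, struct ggml_tensor * grad) {
    struct ggml_tensor * sgd_params = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 2);
    ggml_set_input(sgd_params); // filled with {alpha, wd} before each optimizer step

    struct ggml_tensor * step = ggml_opt_step_sgd(ctx, w, grad, sgd_params);
    ggml_build_forward_expand(gb, step);
}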
ggml/src/ggml-cpu/ggml-cpu.c CHANGED
@@ -2022,6 +2022,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 ggml_compute_forward_opt_step_adamw(params, tensor);
             }
             break;
+        case GGML_OP_OPT_STEP_SGD:
+            {
+                ggml_compute_forward_opt_step_sgd(params, tensor);
+            }
+            break;
         case GGML_OP_NONE:
             {
                 // nop
@@ -2325,6 +2330,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
+        case GGML_OP_OPT_STEP_SGD:
             {
                 n_tasks = n_threads;
             } break;
ggml/src/ggml-cpu/ops.cpp CHANGED
@@ -10330,6 +10330,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
     const int ir1 = MIN(ir0 + dr, nr);
 
     const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
+
     const float alpha = adamw_params_ptr[0];
     const float beta1 = adamw_params_ptr[1];
     const float beta2 = adamw_params_ptr[2];
@@ -10337,7 +10338,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
     const float wd = adamw_params_ptr[4];
     const float beta1h = adamw_params_ptr[5];
     const float beta2h = adamw_params_ptr[6];
-
+    const float keep = 1.f - alpha * wd;
     for (int ir = ir0; ir < ir1; ++ir) {
         const int64_t i03 = ir/(ne02*ne01);
         const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
@@ -10360,7 +10361,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
             // The weight decay is applied independently of the Adam momenta m and v.
             // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
             // See: https://arxiv.org/pdf/1711.05101v3.pdf
-            w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh;
+            w[i00] = w[i00] * keep - alpha * mh / vh;
         }
     }
 }
@@ -10382,3 +10383,63 @@ void ggml_compute_forward_opt_step_adamw(
         }
     }
 }
+
+static void ggml_compute_forward_opt_step_sgd_f32(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0       = dst->src[0];
+    const ggml_tensor * src0_grad  = dst->src[1];
+    const ggml_tensor * sgd_params = dst->src[2];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_nelements(sgd_params) == 2);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1) / nth;
+
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    // using adamw param subset we care about - alpha, wd - could have a separate struct
+    const float * sgd_params_ptr = ggml_get_data_f32(sgd_params);
+    const float alpha = sgd_params_ptr[0];
+    const float keep  = 1.f - alpha * sgd_params_ptr[1];
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir / (ne02 * ne01);
+        const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01;
+        const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01);
+
+        const size_t offset = i03 * nb03 + i02 * nb02 + i01 * nb01;
+
+        float       * w = (float       *) ((char       *) src0->data      + offset); // weight
+        const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad
+
+        for (int i00 = 0; i00 < ne00; ++i00) {
+            w[i00] = w[i00] * keep - alpha * g[i00];
+        }
+    }
+}
+
+void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_opt_step_sgd_f32(params, dst);
+            }
+            break;
+        default:
+            {
+                GGML_ABORT("fatal error - sgd is F32 only");
+            }
+    }
+}
ggml/src/ggml-cpu/ops.h CHANGED
@@ -107,7 +107,7 @@ void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params *
 void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-
+void ggml_compute_forward_opt_step_sgd(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 #ifdef __cplusplus
 }
 #endif
ggml/src/ggml-cuda/ggml-cuda.cu CHANGED
@@ -28,6 +28,7 @@
 #include "ggml-cuda/mmvq.cuh"
 #include "ggml-cuda/norm.cuh"
 #include "ggml-cuda/opt-step-adamw.cuh"
+#include "ggml-cuda/opt-step-sgd.cuh"
 #include "ggml-cuda/out-prod.cuh"
 #include "ggml-cuda/pad.cuh"
 #include "ggml-cuda/pool2d.cuh"
@@ -2479,6 +2480,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_OPT_STEP_ADAMW:
             ggml_cuda_opt_step_adamw(ctx, dst);
             break;
+        case GGML_OP_OPT_STEP_SGD:
+            ggml_cuda_opt_step_sgd(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -3536,6 +3540,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
+        case GGML_OP_OPT_STEP_SGD:
             return true;
         default:
             return false;
ggml/src/ggml-cuda/opt-step-sgd.cu ADDED
@@ -0,0 +1,49 @@
+#include "ggml-impl.h"
+#include "opt-step-sgd.cuh"
+
+#include <cstdint>
+
+static __global__ void opt_step_sgd_f32(
+    float * __restrict__ x, const float * __restrict__ g,
+    const float * __restrict__ pars, const int64_t k) {
+
+    const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    x[i] = x[i] * (1.0f - pars[0] * pars[1]) - pars[0] * g[i];
+}
+
+static void opt_step_sgd_f32_cuda(
+    float * x, const float * g, const float * __restrict__ pars, const int64_t k, cudaStream_t stream) {
+
+    const dim3 block_dims(CUDA_OPT_STEP_SGD_BLOCK_SIZE, 1, 1);
+    const dim3 block_nums((k + CUDA_OPT_STEP_SGD_BLOCK_SIZE - 1) / CUDA_OPT_STEP_SGD_BLOCK_SIZE, 1, 1);
+    opt_step_sgd_f32<<<block_nums, block_dims, 0, stream>>>(x, g, pars, k);
+}
+
+void ggml_cuda_opt_step_sgd(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0      = dst->src[0];
+    const ggml_tensor * src0_grad = dst->src[1];
+    const ggml_tensor * params    = dst->src[2];
+
+    GGML_ASSERT(src0->type      == GGML_TYPE_F32);
+    GGML_ASSERT(src0_grad->type == GGML_TYPE_F32);
+    GGML_ASSERT(params->type    == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src0_grad));
+    GGML_ASSERT(ggml_is_contiguous(params));
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_nelements(params) == 2);
+
+    float       * src0_d      = (float       *) src0->data;
+    const float * src0_grad_d = (const float *) src0_grad->data;
+    const float * params_d    = (const float *) params->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    const int64_t ne = ggml_nelements(src0);
+
+    opt_step_sgd_f32_cuda(src0_d, src0_grad_d, params_d, ne, stream);
+}
ggml/src/ggml-cuda/opt-step-sgd.cuh ADDED
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_OPT_STEP_SGD_BLOCK_SIZE 256
+
+void ggml_cuda_opt_step_sgd(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml-opt.cpp CHANGED
@@ -64,9 +64,11 @@ struct ggml_opt_context {
     int32_t opt_i = 0;
     bool loss_per_datapoint = false;
 
-    ggml_opt_get_optimizer_params get_opt_pars = nullptr;
-    void * get_opt_pars_ud = nullptr;
-    struct ggml_tensor * adamw_params = nullptr;
+    ggml_opt_get_optimizer_params get_opt_pars = nullptr;
+    void * get_opt_pars_ud = nullptr;
+    struct ggml_tensor * opt_step_params = nullptr; // Stores output of get_opt_pars.
+
+    enum ggml_opt_optimizer_type optimizer = GGML_OPT_OPTIMIZER_TYPE_ADAMW;
 };
 
 struct ggml_opt_result {
@@ -229,9 +231,13 @@ struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * us
     result.adamw.eps = 1e-8f;
     result.adamw.wd = 0.0f;
 
+    result.sgd.alpha = 1e-3f;
+    result.sgd.wd = 0.0f;
+
     return result;
 }
 
+
 struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata) {
     return *((struct ggml_opt_optimizer_params *) userdata);
 }
@@ -249,6 +255,7 @@ struct ggml_opt_params ggml_opt_default_params(
         /*opt_period      =*/ 1,
         /*get_opt_pars    =*/ ggml_opt_get_default_optimizer_params,
         /*get_opt_pars_ud =*/ nullptr,
+        /*optimizer       =*/ GGML_OPT_OPTIMIZER_TYPE_ADAMW,
     };
 }
 
@@ -316,9 +323,14 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
     GGML_ASSERT(opt_ctx->ctx_compute && "no compute context set, either use static graphs or set one with ggml_opt_prepare_alloc");
     GGML_ASSERT((!opt_ctx->static_graphs || opt_ctx->inputs->data) && "when using static graphs the inputs must be allocated statically");
 
+    const enum ggml_opt_optimizer_type optimizer = opt_ctx->optimizer;
+
     const bool accumulate = opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD &&
         !(opt_ctx->static_graphs && opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period == 1);
 
+    const bool need_momenta = opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT &&
+        opt_ctx->optimizer == GGML_OPT_OPTIMIZER_TYPE_ADAMW;
+
     ggml_set_input(opt_ctx->inputs);
     ggml_set_output(opt_ctx->outputs);
 
@@ -340,8 +352,7 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
     // - pred (if using static graphs)
     // - ncorrect (if using static graphs, 2 tensors).
     constexpr size_t n_loss = 1;
-    const size_t tensors_per_param = (accumulate ? 1 : 0) +
-        (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0);
+    const size_t tensors_per_param = (accumulate ? 1 : 0) + (need_momenta ? 2 : 0);
     const size_t tensors_const = opt_ctx->static_graphs ? 9 : 0;
     const size_t size_meta = (n_loss + tensors_per_param*n_param + tensors_const) * ggml_tensor_overhead();
     struct ggml_init_params params = {
@@ -458,7 +469,7 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
         }
     }
 
-    if (opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_OPT) {
+    if (need_momenta && opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_OPT) {
         opt_ctx->grad_m.resize(n_nodes);
         opt_ctx->grad_v.resize(n_nodes);
         for (int i = 0; i < n_nodes; ++i) {
@@ -492,23 +503,36 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
     // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
     opt_ctx->gb_opt = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gb_grad, /*force_grads =*/ true);
 
-    opt_ctx->adamw_params = ggml_new_tensor_1d(opt_ctx->ctx_cpu, GGML_TYPE_F32, 7);
-    ggml_set_input(opt_ctx->adamw_params);
-    ggml_set_name(opt_ctx->adamw_params, "adamw_params");
-
+    opt_ctx->opt_step_params = ggml_new_tensor_1d(opt_ctx->ctx_cpu, GGML_TYPE_F32, need_momenta ? 7 : 2);
+    ggml_tensor * adamw_params = opt_ctx->opt_step_params;
+    ggml_set_input(adamw_params);
+    const char * optimizer_name = ggml_opt_optimizer_name(opt_ctx->optimizer);
+    ggml_format_name(adamw_params, "%s_params", optimizer_name);
     for (int i = opt_ctx->gf->n_nodes-1; i >= 0; --i) {
         struct ggml_tensor * node = opt_ctx->gb_opt->nodes[i];
         struct ggml_tensor * grad = ggml_graph_get_grad(opt_ctx->gb_opt, node);
 
         if (grad && (node->flags & GGML_TENSOR_FLAG_PARAM)) {
-            struct ggml_tensor * m = opt_ctx->grad_m[i];
-            struct ggml_tensor * v = opt_ctx->grad_v[i];
-            struct ggml_tensor * opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, opt_ctx->adamw_params);
-
-            ggml_set_name(m, (std::string("AdamW m for ") + std::string(node->name)).c_str());
-            ggml_set_name(v, (std::string("AdamW v for ") + std::string(node->name)).c_str());
-            ggml_set_name(opt_step, (std::string("AdamW step for ") + std::string(node->name)).c_str());
-
+            struct ggml_tensor * m = nullptr;
+            struct ggml_tensor * v = nullptr;
+            if (need_momenta) {
+                m = opt_ctx->grad_m[i];
+                v = opt_ctx->grad_v[i];
+                ggml_format_name(m, "AdamW m for %s", node->name);
+                ggml_format_name(v, "AdamW v for %s", node->name);
+            }
+            struct ggml_tensor * opt_step;
+            switch (optimizer) {
+                case GGML_OPT_OPTIMIZER_TYPE_ADAMW:
+                    opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, adamw_params);
+                    break;
+                case GGML_OPT_OPTIMIZER_TYPE_SGD:
+                    opt_step = ggml_opt_step_sgd(opt_ctx->ctx_compute, node, grad, adamw_params);
+                    break;
+                default:
+                    GGML_ABORT("fatal error");
+            }
+            ggml_format_name(opt_step, "%s step for %s", optimizer_name, node->name);
             ggml_build_forward_expand(opt_ctx->gb_opt, opt_step);
         }
     }
@@ -534,6 +558,7 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
     result->opt_period      = params.opt_period;
    result->get_opt_pars    = params.get_opt_pars;
     result->get_opt_pars_ud = params.get_opt_pars_ud;
+    result->optimizer       = params.optimizer;
 
     GGML_ASSERT(result->opt_period >= 1);
 
@@ -756,29 +781,43 @@ void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward) {
 void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) {
     GGML_ASSERT(opt_ctx->eval_ready);
     if (opt_ctx->allocated_graph == opt_ctx->gb_opt) {
-        struct ggml_opt_optimizer_params opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);
-
-        GGML_ASSERT(opt_pars.adamw.alpha > 0.0f);
-        GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f);
-        GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f);
-        GGML_ASSERT(opt_pars.adamw.beta2 >= 0.0f);
-        GGML_ASSERT(opt_pars.adamw.beta2 <= 1.0f);
-        GGML_ASSERT(opt_pars.adamw.eps >= 0.0f);
-        GGML_ASSERT(opt_pars.adamw.wd >= 0.0f);
-        GGML_ASSERT(opt_pars.adamw.wd <= 1.0f);
-
-        // beta1, beta2 after applying warmup
-        const float beta1h = 1.0f/(1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter));
-        const float beta2h = 1.0f/(1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter));
-
-        float * adamw_par_data = ggml_get_data_f32(opt_ctx->adamw_params);
-        adamw_par_data[0] = opt_pars.adamw.alpha;
-        adamw_par_data[1] = opt_pars.adamw.beta1;
-        adamw_par_data[2] = opt_pars.adamw.beta2;
-        adamw_par_data[3] = opt_pars.adamw.eps;
-        adamw_par_data[4] = opt_pars.adamw.wd;
-        adamw_par_data[5] = beta1h;
-        adamw_par_data[6] = beta2h;
+        const ggml_opt_optimizer_params & opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);
+
+        switch (opt_ctx->optimizer) {
+            case GGML_OPT_OPTIMIZER_TYPE_ADAMW: {
+                GGML_ASSERT(opt_pars.adamw.alpha > 0.0f);
+                GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f);
+                GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f);
+                GGML_ASSERT(opt_pars.adamw.beta2 >= 0.0f);
+                GGML_ASSERT(opt_pars.adamw.beta2 <= 1.0f);
+                GGML_ASSERT(opt_pars.adamw.eps >= 0.0f);
+                GGML_ASSERT(opt_pars.adamw.wd >= 0.0f);
+                GGML_ASSERT(opt_pars.adamw.wd <= 1.0f);
+
+                // beta1, beta2 after applying warmup
+                const float beta1h = 1.0f / (1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter));
+                const float beta2h = 1.0f / (1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter));
+
+                float * adamw_par_data = ggml_get_data_f32(opt_ctx->opt_step_params);
+                adamw_par_data[0] = opt_pars.adamw.alpha;
+                adamw_par_data[1] = opt_pars.adamw.beta1;
+                adamw_par_data[2] = opt_pars.adamw.beta2;
+                adamw_par_data[3] = opt_pars.adamw.eps;
+                adamw_par_data[4] = opt_pars.adamw.wd;
+                adamw_par_data[5] = beta1h;
+                adamw_par_data[6] = beta2h;
+            } break;
+            case GGML_OPT_OPTIMIZER_TYPE_SGD: {
+                GGML_ASSERT(opt_pars.sgd.alpha > 0.0f);
+                GGML_ASSERT(opt_pars.sgd.wd >= 0.0f);
+                GGML_ASSERT(opt_pars.sgd.wd <= 1.0f);
+                float * sgd = ggml_get_data_f32(opt_ctx->opt_step_params);
+                sgd[0] = opt_pars.sgd.alpha;
+                sgd[1] = opt_pars.sgd.wd;
+            } break;
+            default:
+                GGML_ABORT("fatal error");
+        }
     }
 
     ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
@@ -963,6 +1002,7 @@ void ggml_opt_fit(
         ggml_tensor * outputs,
         ggml_opt_dataset_t dataset,
         enum ggml_opt_loss_type loss_type,
+        enum ggml_opt_optimizer_type optimizer,
         ggml_opt_get_optimizer_params get_opt_pars,
         int64_t nepoch,
         int64_t nbatch_logical,
@@ -993,6 +1033,7 @@ void ggml_opt_fit(
     params.opt_period      = opt_period;
     params.get_opt_pars    = get_opt_pars;
     params.get_opt_pars_ud = &epoch;
+    params.optimizer       = optimizer;
     ggml_opt_context_t opt_ctx = ggml_opt_init(params);
 
     // Shuffling the data is generally useful but there is only a point if not all data is used in a single batch.
@@ -1035,3 +1076,18 @@ void ggml_opt_fit(
     ggml_opt_result_free(result_train);
     ggml_opt_result_free(result_val);
 }
+
+enum ggml_opt_optimizer_type ggml_opt_context_optimizer_type(ggml_opt_context_t c) {
+    return c->optimizer;
+}
+
+GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type o) {
+    switch (o) {
+        case GGML_OPT_OPTIMIZER_TYPE_ADAMW:
+            return "adamw";
+        case GGML_OPT_OPTIMIZER_TYPE_SGD:
+            return "sgd";
+        default:
+            return "undefined";
+    };
+}
ggml/src/ggml-vulkan/ggml-vulkan.cpp CHANGED
@@ -510,6 +510,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_rwkv_wkv6_f32;
     vk_pipeline pipeline_rwkv_wkv7_f32;
     vk_pipeline pipeline_opt_step_adamw_f32;
+    vk_pipeline pipeline_opt_step_sgd_f32;
     vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
     vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
     vk_pipeline pipeline_conv2d_dw_whcn_f32;
@@ -3123,6 +3124,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+
     // conv2d
     for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
         uint32_t conv2d_WG_SIZE = 256;
@@ -7193,6 +7196,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
            return ctx->device->pipeline_opt_step_adamw_f32;
        }
        return nullptr;
+    case GGML_OP_OPT_STEP_SGD:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_opt_step_sgd_f32;
+        }
+        return nullptr;
     case GGML_OP_LEAKY_RELU:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_leaky_relu_f32;
@@ -7692,6 +7700,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz);
         ggml_vk_sync_buffers(subctx);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+    } else if (op == GGML_OP_OPT_STEP_SGD) {
+        // OPT_STEP_SGD works on src0, it does not need dst
+        ggml_vk_sync_buffers(subctx);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz } }, pc, elements);
     } else if (use_src2) {
         ggml_vk_sync_buffers(subctx);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
@@ -8045,6 +8057,12 @@ static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& su
     );
 }
 
+static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
+    const size_t n = ggml_nelements(dst->src[0]);
+
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f }, dryrun);
+}
+
 static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     int * op_params = (int *)dst->op_params;
 
@@ -9598,6 +9616,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_LEAKY_RELU:
     case GGML_OP_FLASH_ATTN_EXT:
     case GGML_OP_OPT_STEP_ADAMW:
+    case GGML_OP_OPT_STEP_SGD:
         break;
     default:
         std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
@@ -9662,6 +9681,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_CONV_2D:
     case GGML_OP_CONV_2D_DW:
     case GGML_OP_LEAKY_RELU:
+    case GGML_OP_OPT_STEP_SGD:
         {
             // These operations all go through ggml_vk_op_f32, so short-circuit and
             // do the only thing needed for the dryrun.
@@ -9911,6 +9931,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_OPT_STEP_ADAMW:
         ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
 
+        break;
+
+    case GGML_OP_OPT_STEP_SGD:
+        ggml_vk_opt_step_sgd(ctx, compute_ctx, src0, src1, src2, node, dryrun);
+
         break;
     default:
         return false;
@@ -10014,8 +10039,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
     case GGML_OP_REPEAT:
     case GGML_OP_REPEAT_BACK:
     case GGML_OP_OPT_STEP_ADAMW:
+    case GGML_OP_OPT_STEP_SGD:
         buf = tensor->buffer;
-
         break;
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(tensor)) {
@@ -11154,6 +11179,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_SIN:
         case GGML_OP_COS:
         case GGML_OP_CLAMP:
+        case GGML_OP_LEAKY_RELU:
+        case GGML_OP_OPT_STEP_ADAMW:
+        case GGML_OP_OPT_STEP_SGD:
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_UPSCALE:
         case GGML_OP_ACC:
@@ -11175,8 +11203,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_POOL_2D:
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_RWKV_WKV7:
-        case GGML_OP_LEAKY_RELU:
-        case GGML_OP_OPT_STEP_ADAMW:
             return true;
         case GGML_OP_CONV_TRANSPOSE_1D:
             return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
@@ -11774,6 +11800,10 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         src_clone[0]->flags = src0->flags;
         tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
                                            src_clone[2], src_clone[3], src_clone[4]);
+    } else if (tensor->op == GGML_OP_OPT_STEP_SGD) {
+        src_clone[0]->flags = src0->flags;
+        tensor_clone = ggml_opt_step_sgd(ggml_ctx, src_clone[0], src_clone[1],
+                                         src_clone[2]);
     }
     else {
         std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp ADDED
@@ -0,0 +1,22 @@
+#version 450
+
+#include "generic_head.comp"
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) buffer X {A_TYPE data_x[];};
+layout (binding = 1) readonly buffer G {A_TYPE data_grad[];};
+layout (binding = 2) readonly buffer P {float data_params[2];};
+
+void main() {
+    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+    if (i >= p.KX) {
+        return;
+    }
+
+    const float alpha = data_params[0];
+    const float keep  = 1.f - alpha * data_params[1];
+
+    data_x[i] = data_x[i] * keep - alpha * data_grad[i];
+}
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp CHANGED
@@ -657,6 +657,7 @@ void process_shaders() {
     string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 
     string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
+    string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 
     string_to_spv("conv2d_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}});
     string_to_spv("conv2d_f16_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}});
ggml/src/ggml.c CHANGED
@@ -1012,11 +1012,12 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS",
     "CROSS_ENTROPY_LOSS_BACK",
     "OPT_STEP_ADAMW",
+    "OPT_STEP_SGD",
 
     "GLU",
 };
 
-static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
+static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1113,15 +1114,15 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss(x,y)",
     "cross_entropy_loss_back(x,y)",
     "adamw(x)",
+    "sgd(x)",
 
     "glu(x)",
 };
 
-static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
+static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
-
 static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
     "ABS",
     "SGN",
@@ -5606,6 +5607,28 @@ struct ggml_tensor * ggml_opt_step_adamw(
     return result;
 }
 
+// opt_step_sgd
+
+struct ggml_tensor * ggml_opt_step_sgd(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * grad,
+        struct ggml_tensor  * params) {
+    GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM);
+    GGML_ASSERT(ggml_are_same_shape(a, grad));
+    GGML_ASSERT(params->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_nelements(params) == 2);
+
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    result->op     = GGML_OP_OPT_STEP_SGD;
+    result->src[0] = a;
+    result->src[1] = grad;
+    result->src[2] = params;
+
+    return result;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 struct ggml_hash_set ggml_hash_set_new(size_t size) {