finetune: SGD optimizer, more CLI args (llama/13873)
* examples/finetune -opt SGD (stochastic gradient descent) memory opt

add unit-tested GGML_OPT_OPTIMIZER_SGD to ggml - it avoids allocating the
AdamW m, v momentum tensors.

support the finetune.cpp arg -opt SGD (or sgd); the default is adamw, as before.

llama 3.2-1b-F32 result: observed 11 GB GPU RAM (41 sec/epoch) when using SGD,
instead of 19 GB (55 sec/epoch) using adamw
(finetune on 100 lines of wikipedia).
(Using the same GPU memory, adamw can only fit 512 batch/context before OOM, reaching:

train: [████████] data=0000140/0000140 loss=0.02575±0.00099 acc=99.52±0.03% t=00:00:47 ETA=00:00:00
val:   [████████] data=0000008/0000008 loss=4.76565±0.28810 acc=41.46±0.77% t=00:00:00 ETA=00:00:00

SGD comes out ahead: although it converges more slowly, it fits up to 1728 batch/context before OOM - note especially the better validation performance:

train: [████████] data=0000039/0000039 loss=0.00371±0.00010 acc=99.96±0.01% t=00:00:41 ETA=00:00:00
val:   [████████] data=0000003/0000003 loss=5.11406±0.76034 acc=48.01±0.69% t=00:00:01 ETA=00:00:00
)
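For context (not part of the patch): the memory gap comes from AdamW's per-parameter state - two extra F32 tensors (m and v), roughly 2 x 4 bytes per parameter, or about 8 GB for a ~1B-parameter F32 model, which lines up with the 19 GB vs 11 GB observation above. A minimal sketch of the two element-wise updates, mirroring the CPU implementations added further down in this PR:

```c
#include <math.h>

// Sketch only - mirrors ggml_compute_forward_opt_step_{sgd,adamw}_f32 below.

// SGD: needs only the weight w and its gradient g.
static inline void sgd_step(float * w, float g, float alpha, float wd) {
    *w = *w * (1.0f - alpha * wd) - alpha * g;
}

// AdamW: additionally updates per-parameter momenta m and v (the extra memory).
static inline void adamw_step(float * w, float * m, float * v, float g,
                              float alpha, float beta1, float beta2,
                              float eps, float wd, float beta1h, float beta2h) {
    *m = beta1 * *m + (1.0f - beta1) * g;
    *v = beta2 * *v + (1.0f - beta2) * g * g;
    const float mh = *m * beta1h;                // bias-corrected first moment
    const float vh = sqrtf(*v * beta2h) + eps;   // bias-corrected second moment
    *w = *w * (1.0f - alpha * wd) - alpha * mh / vh;
}
```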
note: when finetuning long enough (or with a high enough -lr), validation accuracy *eventually* drops ('catastrophic forgetting').

The -lr-half (halflife) option is useful for SGD to avoid oscillation or extremely slow underdamped learning (it makes setting -lr more forgiving). For now the terminal -lr is set via -lr-halvings, i.e. if you want at most 1/8 the initial -lr you set -lr-halvings 3.
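A tiny illustration of the floor this gives (the helper name is hypothetical; the actual flag handling lives in finetune.cpp):

```c
#include <math.h>

// Hypothetical helper: terminal learning rate implied by -lr and -lr-halvings.
static float terminal_lr(float lr0, int lr_halvings) {
    return lr0 * powf(0.5f, (float) lr_halvings);
}
// e.g. terminal_lr(1e-3f, 3) == 1.25e-4f, i.e. at most 1/8 of the initial -lr.
```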
note: the objective loss is probably not directly comparable between adamw and sgd - check perplexity or accuracy, or look at relative improvements, when judging convergence.
new finetune args: -wd 1e-9 enables weight decay in sgd or adamw, and -epochs N caps the number of epochs (default 2, as before).
cache (1 - wd*alpha) in the 'adamw' opt struct - no noticeable perf benefit, so it is disabled there (it is still done for the new SGD, though).
since optimizer memory is pre-allocated, the ggml_opt_get_optimizer_params callback would probably be able to switch between SGD and AdamW with each epoch, but it would need to use adamw for the first one (unconfirmed - there is no cmdline arg to set such a policy yet).
test-opt checks adamw as before and now also sgd (except for a few tests disabled for sgd only; they probably just need logging of values and adding alternate reference values); the tolerance on the 'regression' test is wider for sgd (so we don't need many more epochs).
* Vulkan: Implement GGML_OP_OPT_STEP_SGD
* tests: Fix OPT_STEP_SGD test-backend-ops
* SGD op params store weight-decay, not 1-alpha*wd
* minor + cosmetic changes
* fix vulkan sgd
* try CI fix
---------
Co-authored-by: 0cc4m <[email protected]>
Co-authored-by: Johannes Gäßler <[email protected]>
- ggml/include/ggml-opt.h +25 -6
- ggml/include/ggml.h +9 -1
- ggml/src/ggml-cpu/ggml-cpu.c +6 -0
- ggml/src/ggml-cpu/ops.cpp +63 -2
- ggml/src/ggml-cpu/ops.h +1 -1
- ggml/src/ggml-cuda/ggml-cuda.cu +5 -0
- ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
- ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
- ggml/src/ggml-opt.cpp +97 -41
- ggml/src/ggml-vulkan/ggml-vulkan.cpp +33 -3
- ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
- ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +1 -0
- ggml/src/ggml.c +26 -3

ggml/include/ggml-opt.h

```diff
@@ -74,16 +74,26 @@ extern "C" {
         GGML_OPT_BUILD_TYPE_OPT  = 30,
     };
 
+    enum ggml_opt_optimizer_type {
+        GGML_OPT_OPTIMIZER_TYPE_ADAMW,
+        GGML_OPT_OPTIMIZER_TYPE_SGD,
+
+        GGML_OPT_OPTIMIZER_TYPE_COUNT
+    };
+
     // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
     struct ggml_opt_optimizer_params {
-        // AdamW optimizer parameters
         struct {
             float alpha; // learning rate
-            float beta1;
-            float beta2;
+            float beta1; // first AdamW momentum
+            float beta2; // second AdamW momentum
             float eps;   // epsilon for numerical stability
-            float wd;    // weight decay
+            float wd;    // weight decay - 0.0f to disable
         } adamw;
+        struct {
+            float alpha; // learning rate
+            float wd;    // weight decay
+        } sgd;
     };
 
     // callback to calculate optimizer parameters prior to a backward pass
@@ -112,8 +122,11 @@ extern "C" {
 
         int32_t opt_period; // after how many gradient accumulation steps an optimizer step should be done
 
-        ggml_opt_get_optimizer_params get_opt_pars;
-        void * get_opt_pars_ud;
+        ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
+        void * get_opt_pars_ud;                     // userdata for calculating optimizer parameters
+
+        // only GGML_OPT_OPTIMIZER_TYPE_ADAMW needs m, v momenta per parameter tensor
+        enum ggml_opt_optimizer_type optimizer;
     };
 
     // get parameters for an optimization context with defaults set where possible
@@ -142,6 +155,10 @@ extern "C" {
     // get the gradient accumulator for a node from the forward graph
     GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node);
 
+    GGML_API enum ggml_opt_optimizer_type ggml_opt_context_optimizer_type(ggml_opt_context_t); //TODO consistent naming scheme
+
+    GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type);
+
     // ====== Optimization Result ======
 
     GGML_API ggml_opt_result_t ggml_opt_result_init(void);
@@ -226,12 +243,14 @@ extern "C" {
             struct ggml_tensor            * outputs,        // output tensor, must have shape [ne_label, ndata_batch] if labels are used
             ggml_opt_dataset_t              dataset,        // dataset with data and optionally also labels
             enum ggml_opt_loss_type         loss_type,      // loss to minimize
+            enum ggml_opt_optimizer_type    optimizer,      // sgd or adamw
             ggml_opt_get_optimizer_params   get_opt_pars,   // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
             int64_t                         nepoch,         // how many times the dataset should be iterated over
             int64_t                         nbatch_logical, // datapoints optimizer step, must be a multiple of ndata_batch in inputs/outputs
             float                           val_split,      // fraction of the dataset to use for validation, must be in [0.0f, 1.0f)
             bool                            silent);        // whether or not info prints to stderr should be suppressed
 
+
 #ifdef __cplusplus
 }
 #endif
```
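A minimal caller-side sketch of the extended ggml_opt_fit above - assuming the leading parameters (backend scheduler, compute context, input/output tensors) keep their existing meaning; the helper name and numeric values are illustrative:

```c
#include "ggml-opt.h"

// Sketch: this is roughly what finetune's "-opt sgd" selects through the new parameter.
static void fit_with_sgd(ggml_backend_sched_t sched, struct ggml_context * ctx_compute,
                         struct ggml_tensor * inputs, struct ggml_tensor * outputs,
                         ggml_opt_dataset_t dataset) {
    ggml_opt_fit(sched, ctx_compute, inputs, outputs, dataset,
                 GGML_OPT_LOSS_TYPE_CROSS_ENTROPY,
                 GGML_OPT_OPTIMIZER_TYPE_SGD,           // new: sgd or adamw
                 ggml_opt_get_default_optimizer_params, // or a custom callback filling sgd.alpha / sgd.wd
                 /*nepoch         =*/ 2,
                 /*nbatch_logical =*/ 512,
                 /*val_split      =*/ 0.05f,
                 /*silent         =*/ false);
}
```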
ggml/include/ggml.h

```diff
@@ -542,6 +542,7 @@ extern "C" {
         GGML_OP_CROSS_ENTROPY_LOSS,
         GGML_OP_CROSS_ENTROPY_LOSS_BACK,
         GGML_OP_OPT_STEP_ADAMW,
+        GGML_OP_OPT_STEP_SGD,
 
         GGML_OP_GLU,
 
@@ -2311,7 +2312,14 @@ extern "C" {
             struct ggml_tensor  * grad,
             struct ggml_tensor  * m,
             struct ggml_tensor  * v,
             struct ggml_tensor  * adamw_params); // parameters such as the learning rate
+
+    // stochastic gradient descent step (with weight decay)
+    GGML_API struct ggml_tensor * ggml_opt_step_sgd(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * grad,
+            struct ggml_tensor  * sgd_params); // alpha, weight decay
 
     //
     // automatic differentiation
```
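For illustration, a hand-wired use of the new op declared above (ggml-opt.cpp does this automatically; shown only to make the 2-element params layout explicit):

```c
#include "ggml.h"

// params layout consumed by GGML_OP_OPT_STEP_SGD: [0] = alpha (learning rate), [1] = wd (weight decay).
static struct ggml_tensor * add_sgd_step(struct ggml_context * ctx,
                                         struct ggml_tensor * weights, // must carry GGML_TENSOR_FLAG_PARAM
                                         struct ggml_tensor * grad) {
    struct ggml_tensor * sgd_params = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 2);
    ggml_set_input(sgd_params); // filled with {alpha, wd} before each optimizer step
    return ggml_opt_step_sgd(ctx, weights, grad, sgd_params);
}
```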
ggml/src/ggml-cpu/ggml-cpu.c

```diff
@@ -2022,6 +2022,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
                 ggml_compute_forward_opt_step_adamw(params, tensor);
             }
             break;
+        case GGML_OP_OPT_STEP_SGD:
+            {
+                ggml_compute_forward_opt_step_sgd(params, tensor);
+            }
+            break;
         case GGML_OP_NONE:
             {
                 // nop
@@ -2325,6 +2330,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
+        case GGML_OP_OPT_STEP_SGD:
             {
                 n_tasks = n_threads;
             } break;
```
ggml/src/ggml-cpu/ops.cpp

```diff
@@ -10330,6 +10330,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
     const int ir1 = MIN(ir0 + dr, nr);
 
     const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
+
     const float alpha  = adamw_params_ptr[0];
     const float beta1  = adamw_params_ptr[1];
     const float beta2  = adamw_params_ptr[2];
@@ -10337,7 +10338,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
     const float wd     = adamw_params_ptr[4];
     const float beta1h = adamw_params_ptr[5];
     const float beta2h = adamw_params_ptr[6];
-
+    const float keep   = 1.f - alpha * wd;
     for (int ir = ir0; ir < ir1; ++ir) {
         const int64_t i03 = ir/(ne02*ne01);
         const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
@@ -10360,7 +10361,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
             // The weight decay is applied independently of the Adam momenta m and v.
             // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
             // See: https://arxiv.org/pdf/1711.05101v3.pdf
-            w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh;
+            w[i00] = w[i00] * keep - alpha * mh / vh;
         }
     }
 }
@@ -10382,3 +10383,63 @@ void ggml_compute_forward_opt_step_adamw(
         }
     }
 }
+
+static void ggml_compute_forward_opt_step_sgd_f32(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0       = dst->src[0];
+    const ggml_tensor * src0_grad  = dst->src[1];
+    const ggml_tensor * sgd_params = dst->src[2];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_nelements(sgd_params) == 2);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1) / nth;
+
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    // using adamw param subset we care about - alpha, wd - could have a separate struct
+    const float * sgd_params_ptr = ggml_get_data_f32(sgd_params);
+    const float   alpha          = sgd_params_ptr[0];
+    const float   keep           = 1.f - alpha * sgd_params_ptr[1];
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir / (ne02 * ne01);
+        const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01;
+        const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01);
+
+        const size_t offset = i03 * nb03 + i02 * nb02 + i01 * nb01;
+
+        float       * w = (float       *) ((char       *) src0->data      + offset); // weight
+        const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad
+
+        for (int i00 = 0; i00 < ne00; ++i00) {
+            w[i00] = w[i00] * keep - alpha * g[i00];
+        }
+    }
+}
+
+void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_opt_step_sgd_f32(params, dst);
+            }
+            break;
+        default:
+            {
+                GGML_ABORT("fatal error - sgd is F32 only");
+            }
+    }
+}
```
ggml/src/ggml-cpu/ops.h

```diff
@@ -107,7 +107,7 @@ void ggml_compute_forward_cross_entropy_loss(...)
 void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-
+void ggml_compute_forward_opt_step_sgd(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 #ifdef __cplusplus
 }
 #endif
```
ggml/src/ggml-cuda/ggml-cuda.cu

```diff
@@ -28,6 +28,7 @@
 #include "ggml-cuda/mmvq.cuh"
 #include "ggml-cuda/norm.cuh"
 #include "ggml-cuda/opt-step-adamw.cuh"
+#include "ggml-cuda/opt-step-sgd.cuh"
 #include "ggml-cuda/out-prod.cuh"
 #include "ggml-cuda/pad.cuh"
 #include "ggml-cuda/pool2d.cuh"
@@ -2479,6 +2480,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
         case GGML_OP_OPT_STEP_ADAMW:
             ggml_cuda_opt_step_adamw(ctx, dst);
             break;
+        case GGML_OP_OPT_STEP_SGD:
+            ggml_cuda_opt_step_sgd(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -3536,6 +3540,7 @@ static bool ggml_backend_cuda_device_supports_op(...)
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
+        case GGML_OP_OPT_STEP_SGD:
             return true;
         default:
             return false;
```
ggml/src/ggml-cuda/opt-step-sgd.cu (new file)

```diff
@@ -0,0 +1,49 @@
+#include "ggml-impl.h"
+#include "opt-step-sgd.cuh"
+
+#include <cstdint>
+
+static __global__ void opt_step_sgd_f32(
+    float * __restrict__ x, const float * __restrict__ g,
+    const float * __restrict__ pars, const int64_t k) {
+
+    const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    x[i] = x[i] * (1.0f - pars[0] * pars[1]) - pars[0] * g[i];
+}
+
+static void opt_step_sgd_f32_cuda(
+    float * x, const float * g, const float * __restrict__ pars, const int64_t k, cudaStream_t stream) {
+
+    const dim3 block_dims(CUDA_OPT_STEP_SGD_BLOCK_SIZE, 1, 1);
+    const dim3 block_nums((k + CUDA_OPT_STEP_SGD_BLOCK_SIZE - 1) / CUDA_OPT_STEP_SGD_BLOCK_SIZE, 1, 1);
+    opt_step_sgd_f32<<<block_nums, block_dims, 0, stream>>>(x, g, pars, k);
+}
+
+void ggml_cuda_opt_step_sgd(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0      = dst->src[0];
+    const ggml_tensor * src0_grad = dst->src[1];
+    const ggml_tensor * params    = dst->src[2];
+
+    GGML_ASSERT(src0->type      == GGML_TYPE_F32);
+    GGML_ASSERT(src0_grad->type == GGML_TYPE_F32);
+    GGML_ASSERT(params->type    == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src0_grad));
+    GGML_ASSERT(ggml_is_contiguous(params));
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_nelements(params) == 2);
+
+    float       * src0_d      = (float       *) src0->data;
+    const float * src0_grad_d = (const float *) src0_grad->data;
+    const float * params_d    = (const float *) params->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    const int64_t ne = ggml_nelements(src0);
+
+    opt_step_sgd_f32_cuda(src0_d, src0_grad_d, params_d, ne, stream);
+}
```
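An aside on the launch geometry above: one thread per element, CUDA_OPT_STEP_SGD_BLOCK_SIZE (256) threads per block, with the block count rounded up and the tail guarded by the i >= k early return. A small illustrative helper (not part of the patch):

```c
#include <stdint.h>

// Illustrative: number of blocks launched for k elements with 256-thread blocks.
static int64_t sgd_num_blocks(int64_t k) {
    return (k + 256 - 1) / 256; // e.g. k = 1000003 -> 3907 blocks (3907*256 = 1000192 >= k)
}
```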
ggml/src/ggml-cuda/opt-step-sgd.cuh (new file)

```diff
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_OPT_STEP_SGD_BLOCK_SIZE 256
+
+void ggml_cuda_opt_step_sgd(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
```
ggml/src/ggml-opt.cpp

```diff
@@ -64,9 +64,11 @@ struct ggml_opt_context {
     int32_t opt_i = 0;
     bool loss_per_datapoint = false;
 
+    ggml_opt_get_optimizer_params get_opt_pars = nullptr;
+    void * get_opt_pars_ud                     = nullptr;
+    struct ggml_tensor * opt_step_params       = nullptr; // Stores output of get_opt_pars.
+
+    enum ggml_opt_optimizer_type optimizer = GGML_OPT_OPTIMIZER_TYPE_ADAMW;
 };
 
 struct ggml_opt_result {
@@ -229,9 +231,13 @@ struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata) {
     result.adamw.eps   = 1e-8f;
     result.adamw.wd    = 0.0f;
 
+    result.sgd.alpha = 1e-3f;
+    result.sgd.wd    = 0.0f;
+
     return result;
 }
 
+
 struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata) {
     return *((struct ggml_opt_optimizer_params *) userdata);
 }
@@ -249,6 +255,7 @@ struct ggml_opt_params ggml_opt_default_params(
         /*opt_period      =*/ 1,
         /*get_opt_pars    =*/ ggml_opt_get_default_optimizer_params,
         /*get_opt_pars_ud =*/ nullptr,
+        /*optimizer       =*/ GGML_OPT_OPTIMIZER_TYPE_ADAMW,
     };
 }
 
@@ -316,9 +323,14 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
     GGML_ASSERT(opt_ctx->ctx_compute && "no compute context set, either use static graphs or set one with ggml_opt_prepare_alloc");
     GGML_ASSERT((!opt_ctx->static_graphs || opt_ctx->inputs->data) && "when using static graphs the inputs must be allocated statically");
 
+    const enum ggml_opt_optimizer_type optimizer = opt_ctx->optimizer;
+
     const bool accumulate = opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD &&
         !(opt_ctx->static_graphs && opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period == 1);
 
+    const bool need_momenta = opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT &&
+        opt_ctx->optimizer == GGML_OPT_OPTIMIZER_TYPE_ADAMW;
+
     ggml_set_input(opt_ctx->inputs);
     ggml_set_output(opt_ctx->outputs);
 
@@ -340,8 +352,7 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
     // - pred (if using static graphs)
     // - ncorrect (if using static graphs, 2 tensors).
     constexpr size_t n_loss = 1;
-    const size_t tensors_per_param = (accumulate ? 1 : 0) +
-        (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0);
+    const size_t tensors_per_param = (accumulate ? 1 : 0) + (need_momenta ? 2 : 0);
     const size_t tensors_const = opt_ctx->static_graphs ? 9 : 0;
     const size_t size_meta = (n_loss + tensors_per_param*n_param + tensors_const) * ggml_tensor_overhead();
     struct ggml_init_params params = {
@@ -458,7 +469,7 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
         }
     }
 
-    if (opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_OPT) {
+    if (need_momenta && opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_OPT) {
         opt_ctx->grad_m.resize(n_nodes);
         opt_ctx->grad_v.resize(n_nodes);
         for (int i = 0; i < n_nodes; ++i) {
@@ -492,23 +503,36 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
     // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
     opt_ctx->gb_opt = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gb_grad, /*force_grads =*/ true);
 
+    opt_ctx->opt_step_params = ggml_new_tensor_1d(opt_ctx->ctx_cpu, GGML_TYPE_F32, need_momenta ? 7 : 2);
+    ggml_tensor * adamw_params = opt_ctx->opt_step_params;
+    ggml_set_input(adamw_params);
+    const char * optimizer_name = ggml_opt_optimizer_name(opt_ctx->optimizer);
+    ggml_format_name(adamw_params, "%s_params", optimizer_name);
     for (int i = opt_ctx->gf->n_nodes-1; i >= 0; --i) {
         struct ggml_tensor * node = opt_ctx->gb_opt->nodes[i];
         struct ggml_tensor * grad = ggml_graph_get_grad(opt_ctx->gb_opt, node);
 
         if (grad && (node->flags & GGML_TENSOR_FLAG_PARAM)) {
+            struct ggml_tensor * m = nullptr;
+            struct ggml_tensor * v = nullptr;
+            if (need_momenta) {
+                m = opt_ctx->grad_m[i];
+                v = opt_ctx->grad_v[i];
+                ggml_format_name(m, "AdamW m for %s", node->name);
+                ggml_format_name(v, "AdamW v for %s", node->name);
+            }
+            struct ggml_tensor * opt_step;
+            switch (optimizer) {
+                case GGML_OPT_OPTIMIZER_TYPE_ADAMW:
+                    opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, adamw_params);
+                    break;
+                case GGML_OPT_OPTIMIZER_TYPE_SGD:
+                    opt_step = ggml_opt_step_sgd(opt_ctx->ctx_compute, node, grad, adamw_params);
+                    break;
+                default:
+                    GGML_ABORT("fatal error");
+            }
+            ggml_format_name(opt_step, "%s step for %s", optimizer_name, node->name);
             ggml_build_forward_expand(opt_ctx->gb_opt, opt_step);
         }
     }
@@ -534,6 +558,7 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
     result->opt_period      = params.opt_period;
     result->get_opt_pars    = params.get_opt_pars;
     result->get_opt_pars_ud = params.get_opt_pars_ud;
+    result->optimizer       = params.optimizer;
 
     GGML_ASSERT(result->opt_period >= 1);
 
@@ -756,29 +781,43 @@ void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward) {
 void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) {
     GGML_ASSERT(opt_ctx->eval_ready);
     if (opt_ctx->allocated_graph == opt_ctx->gb_opt) {
+        const ggml_opt_optimizer_params & opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);
+
+        switch (opt_ctx->optimizer) {
+            case GGML_OPT_OPTIMIZER_TYPE_ADAMW: {
+                GGML_ASSERT(opt_pars.adamw.alpha >  0.0f);
+                GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f);
+                GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f);
+                GGML_ASSERT(opt_pars.adamw.beta2 >= 0.0f);
+                GGML_ASSERT(opt_pars.adamw.beta2 <= 1.0f);
+                GGML_ASSERT(opt_pars.adamw.eps   >= 0.0f);
+                GGML_ASSERT(opt_pars.adamw.wd    >= 0.0f);
+                GGML_ASSERT(opt_pars.adamw.wd    <= 1.0f);
+
+                // beta1, beta2 after applying warmup
+                const float beta1h = 1.0f / (1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter));
+                const float beta2h = 1.0f / (1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter));
+
+                float * adamw_par_data = ggml_get_data_f32(opt_ctx->opt_step_params);
+                adamw_par_data[0] = opt_pars.adamw.alpha;
+                adamw_par_data[1] = opt_pars.adamw.beta1;
+                adamw_par_data[2] = opt_pars.adamw.beta2;
+                adamw_par_data[3] = opt_pars.adamw.eps;
+                adamw_par_data[4] = opt_pars.adamw.wd;
+                adamw_par_data[5] = beta1h;
+                adamw_par_data[6] = beta2h;
+            } break;
+            case GGML_OPT_OPTIMIZER_TYPE_SGD: {
+                GGML_ASSERT(opt_pars.sgd.alpha >  0.0f);
+                GGML_ASSERT(opt_pars.sgd.wd    >= 0.0f);
+                GGML_ASSERT(opt_pars.sgd.wd    <= 1.0f);
+                float * sgd = ggml_get_data_f32(opt_ctx->opt_step_params);
+                sgd[0] = opt_pars.sgd.alpha;
+                sgd[1] = opt_pars.sgd.wd;
+            } break;
+            default:
+                GGML_ABORT("fatal error");
+        }
     }
 
     ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
@@ -963,6 +1002,7 @@ void ggml_opt_fit(
         ggml_tensor                   * outputs,
         ggml_opt_dataset_t              dataset,
         enum ggml_opt_loss_type         loss_type,
+        enum ggml_opt_optimizer_type    optimizer,
         ggml_opt_get_optimizer_params   get_opt_pars,
         int64_t                         nepoch,
         int64_t                         nbatch_logical,
@@ -993,6 +1033,7 @@ void ggml_opt_fit(
     params.opt_period      = opt_period;
     params.get_opt_pars    = get_opt_pars;
     params.get_opt_pars_ud = &epoch;
+    params.optimizer       = optimizer;
     ggml_opt_context_t opt_ctx = ggml_opt_init(params);
 
     // Shuffling the data is generally useful but there is only a point if not all data is used in a single batch.
@@ -1035,3 +1076,18 @@ void ggml_opt_fit(
     ggml_opt_result_free(result_train);
     ggml_opt_result_free(result_val);
 }
+
+enum ggml_opt_optimizer_type ggml_opt_context_optimizer_type(ggml_opt_context_t c) {
+    return c->optimizer;
+}
+
+GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type o) {
+    switch (o) {
+        case GGML_OPT_OPTIMIZER_TYPE_ADAMW:
+            return "adamw";
+        case GGML_OPT_OPTIMIZER_TYPE_SGD:
+            return "sgd";
+        default:
+            return "undefined";
+    };
+}
```
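A sketch of a custom get_opt_pars callback for the SGD path, following the struct and userdata convention above (ggml_opt_fit passes a pointer to the current epoch as userdata; the numbers are illustrative):

```c
#include <math.h>
#include <stdint.h>
#include "ggml-opt.h"

// Example callback: start from the defaults and halve the SGD learning rate each epoch.
static struct ggml_opt_optimizer_params my_sgd_params(void * userdata) {
    const int64_t epoch = *(const int64_t *) userdata;
    struct ggml_opt_optimizer_params p = ggml_opt_get_default_optimizer_params(NULL);
    p.sgd.alpha = 1e-3f * powf(0.5f, (float) epoch); // illustrative decay schedule
    p.sgd.wd    = 1e-9f;                             // small weight decay, as in the -wd example
    return p;
}
```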
ggml/src/ggml-vulkan/ggml-vulkan.cpp

```diff
@@ -510,6 +510,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_rwkv_wkv6_f32;
     vk_pipeline pipeline_rwkv_wkv7_f32;
     vk_pipeline pipeline_opt_step_adamw_f32;
+    vk_pipeline pipeline_opt_step_sgd_f32;
     vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
     vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
     vk_pipeline pipeline_conv2d_dw_whcn_f32;
@@ -3123,6 +3124,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+
     // conv2d
     for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
         uint32_t conv2d_WG_SIZE = 256;
@@ -7193,6 +7196,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(...)
             return ctx->device->pipeline_opt_step_adamw_f32;
         }
         return nullptr;
+    case GGML_OP_OPT_STEP_SGD:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_opt_step_sgd_f32;
+        }
+        return nullptr;
     case GGML_OP_LEAKY_RELU:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_leaky_relu_f32;
@@ -7692,6 +7700,10 @@ static void ggml_vk_op_f32(...)
         ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz);
         ggml_vk_sync_buffers(subctx);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+    } else if (op == GGML_OP_OPT_STEP_SGD) {
+        // OPT_STEP_SGD works on src0, it does not need dst
+        ggml_vk_sync_buffers(subctx);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz } }, pc, elements);
     } else if (use_src2) {
         ggml_vk_sync_buffers(subctx);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
@@ -8045,6 +8057,12 @@ static void ggml_vk_opt_step_adamw(...)
     );
 }
 
+static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
+    const size_t n = ggml_nelements(dst->src[0]);
+
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f }, dryrun);
+}
+
 static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     int * op_params = (int *)dst->op_params;
 
@@ -9598,6 +9616,7 @@ static bool ggml_vk_build_graph(...)
     case GGML_OP_LEAKY_RELU:
     case GGML_OP_FLASH_ATTN_EXT:
     case GGML_OP_OPT_STEP_ADAMW:
+    case GGML_OP_OPT_STEP_SGD:
         break;
     default:
         std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
@@ -9662,6 +9681,7 @@ static bool ggml_vk_build_graph(...)
     case GGML_OP_CONV_2D:
     case GGML_OP_CONV_2D_DW:
     case GGML_OP_LEAKY_RELU:
+    case GGML_OP_OPT_STEP_SGD:
         {
             // These operations all go through ggml_vk_op_f32, so short-circuit and
             // do the only thing needed for the dryrun.
@@ -9911,6 +9931,11 @@ static bool ggml_vk_build_graph(...)
     case GGML_OP_OPT_STEP_ADAMW:
         ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
 
+        break;
+
+    case GGML_OP_OPT_STEP_SGD:
+        ggml_vk_opt_step_sgd(ctx, compute_ctx, src0, src1, src2, node, dryrun);
+
         break;
     default:
         return false;
@@ -10014,8 +10039,8 @@ static bool ggml_vk_compute_forward(...)
     case GGML_OP_REPEAT:
     case GGML_OP_REPEAT_BACK:
     case GGML_OP_OPT_STEP_ADAMW:
+    case GGML_OP_OPT_STEP_SGD:
         buf = tensor->buffer;
-
         break;
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(tensor)) {
@@ -11154,6 +11179,9 @@ static bool ggml_backend_vk_device_supports_op(...)
         case GGML_OP_SIN:
         case GGML_OP_COS:
         case GGML_OP_CLAMP:
+        case GGML_OP_LEAKY_RELU:
+        case GGML_OP_OPT_STEP_ADAMW:
+        case GGML_OP_OPT_STEP_SGD:
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_UPSCALE:
         case GGML_OP_ACC:
@@ -11175,8 +11203,6 @@ static bool ggml_backend_vk_device_supports_op(...)
         case GGML_OP_POOL_2D:
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_RWKV_WKV7:
-        case GGML_OP_LEAKY_RELU:
-        case GGML_OP_OPT_STEP_ADAMW:
             return true;
         case GGML_OP_CONV_TRANSPOSE_1D:
             return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
@@ -11774,6 +11800,10 @@ static void ggml_vk_check_results_0(...)
             src_clone[0]->flags = src0->flags;
             tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
                 src_clone[2], src_clone[3], src_clone[4]);
+        } else if (tensor->op == GGML_OP_OPT_STEP_SGD) {
+            src_clone[0]->flags = src0->flags;
+            tensor_clone = ggml_opt_step_sgd(ggml_ctx, src_clone[0], src_clone[1],
+                src_clone[2]);
         }
         else {
             std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
```
ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp (new file)

```diff
@@ -0,0 +1,22 @@
+#version 450
+
+#include "generic_head.comp"
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) buffer X {A_TYPE data_x[];};
+layout (binding = 1) readonly buffer G {A_TYPE data_grad[];};
+layout (binding = 2) readonly buffer P {float data_params[2];};
+
+void main() {
+    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+    if (i >= p.KX) {
+        return;
+    }
+
+    const float alpha = data_params[0];
+    const float keep = 1.f - alpha * data_params[1];
+
+    data_x[i] = data_x[i] * keep - alpha * data_grad[i];
+}
```
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

```diff
@@ -657,6 +657,7 @@ void process_shaders() {
     string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 
     string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
+    string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 
     string_to_spv("conv2d_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}});
     string_to_spv("conv2d_f16_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}});
```
ggml/src/ggml.c

```diff
@@ -1012,11 +1012,12 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS",
     "CROSS_ENTROPY_LOSS_BACK",
     "OPT_STEP_ADAMW",
+    "OPT_STEP_SGD",
 
     "GLU",
 };
 
-static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
+static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1113,15 +1114,15 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss(x,y)",
     "cross_entropy_loss_back(x,y)",
     "adamw(x)",
+    "sgd(x)",
 
     "glu(x)",
 };
 
-static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
+static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
-
 static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
     "ABS",
     "SGN",
@@ -5606,6 +5607,28 @@ struct ggml_tensor * ggml_opt_step_adamw(
     return result;
 }
 
+// opt_step_sgd
+
+struct ggml_tensor * ggml_opt_step_sgd(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * grad,
+        struct ggml_tensor  * params) {
+    GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM);
+    GGML_ASSERT(ggml_are_same_shape(a, grad));
+    GGML_ASSERT(params->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_nelements(params) == 2);
+
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    result->op     = GGML_OP_OPT_STEP_SGD;
+    result->src[0] = a;
+    result->src[1] = grad;
+    result->src[2] = params;
+
+    return result;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 struct ggml_hash_set ggml_hash_set_new(size_t size) {
```