
Commit 7534bbf

examples/finetune -opt SGD (stochastic gradient descent) memory opt
New finetune CLI args: -wd 1e-5 to enable weight decay in SGD or AdamW, and -epochs N (default 2, as before).

- Cache 1 - wd*alpha (the factor the weights are multiplied by each step) in the 'adamw' opt struct, and cache the computed optimizer options, which were formerly computed twice per epoch.
- New GGML_OPT_OPTIMIZER_SGD in ggml; it avoids allocating the m and v moment tensors. ggml_opt_init is now aware of the optimization method.
- Observed 11 GB of GPU RAM with SGD instead of 20 GB with AdamW for llama 3.2-1b-F32 (finetune/ggml-opt only works on F32 so far). The objective (perplexity) is not directly comparable between optimizers, but improvements were observed over two epochs, and accuracy on the training set strictly improves when switching between tuning methods.
- Since memory is pre-allocated, the user-defined function that can vary optimizer settings could probably switch between SGD and AdamW on each epoch, but it would need to use AdamW for the first epoch (not verified).
1 parent aa59aa3 commit 7534bbf
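
For orientation, the update rule this commit implements, in scalar form: the decoupled weight decay is folded into a multiplicative "keep" factor that is computed once per optimizer step (rather than per element) and passed to the kernels as the 8th optimizer parameter. A minimal sketch with illustrative values:

// scalar view of the update used by the new SGD step and the revised AdamW step below;
// the values of alpha and wd here are illustrative only
float alpha = 1e-4f;             // learning rate (-lr)
float wd    = 1e-5f;             // weight decay (-wd), 0 disables it
float keep  = 1.0f - alpha * wd; // cached once per step

// SGD:   w = w*keep - alpha*g
// AdamW: w = w*keep - alpha*mh/vh   (mh, vh derived from the moments m, v)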

File tree

16 files changed: +314 additions, -69 deletions


CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,8 @@ if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
     set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
 endif()
 
+message("CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}")
+
 # Add path to modules
 list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")

common/arg.cpp

Lines changed: 16 additions & 9 deletions
@@ -1237,8 +1237,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     }
     sampler_type_names.pop_back();
 
-    params.optimize = ggml_opt_get_default_optimizer_params(NULL);
-    params.optimize.adamw.alpha = 1e-8; // default 1e-3 is much too high for LLAMA_EXAMPLE_FINETUNE
+    params.optimize = ggml_opt_get_default_optimizer_params(NULL);
 
     /**
      * filter options by example
@@ -2182,19 +2181,27 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.ppl_output_type = value;
         }
     ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
-    add_opt(common_arg({ "-lr", "--learning-rate" }, "ALPHA",
-        string_format("adamw optimizer alpha (default: %.1f)", (double) params.optimize.adamw.alpha),
-        [](common_params & params, const std::string & value) {
-            params.optimize.adamw.alpha = std::stof(value);
-        })
+    add_opt(
+        common_arg(
+            { "-lr", "--learning-rate" }, "ALPHA",
+            string_format("adamw or sgd optimizer alpha (default: %.2g)", (double) params.optimize.adamw.alpha),
+            [](common_params & params, const std::string & value) { params.optimize.adamw.alpha = std::stof(value); })
+            .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg(
+        { "-wd", "--weight-decay" }, "WD",
+        string_format("adamw or sgd optimizer weight decay (0 is off) (default: %.2g)",
+                      (double) params.optimize.adamw.wd),
+        [](common_params & params, const std::string & value) { params.optimize.adamw.wd = std::stof(value); })
+        .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg({ "-epochs", "--epochs" }, "N",
+        string_format("optimizer max # of epochs (default: %d)", params.optimize.epochs),
+        [](common_params & params, int epochs) { params.optimize.epochs = epochs; })
         .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
     add_opt(common_arg({ "-opt", "--optimizer" }, "sgd|adamw", "adamw or //TODO:sgd",
         [](common_params & params, const std::string & name) {
             params.optimize.optimizer = named_ggml_opt_optimizer(name.c_str());
             if (params.optimize.optimizer == GGML_OPT_OPTIMIZER_COUNT) {
                 throw std::invalid_argument("invalid --optimizer (try adamw)");
-            } else if (params.optimize.optimizer == GGML_OPT_OPTIMIZER_SGD) {
-                throw std::invalid_argument("TODO: implement SGD");
             }
         })
         .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
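
Note that the handlers above call named_ggml_opt_optimizer (and finetune.cpp below calls ggml_opt_optimizer_name), whose definitions live in changed files not shown in this excerpt. Purely as an assumption about their behavior, a minimal sketch of such a string/enum mapping pair might look like this (GGML_OPT_OPTIMIZER_ADAMW is assumed to be the pre-existing enum member for the AdamW path):

#include <string.h>

// hedged sketch only -- not the actual helpers added by this commit
enum ggml_opt_optimizer named_ggml_opt_optimizer(const char * name) {
    if (strcmp(name, "adamw") == 0) { return GGML_OPT_OPTIMIZER_ADAMW; }
    if (strcmp(name, "sgd")   == 0) { return GGML_OPT_OPTIMIZER_SGD;   }
    return GGML_OPT_OPTIMIZER_COUNT; // unknown name; the CLI handler above turns this into an error
}

const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer opt) {
    switch (opt) {
        case GGML_OPT_OPTIMIZER_ADAMW: return "adamw";
        case GGML_OPT_OPTIMIZER_SGD:   return "sgd";
        default:                       return "undefined";
    }
}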

examples/training/finetune.cpp

Lines changed: 3 additions & 3 deletions
@@ -38,7 +38,6 @@ int main(int argc, char ** argv) {
     common_init();
     llama_backend_init();
     llama_numa_init(params.numa);
-
     // load the model and apply lora adapter, if any
     common_init_result llama_init = common_init_from_params(params);
     llama_model_ptr & model = llama_init.model;
@@ -61,7 +60,8 @@ int main(int argc, char ** argv) {
     ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2);
 
     struct ggml_opt_optimizer_params & optimizer_params = params.optimize;
-    LOG_INF("-optimizer %d -lr: %.1f", optimizer_params.optimizer, (double) optimizer_params.adamw.alpha);
+    LOG_INF("-optimizer %s -lr: %.2g -epochs %d\n", ggml_opt_optimizer_name(optimizer_params.optimizer),
+            (double) optimizer_params.adamw.alpha, optimizer_params.epochs);
 
     struct llama_opt_params lopt_params {
         /*n_ctx_train =*/ 0,
@@ -77,7 +77,7 @@ int main(int argc, char ** argv) {
     ggml_opt_result_t result_train = ggml_opt_result_init();
     ggml_opt_result_t result_eval = ggml_opt_result_init();
 
-    for (int epoch = 0; epoch < 2; ++epoch) {
+    for (unsigned epoch = 0; epoch < optimizer_params.epochs; ++epoch) {
        llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split,
            ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
        fprintf(stderr, "\n");

ggml/include/ggml-opt.h

Lines changed: 10 additions & 3 deletions
@@ -90,12 +90,17 @@ extern "C" {
         // AdamW optimizer parameters
         struct {
             float alpha; // learning rate
-            float beta1;
-            float beta2;
+            float beta1; // adamw
+            float beta2; // adamw
             float eps; // epsilon for numerical stability
-            float wd; // weight decay for AdamW, use 0.0f to disable
+            float wd; // weight decay for SGD or AdamW, use 0.0f to disable
         } adamw;
+
+        // only GGML_OPT_OPTIMIZER_ADMW allocates m, v per parameter
         enum ggml_opt_optimizer optimizer;
+
+        // affects finetune.cpp only so far:
+        unsigned epochs; // max # of epochs sampling over training data
     };
 
     // callback to calculate optimizer parameters prior to a backward pass
@@ -126,6 +131,8 @@ extern "C" {
 
         ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
         void * get_opt_pars_ud; // userdata for calculating optimizer parameters
+        struct ggml_opt_optimizer_params
+            opt_params; // holds result of get_opt_pars(get_opt_pars_ud) after ggml_opt_init (could call get_opt_pars repeatedly instead)
     };
 
     // get parameters for an optimization context with defaults set where possible
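
As a usage note for the struct and the cached opt_params field above: a caller can supply a get_opt_pars callback that starts from the library defaults and switches to the new SGD path. A minimal sketch with illustrative values; ggml_opt_get_default_optimizer_params and the field names come from this header, GGML_OPT_OPTIMIZER_SGD from this commit:

// hedged example of a user-supplied optimizer-params callback
static struct ggml_opt_optimizer_params my_opt_pars(void * userdata) {
    struct ggml_opt_optimizer_params p = ggml_opt_get_default_optimizer_params(userdata);
    p.optimizer   = GGML_OPT_OPTIMIZER_SGD; // skips allocating AdamW's per-parameter m, v
    p.adamw.alpha = 1e-4f;                  // learning rate (field shared with SGD)
    p.adamw.wd    = 1e-5f;                  // decoupled weight decay, 0.0f disables it
    p.epochs      = 2;                      // read by finetune.cpp for its epoch loop
    return p;
}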

ggml/include/ggml.h

Lines changed: 8 additions & 2 deletions
@@ -450,7 +450,7 @@ extern "C" {
         GGML_OP_REPEAT_BACK,
         GGML_OP_CONCAT,
         GGML_OP_SILU_BACK,
-        GGML_OP_NORM, // normalize
+        GGML_OP_NORM, // normalize
         GGML_OP_RMS_NORM,
         GGML_OP_RMS_NORM_BACK,
         GGML_OP_GROUP_NORM,
@@ -486,7 +486,7 @@ extern "C" {
         GGML_OP_POOL_1D,
         GGML_OP_POOL_2D,
         GGML_OP_POOL_2D_BACK,
-        GGML_OP_UPSCALE, // nearest interpolate
+        GGML_OP_UPSCALE, // nearest interpolate
         GGML_OP_PAD,
         GGML_OP_PAD_REFLECT_1D,
         GGML_OP_ARANGE,
@@ -517,6 +517,7 @@ extern "C" {
         GGML_OP_CROSS_ENTROPY_LOSS,
         GGML_OP_CROSS_ENTROPY_LOSS_BACK,
         GGML_OP_OPT_STEP_ADAMW,
+        GGML_OP_OPT_STEP_SGD,
 
         GGML_OP_COUNT,
     };
@@ -2063,6 +2064,11 @@ extern "C" {
             struct ggml_tensor * v,
            struct ggml_tensor * adamw_params); // parameters such a the learning rate
 
+    // SGD (with weight decay) step
+    GGML_API struct ggml_tensor * ggml_opt_step_sgd(
+            struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * grad,
+            struct ggml_tensor * adamw_params); // parameters: alpha, the learning rate, and wd, weight decay
+
     //
     // automatic differentiation
     //
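
A hedged sketch of how the new op could be placed into a graph by hand; ctx, gf, w, and w_grad are assumed to exist already, and in practice the ggml-opt machinery does this wiring itself:

// illustrative only: one SGD step node; the 8-float params tensor uses the same layout
// the AdamW step reads (alpha at index 0, the cached keep = 1 - alpha*wd factor at index 7)
struct ggml_tensor * sgd_pars = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
struct ggml_tensor * w_step   = ggml_opt_step_sgd(ctx, w, w_grad, sgd_pars);
ggml_build_forward_expand(gf, w_step);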

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 6 additions & 0 deletions
@@ -2057,6 +2057,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                ggml_compute_forward_opt_step_adamw(params, tensor);
            }
            break;
+        case GGML_OP_OPT_STEP_SGD:
+            {
+                ggml_compute_forward_opt_step_sgd(params, tensor);
+            }
+            break;
        case GGML_OP_NONE:
            {
                // nop
@@ -2341,6 +2346,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
        case GGML_OP_CROSS_ENTROPY_LOSS:
        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
        case GGML_OP_OPT_STEP_ADAMW:
+        case GGML_OP_OPT_STEP_SGD:
            {
                n_tasks = n_threads;
            } break;

ggml/src/ggml-cpu/ops.cpp

Lines changed: 64 additions & 4 deletions
@@ -8831,7 +8831,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
-    GGML_ASSERT(ggml_nelements(adamw_params) == 7);
+    GGML_ASSERT(ggml_nelements(adamw_params) == 8);
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -8849,14 +8849,14 @@
     const int ir1 = MIN(ir0 + dr, nr);
 
     const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
+
     const float alpha = adamw_params_ptr[0];
     const float beta1 = adamw_params_ptr[1];
     const float beta2 = adamw_params_ptr[2];
     const float eps = adamw_params_ptr[3];
-    const float wd = adamw_params_ptr[4];
     const float beta1h = adamw_params_ptr[5];
     const float beta2h = adamw_params_ptr[6];
-
+    const float keep = adamw_params_ptr[7];
     for (int ir = ir0; ir < ir1; ++ir) {
         const int64_t i03 = ir/(ne02*ne01);
         const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
@@ -8879,7 +8879,7 @@
             // The weight decay is applied independently of the Adam momenta m and v.
             // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
             // See: https://arxiv.org/pdf/1711.05101v3.pdf
-            w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh;
+            w[i00] = w[i00] * keep - alpha * mh / vh;
         }
     }
 }
@@ -8901,3 +8901,63 @@ void ggml_compute_forward_opt_step_adamw(
         }
     }
 }
+
+static void ggml_compute_forward_opt_step_sgd_f32(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0         = dst->src[0];
+    const ggml_tensor * src0_grad    = dst->src[1];
+    const ggml_tensor * adamw_params = dst->src[2];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_nelements(adamw_params) == 8);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1) / nth;
+
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    // using adamw param subset we care about - alpha, wd - could have a separate struct
+    const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
+    const float alpha = adamw_params_ptr[0];
+    const float keep  = adamw_params_ptr[7];
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir / (ne02 * ne01);
+        const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01;
+        const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01);
+
+        const size_t offset = i03 * nb03 + i02 * nb02 + i01 * nb01;
+
+        float *       w = (float *) ((char *) src0->data + offset);                  // weight
+        const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad
+
+        for (int i00 = 0; i00 < ne00; ++i00) {
+            w[i00] = w[i00] * keep - alpha * g[i00];
+        }
+    }
+}
+
+void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_opt_step_sgd_f32(params, dst);
+            }
+            break;
+        default:
+            {
+                GGML_ABORT("fatal error - sgd is F32 only");
+            }
+    }
+}
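
For reference, the layout of the shared 8-float adamw_params tensor, as read by the CPU kernels above and the CUDA kernel below, can be summarized like this (the enum and its names are illustrative, not part of the commit):

// hedged summary of the optimizer-params layout inferred from this commit's kernels
enum opt_step_par_index {
    OPT_PAR_ALPHA  = 0, // learning rate, read by AdamW and SGD
    OPT_PAR_BETA1  = 1, // AdamW only
    OPT_PAR_BETA2  = 2, // AdamW only
    OPT_PAR_EPS    = 3, // AdamW only
    OPT_PAR_WD     = 4, // raw weight decay; no longer read by either kernel in this commit
    OPT_PAR_BETA1H = 5, // AdamW only
    OPT_PAR_BETA2H = 6, // AdamW only
    OPT_PAR_KEEP   = 7, // 1 - alpha*wd, read by AdamW and SGD
};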

ggml/src/ggml-cpu/ops.h

Lines changed: 1 addition & 1 deletion
@@ -104,7 +104,7 @@ void ggml_compute_forward_custom(const struct ggml_compute_params * params, stru
 void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-
+void ggml_compute_forward_opt_step_sgd(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 #ifdef __cplusplus
 }
 #endif

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 5 additions & 0 deletions
@@ -24,6 +24,7 @@
 #include "ggml-cuda/mmvq.cuh"
 #include "ggml-cuda/norm.cuh"
 #include "ggml-cuda/opt-step-adamw.cuh"
+#include "ggml-cuda/opt-step-sgd.cuh"
 #include "ggml-cuda/out-prod.cuh"
 #include "ggml-cuda/pad.cuh"
 #include "ggml-cuda/pool2d.cuh"
@@ -2352,6 +2353,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_OPT_STEP_ADAMW:
             ggml_cuda_opt_step_adamw(ctx, dst);
             break;
+        case GGML_OP_OPT_STEP_SGD:
+            ggml_cuda_opt_step_sgd(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -3256,6 +3260,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
+        case GGML_OP_OPT_STEP_SGD:
             return true;
         default:
             return false;

ggml/src/ggml-cuda/opt-step-adamw.cu

Lines changed: 3 additions & 4 deletions
@@ -17,9 +17,9 @@ static __global__ void opt_step_adamw_f32(
     const float beta1  = pars[1];
     const float beta2  = pars[2];
     const float eps    = pars[3];
-    const float wd     = pars[4];
     const float beta1h = pars[5];
     const float beta2h = pars[6];
+    const float keep   = pars[7];
 
     const float gi = g[i];
     const float gmi = g_m[i]*beta1 + gi*(1.0f - beta1);
@@ -31,7 +31,7 @@
     const float mh = gmi*beta1h;
     const float vh = sqrtf(gvi*beta2h) + eps;
 
-    x[i] = x[i]*(1.0f - alpha*wd) - alpha*mh/vh;
+    x[i] = x[i] * keep - alpha * mh / vh;
 }
 
 static void opt_step_adamw_f32_cuda(
@@ -62,14 +62,13 @@ void ggml_cuda_opt_step_adamw(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
-    GGML_ASSERT(ggml_nelements(adamw_params) == 7);
+    GGML_ASSERT(ggml_nelements(adamw_params) == 8);
 
     float * src0_d = (float *) src0->data;
     const float * src0_grad_d = (const float *) src0_grad->data;
     float * src0_grad_m_d = (float *) src0_grad_m->data;
     float * src0_grad_v_d = (float *) src0_grad_v->data;
     const float * adamw_params_d = (const float *) adamw_params->data;
-
     cudaStream_t stream = ctx.stream();
 
     const int64_t ne = ggml_nelements(src0);