
Commit 424c5d0

JohannesGaessler, ggerganov, and slaren committed
ggml/examples: add backend support for numerical optimization (ggml/949)
* CUDA eval works
* stochastic gradient descent op
* Adam except decay
* CUDA CROSS_ENTROPY_LOSS_BACK
* CUDA mnist-fc training works
* backend CLI arg
* refactor gguf load
* remove sched from opt_step_adam
* implement l1 regularization (weight decay)
* extra call to add optimizer
* initialize gradients with ggml_graph_reset
* gradient accumulation
* increment iter per eval instead of epoch
* adjust backend interfaces
* fix ggml_graph_reset without backend
* fix ggml graph export/import
* fixup
* rename
* revert ggml_opt changes
* more general CUDA repeat_back
* update documentation, fix CNN
* validation split
* add clarifying comment
* optimize PyTorch training
* adjust buffer size, thread count
* fix 0.0f validation split
* Update examples/mnist/mnist-common.cpp
  Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* fix gradient accumulation
* tensor flag for accumulators -> tensor hash set
* Update include/ggml.h
  Co-authored-by: slaren <slarengh@gmail.com>
* Update tests/test-backend-ops.cpp
  Co-authored-by: slaren <slarengh@gmail.com>
* Update tests/test-backend-ops.cpp
  Co-authored-by: slaren <slarengh@gmail.com>
* fix test prints
* Update src/ggml-backend.c
  Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* better CUDA support for noncontiguous out_prod
* add comment

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: slaren <slarengh@gmail.com>
1 parent a6809c6 commit 424c5d0

24 files changed: +883 -129 lines changed

ggml/include/ggml-backend.h

Lines changed: 2 additions & 1 deletion
@@ -66,6 +66,7 @@ extern "C" {
     // "offset" refers to the offset of the tensor data for setting/getting data
     GGML_API GGML_CALL void ggml_backend_tensor_set(      struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
     GGML_API GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
+    GGML_API GGML_CALL void ggml_backend_tensor_memset(   struct ggml_tensor * tensor,     uint8_t value, size_t offset, size_t size);
 
     GGML_API void ggml_backend_synchronize(ggml_backend_t backend);
 
@@ -122,7 +123,7 @@ extern "C" {
     // The backend registry is a registry of all the available backends, and allows initializing backends in a generic way
 
     GGML_API size_t         ggml_backend_reg_get_count(void);
-    GGML_API size_t         ggml_backend_reg_find_by_name(const char * name);
+    GGML_API size_t         ggml_backend_reg_find_by_name(const char * name); // returns index of backend with name, or SIZE_MAX if not found
     GGML_API ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str); // str is backend_name:params (params is optional)
     GGML_API const char *   ggml_backend_reg_get_name(size_t i);
     GGML_API ggml_backend_t ggml_backend_reg_init_backend(size_t i, const char * params); // params is backend-specific
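The new ggml_backend_tensor_memset mirrors ggml_backend_tensor_set/ggml_backend_tensor_get, but fills a byte range of a tensor with a constant value — exactly what gradient accumulators need between optimizer steps. A minimal usage sketch (the context and buffer setup here are illustrative, not part of this commit):

    #include "ggml.h"
    #include "ggml-alloc.h"
    #include "ggml-backend.h"

    // Zero a tensor that lives in a backend buffer. Works on any backend
    // whose buffer implements memset_tensor (CPU and CUDA after this commit).
    void zero_tensor_example(ggml_backend_t backend) {
        struct ggml_init_params params = {
            /* .mem_size   = */ ggml_tensor_overhead(),
            /* .mem_buffer = */ NULL,
            /* .no_alloc   = */ true, // data is allocated in the backend buffer below
        };
        struct ggml_context * ctx = ggml_init(params);

        struct ggml_tensor  * t   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);
        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);

        // all-zero bytes == 0.0f for F32, so this zeroes the whole tensor
        ggml_backend_tensor_memset(t, 0, 0, ggml_nbytes(t));

        ggml_backend_buffer_free(buf);
        ggml_free(ctx);
    }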

ggml/include/ggml.h

Lines changed: 32 additions & 8 deletions
@@ -534,6 +534,7 @@ extern "C" {
 
         GGML_OP_CROSS_ENTROPY_LOSS,
         GGML_OP_CROSS_ENTROPY_LOSS_BACK,
+        GGML_OP_OPT_STEP_ADAMW,
 
         GGML_OP_COUNT,
     };
@@ -571,10 +572,12 @@ extern "C" {
         GGML_LOG_LEVEL_DEBUG = 4,
     };
 
+    // this tensor...
     enum ggml_tensor_flag {
-        GGML_TENSOR_FLAG_INPUT  = 1,
-        GGML_TENSOR_FLAG_OUTPUT = 2,
-        GGML_TENSOR_FLAG_PARAM  = 4,
+        GGML_TENSOR_FLAG_INPUT  = 1, // ...is an input for the GGML compute graph
+        GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph
+        GGML_TENSOR_FLAG_PARAM  = 4, // ...contains trainable parameters
+        GGML_TENSOR_FLAG_LOSS   = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
     };
 
     // n-dimensional tensor
@@ -2037,23 +2040,44 @@ extern "C" {
             struct ggml_tensor  * b,
             struct ggml_tensor  * c);
 
+    // AdamW optimizer step
+    // Paper: https://arxiv.org/pdf/1711.05101v3.pdf
+    // PyTorch: https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html
+    GGML_API struct ggml_tensor * ggml_opt_step_adamw(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 alpha,
+            float                 beta1,
+            float                 beta2,
+            float                 eps,
+            float                 wd); // weight decay
+
     //
     // automatic differentiation
     //
 
-    GGML_API void ggml_set_param(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * tensor);
+    GGML_API void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor);
+    GGML_API void ggml_set_loss(struct ggml_tensor * tensor);
 
     GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
-    GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep);
+    GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate, bool keep);
+
+    GGML_API void ggml_build_opt_adamw(
+            struct ggml_context * ctx,
+            struct ggml_cgraph  * gf,
+            struct ggml_cgraph  * gb,
+            float                 alpha,
+            float                 beta1,
+            float                 beta2,
+            float                 eps,
+            float                 wd); // weight decay
 
     // graph allocation in a context
     GGML_API struct ggml_cgraph * ggml_new_graph       (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
     GGML_API struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t size, bool grads);
     GGML_API struct ggml_cgraph * ggml_graph_dup       (struct ggml_context * ctx, struct ggml_cgraph * cgraph);
     GGML_API void                 ggml_graph_cpy       (struct ggml_cgraph * src, struct ggml_cgraph * dst);
-    GGML_API void                 ggml_graph_reset     (struct ggml_cgraph * cgraph); // zero grads
+    GGML_API void                 ggml_graph_reset     (struct ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1
     GGML_API void                 ggml_graph_clear     (struct ggml_cgraph * cgraph);
 
     GGML_API int                  ggml_graph_size      (struct ggml_cgraph * cgraph);
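Taken together, these additions let a training step be expressed directly as a compute graph: ggml_set_param flags the weights, ggml_set_loss flags the loss, ggml_build_backward_expand adds the gradient ops (with optional accumulation across micro-batches), and ggml_build_opt_adamw appends one GGML_OP_OPT_STEP_ADAMW per parameter tensor. A hedged sketch of the flow — the model function my_forward and the hyperparameter values are placeholders, not from this commit:

    // Assumes ctx, weights, inputs, and labels are already set up,
    // and my_forward() is a hypothetical model returning logits.
    struct ggml_tensor * logits = my_forward(ctx, weights, inputs);
    struct ggml_tensor * loss   = ggml_cross_entropy_loss(ctx, logits, labels);

    ggml_set_param(ctx, weights); // trainable
    ggml_set_loss(loss);          // defines the optimization target

    struct ggml_cgraph * gf = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true);
    ggml_build_forward_expand(gf, loss);

    struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf);
    ggml_build_backward_expand(ctx, gf, gb, /*accumulate =*/ false, /*keep =*/ true);
    ggml_build_opt_adamw(ctx, gf, gb, /*alpha =*/ 1e-3f, /*beta1 =*/ 0.9f,
                         /*beta2 =*/ 0.999f, /*eps =*/ 1e-8f, /*wd =*/ 0.0f);

    ggml_graph_reset(gb); // grads + momenta to 0, loss grad to 1
    // then, per batch: upload inputs/labels and compute gb on the chosen backend

Note the repurposed ggml_graph_reset: besides zeroing gradients and optimizer momenta, it now initializes the loss gradient to 1, which is what seeds backpropagation.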

ggml/src/ggml-backend-impl.h

Lines changed: 10 additions & 9 deletions
@@ -38,15 +38,16 @@ extern "C" {
     typedef void * ggml_backend_buffer_context_t;
 
     struct ggml_backend_buffer_i {
-        const char * (*GGML_CALL get_name)   (ggml_backend_buffer_t buffer);
-        void         (*GGML_CALL free_buffer)(ggml_backend_buffer_t buffer);
-        void *       (*GGML_CALL get_base)   (ggml_backend_buffer_t buffer);
-        void         (*GGML_CALL init_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
-        void         (*GGML_CALL set_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
-        void         (*GGML_CALL get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
-        bool         (*GGML_CALL cpy_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst); // dst is in the buffer, src may be in any buffer
-        void         (*GGML_CALL clear)      (ggml_backend_buffer_t buffer, uint8_t value);
-        void         (*GGML_CALL reset)      (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras
+        const char * (*GGML_CALL get_name)      (ggml_backend_buffer_t buffer);
+        void         (*GGML_CALL free_buffer)   (ggml_backend_buffer_t buffer);
+        void *       (*GGML_CALL get_base)      (ggml_backend_buffer_t buffer);
+        void         (*GGML_CALL init_tensor)   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+        void         (*GGML_CALL memset_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size);
+        void         (*GGML_CALL set_tensor)    (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+        void         (*GGML_CALL get_tensor)    (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
+        bool         (*GGML_CALL cpy_tensor)    (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst); // dst is in the buffer, src may be in any buffer
+        void         (*GGML_CALL clear)         (ggml_backend_buffer_t buffer, uint8_t value);
+        void         (*GGML_CALL reset)         (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras
     };
 
     struct ggml_backend_buffer {
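Since ggml_backend_buffer_i gained a member, every buffer implementation has to account for memset_tensor in its initializer table; NULL is a valid value and makes ggml_backend_tensor_memset fail with an assertion rather than silently doing nothing. A sketch for a hypothetical out-of-tree backend (all my_* symbols are placeholders):

    static struct ggml_backend_buffer_i my_backend_buffer_i = {
        /* .get_name      = */ my_buffer_get_name,
        /* .free_buffer   = */ my_buffer_free_buffer,
        /* .get_base      = */ my_buffer_get_base,
        /* .init_tensor   = */ NULL, // no initialization required
        /* .memset_tensor = */ NULL, // ggml_backend_tensor_memset will assert if used
        /* .set_tensor    = */ my_buffer_set_tensor,
        /* .get_tensor    = */ my_buffer_get_tensor,
        /* .cpy_tensor    = */ NULL,
        /* .clear         = */ my_buffer_clear,
        /* .reset         = */ NULL,
    };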

ggml/src/ggml-backend.c

Lines changed: 25 additions & 0 deletions
@@ -246,6 +246,22 @@ GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void *
     buf->iface.get_tensor(buf, tensor, data, offset, size);
 }
 
+GGML_API GGML_CALL void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+
+    GGML_ASSERT(buf != NULL && "tensor buffer not set");
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+
+    if (!size) {
+        return;
+    }
+
+    GGML_ASSERT(buf->iface.memset_tensor != NULL && "memset not supported by backend buffer");
+
+    buf->iface.memset_tensor(buf, tensor, value, offset, size);
+}
+
 void ggml_backend_synchronize(ggml_backend_t backend) {
     if (backend->iface.synchronize == NULL) {
         return;
@@ -569,6 +585,12 @@ GGML_CALL static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t
     free(buffer->context);
 }
 
+GGML_CALL static void ggml_backend_cpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+    memset((char *)tensor->data + offset, value, size);
+
+    GGML_UNUSED(buffer);
+}
+
 GGML_CALL static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     memcpy((char *)tensor->data + offset, data, size);
 
@@ -600,6 +622,7 @@ static struct ggml_backend_buffer_i cpu_backend_buffer_i = {
     /* .free_buffer   = */ ggml_backend_cpu_buffer_free_buffer,
     /* .get_base      = */ ggml_backend_cpu_buffer_get_base,
     /* .init_tensor   = */ NULL, // no initialization required
+    /* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor,
     /* .set_tensor    = */ ggml_backend_cpu_buffer_set_tensor,
     /* .get_tensor    = */ ggml_backend_cpu_buffer_get_tensor,
     /* .cpy_tensor    = */ ggml_backend_cpu_buffer_cpy_tensor,
@@ -613,6 +636,7 @@ static struct ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = {
     /* .free_buffer   = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed
     /* .get_base      = */ ggml_backend_cpu_buffer_get_base,
     /* .init_tensor   = */ NULL, // no initialization required
+    /* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor,
     /* .set_tensor    = */ ggml_backend_cpu_buffer_set_tensor,
     /* .get_tensor    = */ ggml_backend_cpu_buffer_get_tensor,
     /* .cpy_tensor    = */ ggml_backend_cpu_buffer_cpy_tensor,
@@ -980,6 +1004,7 @@ static struct ggml_backend_buffer_i ggml_backend_multi_buffer_context_interface(
     /* .free_buffer   = */ ggml_backend_multi_buffer_free_buffer,
     /* .get_base      = */ NULL,
     /* .init_tensor   = */ NULL,
+    /* .memset_tensor = */ NULL,
     /* .set_tensor    = */ NULL,
     /* .get_tensor    = */ NULL,
     /* .cpy_tensor    = */ NULL,
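Note the order of checks in the dispatcher: the buffer is resolved through view_src so that memsetting a view writes into the underlying buffer, the bounds assertion always runs, and the size == 0 early return means the missing-memset_tensor assertion only fires when work is actually requested. A quick round-trip check, assuming a CPU backend and an allocated F32 tensor t:

    // Fill with zero bytes, then read back through the public API.
    float out[4];
    ggml_backend_tensor_memset(t, 0, 0, ggml_nbytes(t));
    ggml_backend_tensor_get(t, out, 0, sizeof(out));
    // out[0..3] are now 0.0f: an all-zero bit pattern is 0.0f in IEEE-754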

ggml/src/ggml-cann.cpp

Lines changed: 1 addition & 0 deletions
@@ -1037,6 +1037,7 @@ static ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
     /* .free_buffer   = */ ggml_backend_cann_buffer_free_buffer,
     /* .get_base      = */ ggml_backend_cann_buffer_get_base,
     /* .init_tensor   = */ ggml_backend_cann_buffer_init_tensor,
+    /* .memset_tensor = */ NULL,
     /* .set_tensor    = */ ggml_backend_cann_buffer_set_tensor,
     /* .get_tensor    = */ ggml_backend_cann_buffer_get_tensor,
     /* .cpy_tensor    = */ ggml_backend_cann_buffer_cpy_tensor,

ggml/src/ggml-cuda.cu

Lines changed: 39 additions & 1 deletion
@@ -21,6 +21,8 @@
 #include "ggml-cuda/mmq.cuh"
 #include "ggml-cuda/mmvq.cuh"
 #include "ggml-cuda/norm.cuh"
+#include "ggml-cuda/opt-step-adamw.cuh"
+#include "ggml-cuda/out-prod.cuh"
 #include "ggml-cuda/pad.cuh"
 #include "ggml-cuda/pool2d.cuh"
 #include "ggml-cuda/quantize.cuh"
@@ -493,6 +495,14 @@ GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t
     }
 }
 
+GGML_CALL static void ggml_backend_cuda_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
+
+    ggml_cuda_set_device(ctx->device);
+    CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + offset, value, size, cudaStreamPerThread));
+    CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
+}
+
 GGML_CALL static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
 
@@ -544,6 +554,7 @@ static ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = {
     /* .free_buffer   = */ ggml_backend_cuda_buffer_free_buffer,
     /* .get_base      = */ ggml_backend_cuda_buffer_get_base,
     /* .init_tensor   = */ ggml_backend_cuda_buffer_init_tensor,
+    /* .memset_tensor = */ ggml_backend_cuda_buffer_memset_tensor,
     /* .set_tensor    = */ ggml_backend_cuda_buffer_set_tensor,
     /* .get_tensor    = */ ggml_backend_cuda_buffer_get_tensor,
     /* .cpy_tensor    = */ ggml_backend_cuda_buffer_cpy_tensor,
@@ -860,6 +871,7 @@ static struct ggml_backend_buffer_i ggml_backend_cuda_split_buffer_interface = {
     /* .free_buffer   = */ ggml_backend_cuda_split_buffer_free_buffer,
     /* .get_base      = */ ggml_backend_cuda_split_buffer_get_base,
     /* .init_tensor   = */ ggml_backend_cuda_split_buffer_init_tensor,
+    /* .memset_tensor = */ NULL,
     /* .set_tensor    = */ ggml_backend_cuda_split_buffer_set_tensor,
     /* .get_tensor    = */ ggml_backend_cuda_split_buffer_get_tensor,
     /* .cpy_tensor    = */ NULL,
@@ -2168,6 +2180,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_REPEAT:
             ggml_cuda_op_repeat(ctx, dst);
             break;
+        case GGML_OP_REPEAT_BACK:
+            ggml_cuda_op_repeat_back(ctx, dst);
+            break;
         case GGML_OP_GET_ROWS:
             ggml_cuda_op_get_rows(ctx, dst);
             break;
@@ -2201,6 +2216,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             case GGML_UNARY_OP_NEG:
                 ggml_cuda_op_neg(ctx, dst);
                 break;
+            case GGML_UNARY_OP_STEP:
+                ggml_cuda_op_step(ctx, dst);
+                break;
             case GGML_UNARY_OP_GELU:
                 ggml_cuda_op_gelu(ctx, dst);
                 break;
@@ -2267,6 +2285,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_MUL_MAT_ID:
             ggml_cuda_mul_mat_id(ctx, dst);
             break;
+        case GGML_OP_OUT_PROD:
+            ggml_cuda_out_prod(ctx, dst);
+            break;
         case GGML_OP_SCALE:
             ggml_cuda_op_scale(ctx, dst);
             break;
@@ -2324,6 +2345,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_CROSS_ENTROPY_LOSS:
             ggml_cuda_cross_entropy_loss(ctx, dst);
             break;
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+            ggml_cuda_cross_entropy_loss_back(ctx, dst);
+            break;
+        case GGML_OP_OPT_STEP_ADAMW:
+            ggml_cuda_opt_step_adamw(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -2761,6 +2788,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
                 case GGML_UNARY_OP_NEG:
+                case GGML_UNARY_OP_STEP:
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_RELU:
@@ -2813,6 +2841,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                     return false;
                 }
             } break;
+        case GGML_OP_OUT_PROD:
+            return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
         case GGML_OP_GET_ROWS:
             {
                 switch (op->src[0]->type) {
@@ -2869,6 +2899,12 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
             } break;
         case GGML_OP_DUP:
         case GGML_OP_REPEAT:
+            {
+                ggml_type src0_type = op->src[0]->type;
+                return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
+            } break;
+        case GGML_OP_REPEAT_BACK:
+            return op->type == GGML_TYPE_F32 && op->src[0]->ne[3] == 1;
         case GGML_OP_CONCAT:
             {
                 ggml_type src0_type = op->src[0]->type;
@@ -2935,9 +2971,11 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
             }
             return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
                 op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         case GGML_OP_CROSS_ENTROPY_LOSS:
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+        case GGML_OP_OPT_STEP_ADAMW:
             return true;
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         default:
             return false;
     }
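For reference, GGML_OP_OPT_STEP_ADAMW applies the standard AdamW update from the paper linked in ggml.h elementwise over a parameter tensor. The actual kernel lives behind the new ggml-cuda/opt-step-adamw.cuh include and is not shown in this excerpt; what follows is the per-element math sketched in plain C, under the assumption that it matches the PyTorch formulation referenced in the header:

    #include <math.h> // powf, sqrtf

    // One AdamW step for a single parameter x with gradient g.
    // m and v are the first/second moment accumulators; t is the 1-based iteration.
    static void adamw_step_one(float * x, float g, float * m, float * v, int t,
                               float alpha, float beta1, float beta2, float eps, float wd) {
        *m = beta1 * *m + (1.0f - beta1) * g;
        *v = beta2 * *v + (1.0f - beta2) * g * g;

        const float mh = *m / (1.0f - powf(beta1, t)); // bias-corrected first moment
        const float vh = *v / (1.0f - powf(beta2, t)); // bias-corrected second moment

        // decoupled weight decay first, then the Adam update
        *x = *x * (1.0f - alpha * wd) - alpha * mh / (sqrtf(vh) + eps);
    }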
