Skip to content

Commit dfa89e7

Browse files
JohannesGaessler authored and ggerganov committed
ggml: backward pass for split swiglu (llama/14483)
1 parent 417be2a commit dfa89e7

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

ggml/src/ggml.c

Lines changed: 17 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -6042,13 +6042,28 @@ static void ggml_compute_backward(
60426042
}
60436043
GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
60446044
} break;
6045+
case GGML_OP_GLU: {
6046+
switch (ggml_get_glu_op(tensor)) {
6047+
case GGML_GLU_OP_SWIGLU: {
6048+
if (src0_needs_grads) {
6049+
GGML_ASSERT(src1 && "backward pass only implemented for split swiglu");
6050+
ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, ggml_mul(ctx, grad, src1), src0));
6051+
}
6052+
if (src1_needs_grads) {
6053+
ggml_add_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, ggml_silu(ctx, src0), grad));
6054+
}
6055+
} break;
6056+
default: {
6057+
GGML_ABORT("unsupported glu op for backward pass: %s", ggml_glu_op_name(ggml_get_glu_op(tensor)));
6058+
} //break;
6059+
}
6060+
} break;
60456061
case GGML_OP_NONE: {
60466062
// noop
60476063
} break;
60486064
case GGML_OP_COUNT:
60496065
default: {
6050-
fprintf(stderr, "%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
6051-
GGML_ABORT("fatal error");
6066+
GGML_ABORT("%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
60526067
} //break;
60536068
}
60546069

0 commit comments

Comments
 (0)