Skip to content

Commit 6c596e9

Browse files
ggerganovMinh141120
authored andcommitted
graph : fix geglu (ggml-org#14077)
ggml-ci
1 parent 87db860 commit 6c596e9

File tree

1 file changed

+8
-16
lines changed

1 file changed

+8
-16
lines changed

src/llama-graph.cpp

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -617,22 +617,14 @@ ggml_tensor * llm_graph_context::build_ffn(
617617
{
618618
// Split into two equal parts
619619
int64_t split_point = cur->ne[0] / 2;
620-
ggml_tensor * output_ffn_up = ggml_cont(ctx0, ggml_view_2d(
621-
ctx0, cur, split_point,
622-
cur->ne[1], cur->nb[1], 0
623-
));
624-
ggml_tensor * output_ffn_gate = ggml_cont(ctx0, ggml_view_2d(
625-
ctx0, cur, split_point,
626-
cur->ne[1], cur->nb[1],
627-
split_point * ggml_element_size(cur)
628-
));
629-
630-
// Apply GELU activation function to the first part
631-
output_ffn_up = ggml_gelu(ctx0, output_ffn_up);
632-
cb(output_ffn_up, "ffn_gelu", il);
633-
634-
// Element-wise multiplication between the activated part and the gate part
635-
cur = ggml_mul(ctx0, output_ffn_up, output_ffn_gate);
620+
// TODO: these conts should not be needed
621+
ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
622+
ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
623+
624+
x0 = ggml_gelu(ctx0, x0);
625+
cb(x0, "ffn_gelu", il);
626+
627+
cur = ggml_mul(ctx0, x0, x1);
636628
cb(cur, "ffn_geglu", il);
637629
} break;
638630
}

0 commit comments

Comments
 (0)