Skip to content

Commit a603bc6

Browse files
ggerganovqnixsynapse
authored andcommitted
graph : fix geglu (ggml-org#14077)
ggml-ci
1 parent 562b68d commit a603bc6

File tree

1 file changed

+8
-16
lines changed

1 file changed

+8
-16
lines changed

src/llama-graph.cpp

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -622,22 +622,14 @@ ggml_tensor * llm_graph_context::build_ffn(
622622
{
623623
// Split into two equal parts
624624
int64_t split_point = cur->ne[0] / 2;
625-
ggml_tensor * output_ffn_up = ggml_cont(ctx0, ggml_view_2d(
626-
ctx0, cur, split_point,
627-
cur->ne[1], cur->nb[1], 0
628-
));
629-
ggml_tensor * output_ffn_gate = ggml_cont(ctx0, ggml_view_2d(
630-
ctx0, cur, split_point,
631-
cur->ne[1], cur->nb[1],
632-
split_point * ggml_element_size(cur)
633-
));
634-
635-
// Apply GELU activation function to the first part
636-
output_ffn_up = ggml_gelu(ctx0, output_ffn_up);
637-
cb(output_ffn_up, "ffn_gelu", il);
638-
639-
// Element-wise multiplication between the activated part and the gate part
640-
cur = ggml_mul(ctx0, output_ffn_up, output_ffn_gate);
625+
// TODO: these conts should not be needed
626+
ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
627+
ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
628+
629+
x0 = ggml_gelu(ctx0, x0);
630+
cb(x0, "ffn_gelu", il);
631+
632+
cur = ggml_mul(ctx0, x0, x1);
641633
cb(cur, "ffn_geglu", il);
642634
} break;
643635
}

0 commit comments

Comments
 (0)