Skip to content

Commit a9d5096

Browse files
committed
graph : remove model reference from build_pooling
ggml-ci
1 parent 7b51bed commit a9d5096

File tree

3 files changed

+19
-9
lines changed

3 files changed

+19
-9
lines changed

src/llama-graph.cpp

Lines changed: 12 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1992,7 +1992,12 @@ ggml_tensor * llm_graph_context::build_rwkv_channel_mix(
19921992
return cur;
19931993
}
19941994

1995-
void llm_graph_context::build_pooling(ggml_cgraph * gf) const {
1995+
void llm_graph_context::build_pooling(
1996+
ggml_cgraph * gf,
1997+
ggml_tensor * cls,
1998+
ggml_tensor * cls_b,
1999+
ggml_tensor * cls_out,
2000+
ggml_tensor * cls_out_b) const {
19962001
if (!cparams.embeddings) {
19972002
return;
19982003
}
@@ -2036,18 +2041,18 @@ void llm_graph_context::build_pooling(ggml_cgraph * gf) const {
20362041

20372042
// classification head
20382043
// https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
2039-
GGML_ASSERT(model.cls != nullptr);
2040-
GGML_ASSERT(model.cls_b != nullptr);
2044+
GGML_ASSERT(cls != nullptr);
2045+
GGML_ASSERT(cls_b != nullptr);
20412046

2042-
cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls, inp), model.cls_b);
2047+
cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
20432048
cur = ggml_tanh(ctx0, cur);
20442049

20452050
// some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
20462051
// https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
2047-
if (model.cls_out) {
2048-
GGML_ASSERT(model.cls_out_b != nullptr);
2052+
if (cls_out) {
2053+
GGML_ASSERT(cls_out_b != nullptr);
20492054

2050-
cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b);
2055+
cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
20512056
}
20522057
} break;
20532058
default:

src/llama-graph.h

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -617,5 +617,10 @@ struct llm_graph_context {
617617
// pooling
618618
//
619619

620-
void build_pooling(ggml_cgraph * gf) const;
620+
void build_pooling(
621+
ggml_cgraph * gf,
622+
ggml_tensor * cls,
623+
ggml_tensor * cls_b,
624+
ggml_tensor * cls_out,
625+
ggml_tensor * cls_out_b) const;
621626
};

src/llama-model.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -10583,7 +10583,7 @@ llm_graph_result_ptr llama_model::build_graph(
1058310583
}
1058410584

1058510585
// add on pooling layer
10586-
llm->build_pooling(gf);
10586+
llm->build_pooling(gf, cls, cls_b, cls_out, cls_out_b);
1058710587

1058810588
return std::move(llm->res);
1058910589
}

0 commit comments

Comments (0)