@@ -1992,7 +1992,12 @@ ggml_tensor * llm_graph_context::build_rwkv_channel_mix(
     return cur;
 }
 
-void llm_graph_context::build_pooling(ggml_cgraph * gf) const {
+void llm_graph_context::build_pooling(
+        ggml_cgraph * gf,
+        ggml_tensor * cls,
+        ggml_tensor * cls_b,
+        ggml_tensor * cls_out,
+        ggml_tensor * cls_out_b) const {
     if (!cparams.embeddings) {
         return;
     }
@@ -2036,18 +2041,18 @@ void llm_graph_context::build_pooling(ggml_cgraph * gf) const {
 
                 // classification head
                 // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
-                GGML_ASSERT(model.cls   != nullptr);
-                GGML_ASSERT(model.cls_b != nullptr);
+                GGML_ASSERT(cls   != nullptr);
+                GGML_ASSERT(cls_b != nullptr);
 
-                cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.cls, inp), model.cls_b);
+                cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
                 cur = ggml_tanh(ctx0, cur);
 
                 // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
                 // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
-                if (model.cls_out) {
-                    GGML_ASSERT(model.cls_out_b != nullptr);
+                if (cls_out) {
+                    GGML_ASSERT(cls_out_b != nullptr);
 
-                    cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b);
+                    cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
                 }
             } break;
         default:
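
For context, the branch touched by the second hunk is the RoBERTa-style classification head referenced in the comment: a dense projection plus tanh over the pooled token, followed by an optional `cls_out` projection (absent in models such as jina-reranker-v1-tiny-en, hence the `if (cls_out)` guard). The snippet below is a minimal standalone sketch of that math using plain C++ vectors instead of ggml tensors; all names, shapes, and values are illustrative only and are not part of this change or of the llama.cpp API.

```cpp
// Standalone sketch of the computation in the RANK-pooling branch above:
//   cur = tanh(cls * inp + cls_b), optionally followed by cls_out * cur + cls_out_b.
// Plain std::vector math is used instead of ggml; all names/sizes are illustrative.
#include <cmath>
#include <cstdio>
#include <vector>

// y = W * x + b, with W stored row-major as rows x cols.
static std::vector<float> linear(const std::vector<float> & W,
                                 const std::vector<float> & b,
                                 const std::vector<float> & x,
                                 size_t rows, size_t cols) {
    std::vector<float> y(rows);
    for (size_t r = 0; r < rows; ++r) {
        float acc = b[r];
        for (size_t c = 0; c < cols; ++c) {
            acc += W[r*cols + c] * x[c];
        }
        y[r] = acc;
    }
    return y;
}

int main() {
    // toy "embedding" of the pooled token (n_embd = 3)
    const std::vector<float> inp = { 0.5f, -1.0f, 2.0f };

    // cls / cls_b: dense layer of the classification head (3 -> 3)
    const std::vector<float> cls   = { 0.1f, 0.2f, 0.3f,
                                       0.0f, 1.0f, 0.0f,
                                      -0.5f, 0.5f, 0.5f };
    const std::vector<float> cls_b = { 0.0f, 0.1f, -0.1f };

    // cls_out / cls_out_b: optional output projection (3 -> 1); some models omit it
    const std::vector<float> cls_out   = { 1.0f, -1.0f, 0.5f };
    const std::vector<float> cls_out_b = { 0.2f };

    std::vector<float> cur = linear(cls, cls_b, inp, 3, 3);
    for (float & v : cur) {
        v = std::tanh(v);           // same role as ggml_tanh in the diff
    }

    const bool have_cls_out = true; // mirrors the `if (cls_out)` check
    if (have_cls_out) {
        cur = linear(cls_out, cls_out_b, cur, 1, 3);
    }

    printf("rank/classification score: %f\n", cur[0]);
    return 0;
}
```

The only behavioural point the diff changes is where these tensors come from: they are now passed in as parameters instead of being read from `model.cls*` inside `build_pooling()`.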