Skip to content

Commit 4d6b804

Browse files
committed
handle the case when cls_b and cls_out_b are absent
1 parent 8b794f9 commit 4d6b804

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

src/llama-graph.cpp

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1572,10 +1572,13 @@ void llm_graph_context::build_pooling(
15721572
ggml_tensor * inp_cls = build_inp_cls();
15731573
inp = ggml_get_rows(ctx0, inp, inp_cls);
15741574

1575-
if (cls != nullptr && cls_b != nullptr) {
1575+
if (cls != nullptr) {
15761576
// classification head
15771577
// https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
1578-
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
1578+
cur = ggml_mul_mat(ctx0, cls, inp);
1579+
if (cls_b != nullptr) {
1580+
cur = ggml_add(ctx0, cur, cls_b);
1581+
}
15791582
cur = ggml_tanh(ctx0, cur);
15801583

15811584
if (cls_norm) {
@@ -1586,16 +1589,22 @@ void llm_graph_context::build_pooling(
15861589
// some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
15871590
// https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
15881591
if (cls_out) {
1589-
GGML_ASSERT(cls_out_b != nullptr);
1590-
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
1592+
cur = ggml_mul_mat(ctx0, cls_out, cur);
1593+
if (cls_out_b != nullptr) {
1594+
cur = ggml_add(ctx0, cur, cls_out_b);
1595+
}
15911596
}
15921597
} else if (cls_out) {
15931598
// Single layer classification head (direct projection)
15941599
// https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
1595-
GGML_ASSERT(cls_out_b != nullptr);
1596-
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, inp), cls_out_b);
1600+
cur = ggml_mul_mat(ctx0, cls_out, inp);
1601+
if (cls_out_b != nullptr) {
1602+
cur = ggml_add(ctx0, cur, cls_out_b);
1603+
}
15971604
} else {
1598-
GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
1605+
// Some models may not have either classification head
1606+
// In this case, just use the CLS/pooled embedding directly
1607+
cur = inp;
15991608
}
16001609
} break;
16011610
default:

0 commit comments

Comments
 (0)