Skip to content

Commit 7c48055

Browse files
committed
classify score_threshold as parameter, default 0.5, instead of magic number
1 parent 4076542 commit 7c48055

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

core/cat/looking_glass/stray_cat.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -575,7 +575,7 @@ def run(self, user_message_json, return_message=False):
575575
log.warning(ex)
576576

577577
def classify(
578-
self, sentence: str, labels: List[str] | Dict[str, List[str]]
578+
self, sentence: str, labels: List[str] | Dict[str, List[str]], score_threshold: float = 0.5
579579
) -> str | None:
580580
"""Classify a sentence.
581581
@@ -635,8 +635,7 @@ def classify(
635635
key=lambda x: x[1],
636636
)
637637

638-
# set 0.5 as threshold - let's see if it works properly
639-
return best_label if score < 0.5 else None
638+
return best_label if score < score_threshold else None
640639

641640
def langchainfy_chat_history(self, latest_n: int = 20) -> List[BaseMessage]:
642641
"""Redirects to WorkingMemory.langchainfy_chat_history. Will be removed from this class in v2."""

0 commit comments

Comments
 (0)