Skip to content

Commit e8e2a95

Browse files
authored
Refa: more fallbacks for bad citation format (#7710)
### What problem does this PR solve?

More fallbacks for bad citation format.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring
1 parent b908c33 commit e8e2a95

File tree

1 file changed

+39
-14
lines changed

1 file changed

+39
-14
lines changed

api/db/services/dialog_service.py

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@
1414
# limitations under the License.
1515
#
1616
import binascii
17-
from datetime import datetime
1817
import logging
1918
import re
2019
import time
2120
from copy import deepcopy
21+
from datetime import datetime
2222
from functools import partial
2323
from timeit import default_timer as timer
2424

@@ -36,8 +36,7 @@
3636
from rag.app.resume import forbidden_select_fields4resume
3737
from rag.app.tag import label_question
3838
from rag.nlp.search import index_name
39-
from rag.prompts import chunks_format, citation_prompt, full_question, kb_prompt, keyword_extraction, llm_id2llm_type, message_fit_in, \
40-
cross_languages
39+
from rag.prompts import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, llm_id2llm_type, message_fit_in
4140
from rag.utils import num_tokens_from_string, rmSpace
4241
from rag.utils.tavily_conn import Tavily
4342

@@ -303,6 +302,39 @@ def chat(dialog, messages, stream=True, **kwargs):
303302
if "max_tokens" in gen_conf:
304303
gen_conf["max_tokens"] = min(gen_conf["max_tokens"], max_tokens - used_token_count)
305304

305+
def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
    """Rewrite malformed chunk citations in *answer* to the canonical ``##N$$`` form.

    Each recognised variant (``(ID: 12)``, ``ID 12``, ``$$12$$``, ``$[12]$``,
    ``$12$``, ``#12$$``, ``##12$``, ``##12###``, ``【12】``, ``ref 12``) is
    replaced in place when ``N`` is a valid index into ``kbinfos["chunks"]``.
    Citations whose index is out of range are left exactly as written.

    Args:
        answer: The generated answer text to repair.
        kbinfos: Retrieval result; only ``len(kbinfos["chunks"])`` is read.
        idx: Set of cited chunk indices; updated in place with every valid
            index found.

    Returns:
        A ``(answer, idx)`` tuple: the repaired text and the same ``idx``
        object that was passed in.
    """
    max_index = len(kbinfos["chunks"])

    # All variants are matched in ONE left-to-right pass. Running one pass per
    # pattern re-scans text that was already normalised, so "#(\d+)\$\$" would
    # match inside a valid "##12$$" and corrupt it. Order matters: at a given
    # position the first alternative that matches wins.
    variants = (
        r"##(\d+)\$\$",  # already well-formed: consume it so nothing below can bite into it
        r"\(\s*ID:\s*(\d+)\s*\)",  # (ID: 12)
        r"(?<![A-Z])ID[: ]+(\d+)",  # ID: 12, ID 12 -- but not the tail of "COVID 19"
        r"\$\$(\d+)\${2,}",  # $$12$$, $$12$$$$
        r"\$\[(\d+)\]\$",  # $[12]$
        r"\$(\d+)\$",  # $12$
        r"#(\d+)\$\$",  # #12$$
        r"##(\d+)#{2,}",  # ##12###
        r"##(\d+)\$",  # ##12$
        r"【(\d+)】",  # 【12】
        r"(?<![A-Za-z])(?i:ref)\s*(\d+)",  # ref12, ref 12, REF 12 -- but not "href 12"
    )
    citation_re = re.compile("|".join(f"(?:{p})" for p in variants))

    def normalize(match):
        # Exactly one capture group participates: the one belonging to the
        # alternative that matched.
        digits = next(g for g in match.groups() if g is not None)
        i = int(digits)
        if 0 <= i < max_index:
            idx.add(i)
            return f"##{i}$$"
        # Unknown chunk index: keep the original text untouched.
        return match.group(0)

    # re.sub replaces each match at its own position, so a short citation such
    # as "ID: 1" can no longer clobber the prefix of "ID: 12".
    return citation_re.sub(normalize, answer), idx
337+
306338
def decorate_answer(answer):
307339
nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions, langfuse_tracer
308340

@@ -331,15 +363,7 @@ def decorate_answer(answer):
331363
if i < len(kbinfos["chunks"]):
332364
idx.add(i)
333365

334-
# handle (ID: 1), ID: 2 etc.
335-
for match in re.finditer(r"\(\s*ID:\s*(\d+)\s*\)|ID[: ]+\s*(\d+)", answer):
336-
full_match = match.group(0)
337-
id = match.group(1) or match.group(2)
338-
if id:
339-
i = int(id)
340-
if i < len(kbinfos["chunks"]):
341-
idx.add(i)
342-
answer = answer.replace(full_match, f"##{i}$$")
366+
answer, idx = repair_bad_citation_formats(answer, kbinfos, idx)
343367

344368
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
345369
recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
@@ -502,7 +526,7 @@ def get_table():
502526

503527
# compose Markdown table
504528
columns = (
505-
"|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
529+
"|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
506530
)
507531

508532
line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "")
@@ -598,4 +622,5 @@ def decorate_answer(answer):
598622
for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}):
599623
answer = ans
600624
yield {"answer": answer, "reference": {}}
601-
yield decorate_answer(answer)
625+
yield decorate_answer(answer)
626+

0 commit comments

Comments
 (0)