14
14
# limitations under the License.
15
15
#
16
16
import binascii
17
- from datetime import datetime
18
17
import logging
19
18
import re
20
19
import time
21
20
from copy import deepcopy
21
+ from datetime import datetime
22
22
from functools import partial
23
23
from timeit import default_timer as timer
24
24
36
36
from rag .app .resume import forbidden_select_fields4resume
37
37
from rag .app .tag import label_question
38
38
from rag .nlp .search import index_name
39
- from rag .prompts import chunks_format , citation_prompt , full_question , kb_prompt , keyword_extraction , llm_id2llm_type , message_fit_in , \
40
- cross_languages
39
+ from rag .prompts import chunks_format , citation_prompt , cross_languages , full_question , kb_prompt , keyword_extraction , llm_id2llm_type , message_fit_in
41
40
from rag .utils import num_tokens_from_string , rmSpace
42
41
from rag .utils .tavily_conn import Tavily
43
42
@@ -303,6 +302,39 @@ def chat(dialog, messages, stream=True, **kwargs):
303
302
if "max_tokens" in gen_conf :
304
303
gen_conf ["max_tokens" ] = min (gen_conf ["max_tokens" ], max_tokens - used_token_count )
305
304
305
+ def repair_bad_citation_formats (answer : str , kbinfos : dict , idx : dict ):
306
+ max_index = len (kbinfos ["chunks" ])
307
+
308
+ def safe_add (i ):
309
+ if 0 <= i < max_index :
310
+ idx .add (i )
311
+ return True
312
+ return False
313
+
314
+ def find_and_replace (pattern , group_index = 1 , repl = lambda i : f"##{ i } $$" , flags = 0 ):
315
+ nonlocal answer
316
+ for match in re .finditer (pattern , answer , flags = flags ):
317
+ try :
318
+ i = int (match .group (group_index ))
319
+ if safe_add (i ):
320
+ answer = answer .replace (match .group (0 ), repl (i ))
321
+ except Exception :
322
+ continue
323
+
324
+ find_and_replace (r"\(\s*ID:\s*(\d+)\s*\)" ) # (ID: 12)
325
+ find_and_replace (r"ID[: ]+(\d+)" ) # ID: 12, ID 12
326
+ find_and_replace (r"\$\$(\d+)\$\$" ) # $$12$$
327
+ find_and_replace (r"\$\[(\d+)\]\$" ) # $[12]$
328
+ find_and_replace (r"\$\$(\d+)\${2,}" ) # $$12$$$$
329
+ find_and_replace (r"\$(\d+)\$" ) # $12$
330
+ find_and_replace (r"#(\d+)\$\$" ) # #12$$
331
+ find_and_replace (r"##(\d+)\$" ) # ##12$
332
+ find_and_replace (r"##(\d+)#{2,}" ) # ##12###
333
+ find_and_replace (r"【(\d+)】" ) # 【12】
334
+ find_and_replace (r"ref\s*(\d+)" , flags = re .IGNORECASE ) # ref12, ref 12, REF 12
335
+
336
+ return answer , idx
337
+
306
338
def decorate_answer (answer ):
307
339
nonlocal prompt_config , knowledges , kwargs , kbinfos , prompt , retrieval_ts , questions , langfuse_tracer
308
340
@@ -331,15 +363,7 @@ def decorate_answer(answer):
331
363
if i < len (kbinfos ["chunks" ]):
332
364
idx .add (i )
333
365
334
- # handle (ID: 1), ID: 2 etc.
335
- for match in re .finditer (r"\(\s*ID:\s*(\d+)\s*\)|ID[: ]+\s*(\d+)" , answer ):
336
- full_match = match .group (0 )
337
- id = match .group (1 ) or match .group (2 )
338
- if id :
339
- i = int (id )
340
- if i < len (kbinfos ["chunks" ]):
341
- idx .add (i )
342
- answer = answer .replace (full_match , f"##{ i } $$" )
366
+ answer , idx = repair_bad_citation_formats (answer , kbinfos , idx )
343
367
344
368
idx = set ([kbinfos ["chunks" ][int (i )]["doc_id" ] for i in idx ])
345
369
recall_docs = [d for d in kbinfos ["doc_aggs" ] if d ["doc_id" ] in idx ]
@@ -502,7 +526,7 @@ def get_table():
502
526
503
527
# compose Markdown table
504
528
columns = (
505
- "|" + "|" .join ([re .sub (r"(/.*|([^()]+))" , "" , field_map .get (tbl ["columns" ][i ]["name" ], tbl ["columns" ][i ]["name" ])) for i in column_idx ]) + ("|Source|" if docid_idx and docid_idx else "|" )
529
+ "|" + "|" .join ([re .sub (r"(/.*|([^()]+))" , "" , field_map .get (tbl ["columns" ][i ]["name" ], tbl ["columns" ][i ]["name" ])) for i in column_idx ]) + ("|Source|" if docid_idx and docid_idx else "|" )
506
530
)
507
531
508
532
line = "|" + "|" .join (["------" for _ in range (len (column_idx ))]) + ("|------|" if docid_idx and docid_idx else "" )
@@ -598,4 +622,5 @@ def decorate_answer(answer):
598
622
for ans in chat_mdl .chat_streamly (prompt , msg , {"temperature" : 0.1 }):
599
623
answer = ans
600
624
yield {"answer" : answer , "reference" : {}}
601
- yield decorate_answer (answer )
625
+ yield decorate_answer (answer )
626
+
0 commit comments