@@ -327,25 +327,29 @@ async def _generate_question(
327327    ):
328328        if  len (node_or_edge ) ==  2 :
329329            des  =  node_or_edge [0 ] +  ": "  +  node_or_edge [1 ]['description' ]
330-             answer  =  node_or_edge [1 ]['description' ]
331330            loss  =  node_or_edge [1 ]['loss' ]
332331        else :
333332            des  =  node_or_edge [2 ]['description' ]
334-             answer  =  node_or_edge [2 ]['description' ]
335333            loss  =  node_or_edge [2 ]['loss' ]
336334
337335        async  with  semaphore :
338336            try :
339337                language  =  "Chinese"  if  detect_main_language (des ) ==  "zh"  else  "English" 
340-                 question  =  await  llm_client .generate_answer (
341-                     QUESTION_GENERATION_PROMPT [language ]['SINGLE_TEMPLATE' ].format (
342-                         answer = des 
338+ 
339+                 qa  =  await  llm_client .generate_answer (
340+                     QUESTION_GENERATION_PROMPT [language ]['SINGLE_QA_TEMPLATE' ].format (
341+                         doc = des 
343342                    )
344343                )
345-                 if  question .startswith ("Question:" ):
346-                     question  =  question [len ("Question:" ):].strip ()
347-                 elif  question .startswith ("问题:" ):
348-                     question  =  question [len ("问题:" ):].strip ()
344+ 
345+                 if  "Question:"  in  qa  and  "Answer:"  in  qa :
346+                     question  =  qa .split ("Question:" )[1 ].split ("Answer:" )[0 ].strip ()
347+                     answer  =  qa .split ("Answer:" )[1 ].strip ()
348+                 elif  "问题:"  in  qa  and  "答案:"  in  qa :
349+                     question  =  qa .split ("问题:" )[1 ].split ("答案:" )[0 ].strip ()
350+                     answer  =  qa .split ("答案:" )[1 ].strip ()
351+                 else :
352+                     return  {}
349353
350354                question  =  question .strip ("\" " )
351355                answer  =  answer .strip ("\" " )
@@ -370,9 +374,7 @@ async def _generate_question(
370374
371375    edges , nodes  =  await  _pre_tokenize (graph_storage , tokenizer , edges , nodes )
372376
373-     # TODO: 需要把node的name也加进去,或者只用edge,两种都试一下 
374377    tasks  =  []
375-     # des中可能会有SEP分割符 
376378    for  node  in  nodes :
377379        if  "<SEP>"  in  node [1 ]['description' ]:
378380            description_list  =  node [1 ]['description' ].split ("<SEP>" )
0 commit comments