@@ -3,7 +3,7 @@
 from tqdm.asyncio import tqdm as tqdm_async

 from models import OpenAIModel, NetworkXStorage, TraverseStrategy, Tokenizer, JsonKVStorage
-from templates import ANSWER_REPHRASING_PROMPT, QUESTION_GENERATION_PROMPT
+from templates import ANSWER_REPHRASING_PROMPT, QUESTION_GENERATION_PROMPT, MULTI_HOP_GENERATION_PROMPT
 from utils import detect_main_language, compute_content_hash, logger
 from graphgen.operators.split_graph import get_batches_with_strategy

@@ -399,3 +399,110 @@ async def _generate_question(
         except Exception as e: # pylint: disable=broad-except
             logger.error("Error occurred while generating questions: %s", e)
     return results
+
+async def traverse_graph_for_multi_hop(
+    llm_client: OpenAIModel,
+    tokenizer: Tokenizer,
+    graph_storage: NetworkXStorage,
+    traverse_strategy: TraverseStrategy,
+    text_chunks_storage: JsonKVStorage,
+    max_concurrent: int = 1000
+) -> dict:
+    """
+    Traverse the graph to generate multi-hop question-answer pairs.
+
+    :param llm_client: LLM client used to generate the QA pairs
+    :param tokenizer: tokenizer used to pre-tokenize node and edge descriptions
+    :param graph_storage: graph storage instance to traverse
+    :param traverse_strategy: strategy controlling how the graph is split into batches
+    :param text_chunks_storage: key-value storage holding the original text chunks
+    :param max_concurrent: maximum number of concurrent LLM requests
+    :return: dict of generated question-answer pairs
+    """
+    assert traverse_strategy.qa_form == "multi_hop"
+
+    semaphore = asyncio.Semaphore(max_concurrent)
+
+    results = {}
+    edges = list(await graph_storage.get_all_edges())
+    nodes = list(await graph_storage.get_all_nodes())
+
+    edges, nodes = await _pre_tokenize(graph_storage, tokenizer, edges, nodes)
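+    # Assumption (not visible in this diff): _pre_tokenize attaches token counts to the
+    # node/edge descriptions so the batching strategy below can respect token budgets.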
+
+    processing_batches = await get_batches_with_strategy(
+        nodes,
+        edges,
+        graph_storage,
+        traverse_strategy
+    )
+
+    processing_batches = assign_difficulty(processing_batches, traverse_strategy.difficulty_order)
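+    # After difficulty assignment each batch is a (nodes, edges, difficulty) tuple,
+    # which is unpacked positionally in _process_single_batch below.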
+
+    async def _process_single_batch(
+        _process_batch: tuple
+    ) -> dict:
+        async with semaphore:
+            try:
+                language = "Chinese" if detect_main_language(_process_batch[0][0]['description']) == "zh" else "English"
+
+                _process_nodes = _process_batch[0]
+                _process_edges = _process_batch[1]
+
+                entities = [
+                    f"{_process_node['node_id']}: {_process_node['description']}" for _process_node in _process_nodes
+                ]
+
+                relations = [
+                    f"{_process_edge[0]} -- {_process_edge[1]}: {_process_edge[2]['description']}"
+                    for _process_edge in _process_edges
+                ]
+
+                entities_str = "\n".join([f"{index + 1}. {entity}" for index, entity in enumerate(entities)])
+                relations_str = "\n".join([f"{index + 1}. {relation}" for index, relation in enumerate(relations)])
+
+                prompt = MULTI_HOP_GENERATION_PROMPT[language].format(
+                    entities=entities_str,
+                    relationships=relations_str
+                )
+
+                context = await llm_client.generate_answer(prompt)
+
+                # post-process the context
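+                # The model is expected to answer either in English ("Question:" / "Answer:")
+                # or in Chinese with full-width colons ("问题:" / "答案:"); anything else is skipped.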
+                if "Question:" in context and "Answer:" in context:
+                    question = context.split("Question:")[1].split("Answer:")[0].strip()
+                    answer = context.split("Answer:")[1].strip()
+                elif "问题:" in context and "答案:" in context:
+                    question = context.split("问题:")[1].split("答案:")[0].strip()
+                    answer = context.split("答案:")[1].strip()
+                else:
+                    return {}
+
+                question = question.strip("\"")
+                answer = answer.strip("\"")
+
+                logger.info("Question: %s", question)
+                logger.info("Answer: %s", answer)
+
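+                # Keys are content hashes of the question text, so identical questions
+                # produced by different batches collapse into a single entry in `results`.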
+                return {
+                    compute_content_hash(question): {
+                        "question": question,
+                        "answer": answer,
+                        "loss": get_average_loss(_process_batch),
+                        "difficulty": _process_batch[2],
+                    }
+                }
+
+            except Exception as e: # pylint: disable=broad-except
+                logger.error("Error occurred while processing batch: %s", e)
+                return {}
+
+    for result in tqdm_async(
+        asyncio.as_completed([_process_single_batch(batch) for batch in processing_batches]),
+        total=len(processing_batches),
+        desc="Processing batches"
+    ):
+        try:
+            results.update(await result)
+        except Exception as e: # pylint: disable=broad-except
+            logger.error("Error occurred while processing batches: %s", e)
+    return results
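
For reference, a minimal sketch of how this new entry point could be driven, assuming the caller already builds the same client and storage objects used by the existing traversal functions; the import path and wrapper name below are illustrative and not taken from this diff:

# Illustrative import path; the actual module layout is not shown in this diff.
from graphgen.operators.traverse_graph import traverse_graph_for_multi_hop

async def build_multi_hop_dataset(llm_client, tokenizer, graph_storage,
                                  traverse_strategy, text_chunks_storage) -> dict:
    # traverse_strategy.qa_form must be "multi_hop", or the assert at the top of the function fails.
    qa_pairs = await traverse_graph_for_multi_hop(
        llm_client,
        tokenizer,
        graph_storage,
        traverse_strategy,
        text_chunks_storage,
        max_concurrent=100,
    )
    # Each value carries "question", "answer", "loss" and "difficulty",
    # keyed by the question's content hash.
    return qa_pairs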