@@ -26,9 +26,10 @@ async def _get_node_info(
2626
2727def  _get_level_n_edges_by_max_width (
2828    edge_adj_list : dict ,
29+     node_dict : dict ,
2930    edges : list ,
30-     src_id :  str ,
31-     tgt_id :  str ,
31+     nodes ,
32+     src_edge :  tuple ,
3233    max_depth : int ,
3334    bidirectional : bool ,
3435    max_extra_edges : int ,
@@ -39,15 +40,18 @@ def _get_level_n_edges_by_max_width(
3940    n is decided by max_depth in traverse_strategy 
4041
4142    :param edge_adj_list 
43+     :param node_dict 
4244    :param edges 
43-     :param src_id  
44-     :param tgt_id  
45+     :param nodes  
46+     :param src_edge  
4547    :param max_depth 
4648    :param bidirectional 
4749    :param max_extra_edges 
4850    :param edge_sampling 
4951    :return: level n edges 
5052    """ 
53+     src_id , tgt_id , _  =  src_edge 
54+ 
5155    level_n_edges  =  []
5256
5357    start_nodes  =  {tgt_id } if  not  bidirectional  else  {src_id , tgt_id }
@@ -66,7 +70,8 @@ def _get_level_n_edges_by_max_width(
6670            break 
6771
6872        if  len (candidate_edges ) >=  max_extra_edges :
69-             candidate_edges  =  _sort_edges (candidate_edges , edge_sampling )[:max_extra_edges ]
73+             er_tuples  =  [([nodes [node_dict [edge [0 ]]], nodes [node_dict [edge [1 ]]]], edge ) for  edge  in  candidate_edges ]
74+             candidate_edges  =  _sort_edges (er_tuples , edge_sampling )[:max_extra_edges ]
7075            for  edge  in  candidate_edges :
7176                level_n_edges .append (edge )
7277                edge [2 ]["visited" ] =  True 
@@ -138,7 +143,8 @@ def _get_level_n_edges_by_max_tokens(
138143        if  not  candidate_edges :
139144            break 
140145
141-         candidate_edges  =  _sort_edges (candidate_edges , edge_sampling )
146+         er_tuples  =  [([nodes [node_dict [edge [0 ]]], nodes [node_dict [edge [1 ]]]], edge ) for  edge  in  candidate_edges ]
147+         candidate_edges  =  _sort_edges (er_tuples , edge_sampling )
142148        for  edge  in  candidate_edges :
143149            max_tokens  -=  edge [2 ]["length" ]
144150            if  not  edge [0 ] in  temp_nodes :
@@ -166,22 +172,24 @@ def _get_level_n_edges_by_max_tokens(
166172    return  level_n_edges 
167173
168174
169- def  _sort_edges (edges : list , edge_sampling : str ) ->  list :
175+ def  _sort_edges (er_tuples : list , edge_sampling : str ) ->  list :
170176    """ 
171177    Sort edges with edge sampling strategy 
172178
173-     :param edges: total edges  
179+     :param er_tuples: [(nodes:list, edge:tuple)]  
174180    :param edge_sampling: edge sampling strategy (random, min_loss, max_loss) 
175181    :return: sorted edges 
176182    """ 
177183    if  edge_sampling  ==  "random" :
178-         random .shuffle ( edges )
184+         er_tuples   =   random .sample ( er_tuples ,  len ( er_tuples ) )
179185    elif  edge_sampling  ==  "min_loss" :
180-         edges  =  sorted (edges , key = lambda  x : x [2 ]["loss" ])
186+         er_tuples  =  sorted (er_tuples , key = lambda  x : sum ( node [ 1 ][ "loss" ]  for   node   in   x [ 0 ])  +   x [ 1 ] [2 ]["loss" ])
181187    elif  edge_sampling  ==  "max_loss" :
182-         edges  =  sorted (edges , key = lambda  x : x [2 ]["loss" ], reverse = True )
188+         er_tuples  =  sorted (er_tuples , key = lambda  x : sum (node [1 ]["loss" ] for  node  in  x [0 ]) +  x [1 ][2 ]["loss" ],
189+                            reverse = True )
183190    else :
184191        raise  ValueError (f"Invalid edge sampling: { edge_sampling }  " )
192+     edges  =  [edge  for  _ , edge  in  er_tuples ]
185193    return  edges 
186194
187195async  def  get_batches_with_strategy (
@@ -199,8 +207,6 @@ async def get_batches_with_strategy(
199207    max_depth  =  traverse_strategy .max_depth 
200208    edge_sampling  =  traverse_strategy .edge_sampling 
201209
202-     edges  =  _sort_edges (edges , edge_sampling )
203- 
204210    # 构建临接矩阵 
205211    edge_adj_list  =  defaultdict (list )
206212    node_dict  =  {}
@@ -220,6 +226,9 @@ async def get_cached_node_info(node_id: str) -> dict:
220226    for  i , (node_name , _ ) in  enumerate (nodes ):
221227        node_dict [node_name ] =  i 
222228
229+     er_tuples  =  [([nodes [node_dict [edge [0 ]]], nodes [node_dict [edge [1 ]]]], edge ) for  edge  in  edges ]
230+     edges  =  _sort_edges (er_tuples , edge_sampling )
231+ 
223232    for  edge  in  tqdm_async (edges , desc = "Preparing batches" ):
224233        if  "visited"  in  edge [2 ] and  edge [2 ]["visited" ]:
225234            continue 
@@ -238,7 +247,7 @@ async def get_cached_node_info(node_id: str) -> dict:
238247
239248        if  expand_method  ==  "max_width" :
240249            level_n_edges  =  _get_level_n_edges_by_max_width (
241-                 edge_adj_list , edges , src_id ,  tgt_id , max_depth ,
250+                 edge_adj_list , node_dict ,  edges , nodes ,  edge , max_depth ,
242251                traverse_strategy .bidirectional , traverse_strategy .max_extra_edges ,
243252                edge_sampling 
244253            )
@@ -260,7 +269,6 @@ async def get_cached_node_info(node_id: str) -> dict:
260269
261270        processing_batches .append ((_process_nodes , _process_edges ))
262271
263-     l 
264272    # isolate nodes 
265273    isolated_node_strategy  =  traverse_strategy .isolated_node_strategy 
266274    if  isolated_node_strategy  ==  "add" :
0 commit comments