@@ -100,6 +100,33 @@ def get_loss_tercile(losses: list) -> (float, float):
100100
101101    return  losses [q1_index ], losses [q2_index ]
102102
103+ def  assign_difficulty (subgraphs : list , difficulty_order : list ) ->  list :
104+     """ 
105+     Assign difficulty to subgraphs based on the loss 
106+ 
107+     :param subgraphs 
108+     :param difficulty_order 
109+     :return 
110+     """ 
111+     losses  =  []
112+     for  subgraph  in  subgraphs :
113+         loss  =  get_average_loss (subgraph )
114+         losses .append (loss )
115+     q1 , q2  =  get_loss_tercile (losses )
116+ 
117+     for  i , subgraph  in  enumerate (subgraphs ):
118+         loss  =  get_average_loss (subgraph )
119+         if  loss  <  q1 :
120+             # easy 
121+             subgraphs [i ] =  (subgraph [0 ], subgraph [1 ], difficulty_order [0 ])
122+         elif  loss  <  q2 :
123+             # medium 
124+             subgraphs [i ] =  (subgraph [0 ], subgraph [1 ], difficulty_order [1 ])
125+         else :
126+             # hard 
127+             subgraphs [i ] =  (subgraph [0 ], subgraph [1 ], difficulty_order [2 ])
128+     return  subgraphs 
129+ 
103130def  get_average_loss (batch : tuple ) ->  float :
104131    if  loss_strategy  ==  "only_edge" :
105132        return  sum (edge [2 ]['loss' ] for  edge  in  batch [1 ]) /  len (batch [1 ])
@@ -258,24 +285,7 @@ async def _process_single_batch(
258285        traverse_strategy 
259286    )
260287
261-     losses  =  []
262-     for  batch  in  processing_batches :
263-         loss  =  get_average_loss (batch )
264-         losses .append (loss )
265-     q1 , q2  =  get_loss_tercile (losses )
266- 
267-     difficulty_order  =  traverse_strategy .difficulty_order 
268-     for  i , batch  in  enumerate (processing_batches ):
269-         loss  =  get_average_loss (batch )
270-         if  loss  <  q1 :
271-             # easy 
272-             processing_batches [i ] =  (batch [0 ], batch [1 ], difficulty_order [0 ])
273-         elif  loss  <  q2 :
274-             # medium 
275-             processing_batches [i ] =  (batch [0 ], batch [1 ], difficulty_order [1 ])
276-         else :
277-             # hard 
278-             processing_batches [i ] =  (batch [0 ], batch [1 ], difficulty_order [2 ])
288+     processing_batches  =  assign_difficulty (processing_batches , traverse_strategy .difficulty_order )
279289
280290    for  result  in  tqdm_async (asyncio .as_completed (
281291        [_process_single_batch (batch ) for  batch  in  processing_batches ]
0 commit comments