1515import  copy 
1616import  heapq 
1717from  itertools  import  chain 
18- from  typing  import  List , Tuple 
18+ from  typing  import  Dict ,  List ,  Optional , Tuple 
1919
2020import  torch 
2121from  tensordict  import  TensorDict 
@@ -150,7 +150,7 @@ def greedy_partition(seqlen_list: List[int], k_partitions: int, equal_size: bool
150150    return  partitions 
151151
152152
153- def  get_seqlen_balanced_partitions (seqlen_list : List [int ], k_partitions : int , equal_size : bool ):
153+ def  get_seqlen_balanced_partitions (seqlen_list : List [int ], k_partitions : int , equal_size : bool )  ->   List [ List [ int ]] :
154154    """Get order of seq lengths to make partitions balanced, this is 
155155    used in balacing sum of seqlength across dp ranks and microbatches. 
156156
@@ -161,8 +161,7 @@ def get_seqlen_balanced_partitions(seqlen_list: List[int], k_partitions: int, eq
161161            resulting number of partitions 
162162        equal_size (bool): 
163163            if True, number of items in each partitions must be equal. 
164-             if False, only consider balancing the sum, each partition can have 
165-             variable number of items 
164+             if False, only consider balancing the sum, each partition can have variable number of items 
166165
167166    Returns: 
168167        partitions (List[List[int]]): 
@@ -186,14 +185,28 @@ def _check_and_sort_partitions(partitions):
186185    return  _check_and_sort_partitions (partitions )
187186
188187
189- def  log_seqlen_unbalance (seqlen_list : List [int ], partitions : List [List [int ]], prefix ):
190-     # add some metrics of seqlen sum on dp ranks 
188+ def  log_seqlen_unbalance (seqlen_list : List [int ], partitions : List [List [int ]], prefix : str ) ->  Dict [str , float ]:
189+     """ 
190+     Calculate and log metrics related to sequence length imbalance before and after partitioning. 
191+ 
192+     Args: 
193+         seqlen_list (List[int]): A list of sequence lengths for each item. 
194+         partitions (List[List[int]]): A list of partitions, where each inner list contains indices 
195+                                       from seqlen_list assigned to that partition. 
196+         prefix (str): A prefix to be added to each metric key in the returned dictionary. 
197+ 
198+     Returns: 
199+         dict: A dictionary containing metrics related to sequence length imbalance. 
200+     """ 
201+     # Get the number of partitions 
191202    k_partition  =  len (partitions )
192203    # assert len(seqlen_list) % k_partition == 0 
193204    batch_size  =  len (seqlen_list ) //  k_partition 
194205    min_sum_seqlen  =  None 
195206    max_sum_seqlen  =  None 
196207    total_sum_seqlen  =  0 
208+ 
209+     # Iterate over each batch of sequence lengths 
197210    for  offset  in  range (0 , len (seqlen_list ), batch_size ):
198211        cur_sum_seqlen  =  sum (seqlen_list [offset  : offset  +  batch_size ])
199212        if  min_sum_seqlen  is  None  or  cur_sum_seqlen  <  min_sum_seqlen :
@@ -206,7 +219,7 @@ def log_seqlen_unbalance(seqlen_list: List[int], partitions: List[List[int]], pr
206219    for  partition  in  partitions :
207220        cur_sum_seqlen_balanced  =  sum ([seqlen_list [i ] for  i  in  partition ])
208221        balanced_sum_seqlen_list .append (cur_sum_seqlen_balanced )
209-      # print("balanced_sum_seqlen_list: ", balanced_sum_seqlen_list) 
222+ 
210223    min_sum_seqlen_balanced  =  min (balanced_sum_seqlen_list )
211224    max_sum_seqlen_balanced  =  max (balanced_sum_seqlen_list )
212225
@@ -220,11 +233,13 @@ def log_seqlen_unbalance(seqlen_list: List[int], partitions: List[List[int]], pr
220233    }
221234
222235
def ceildiv(a: float, b: float) -> float:
    """Return the ceiling of a / b using negated floor division.

    Works for both ints and floats; for int inputs the result is an int.
    """
    floored = a // -b
    return -floored
225238
226239
227- def  rearrange_micro_batches (batch : TensorDict , max_token_len , dp_group = None ):
240+ def  rearrange_micro_batches (
241+     batch : TensorDict , max_token_len : int , dp_group : Optional [dist .ProcessGroup ] =  None 
242+ ) ->  Tuple [List [TensorDict ], List [List [int ]]]:
228243    """Split the batch into a list of micro_batches, where the max_token_len is smaller than max_token_len 
229244    and the number of valid tokens in each micro batch is well balanced. 
230245    """ 
@@ -253,7 +268,16 @@ def rearrange_micro_batches(batch: TensorDict, max_token_len, dp_group=None):
253268    return  micro_batches , micro_bsz_idx 
254269
255270
256- def  get_reverse_idx (idx_map ):
271+ def  get_reverse_idx (idx_map : List [int ]) ->  List [int ]:
272+     """ 
273+     Build the inverse of an index mapping. 
274+ 
275+     Args: 
276+         idx_map (Sequence[int]): Sequence where idx_map[i] = j. 
277+ 
278+     Returns: 
279+         List[int]: Inverse mapping list such that output[j] = i for each i. 
280+     """ 
257281    reverse_idx_map  =  copy .deepcopy (idx_map )
258282
259283    for  i , idx  in  enumerate (idx_map ):
@@ -263,20 +287,38 @@ def get_reverse_idx(idx_map):
263287
264288
def prepare_dynamic_batch(data: DataProto, max_token_len: int) -> tuple[list[DataProto], list[list[int]]]:
    """
    Split a DataProto into token-balanced micro-batches for dynamic batching.

    Args:
        data (DataProto): The input data.
        max_token_len (int): The maximum token length per micro-batch.

    Returns:
        tuple[list[DataProto], list[list[int]]]: The micro-batches, and for each
        micro-batch the list of original item indices it was built from.
    """
    rearranged, batch_idx_list = rearrange_micro_batches(data.batch, max_token_len=max_token_len)

    micro_batches = []
    for position, indices in enumerate(batch_idx_list):
        tensor_dict = dict(rearranged[position])
        # Select the matching non-tensor entries for this micro-batch.
        non_tensor_dict = {}
        for key, value in data.non_tensor_batch.items():
            non_tensor_dict[key] = value[indices]
        micro_batches.append(DataProto.from_dict(tensor_dict, non_tensor_dict))

    return micro_batches, batch_idx_list
277309
278310
def restore_dynamic_batch(data: torch.Tensor, batch_idx_list: List[List[int]]) -> torch.Tensor:
    """
    Undo dynamic-batch reordering, putting rows of `data` back in original order.

    Args:
        data (torch.Tensor): Data laid out in micro-batch (rearranged) order.
        batch_idx_list (List[List[int]]): Per-micro-batch lists of original indices.

    Returns:
        torch.Tensor: The data reindexed into the original ordering.
    """
    # Flatten the per-micro-batch index lists into one permutation.
    flat_indices: List[int] = []
    for chunk in batch_idx_list:
        flat_indices.extend(chunk)
    # Invert the permutation and gather rows back into place.
    inverse = get_reverse_idx(flat_indices)
    return data[torch.tensor(inverse, dtype=torch.long)]
0 commit comments