@@ -60,7 +60,66 @@ def chunked_selective_log_softmax(logits, index):
6060 return all_per_token_logps
6161pass
6262
63+ def calculate_pad_tokens_in_prompt (
64+ input_ids : torch .Tensor ,
65+ logits_to_keep : int ,
66+ pad_token_id : int
67+ ) -> torch .Tensor :
68+ """
69+ Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens
70+ """
71+ if logits_to_keep >= input_ids .shape [1 ]:
72+ raise ValueError ("logits_to_keep must be smaller than the sequence length." )
73+
74+ prompt_section = input_ids [:, :- logits_to_keep ]
75+
76+ padding_mask = (prompt_section == pad_token_id )
77+
78+ pad_token_counts = padding_mask .sum (dim = 1 )
79+
80+ return pad_token_counts
81+
82+ def create_completion_attention_mask (
83+ completion_input_ids : torch .Tensor ,
84+ left_pad_tokens_per_prompt : torch .Tensor ,
85+ max_left_pad : int ,
86+ pad_token_id : int
87+ ) -> torch .Tensor :
88+ """
89+ Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad]
90+
91+ Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens
92+ and pad are pad tokens, this function would make a completion mask that would 0 out the pad
93+ and p tokens. so in this example [0,0,0,1,1,1,0,0,0]
94+ """
95+ batch_size , completion_len = completion_input_ids .shape
96+ device = completion_input_ids .device
97+
98+ num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt
99+
100+ indices = torch .arange (completion_len , device = device ).unsqueeze (0 )
101+ shift_mask = indices >= num_tokens_to_mask .unsqueeze (1 )
102+
103+ non_padding_mask = (completion_input_ids != pad_token_id )
104+
105+ final_mask = shift_mask & non_padding_mask
106+
107+ return final_mask
108+
109+ def left_pack_padding (tensor : torch .Tensor , pad_id : int ) -> torch .Tensor :
110+ """
111+ Moves all padding tokens in each sequence of a batch to the right.
112+ """
113+ mask = (tensor != pad_id )
114+ sorted_indices = torch .argsort (mask , dim = 1 , descending = True , stable = True )
115+ packed_tensor = torch .gather (tensor , 1 , sorted_indices )
116+
117+ return packed_tensor
118+
63119RL_REPLACEMENTS ["selective_log_softmax" ] = chunked_selective_log_softmax
120+ RL_REPLACEMENTS ["create_completion_attention_mask" ] = create_completion_attention_mask
121+ RL_REPLACEMENTS ["calculate_pad_tokens_in_prompt" ] = calculate_pad_tokens_in_prompt
122+ RL_REPLACEMENTS ["left_pack_padding" ] = left_pack_padding
64123
65124
66125# Custom compiled GRPO loss - creates 3 Triton kernels
@@ -85,6 +144,7 @@ def grpo_compute_loss(
85144 logit_scale_multiply = kwargs .get ("logit_scale_multiply" , 0.0 )
86145 logit_scale_divide = kwargs .get ("logit_scale_divide" , 0.0 )
87146 logit_softcapping = kwargs .get ("logit_softcapping" , 0.0 )
147+ importance_sampling_level = kwargs .get ("importance_sampling_level" , "token" )
88148
89149 input_ids = input_ids .unsqueeze (- 1 )
90150
@@ -98,7 +158,6 @@ def grpo_compute_loss(
98158 if temperature != 1.0 : new_logits = new_logits / temperature
99159 new_x = torch .gather (new_logits , dim = - 1 , index = input_ids ).squeeze (- 1 )
100160 new = new_x - torch .logsumexp (new_logits , dim = - 1 )
101-
102161 # x_i - logsumexp(x_i)
103162 with torch .no_grad ():
104163 if beta != 0.0 :
@@ -143,9 +202,23 @@ def grpo_compute_loss(
143202 # Below is forward KL (normal KL)
144203 # kl_i = torch.exp(old) * (old - new)
145204 if old_logits is not None :
146- coef_1 = torch .exp (new - old )
205+ log_ratio = new - old
206+ else :
207+ log_ratio = new - new .detach ()
208+
209+ if importance_sampling_level == "token" :
210+ log_importance_weights = log_ratio
211+ elif importance_sampling_level == "sequence" :
212+ log_importance_weights = (log_ratio * mask ).sum (- 1 ) / mask .sum (- 1 ).clamp (min = 1.0 )
213+ log_importance_weights = log_importance_weights .unsqueeze (- 1 )
147214 else :
148- coef_1 = torch .exp (new - new .detach ())
215+ raise ValueError (
216+ f"Unknown importance sampling level: { importance_sampling_level } . Possible values are 'token' "
217+ "and 'sequence'."
218+ )
219+
220+ coef_1 = torch .exp (log_importance_weights )
221+
149222 coef_2 = torch .clamp (coef_1 , 1 - epsilon_low , 1 + epsilon_high )
150223
151224 if delta is not None :
@@ -178,15 +251,19 @@ def grpo_compute_loss(
178251
179252 # loss = (loss_i * mask).sum() / mask.sum()
180253
181- # Get metrics as well which are folded
182- with torch .inference_mode ():
183- completion_length = n_mask_per_reward .mean ()
184- n_mask_per_reward = n_mask_per_reward .clamp (min = 1.0 ) # Counteracts division by 0
185- mean_kl_per_reward = (kl_i * mask ).sum (1 ) / n_mask_per_reward
186- mean_kl = mean_kl_per_reward .mean ()
187- pass
254+ completion_length = n_mask_per_reward .mean ()
188255
189- return loss , completion_length , mean_kl
256+ # Get metrics as well which are folded
257+ def masked_batch_mean (x ):
258+ with torch .inference_mode ():
259+ if x .shape [1 ] == 1 : # when importance_sampling_level == "sequence"
260+ return x .mean ()
261+ else :
262+ mean_kl_per_reward = (kl_i * mask ).sum (1 ) / n_mask_per_reward
263+ mean_kl = mean_kl_per_reward .mean ()
264+ return mean_kl
265+
266+ return loss , completion_length , masked_batch_mean (kl_i )
190267pass
191268RL_REPLACEMENTS ["grpo_compute_loss" ] = grpo_compute_loss
192269RL_REPLACEMENTS ["grpo_compute_loss_slow" ] = \
@@ -283,7 +360,10 @@ def accumulate_chunk(
283360 old_hidden_states = torch .chunk (_old_hidden_states , chunks = n_chunks , dim = 0 )
284361 else :
285362 old_hidden_states = [None ] * n_chunks
286- ref_hidden_states = torch .chunk (_ref_hidden_states , chunks = n_chunks , dim = 0 )
363+ if _ref_hidden_states is not None :
364+ ref_hidden_states = torch .chunk (_ref_hidden_states , chunks = n_chunks , dim = 0 )
365+ else :
366+ ref_hidden_states = [None ] * n_chunks
287367 input_ids = torch .chunk (_input_ids , chunks = n_chunks , dim = 0 )
288368 mask = torch .chunk (_mask , chunks = n_chunks , dim = 0 )
289369 advantages = torch .chunk (_advantages , chunks = n_chunks , dim = 0 )
@@ -347,12 +427,17 @@ def grpo_accumulated_loss(
347427 completion_mask ,
348428 advantages ,
349429 old_hidden_states ,
430+ ref_hidden_states ,
350431 n_chunks = - 1 ,
351432 ** kwargs ,
352433):
353434 # All Unsloth Zoo code licensed under LGPLv3
354435 bsz , qlen = input_ids .shape
355436
437+ pixel_values = kwargs .get ('pixel_values' ,None )
438+ image_grid_thw = kwargs .get ('image_grid_thw' ,None )
439+ pixel_attention_mask = kwargs .get ('pixel_attention_mask' ,None )
440+ image_sizes = kwargs .get ('image_sizes' ,None )
356441 # Find closest multiple
357442 factors = [i for i in range (1 , bsz + 1 ) if bsz % i == 0 ]
358443 if n_chunks == - 1 : n_chunks = bsz
@@ -364,41 +449,72 @@ def grpo_accumulated_loss(
364449 pass
365450 os .environ ["UNSLOTH_RETURN_HIDDEN_STATES" ] = "1"
366451
367- completion_input_ids = input_ids [:, - logits_to_keep :]
368452 lm_head = trainer .model .get_output_embeddings ().weight
369453
370- with torch .amp .autocast (device_type = trainer .model .device .type , dtype = trainer ._autocast_dtype ):
371- with torch .inference_mode (), trainer .accelerator .unwrap_model (trainer .model , keep_fp32_wrapper = False ).disable_adapter ():
372- ref_hidden_states = trainer .model (
454+ if pixel_values is None :
455+ left_pad_tokens_per_prompt = calculate_pad_tokens_in_prompt (input_ids , logits_to_keep , trainer .processing_class .pad_token_id )
456+
457+ max_left_pad = max (left_pad_tokens_per_prompt ).item ()
458+
459+ input_ids = left_pack_padding (input_ids , trainer .processing_class .pad_token_id )
460+
461+ completion_input_ids = input_ids [:, - (logits_to_keep + max_left_pad ):]
462+
463+ completion_mask = create_completion_attention_mask (completion_input_ids , left_pad_tokens_per_prompt , max_left_pad , trainer .processing_class .pad_token_id ).to (attention_mask .dtype )
464+ attention_mask = input_ids != trainer .processing_class .pad_token_id
465+ attention_mask = attention_mask .to (attention_mask .dtype )
466+ else :
467+ completion_input_ids = input_ids [:, - logits_to_keep :]
468+
469+ unwrapped_model = trainer .accelerator .unwrap_model (trainer .model , keep_fp32_wrapper = False )
470+ with torch .amp .autocast (device_type = trainer .model .device .type , dtype = trainer ._autocast_dtype ):
471+ if pixel_values is None :
472+ new_hidden_states = unwrapped_model (
473+ input_ids = input_ids ,
474+ attention_mask = attention_mask ,
475+ pixel_values = pixel_values ,
476+ image_grid_thw = image_grid_thw ,
477+ pixel_attention_mask = pixel_attention_mask ,
478+ image_sizes = image_sizes ,
479+ #logits_to_keep = logits_to_keep + 1,
480+ ).logits
481+
482+ #keep extra logit as we generated a new token
483+ new_hidden_states = new_hidden_states [:, - (logits_to_keep + max_left_pad + 1 ): , :]
484+ if ref_hidden_states is not None :
485+ ref_hidden_states = ref_hidden_states [:, - (logits_to_keep + max_left_pad + 1 ): , :]
486+ if old_hidden_states is not None :
487+ old_hidden_states = old_hidden_states [:, - (logits_to_keep + max_left_pad + 1 ): , :]
488+ else :
489+ new_hidden_states = unwrapped_model (
373490 input_ids = input_ids ,
374491 attention_mask = attention_mask ,
492+ pixel_values = pixel_values ,
493+ image_grid_thw = image_grid_thw ,
494+ pixel_attention_mask = pixel_attention_mask ,
495+ image_sizes = image_sizes ,
375496 logits_to_keep = logits_to_keep + 1 ,
376497 ).logits
377- pass
378- new_hidden_states = trainer .model (
379- input_ids = input_ids ,
380- attention_mask = attention_mask ,
381- logits_to_keep = logits_to_keep + 1 ,
382- ).logits
383-
384- loss , completion_length , mean_kl = UnslothEfficientGRPO .apply (
385- new_hidden_states ,
386- old_hidden_states ,
387- ref_hidden_states ,
388- lm_head ,
389- completion_input_ids ,
390- completion_mask ,
391- advantages ,
392- trainer .beta ,
393- trainer .accelerator .scaler ,
394- n_chunks ,
395- kwargs # pass kwargs as a dict
396- )
498+
499+ loss , completion_length , mean_kl = UnslothEfficientGRPO .apply (
500+ new_hidden_states ,
501+ old_hidden_states ,
502+ ref_hidden_states ,
503+ lm_head ,
504+ completion_input_ids ,
505+ completion_mask ,
506+ advantages ,
507+ trainer .beta ,
508+ trainer .accelerator .scaler ,
509+ n_chunks ,
510+ kwargs # pass kwargs as a dict
511+ )
397512 pass
513+
398514 # Must force not returning hidden states but logits otherwise gibberish
399515 os .environ ["UNSLOTH_RETURN_HIDDEN_STATES" ] = "0"
400- return loss , completion_length , mean_kl
401516
517+ return loss , completion_length , mean_kl
402518 # Old non efficient code path
403519 new_logits = torch .matmul (new_hidden_states , lm_head .t ())
404520 new_logits = new_logits [:, :- 1 , :] # exclude the last logit: it corresponds to the next token pred
0 commit comments