2525import  torch 
2626import  inspect 
2727from  collections  import  defaultdict 
28- from  unsloth_zoo .rl_replacements  import  RL_REPLACEMENTS 
28+ from  unsloth_zoo .rl_replacements  import  RL_REPLACEMENTS ,  left_pack_padding 
2929from  unsloth  import  DEVICE_TYPE 
3030
3131RL_EXTRA_ARGS       =  defaultdict (list )
@@ -245,6 +245,18 @@ def grpo_trainer__generate_and_score_completions(function_name, function):
245245        "prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False" ,
246246    )
247247
248+     # Left pad prompt before calculation old and ref hidden states 
249+     line_to_replace  =  "batch_size = self.args.per_device_train_batch_size if mode == \" train\"  else self.args.per_device_eval_batch_size" 
250+ 
251+     # The new multi-line string that will replace the line above 
252+     replacement_lines  =  """ 
253+         batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size 
254+         if not has_images: 
255+             # Left pad prompt before calculation old and ref hidden states 
256+             prompt_completion_ids = left_pack_padding(prompt_completion_ids, self.processing_class.pad_token_id)""" 
257+ 
258+     function  =  function .replace (line_to_replace , replacement_lines )
259+ 
248260    # Always between max_prompt_length and use_vllm 
249261    found  =  re .findall (
250262        r"\n(([ ]{8,})if self\.max_prompt_length is not None:.*?" \
@@ -263,26 +275,27 @@ def grpo_trainer__generate_and_score_completions(function_name, function):
263275            # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens. 
264276            # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text, 
265277            # because we can't use `skip_special_tokens=True` (some special tokens are still needed for generation). 
266-             prompt_ids  = prompt_ids[:, - self.max_prompt_length : ] 
267-             prompt_mask  = prompt_mask[:, -self.max_prompt_length : ] 
268-             prompts_text = self.processing_class.batch_decode ( 
269-                 prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False  
278+             protected  = [self.image_token_id,  self.vision_start_token_id, self.vision_end_token_id ] 
279+             protected  = [token for token in protected if token is not None ] 
280+             prompt_ids, prompt_mask = truncate_with_protected_tokens ( 
281+                 prompt_ids, prompt_mask, self.max_prompt_length, protected  
270282            ) 
271-             pad_token = self.processing_class.pad_token 
272-             def strip_leading_tokens(text): 
273-                 while text.startswith(pad_token): 
274-                     text = text.removeprefix(pad_token) 
275-                 return text 
276283
277-             if pad_token is not None: 
284+             prompts_text = [re.sub(rf"^({{re.escape(self.pad_token)}})+", "", text) for text in prompts_text] 
285+ 
286+             # The chat template inserts a single image token into the prompt text. However, when this text is later 
287+             # tokenized, the single image token string is expanded into multiple image token IDs, depending on the 
288+             # image size. Since we're detokenizing here, we may see repeated image tokens in the decoded text. We 
289+             # collapse them back into a single token string to match the original template. 
290+             if self.image_token is not None: 
278291                prompts_text = [ 
279-                     strip_leading_tokens( text) for text in prompts_text 
292+                     re.sub(rf"({{re.escape(self.image_token)}})+", self.image_token,  text) for text in prompts_text 
280293                ] 
281- 
282294        # Generate completions using either vLLM or regular generation 
283295        if self.use_vllm:""" 
284296            function  =  function .replace (replace_part , new_replacement )
285-     pass 
297+ 
298+ 
286299    return  function 
287300pass 
288301RL_FUNCTIONS ["grpo_trainer" ].append (grpo_trainer__generate_and_score_completions )
@@ -304,7 +317,7 @@ def _move_model_to_vllm(self, *args, **kwargs): return None
304317def  grpo_trainer__get_per_token_logps (function_name , function ):
305318    if  function_name  !=  "_get_per_token_logps" : return  function 
306319
307-     def  _get_per_token_logps (self , model , input_ids , attention_mask , logits_to_keep ):
320+     def  _get_per_token_logps (self , model , input_ids , attention_mask , logits_to_keep ,  compute_efficient   =   False ):
308321        if  True : # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0': 
309322            return  None  # Unsloth efficient GRPO 
310323        # Otherwise, calculate normally: 
@@ -350,28 +363,59 @@ def grpo_trainer__get_per_token_logps_and_entropies(function_name, function):
350363    if  function_name  !=  "_get_per_token_logps_and_entropies" : return  function 
351364
352365    # Just copy over from _get_per_token_logps replacement function above. For now this returns None anyway 
353-     def  _get_per_token_logps_and_entropies (self , model , input_ids , attention_mask , logits_to_keep , batch_size  =  None , compute_entropy  =  False , * args , ** kwargs ):
354-         if  True : # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0': 
355-             return  None , None   # logps, entropies Unsloth efficient GRPO 
356-         # Otherwise, calculate normally: 
357-         if  not  hasattr (self , '_autocast_dtype' ):
358-             self ._autocast_dtype  =  torch .float16  if  os .environ .get ('ACCELERATE_MIXED_PRECISION' , 'fp16' ) ==  'fp16'  else  torch .bfloat16 
359-             if  os .environ .get ('UNSLOTH_FORCE_FLOAT32' , '0' ) ==  '1' : self ._autocast_dtype  =  torch .float16 
360- 
361-         os .environ ["UNSLOTH_RETURN_HIDDEN_STATES" ] =  "1" 
362-         with  torch .amp .autocast (device_type  =  'cuda' , dtype  =  self ._autocast_dtype ):
363-             # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded 
364-             logits  =  model (
365-                 input_ids  =  input_ids ,
366-                 attention_mask  =  attention_mask ,
367-                 logits_to_keep  =  logits_to_keep  +  1 ,
368-             ).logits 
369- 
370-             entropies  =  None 
371-             if  compute_entropy :
372-                 from  trl .trainer .utils  import  entropy_from_logits 
373-                 entropies  =  entropy_from_logits (logits )
374- 
366+     def  _get_per_token_logps_and_entropies (self , model , input_ids , attention_mask , logits_to_keep , batch_size  =  None , 
367+                                            compute_entropy  =  False , compute_efficient  =  False , * args , ** kwargs ):
368+         # if True: # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0': 
369+         #     return None, None  # logps, entropies Unsloth efficient GRPO 
370+         if  compute_efficient :
371+             return  None , None 
372+         else : 
373+             # Otherwise, calculate normally: 
374+             if  not  hasattr (self , '_autocast_dtype' ):
375+                 self ._autocast_dtype  =  torch .float16  if  os .environ .get ('ACCELERATE_MIXED_PRECISION' , 'fp16' ) ==  'fp16'  else  torch .bfloat16 
376+                 if  os .environ .get ('UNSLOTH_FORCE_FLOAT32' , '0' ) ==  '1' : self ._autocast_dtype  =  torch .float16 
377+ 
378+             pixel_values , image_grid_thw  =  kwargs .get ("pixel_values" , None ), kwargs .get ("image_grid_thw" , None )
379+             pixel_attention_mask , image_sizes  =  kwargs .get ('pixel_attention_mask' ,None ), kwargs .get ('image_sizes' ,None )
380+ 
381+             os .environ ["UNSLOTH_RETURN_HIDDEN_STATES" ] =  "1" 
382+            
383+             unwrapped_model  =  self .accelerator .unwrap_model (model , keep_fp32_wrapper = False )
384+ 
385+             with  torch .amp .autocast (device_type  =  'cuda' , dtype  =  self ._autocast_dtype ):
386+                 with  torch .inference_mode ():
387+                     if  pixel_values  is  None : 
388+                         attention_mask  =   input_ids  !=  self .processing_class .pad_token_id 
389+                         attention_mask  =  attention_mask .to (attention_mask .dtype )
390+                         # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded 
391+                         logits  =  unwrapped_model (
392+                             input_ids  =  input_ids ,
393+                             attention_mask  =  attention_mask ,
394+                             pixel_values  =  pixel_values ,
395+                             image_grid_thw  =  image_grid_thw ,
396+                             pixel_attention_mask  =  pixel_attention_mask ,
397+                             image_sizes  =  image_sizes ,
398+                             #logits_to_keep = logits_to_keep + 1, 
399+                         ).logits 
400+                     else :
401+                         logits  =  unwrapped_model (
402+                             input_ids  =  input_ids ,
403+                             attention_mask  =  attention_mask ,
404+                             pixel_values  =  pixel_values ,
405+                             image_grid_thw  =  image_grid_thw ,
406+                             pixel_attention_mask  =  pixel_attention_mask ,
407+                             image_sizes  =  image_sizes ,
408+                             logits_to_keep  =  logits_to_keep  +  1 ,
409+                         ).logits 
410+                     
411+ 
412+                 entropies  =  None 
413+                 if  compute_entropy :
414+                     from  trl .trainer .utils  import  entropy_from_logits 
415+                     entropies  =  entropy_from_logits (logits )
416+ 
417+ 
418+             os .environ ["UNSLOTH_RETURN_HIDDEN_STATES" ] =  "0" 
375419            # logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred 
376420            return  logits , entropies   # logps, entropies 
377421            # input_ids = input_ids[:, -logits_to_keep:] 
@@ -416,8 +460,12 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
416460            raise  ValueError ("The GRPOTrainer does not support returning outputs" )
417461        # Compute the per-token log probabilities for the model 
418462
463+ 
419464        prompt_ids , prompt_mask  =  inputs ["prompt_ids" ], inputs ["prompt_mask" ]
420465        completion_ids , completion_mask  =  inputs ["completion_ids" ], inputs ["completion_mask" ]
466+         pixel_values , image_grid_thw  =  inputs .get ("pixel_values" , None ), inputs .get ("image_grid_thw" , None )
467+         pixel_attention_mask , image_sizes  =  inputs .get ('pixel_attention_mask' ,None ), inputs .get ('image_sizes' ,None )
468+ 
421469        input_ids  =  torch .cat ([prompt_ids , completion_ids ], dim = 1 )
422470        bsz , qlen  =  input_ids .shape 
423471        attention_mask  =  torch .cat ([prompt_mask , completion_mask ], dim = 1 )
@@ -427,28 +475,29 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
427475        _logits_to_keep  =  logits_to_keep 
428476
429477        get_logps_func  =  \
430-             lambda  model , input_ids , attention_mask , logits_to_keep , batch_size = None , compute_entropy = False : \
431-             self ._get_per_token_logps (model , input_ids , attention_mask , logits_to_keep ) \
478+             lambda  model , input_ids , attention_mask , logits_to_keep , batch_size = None , compute_entropy = False ,  compute_efficient   =   False : \
479+             self ._get_per_token_logps (model , input_ids , attention_mask , logits_to_keep ,  compute_efficient ) \
432480            if  hasattr (self , "_get_per_token_logps" ) else  \
433-             self ._get_per_token_logps_and_entropies (model , input_ids , attention_mask , logits_to_keep , batch_size , compute_entropy )[0 ]  # logps 
434- 
435-         per_token_logps  =  get_logps_func (model , input_ids , attention_mask , logits_to_keep )
436- 
481+             self ._get_per_token_logps_and_entropies (model , input_ids , attention_mask , logits_to_keep , batch_size , compute_entropy , compute_efficient )[0 ]  # logps 
482+         #breakpoint() 
483+         per_token_logps  =  get_logps_func (model , input_ids , attention_mask , logits_to_keep , compute_efficient  =  True )
437484        # Compute the KL divergence between the model and the reference model 
438485        # _prepare_inputs doesn't return reference log probs anymore. We need to calculate it ourselves. 
439486        # https://github.com/huggingface/trl/blob/05bc43e960396581e458195b8388efe6b82cae1f/trl/trainer/grpo_trainer.py#L1328 
440-         if  self .beta  !=  0.0 :
441-             with  torch .inference_mode (), model .disable_adapter ():
442-                 ref_per_token_logps  =  per_token_logps  =  get_logps_func (model , input_ids , attention_mask , logits_to_keep )
443-         else :
444-             ref_per_token_logps  =  None 
487+         # if self.beta != 0.0: 
488+         #     with torch.inference_mode(), model.disable_adapter(): 
489+         #         ref_per_token_logps = per_token_logps = get_logps_func(model, input_ids, attention_mask, logits_to_keep) 
490+         # else: 
491+         #     ref_per_token_logps = None 
492+         ref_hidden_states  =  inputs .get ("ref_per_token_logps" , None )
445493        # per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 
446494        # x - x.detach() allows for preserving gradients from x 
447495        advantages  =  inputs ["advantages" ]
448496        # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) 
449497        # per_token_loss = -(per_token_loss - self.beta * per_token_kl) 
450498        # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() 
451499        old_hidden_states  =  inputs .get ("old_per_token_logps" , None )
500+         
452501        input_ids  =  input_ids [:, - logits_to_keep :]
453502
454503        # Get logit softcapping and logit scale 
@@ -459,22 +508,26 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
459508        logit_scale_divide  =  getattr (model .config , "logits_scaling" , 0 ) # Granite 
460509        if  logit_scale_divide  is  None : logit_scale_divide  =  0 
461510
462- 
463511        if  per_token_logps  is  not None :
464512
465-             if  ref_per_token_logps  is  not None :
466-                 ref_per_token_logps  =  ref_per_token_logps [:, :- 1 , :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred 
513+             if  ref_hidden_states  is  not None :
514+                 ref_hidden_states  =  ref_hidden_states [:, :- 1 , :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred 
515+             if  old_hidden_states  is  not None :
516+                 old_hidden_states  =  old_hidden_states [:, :- 1 , :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred 
467517            per_token_logps  =  per_token_logps [:, :- 1 , :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred 
468518
469519            loss , completion_length , mean_kl  =  grpo_compute_loss_slow (
470-                 ref_per_token_logps ,
520+                 ref_hidden_states ,
471521                per_token_logps ,
472522                old_hidden_states ,
473523                input_ids ,
474524                completion_mask ,
475525                self .beta ,
476526                advantages ,
527+                 pixel_values  =  pixel_values ,
528+                 image_grid_thw  =  image_grid_thw ,
477529                loss_type  =  self .args .loss_type ,
530+                 importance_sampling_level  =  self .importance_sampling_level ,
478531                epsilon_low  =  self .epsilon_low ,
479532                epsilon_high  =  self .epsilon_high ,
480533                max_completion_length  =  self .args .max_completion_length ,
@@ -489,12 +542,16 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
489542                loss , completion_length , mean_kl  =  grpo_accumulated_loss (
490543                    trainer  =  self ,
491544                    input_ids  =  _input_ids ,
545+                     pixel_values  =  pixel_values ,
546+                     image_grid_thw  =  image_grid_thw ,
492547                    logits_to_keep  =  logits_to_keep ,
493548                    completion_mask  =  completion_mask ,
494549                    advantages  =  advantages ,
495550                    old_hidden_states  =  old_hidden_states ,
551+                     ref_hidden_states  =  ref_hidden_states ,
496552                    n_chunks  =  self .args .unsloth_num_chunks ,
497553                    loss_type  =  self .args .loss_type ,
554+                     importance_sampling_level  =  self .importance_sampling_level ,
498555                    epsilon_low  =  self .epsilon_low ,
499556                    epsilon_high  =  self .epsilon_high ,
500557                    max_completion_length  =  self .args .max_completion_length ,
@@ -514,6 +571,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
514571                    completion_mask  =  completion_mask ,
515572                    advantages  =  advantages ,
516573                    old_hidden_states  =  old_hidden_states ,
574+                     ref_hidden_states  =  ref_hidden_states ,
517575                    n_chunks  =  self .args .unsloth_num_chunks ,
518576                    temperature  =  self .args .temperature ,
519577                    logit_softcapping  =  logit_softcapping ,
0 commit comments