- 
          
- 
                Notifications
    You must be signed in to change notification settings 
- Fork 3.9k
[Feature] VLMs support for GRPO #2752
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
          
     Open
      
      
            GAD-cell
  wants to merge
  21
  commits into
  unslothai:main
  
    
      
        
          
  
    
      Choose a base branch
      
     
    
      
        
      
      
        
          
          
        
        
          
            
              
              
              
  
           
        
        
          
            
              
              
           
        
       
     
  
        
          
            
          
            
          
        
       
    
      
from
GAD-cell:VLM_GRPO
  
      
      
   
  
    
  
  
  
 
  
      
    base: main
Could not load branches
            
              
  
    Branch not found: {{ refName }}
  
            
                
      Loading
              
            Could not load tags
            
            
              Nothing to show
            
              
  
            
                
      Loading
              
            Are you sure you want to change the base?
            Some commits from the old base branch may be removed from the timeline,
            and old review comments may become outdated.
          
          
      
        
          +122
        
        
          −6
        
        
          
        
      
    
  
  
     Open
                    Changes from 19 commits
      Commits
    
    
            Show all changes
          
          
            21 commits
          
        
        Select commit
          Hold shift + click to select a range
      
      8229a9b
              
                Updated rl and rl_replacements
              
              
                GAD-cell 46fce64
              
                fixed indentation
              
              
                GAD-cell f5d3006
              
                space error
              
              
                GAD-cell 2d4a908
              
                indent fix
              
              
                GAD-cell 15530f1
              
                minor fixes
              
              
                GAD-cell 7ecc622
              
                working generate_and_score_completions
              
              
                GAD-cell d9401f7
              
                working version with hidden states trimming
              
              
                GAD-cell 0a8c3e2
              
                remove print
              
              
                GAD-cell a162fca
              
                Replace _generate_and_score_completions using function replacements i…
              
              
                GAD-cell 32571e9
              
                typo correction
              
              
                GAD-cell 3edd254
              
                indentation fix
              
              
                GAD-cell af3f5e6
              
                fixed _get_per_token_logps slicing
              
              
                GAD-cell 8e3fe8e
              
                slicing condition was off
              
              
                GAD-cell 45d5482
              
                resolve conflicts
              
              
                GAD-cell 0ba0bbc
              
                Merge branch 'main' into VLM_GRPO
              
              
                GAD-cell c073534
              
                resolve conflicts
              
              
                GAD-cell 9dbe0b9
              
                spacing + arg on new line fix
              
              
                GAD-cell 896302f
              
                Merge branch 'main' into VLM_GRPO
              
              
                GAD-cell dfb05c1
              
                efficient vlm grpo compute loss
              
              
                GAD-cell b6fec96
              
                resolve conflicts
              
              
                GAD-cell 12784de
              
                fix
              
              
                GAD-cell File filter
Filter by extension
Conversations
          Failed to load comments.   
        
        
          
      Loading
        
  Jump to
        
          Jump to file
        
      
      
          Failed to load files.   
        
        
          
      Loading
        
  Diff view
Diff view
There are no files selected for viewing
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -168,8 +168,9 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch | |
|  | ||
|  | ||
| # Autocast precision for GRPO | ||
| def grpo_trainer__prepare_inputs(function_name, function): | ||
| if function_name != "_prepare_inputs": return function | ||
| def grpo_generate_and_score_completions(function_name, function): | ||
| if function_name != "_generate_and_score_completions": | ||
| return function | ||
|  | ||
| import re | ||
| # This matches the function signature, decorators and any comments immediately following | ||
|  | @@ -186,7 +187,6 @@ def grpo_trainer__prepare_inputs(function_name, function): | |
| # Find where the code block starts after comments | ||
| code_start_index = match.end(1) | ||
| rest_of_function = function[code_start_index:] | ||
|  | ||
| # Remove any old wake_up call that might be at the start of the function body | ||
| rest_of_function = re.sub( | ||
| r"^\s*if hasattr\(self, 'llm'\):.*?self\.llm\.wake_up\(\).*?\n", | ||
|  | @@ -205,6 +205,91 @@ def grpo_trainer__prepare_inputs(function_name, function): | |
| ) | ||
|  | ||
| function = header_and_comments + insert + rest_of_function | ||
|  | ||
| if """prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]""" not in function: | ||
| return function | ||
|  | ||
|  | ||
| # 1. Output pixel_values and image_grid_thw | ||
| pattern = re.compile( | ||
| r"^(?P<indent>\s*)return {\n" | ||
| r"(?P=indent) {4}\"prompt_ids\": prompt_ids,\n" | ||
| r"(?P=indent) {4}\"prompt_mask\": prompt_mask,\n" | ||
| r"(?P=indent) {4}\"completion_ids\": completion_ids,\n" | ||
| r"(?P=indent) {4}\"completion_mask\": completion_mask,\n" | ||
| r"(?P=indent) {4}\"advantages\": advantages,\n" | ||
| r"(?P=indent) {4}\"old_per_token_logps\": old_per_token_logps,\n" | ||
| r"(?P=indent)}", | ||
| re.MULTILINE | ||
| ) | ||
|  | ||
| replacement = """ return { | ||
| "prompt_ids": prompt_ids, | ||
| "prompt_mask": prompt_mask, | ||
| "pixel_values": pixel_values, | ||
| "image_grid_thw": image_grid_thw, | ||
| "completion_ids": completion_ids, | ||
| "completion_mask": completion_mask, | ||
| "advantages": advantages, | ||
| "old_per_token_logps": old_per_token_logps, | ||
| }""" | ||
| function = re.sub(pattern, replacement, function) | ||
|  | ||
| # 2. Replace the prompt_completion_ids generation | ||
| pattern = re.compile( | ||
| r"^(?P<indent>\s*)prompt_completion_ids = unwrapped_model\.generate\(\n" | ||
| r"(?P=indent) {4}prompt_ids, attention_mask=prompt_mask, generation_config=self\.generation_config\n" | ||
| r"(?P=indent)\)", | ||
| re.MULTILINE | ||
| ) | ||
|  | ||
| replacement = """ if self.use_vision : prompt_completion_ids = unwrapped_model.generate(prompt_ids, attention_mask=prompt_mask,pixel_values = pixel_values,image_grid_thw=image_grid_thw, generation_config=self.generation_config) | ||
| else : prompt_completion_ids = unwrapped_model.generate(prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config)""" | ||
|  | ||
| function = pattern.sub(replacement, function) | ||
|  | ||
| # 3. Replace the old_per_token_logps generation | ||
| pattern = re.compile( | ||
| r"^(?P<indent>\s*)old_per_token_logps = self\._get_per_token_logps\(\n" | ||
| r"(?P=indent) {4}self\.model, prompt_completion_ids, attention_mask, logits_to_keep, batch_size\n" | ||
| r"(?P=indent)\)", | ||
| re.MULTILINE | ||
| ) | ||
|  | ||
| replacement = """ old_per_token_logps = self._get_per_token_logps( | ||
| self.model, prompt_completion_ids, attention_mask, pixel_values, image_grid_thw, logits_to_keep, batch_size | ||
| )""" | ||
|  | ||
| function = re.sub(pattern, replacement, function) | ||
|  | ||
| # 4. Replace the prompt processing section | ||
| pattern = re.compile( | ||
| r"^(?P<indent>\s*)prompts = \[x\[\"prompt\"\] for x in inputs\]\n" | ||
| r"(?P=indent)prompts_text = \[maybe_apply_chat_template\(example, self\.processing_class\)\[\"prompt\"\] for example in inputs\]\n" | ||
| r"(?P=indent)prompt_inputs = self\.processing_class\(\n" | ||
| r"(?P=indent) {4}text=prompts_text, return_tensors=\"pt\", padding=True, padding_side=\"left\", add_special_tokens=False\n" | ||
| r"(?P=indent)\)\n" | ||
| r"(?P=indent)prompt_inputs = super\(\)\._prepare_inputs\(prompt_inputs\)" | ||
| , | ||
| re.MULTILINE | ||
| ) | ||
|  | ||
| replacement = """ prompts = [x["prompt"] for x in inputs] | ||
| prompts_text = [maybe_apply_chat_template(example, self.processing_class)['prompt'] for example in inputs] | ||
| if not self.use_vision: | ||
| pixel_values, image_grid_thw = None, None | ||
| prompt_inputs = self.processing_class(text=prompts_text, return_tensors='pt', padding=True, padding_side="left", add_special_tokens=False) | ||
| prompt_inputs = super()._prepare_inputs(prompt_inputs) | ||
| else: | ||
| images = [x['image'] for x in inputs] # Only image inputs support for now | ||
| prompt_inputs = self.processing_class(images=images, text=prompts_text, return_tensors='pt', padding=True, padding_side="left", add_special_tokens=False) | ||
| prompt_inputs = super()._prepare_inputs(prompt_inputs) | ||
| pixel_values, image_grid_thw = prompt_inputs['pixel_values'], prompt_inputs['image_grid_thw']""" | ||
|  | ||
| function = pattern.sub(replacement, function) | ||
|  | ||
|  | ||
|  | ||
| # Add mixed precision training | ||
| function = function.replace( | ||
| "with torch.inference_mode():", | ||
|  | @@ -230,7 +315,24 @@ def grpo_trainer__prepare_inputs(function_name, function): | |
| function = function.rstrip() + "\n " + sleep_and_cache | ||
| return function | ||
| pass | ||
| RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__prepare_inputs) | ||
|  | ||
| RL_FUNCTIONS["grpo_trainer"].append(grpo_generate_and_score_completions) | ||
|  | ||
| def grpo_prepare_inputs(function_name, function): | ||
| if function_name != "_prepare_inputs": return function | ||
|  | ||
| if "generation_batch = self._generate_and_score_completions(generation_batch)" not in function : return function | ||
|  | ||
| function = function.replace( | ||
| "generation_batch = self._generate_and_score_completions(generation_batch)", | ||
|  | ||
| "generation_batch = self._generate_and_score_completions(generation_batch)\n"\ | ||
| " if self.use_vision : generation_batch['pixel_values']=generation_batch['pixel_values'].view(generation_batch['prompt_ids'].size(0), -1, generation_batch['pixel_values'].size(1)) # (batch_size * n_patches, dim embedding)->(batch_size,n_patches,dim embeddding)" | ||
| ) | ||
|  | ||
| return function | ||
| pass | ||
| RL_FUNCTIONS["grpo_trainer"].append(grpo_prepare_inputs) | ||
|  | ||
|  | ||
| # Remove _move_model_to_vllm | ||
|  | @@ -249,7 +351,7 @@ def _move_model_to_vllm(self, *args, **kwargs): return None | |
| def grpo_trainer__get_per_token_logps(function_name, function): | ||
| if function_name != "_get_per_token_logps": return function | ||
|  | ||
| def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): | ||
| def _get_per_token_logps(self, model, input_ids, attention_mask, pixel_values, image_grid_thw, logits_to_keep): | ||
| if True: # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0': | ||
| return None # Unsloth efficient GRPO | ||
| # Otherwise, calculate normally: | ||
|  | @@ -311,22 +413,23 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch | |
|  | ||
| prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] | ||
| completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] | ||
| pixel_values, image_grid_thw = inputs.get("pixel_values", None), inputs.get("image_grid_thw", None) | ||
|  | ||
| input_ids = torch.cat([prompt_ids, completion_ids], dim=1) | ||
| bsz, qlen = input_ids.shape | ||
| attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) | ||
| # attention_mask = None | ||
| logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens | ||
| _input_ids = input_ids | ||
| _logits_to_keep = logits_to_keep | ||
|  | ||
| per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) | ||
| _logits_to_keep = logits_to_keep | ||
| per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, pixel_values, image_grid_thw, logits_to_keep) | ||
|  | ||
| # Compute the KL divergence between the model and the reference model | ||
| # _prepare_inputs doesn't return reference log probs anymore. We need to calculate it ourselves. | ||
| # https://github.com/huggingface/trl/blob/05bc43e960396581e458195b8388efe6b82cae1f/trl/trainer/grpo_trainer.py#L1328 | ||
| if self.beta != 0.0: | ||
| with torch.inference_mode(), model.disable_adapter(): | ||
| ref_per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) | ||
| ref_per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, pixel_values, image_grid_thw, logits_to_keep) | ||
| else: | ||
| ref_per_token_logps = None | ||
| # per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 | ||
|  | @@ -376,6 +479,8 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch | |
| loss, completion_length, mean_kl = grpo_accumulated_loss( | ||
| trainer = self, | ||
| input_ids = _input_ids, | ||
| pixel_values = pixel_values, | ||
| image_grid_thw = image_grid_thw, | ||
| logits_to_keep = logits_to_keep, | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. changes for efficient path here | ||
| completion_mask = completion_mask, | ||
| advantages = advantages, | ||
|  | @@ -397,6 +502,8 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch | |
| loss, completion_length, mean_kl = grpo_accumulated_loss( | ||
| trainer = self, | ||
| input_ids = _input_ids, | ||
| pixel_values = pixel_values, | ||
| image_grid_thw = image_grid_thw, | ||
| logits_to_keep = logits_to_keep, | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. and here | ||
| completion_mask = completion_mask, | ||
| advantages = advantages, | ||
|  | ||
  Add this suggestion to a batch that can be applied as a single commit.
  This suggestion is invalid because no changes were made to the code.
  Suggestions cannot be applied while the pull request is closed.
  Suggestions cannot be applied while viewing a subset of changes.
  Only one suggestion per line can be applied in a batch.
  Add this suggestion to a batch that can be applied as a single commit.
  Applying suggestions on deleted lines is not supported.
  You must change the existing code in this line in order to create a valid suggestion.
  Outdated suggestions cannot be applied.
  This suggestion has been applied or marked resolved.
  Suggestions cannot be applied from pending reviews.
  Suggestions cannot be applied on multi-line comments.
  Suggestions cannot be applied while the pull request is queued to merge.
  Suggestion cannot be applied right now. Please check back later.
  
    
  
    
Uh oh!
There was an error while loading. Please reload this page.