[Feature] VLMs support for GRPO #2752
Changes from 16 commits
@@ -168,8 +168,9 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
 
 
 # Autocast precision for GRPO
-def grpo_trainer__prepare_inputs(function_name, function):
-    if function_name != "_prepare_inputs": return function
+def grpo_generate_and_score_completions(function_name, function):
+    if function_name != "_generate_and_score_completions":
+        return function
 
     import re
     # This matches the function signature, decorators and any comments immediately following
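Each patcher above is gated on `function_name` and returns the source unchanged when it does not apply. For reference, a registry of such patchers is typically applied along these lines; `apply_patchers` and the usage line are illustrative assumptions, not Unsloth's actual driver code:

```python
# Minimal sketch (assumed driver, not part of this diff): every registered
# patcher receives (function_name, source_text) and returns source_text,
# possibly rewritten, so patchers that don't match simply pass it through.
def apply_patchers(patchers, function_name, source):
    for patcher in patchers:
        source = patcher(function_name, source)
    return source

# Hypothetical usage with the registry style used in this file:
# patched = apply_patchers(RL_FUNCTIONS["grpo_trainer"], "_generate_and_score_completions", source)
```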
@@ -186,7 +187,6 @@ def grpo_trainer__prepare_inputs(function_name, function):
     # Find where the code block starts after comments
     code_start_index = match.end(1)
     rest_of_function = function[code_start_index:]
-
     # Remove any old wake_up call that might be at the start of the function body
     rest_of_function = re.sub(
         r"^\s*if hasattr\(self, 'llm'\):.*?self\.llm\.wake_up\(\).*?\n",
@@ -205,6 +205,91 @@ def grpo_trainer__prepare_inputs(function_name, function):
     )
 
     function = header_and_comments + insert + rest_of_function
+
+    if """prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]""" not in function:
+        return function
+
+
+    # 1. Output pixel_values and image_grid_thw
+    pattern = re.compile(
+        r"^(?P<indent>\s*)return {\n"
+        r"(?P=indent) {4}\"prompt_ids\": prompt_ids,\n"
+        r"(?P=indent) {4}\"prompt_mask\": prompt_mask,\n"
+        r"(?P=indent) {4}\"completion_ids\": completion_ids,\n"
+        r"(?P=indent) {4}\"completion_mask\": completion_mask,\n"
+        r"(?P=indent) {4}\"advantages\": advantages,\n"
+        r"(?P=indent) {4}\"old_per_token_logps\": old_per_token_logps,\n"
+        r"(?P=indent)}",
+        re.MULTILINE
+    )
+
+    replacement = """        return {
+            "prompt_ids": prompt_ids,
+            "prompt_mask": prompt_mask,
+            "pixel_values": pixel_values,
+            "image_grid_thw": image_grid_thw,
+            "completion_ids": completion_ids,
+            "completion_mask": completion_mask,
+            "advantages": advantages,
+            "old_per_token_logps": old_per_token_logps,
+        }"""
+    function = re.sub(pattern, replacement, function)
+
+    # 2. Replace the prompt_completion_ids generation
+    pattern = re.compile(
+        r"^(?P<indent>\s*)prompt_completion_ids = unwrapped_model\.generate\(\n"
+        r"(?P=indent) {4}prompt_ids, attention_mask=prompt_mask, generation_config=self\.generation_config\n"
+        r"(?P=indent)\)",
+        re.MULTILINE
+    )
+
+    replacement = """                if self.use_vision : prompt_completion_ids = unwrapped_model.generate(prompt_ids, attention_mask=prompt_mask,pixel_values = pixel_values,image_grid_thw=image_grid_thw, generation_config=self.generation_config)
+                else : prompt_completion_ids = unwrapped_model.generate(prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config)"""
+
+    function = pattern.sub(replacement, function)
+
+    # 3. Replace the old_per_token_logps generation
+    pattern = re.compile(
+        r"^(?P<indent>\s*)old_per_token_logps = self\._get_per_token_logps\(\n"
+        r"(?P=indent) {4}self\.model, prompt_completion_ids, attention_mask, logits_to_keep, batch_size\n"
+        r"(?P=indent)\)",
+        re.MULTILINE
+    )
+
+    replacement = """            old_per_token_logps = self._get_per_token_logps(
+                self.model, prompt_completion_ids, attention_mask, pixel_values, image_grid_thw, logits_to_keep, batch_size
+            )"""
+
+    function = re.sub(pattern, replacement, function)
+
+    # 4. Replace the prompt processing section
+    pattern = re.compile(
+        r"^(?P<indent>\s*)prompts = \[x\[\"prompt\"\] for x in inputs\]\n"
+        r"(?P=indent)prompts_text = \[maybe_apply_chat_template\(example, self\.processing_class\)\[\"prompt\"\] for example in inputs\]\n"
+        r"(?P=indent)prompt_inputs = self\.processing_class\(\n"
+        r"(?P=indent) {4}text=prompts_text, return_tensors=\"pt\", padding=True, padding_side=\"left\", add_special_tokens=False\n"
+        r"(?P=indent)\)\n"
+        r"(?P=indent)prompt_inputs = super\(\)\._prepare_inputs\(prompt_inputs\)",
+        re.MULTILINE
+    )
+
+    replacement = """        prompts = [x["prompt"] for x in inputs]
+        prompts_text = [maybe_apply_chat_template(example, self.processing_class)['prompt'] for example in inputs]
+        if not self.use_vision:
+            pixel_values, image_grid_thw = None, None
+            prompt_inputs = self.processing_class(text=prompts_text, return_tensors='pt', padding=True, padding_side="left", add_special_tokens=False)
+            prompt_inputs = super()._prepare_inputs(prompt_inputs)
+        else:
+            images = [x['image'] for x in inputs] # Only image inputs support for now
+            prompt_inputs = self.processing_class(images=images, text=prompts_text, return_tensors='pt', padding=True, padding_side="left", add_special_tokens=False)
+            prompt_inputs = super()._prepare_inputs(prompt_inputs)
+            pixel_values, image_grid_thw = prompt_inputs['pixel_values'], prompt_inputs['image_grid_thw']"""
+
+    function = pattern.sub(replacement, function)
+
+
+
     # Add mixed precision training
     function = function.replace(
         "with torch.inference_mode():",
@@ -230,7 +315,24 @@ def grpo_trainer__prepare_inputs(function_name, function):
         function = function.rstrip() + "\n " + sleep_and_cache
     return function
 pass
-RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__prepare_inputs)
+
+RL_FUNCTIONS["grpo_trainer"].append(grpo_generate_and_score_completions)
+
+def grpo_prepare_inputs(function_name, function):
+    if function_name != "_prepare_inputs": return function
+
+    if "generation_batch = self._generate_and_score_completions(generation_batch)" not in function : return function
+
+    function = function.replace(
+        "generation_batch = self._generate_and_score_completions(generation_batch)",
+
+        "generation_batch = self._generate_and_score_completions(generation_batch)\n"\
+        "                if self.use_vision : generation_batch['pixel_values']=generation_batch['pixel_values'].view(generation_batch['prompt_ids'].size(0), -1, generation_batch['pixel_values'].size(1)) # (batch_size * n_patches, dim embedding)->(batch_size,n_patches,dim embeddding)"
+    )
+
+    return function
+pass
+RL_FUNCTIONS["grpo_trainer"].append(grpo_prepare_inputs)
 
 
 # Remove _move_model_to_vllm
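The line spliced in by `grpo_prepare_inputs` reshapes the processor's flattened vision features so they can be buffered and split per sample alongside `prompt_ids`. A small sketch of that reshape on dummy tensors (the shapes are assumptions for illustration):

```python
import torch

# Dummy shapes: 2 samples, 5 vision patches each, 8-dim patch embeddings.
batch_size, n_patches, embed_dim = 2, 5, 8
prompt_ids = torch.zeros(batch_size, 7, dtype=torch.long)
pixel_values = torch.randn(batch_size * n_patches, embed_dim)  # (B * n_patches, dim)

# Same view as the patched _prepare_inputs:
# (batch_size * n_patches, dim) -> (batch_size, n_patches, dim)
pixel_values = pixel_values.view(prompt_ids.size(0), -1, pixel_values.size(1))
assert pixel_values.shape == (batch_size, n_patches, embed_dim)
```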
@@ -249,8 +351,8 @@ def _move_model_to_vllm(self, *args, **kwargs): return None
 def grpo_trainer__get_per_token_logps(function_name, function):
     if function_name != "_get_per_token_logps": return function
 
-    def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
-        if True: # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
+    def _get_per_token_logps(self, model, input_ids, attention_mask,pixel_values,image_grid_thw, logits_to_keep, calc_logprob_flag = None):
+        if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0' and not calc_logprob_flag:
             return None # Unsloth efficient GRPO
         # Otherwise, calculate normally:
         if not hasattr(self, '_autocast_dtype'):
@@ -260,13 +362,16 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep)
         os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
         with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
             # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
-            logits = model(
-                input_ids = input_ids,
-                attention_mask = attention_mask,
-                logits_to_keep = logits_to_keep + 1,
-            ).logits
-            # logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
-            return logits
+            if self.use_vision :
+                hidden_states = model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values, image_grid_thw=image_grid_thw, logits_to_keep=logits_to_keep + 1).logits
+            else:
+                hidden_states = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
+            #logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+
+            if hidden_states.size(1) != logits_to_keep+1 : # Some models like Qwen VL don't have logits_to_keep parameter so you need to trim the output manually
+                hidden_states = hidden_states[:, -(logits_to_keep+1):, :]
+
+            return hidden_states
         # input_ids = input_ids[:, -logits_to_keep:]
         # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
         # See https://github.com/huggingface/trl/issues/2770
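As the comment above notes, some VLMs (e.g. Qwen2-VL) ignore `logits_to_keep`, so the patched function trims the returned tensor itself. On dummy data the trim looks like this (shapes are illustrative):

```python
import torch

batch, seq_len, vocab = 2, 10, 32
logits_to_keep = 4

# Pretend the model ignored logits_to_keep and returned logits for every position.
hidden_states = torch.randn(batch, seq_len, vocab)

if hidden_states.size(1) != logits_to_keep + 1:
    # Keep only the last logits_to_keep + 1 positions, matching the patched code.
    hidden_states = hidden_states[:, -(logits_to_keep + 1):, :]

assert hidden_states.shape == (batch, logits_to_keep + 1, vocab)
```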
@@ -311,22 +416,23 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
 
         prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
         completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
+        pixel_values,image_grid_thw = inputs.get("pixel_values", None), inputs.get("image_grid_thw", None)
         input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
         bsz, qlen = input_ids.shape
         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
         # attention_mask = None
         logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
         _input_ids = input_ids
         _logits_to_keep = logits_to_keep
 
-        per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
+        per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask,pixel_values,image_grid_thw, logits_to_keep)
 
         # Compute the KL divergence between the model and the reference model
         # _prepare_inputs doesn't return reference log probs anymore. We need to calculate it ourselves.
         # https://github.com/huggingface/trl/blob/05bc43e960396581e458195b8388efe6b82cae1f/trl/trainer/grpo_trainer.py#L1328
         if self.beta != 0.0:
             with torch.inference_mode(), model.disable_adapter():
-                ref_per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
+                ref_per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask,pixel_values,image_grid_thw, logits_to_keep)
         else:
             ref_per_token_logps = None
         # per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
 
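The commented-out `logits[:, :-1, :]` line hints at what happens downstream: the returned `(B, logits_to_keep + 1, V)` tensor is reduced to per-token log-probabilities of the completion tokens. A generic sketch of that reduction, not taken from this PR or from Unsloth's fused path:

```python
import torch

# (B, logits_to_keep + 1, vocab): the final kept logit predicts the token after the sequence.
logits = torch.randn(2, 6, 32)
completion_ids = torch.randint(0, 32, (2, 5))  # the last logits_to_keep tokens

logits = logits[:, :-1, :]                      # drop the final next-token logit
log_probs = logits.log_softmax(dim=-1)
per_token_logps = log_probs.gather(-1, completion_ids.unsqueeze(-1)).squeeze(-1)
assert per_token_logps.shape == (2, 5)
```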