8 changes: 8 additions & 0 deletions unsloth/models/rl.py
@@ -134,6 +134,7 @@ class Unsloth{RLConfig_name}({RLConfig_name}):
metadata = {{'help': 'Chunk size to reduce memory usage. -1 is most efficient.'}},
)
def __init__({RLConfig_arguments},
use_vision = False,
vllm_sampling_params = None,
unsloth_num_chunks = -1,
**kwargs,
@@ -142,6 +143,7 @@ def __init__({RLConfig_arguments},
super().__init__({RLConfig_call_args}{RLConfig_kwargs})
self.vllm_sampling_params = vllm_sampling_params
self.unsloth_num_chunks = unsloth_num_chunks
self.use_vision = use_vision
pass

{RLTrainer_extras}
@@ -233,6 +235,12 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):

# Edit bf16, fp16 by checking model's torch_dtype directly
extra_args = ""

# Add boolean for vision support
if "args" in call_args :
use_vision = "self.use_vision = args.use_vision\n"
extra_args += use_vision

if "args" in call_args and "model" in call_args:
mixed_precision = \
"use_bf16 = getattr(args, 'bf16', False)\n"\
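The flag above only has to be set on the config; it is then copied onto the trainer through extra_args. A minimal, illustrative sketch of how it could be used once trl has been patched by unsloth, with the model, processor, reward function, and dataset left as placeholders (they are not part of this diff):

from trl import GRPOConfig, GRPOTrainer  # patched versions after `import unsloth`

training_args = GRPOConfig(
    output_dir = "outputs",
    per_device_train_batch_size = 2,
    num_generations = 2,
    max_completion_length = 64,
    use_vision = True,  # new flag added by this PR
)
# With use_vision enabled, each dataset row is expected to carry an "image" key in addition
# to "prompt", since the patched _generate_and_score_completions reads x["image"] per input.
trainer = GRPOTrainer(
    model = model,                 # placeholder: a vision-language model
    processing_class = processor,  # placeholder: a processor accepting text= and images=
    reward_funcs = [reward_fn],    # placeholder reward function
    args = training_args,
    train_dataset = dataset,       # placeholder dataset with "prompt"/"image" columns
)
trainer.train()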
138 changes: 122 additions & 16 deletions unsloth/models/rl_replacements.py
@@ -168,8 +168,9 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch


# Autocast precision for GRPO
def grpo_trainer__prepare_inputs(function_name, function):
if function_name != "_prepare_inputs": return function
def grpo_generate_and_score_completions(function_name, function):
if function_name != "_generate_and_score_completions":
return function

import re
# This matches the function signature, decorators and any comments immediately following
@@ -186,7 +187,6 @@ def grpo_trainer__prepare_inputs(function_name, function):
# Find where the code block starts after comments
code_start_index = match.end(1)
rest_of_function = function[code_start_index:]

# Remove any old wake_up call that might be at the start of the function body
rest_of_function = re.sub(
r"^\s*if hasattr\(self, 'llm'\):.*?self\.llm\.wake_up\(\).*?\n",
Expand All @@ -205,6 +205,91 @@ def grpo_trainer__prepare_inputs(function_name, function):
)

function = header_and_comments + insert + rest_of_function

if """prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]""" not in function:
return function


# 1. Output pixel_values and image_grid_thw
pattern = re.compile(
r"^(?P<indent>\s*)return {\n"
r"(?P=indent) {4}\"prompt_ids\": prompt_ids,\n"
r"(?P=indent) {4}\"prompt_mask\": prompt_mask,\n"
r"(?P=indent) {4}\"completion_ids\": completion_ids,\n"
r"(?P=indent) {4}\"completion_mask\": completion_mask,\n"
r"(?P=indent) {4}\"advantages\": advantages,\n"
r"(?P=indent) {4}\"old_per_token_logps\": old_per_token_logps,\n"
r"(?P=indent)}",
re.MULTILINE
)

replacement = """ return {
"prompt_ids": prompt_ids,
"prompt_mask": prompt_mask,
"pixel_values": pixel_values,
"image_grid_thw": image_grid_thw,
"completion_ids": completion_ids,
"completion_mask": completion_mask,
"advantages": advantages,
"old_per_token_logps": old_per_token_logps,
}"""
function = re.sub(pattern, replacement, function)

# 2. Replace the prompt_completion_ids generation
pattern = re.compile(
r"^(?P<indent>\s*)prompt_completion_ids = unwrapped_model\.generate\(\n"
r"(?P=indent) {4}prompt_ids, attention_mask=prompt_mask, generation_config=self\.generation_config\n"
r"(?P=indent)\)",
re.MULTILINE
)

replacement = """ if self.use_vision : prompt_completion_ids = unwrapped_model.generate(prompt_ids, attention_mask=prompt_mask,pixel_values = pixel_values,image_grid_thw=image_grid_thw, generation_config=self.generation_config)
else : prompt_completion_ids = unwrapped_model.generate(prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config)"""

function = pattern.sub(replacement, function)

# 3. Replace the old_per_token_logps generation
pattern = re.compile(
r"^(?P<indent>\s*)old_per_token_logps = self\._get_per_token_logps\(\n"
r"(?P=indent) {4}self\.model, prompt_completion_ids, attention_mask, logits_to_keep, batch_size\n"
r"(?P=indent)\)",
re.MULTILINE
)

replacement = """ old_per_token_logps = self._get_per_token_logps(
self.model, prompt_completion_ids, attention_mask, pixel_values, image_grid_thw, logits_to_keep, batch_size
)"""

function = re.sub(pattern, replacement, function)

# 4. Replace the prompt processing section
pattern = re.compile(
r"^(?P<indent>\s*)prompts = \[x\[\"prompt\"\] for x in inputs\]\n"
r"(?P=indent)prompts_text = \[maybe_apply_chat_template\(example, self\.processing_class\)\[\"prompt\"\] for example in inputs\]\n"
r"(?P=indent)prompt_inputs = self\.processing_class\(\n"
r"(?P=indent) {4}text=prompts_text, return_tensors=\"pt\", padding=True, padding_side=\"left\", add_special_tokens=False\n"
r"(?P=indent)\)\n"
r"(?P=indent)prompt_inputs = super\(\)\._prepare_inputs\(prompt_inputs\)"
,
re.MULTILINE
)

replacement = """ prompts = [x["prompt"] for x in inputs]
prompts_text = [maybe_apply_chat_template(example, self.processing_class)['prompt'] for example in inputs]
if not self.use_vision:
pixel_values, image_grid_thw = None, None
prompt_inputs = self.processing_class(text=prompts_text, return_tensors='pt', padding=True, padding_side="left", add_special_tokens=False)
prompt_inputs = super()._prepare_inputs(prompt_inputs)
else:
images = [x['image'] for x in inputs] # Only image inputs are supported for now
prompt_inputs = self.processing_class(images=images, text=prompts_text, return_tensors='pt', padding=True, padding_side="left", add_special_tokens=False)
prompt_inputs = super()._prepare_inputs(prompt_inputs)
pixel_values, image_grid_thw = prompt_inputs['pixel_values'], prompt_inputs['image_grid_thw']"""

function = pattern.sub(replacement, function)



# Add mixed precision training
function = function.replace(
"with torch.inference_mode():",
Expand All @@ -230,7 +315,24 @@ def grpo_trainer__prepare_inputs(function_name, function):
function = function.rstrip() + "\n " + sleep_and_cache
return function
pass
RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__prepare_inputs)

RL_FUNCTIONS["grpo_trainer"].append(grpo_generate_and_score_completions)

def grpo_prepare_inputs(function_name, function):
if function_name != "_prepare_inputs": return function

if "generation_batch = self._generate_and_score_completions(generation_batch)" not in function : return function

function = function.replace(
"generation_batch = self._generate_and_score_completions(generation_batch)",

"generation_batch = self._generate_and_score_completions(generation_batch)\n"\
" if self.use_vision : generation_batch['pixel_values']=generation_batch['pixel_values'].view(generation_batch['prompt_ids'].size(0), -1, generation_batch['pixel_values'].size(1)) # (batch_size * n_patches, dim embedding)->(batch_size,n_patches,dim embeddding)"
)

return function
pass
RL_FUNCTIONS["grpo_trainer"].append(grpo_prepare_inputs)


# Remove _move_model_to_vllm
@@ -249,8 +351,8 @@ def _move_model_to_vllm(self, *args, **kwargs): return None
def grpo_trainer__get_per_token_logps(function_name, function):
if function_name != "_get_per_token_logps": return function

def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
if True: # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
def _get_per_token_logps(self, model, input_ids, attention_mask, pixel_values, image_grid_thw, logits_to_keep, calc_logprob_flag = None):
if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0' and not calc_logprob_flag:
return None # Unsloth efficient GRPO
# Otherwise, calculate normally:
if not hasattr(self, '_autocast_dtype'):
Expand All @@ -260,13 +362,16 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep)
os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
# We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
logits = model(
input_ids = input_ids,
attention_mask = attention_mask,
logits_to_keep = logits_to_keep + 1,
).logits
# logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
return logits
if self.use_vision:
hidden_states = model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values, image_grid_thw=image_grid_thw, logits_to_keep=logits_to_keep + 1).logits
else:
hidden_states = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
# logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred

if hidden_states.size(1) != logits_to_keep + 1: # Some models like Qwen VL don't support the logits_to_keep argument, so the output has to be trimmed manually
Review thread on the line above:
Contributor: Wait, I think we do this automatically in kernels.
Author: Yes, you do it in grpo_accumulated_loss, but for now VLM GRPO uses grpo_compute_loss_slow, and in that case it needs to be trimmed. I've commented about this here. We can implement it for the fast path, but I didn't want to touch that part yet since the flag is not clear to me.
Author: Oh, and sorry for the spacing, I forgot to double-check it :)

hidden_states = hidden_states[:, -(logits_to_keep+1):, :]

return hidden_states
# input_ids = input_ids[:, -logits_to_keep:]
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
# See https://github.com/huggingface/trl/issues/2770
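To make the trimming above concrete: when a model ignores logits_to_keep, its logits span the whole sequence, and only the last logits_to_keep + 1 positions should be kept. A small stand-alone sketch with made-up sizes:

import torch

seq_len, logits_to_keep, vocab_size = 10, 3, 7
hidden_states = torch.randn(1, seq_len, vocab_size)   # as if the model ignored logits_to_keep
if hidden_states.size(1) != logits_to_keep + 1:       # same check as in the patch
    hidden_states = hidden_states[:, -(logits_to_keep + 1):, :]
print(hidden_states.shape)  # torch.Size([1, 4, 7])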
@@ -311,22 +416,23 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch

prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
pixel_values, image_grid_thw = inputs.get("pixel_values", None), inputs.get("image_grid_thw", None)
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
bsz, qlen = input_ids.shape
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
# attention_mask = None
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
_input_ids = input_ids
_logits_to_keep = logits_to_keep

per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, pixel_values, image_grid_thw, logits_to_keep)

# Compute the KL divergence between the model and the reference model
# _prepare_inputs doesn't return reference log probs anymore. We need to calculate it ourselves.
# https://github.com/huggingface/trl/blob/05bc43e960396581e458195b8388efe6b82cae1f/trl/trainer/grpo_trainer.py#L1328
if self.beta != 0.0:
with torch.inference_mode(), model.disable_adapter():
ref_per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
ref_per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, pixel_values, image_grid_thw, logits_to_keep)
else:
ref_per_token_logps = None
# per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1