Merged
Changes from 29 commits
39 commits
f911c32
Kept, padding logic
pluesclues Jun 23, 2025
2ba7f50
Made sure prediction step in rl.py allows logging for callbacks in RL…
pluesclues Jun 23, 2025
0c1bc4d
Merge branch 'unslothai:main' into main
pluesclues Jun 23, 2025
78336ce
updated llama.py to new online_dpo changes
pluesclues Jun 23, 2025
383aa9c
Update rl.py to make logic simpiler
pluesclues Jun 23, 2025
532af4f
Update rl.py, made sure tokenized_output on eval step was on same device
pluesclues Jun 24, 2025
49f77c1
Update rl.py, corrected tokenized_outputs to inputs
pluesclues Jun 24, 2025
7921aa7
Update rl.py, removed sagemaker stuff
pluesclues Jun 25, 2025
54f03ee
Update llama.py, figures out if there is right padding automatically
pluesclues Jul 2, 2025
a8d4168
Update llama.py, changed conditional statement for right padding slig…
pluesclues Jul 2, 2025
236b924
Update llama.py, updated OS.environ variable to temp variable
pluesclues Jul 8, 2025
76d73c6
Merge branch 'main' into main
pluesclues Jul 8, 2025
fa2e18e
Update rl.py, made it account for right padding in online dpo and rew…
pluesclues Jul 8, 2025
80f9cd2
Update llama.py, automatically figures out if right padding is needed
pluesclues Jul 8, 2025
ed1771a
Merge branch 'main' into main
pluesclues Jul 12, 2025
49d3844
Merge branch 'main' into main
pluesclues Aug 3, 2025
b0a9c65
Merge branch 'unslothai:main' into main
pluesclues Aug 8, 2025
30f3366
Update rl_replacements.py, fixed up passing image data to functions
pluesclues Aug 8, 2025
327053f
Merge branch 'unslothai:main' into vlm_grpo_update
pluesclues Aug 11, 2025
8af680f
Update rl_replacements.py, for VLM GRPO support with TRL
pluesclues Aug 11, 2025
5e0fbdb
Update rl_replacements.py, gspo added
pluesclues Aug 11, 2025
ba4fc39
Update rl.py, forgot about Online_DPO changes in this branch
pluesclues Aug 12, 2025
f9a2c18
Update rl.py, forgot to not include Online DPO PR changes
pluesclues Aug 12, 2025
36d3f97
Update llama.py, forgot to disinclude Online DPO PR changes
pluesclues Aug 12, 2025
9c11967
Merge branch 'unslothai:main' into vlm_grpo_update
pluesclues Aug 13, 2025
5a370a5
Merge branch 'unslothai:main' into vlm_grpo_update
pluesclues Aug 18, 2025
7e97306
Merge branch 'unslothai:main' into vlm_grpo_update
pluesclues Aug 20, 2025
a78b407
Merge branch 'unslothai:main' into vlm_grpo_update
pluesclues Aug 22, 2025
8266f9e
Update rl_replacements.py, updated generate and score completions to …
pluesclues Aug 22, 2025
b6fdf4d
Merge branch 'unslothai:main' into vlm_grpo_update
pluesclues Sep 4, 2025
0ad04a6
Merge branch 'unslothai:main' into vlm_grpo_update
pluesclues Sep 8, 2025
2b379ad
Update rl_replacements.py
pluesclues Sep 9, 2025
0eaf2ec
Update rl_replacements.py, fixed nan issues with vlms
pluesclues Sep 9, 2025
fb115fb
Update rl_replacements.py, added indent
pluesclues Sep 9, 2025
694f88e
Update rl_replacements.py, added attention mask to calculations of ol…
pluesclues Sep 10, 2025
5b4a03d
Merge branch 'unslothai:main' into vlm_grpo_update
pluesclues Sep 10, 2025
3e63ed2
Merge branch 'unslothai:main' into vlm_grpo_update
pluesclues Sep 16, 2025
b4b6c65
Update unsloth/models/rl_replacements.py
danielhanchen Sep 16, 2025
24712ca
Update unsloth/models/rl_replacements.py
danielhanchen Sep 16, 2025
163 changes: 95 additions & 68 deletions unsloth/models/rl_replacements.py
@@ -263,22 +263,22 @@ def grpo_trainer__generate_and_score_completions(function_name, function):
# If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
# Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text,
# because we can't use `skip_special_tokens=True` (some special tokens are still needed for generation).
prompt_ids = prompt_ids[:, -self.max_prompt_length :]
prompt_mask = prompt_mask[:, -self.max_prompt_length :]
prompts_text = self.processing_class.batch_decode(
prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id]
protected = [token for token in protected if token is not None]
prompt_ids, prompt_mask = truncate_with_protected_tokens(
prompt_ids, prompt_mask, self.max_prompt_length, protected
)
pad_token = self.processing_class.pad_token
def strip_leading_tokens(text):
while text.startswith(pad_token):
text = text.removeprefix(pad_token)
return text

if pad_token is not None:
prompts_text = [re.sub(rf"^({{re.escape(self.pad_token)}})+", "", text) for text in prompts_text]

# The chat template inserts a single image token into the prompt text. However, when this text is later
# tokenized, the single image token string is expanded into multiple image token IDs, depending on the
# image size. Since we're detokenizing here, we may see repeated image tokens in the decoded text. We
# collapse them back into a single token string to match the original template.
if self.image_token is not None:
prompts_text = [
strip_leading_tokens(text) for text in prompts_text
re.sub(rf"({{re.escape(self.image_token)}})+", self.image_token, text) for text in prompts_text
]

# Generate completions using either vLLM or regular generation
if self.use_vllm:"""
function = function.replace(replace_part, new_replacement)
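For context, a minimal standalone sketch of the decode-side cleanup the new replacement performs: strip leading pad tokens (since `skip_special_tokens=True` cannot be used) and collapse the run of expanded image tokens back into the single token string the chat template inserted. The `<pad>` and `<image>` strings below are illustrative placeholders, not the tokens of any particular processor.

```python
import re

# Illustrative token strings; a real processing_class defines its own.
pad_token   = "<pad>"
image_token = "<image>"

decoded = "<pad><pad><image><image><image>Describe the picture."

# Strip leading pad tokens; skip_special_tokens=True can't be used because
# other special tokens must survive for generation.
decoded = re.sub(rf"^({re.escape(pad_token)})+", "", decoded)

# Collapse the expanded run of image tokens back into the single token string
# the chat template originally inserted.
decoded = re.sub(rf"({re.escape(image_token)})+", image_token, decoded)

print(decoded)  # "<image>Describe the picture."
```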
@@ -304,7 +304,7 @@ def _move_model_to_vllm(self, *args, **kwargs): return None
def grpo_trainer__get_per_token_logps(function_name, function):
if function_name != "_get_per_token_logps": return function

def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, compute_efficient = False):
if True: # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
return None # Unsloth efficient GRPO
# Otherwise, calculate normally:
@@ -350,46 +350,58 @@ def grpo_trainer__get_per_token_logps_and_entropies(function_name, function):
if function_name != "_get_per_token_logps_and_entropies": return function

# Just copy over from _get_per_token_logps replacement function above. For now this returns None anyway
def _get_per_token_logps_and_entropies(self, model, input_ids, attention_mask, logits_to_keep, batch_size = None, compute_entropy = False, *args, **kwargs):
if True: # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
return None, None # logps, entropies Unsloth efficient GRPO
# Otherwise, calculate normally:
if not hasattr(self, '_autocast_dtype'):
self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16

os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
# We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
logits = model(
input_ids = input_ids,
attention_mask = attention_mask,
logits_to_keep = logits_to_keep + 1,
).logits

entropies = None
if compute_entropy:
from trl.trainer.utils import entropy_from_logits
entropies = entropy_from_logits(logits)

# logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
return logits, entropies # logps, entropies
# input_ids = input_ids[:, -logits_to_keep:]
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
# See https://github.com/huggingface/trl/issues/2770
# logits = logits[:, -logits_to_keep:]
# return logits
# See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
# logits = logits / self.temperature
# logps = selective_log_softmax(logits, input_ids)

# row_indices, col_indices = torch.where(logps < -20)

# # Method 1: Check if tensors have elements
# if len(row_indices) > 0 and len(col_indices) > 0:
# breakpoint() # Breakpoint triggered here
# print("Found high values!")
# return logps # compute logprobs for the input tokens
def _get_per_token_logps_and_entropies(self, model, input_ids, attention_mask, logits_to_keep, batch_size = None,
compute_entropy = False, compute_efficient = False, *args, **kwargs):
# if True: # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
# return None, None # logps, entropies Unsloth efficient GRPO
if compute_efficient:
return None, None
else:
# Otherwise, calculate normally:
if not hasattr(self, '_autocast_dtype'):
self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16

pixel_values, image_grid_thw = kwargs.get("pixel_values", None), kwargs.get("image_grid_thw", None)
pixel_attention_mask, image_sizes = kwargs.get('pixel_attention_mask',None), kwargs.get('image_sizes',None)

os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
# We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
logits = model(
input_ids = input_ids,
attention_mask = attention_mask,
pixel_values = pixel_values,
image_grid_thw = image_grid_thw,
pixel_attention_mask = pixel_attention_mask,
image_sizes = image_sizes,
logits_to_keep = logits_to_keep + 1,
).logits
pass

entropies = None
if compute_entropy:
from trl.trainer.utils import entropy_from_logits
entropies = entropy_from_logits(logits)
#breakpoint()
# logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
return logits, entropies # logps, entropies
# input_ids = input_ids[:, -logits_to_keep:]
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
# See https://github.com/huggingface/trl/issues/2770
# logits = logits[:, -logits_to_keep:]
# return logits
# See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
# logits = logits / self.temperature
# logps = selective_log_softmax(logits, input_ids)

# row_indices, col_indices = torch.where(logps < -20)

# # Method 1: Check if tensors have elements
# if len(row_indices) > 0 and len(col_indices) > 0:
# breakpoint() # Breakpoint triggered here
# print("Found high values!")
# return logps # compute logprobs for the input tokens
pass
pass
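The commented-out fallback above ends in `selective_log_softmax(logits, input_ids)`. As a reminder of what that step computes (the fused path returns `None` here and defers the work to Unsloth's efficient GRPO kernels), a simplified dense version is sketched below; TRL's actual helper is more memory-conscious, and this standalone function is only an illustration.

```python
import torch
import torch.nn.functional as F

def selective_log_softmax(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Per-token log-probabilities of the chosen tokens.

    logits: (B, L, V) next-token scores at each position.
    index:  (B, L) token ids whose log-probabilities we want.
    """
    logps = F.log_softmax(logits.float(), dim=-1)                    # (B, L, V)
    return torch.gather(logps, -1, index.unsqueeze(-1)).squeeze(-1)  # (B, L)

# Toy usage: 2 sequences, 3 kept positions, vocabulary of 10.
logits = torch.randn(2, 3, 10)
ids = torch.randint(0, 10, (2, 3))
print(selective_log_softmax(logits, ids).shape)  # torch.Size([2, 3])
```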

@@ -416,8 +428,13 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
raise ValueError("The GRPOTrainer does not support returning outputs")
# Compute the per-token log probabilities for the model

#breakpoint()

prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
pixel_values, image_grid_thw = inputs.get("pixel_values", None), inputs.get("image_grid_thw", None)
pixel_attention_mask, image_sizes = inputs.get('pixel_attention_mask',None), inputs.get('image_sizes',None)

input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
bsz, qlen = input_ids.shape
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
@@ -427,28 +444,29 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
_logits_to_keep = logits_to_keep

get_logps_func = \
lambda model, input_ids, attention_mask, logits_to_keep, batch_size=None, compute_entropy=False: \
self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) \
lambda model, input_ids, attention_mask, logits_to_keep, batch_size=None, compute_entropy=False, compute_efficient = False: \
self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep, compute_efficient) \
if hasattr(self, "_get_per_token_logps") else \
self._get_per_token_logps_and_entropies(model, input_ids, attention_mask, logits_to_keep, batch_size, compute_entropy)[0] # logps

per_token_logps = get_logps_func(model, input_ids, attention_mask, logits_to_keep)

self._get_per_token_logps_and_entropies(model, input_ids, attention_mask, logits_to_keep, batch_size, compute_entropy, compute_efficient)[0] # logps
#breakpoint()
per_token_logps = get_logps_func(model, input_ids, attention_mask, logits_to_keep, compute_efficient = True)
# Compute the KL divergence between the model and the reference model
# _prepare_inputs doesn't return reference log probs anymore. We need to calculate it ourselves.
# https://github.com/huggingface/trl/blob/05bc43e960396581e458195b8388efe6b82cae1f/trl/trainer/grpo_trainer.py#L1328
if self.beta != 0.0:
with torch.inference_mode(), model.disable_adapter():
ref_per_token_logps = per_token_logps = get_logps_func(model, input_ids, attention_mask, logits_to_keep)
else:
ref_per_token_logps = None
# if self.beta != 0.0:
# with torch.inference_mode(), model.disable_adapter():
# ref_per_token_logps = per_token_logps = get_logps_func(model, input_ids, attention_mask, logits_to_keep)
# else:
# ref_per_token_logps = None
ref_hidden_states = inputs.get("ref_per_token_logps", None)
# per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
# x - x.detach() allows for preserving gradients from x
advantages = inputs["advantages"]
# per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
# per_token_loss = -(per_token_loss - self.beta * per_token_kl)
# loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
old_hidden_states = inputs.get("old_per_token_logps", None)

input_ids = input_ids[:, -logits_to_keep:]

# Get logit softcapping and logit scale
@@ -459,22 +477,26 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
logit_scale_divide = getattr(model.config, "logits_scaling", 0) # Granite
if logit_scale_divide is None: logit_scale_divide = 0


if per_token_logps is not None:

if ref_per_token_logps is not None:
ref_per_token_logps = ref_per_token_logps[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
if ref_hidden_states is not None:
ref_hidden_states = ref_hidden_states[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
if old_hidden_states is not None:
old_hidden_states = old_hidden_states[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
per_token_logps = per_token_logps[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred

loss, completion_length, mean_kl = grpo_compute_loss_slow(
ref_per_token_logps,
ref_hidden_states,
per_token_logps,
old_hidden_states,
input_ids,
completion_mask,
self.beta,
advantages,
pixel_values = pixel_values,
image_grid_thw = image_grid_thw,
loss_type = self.args.loss_type,
importance_sampling_level = self.importance_sampling_level,
epsilon_low = self.epsilon_low,
epsilon_high = self.epsilon_high,
max_completion_length = self.args.max_completion_length,
@@ -489,12 +511,16 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
loss, completion_length, mean_kl = grpo_accumulated_loss(
trainer = self,
input_ids = _input_ids,
pixel_values = pixel_values,
image_grid_thw = image_grid_thw,
logits_to_keep = logits_to_keep,
completion_mask = completion_mask,
advantages = advantages,
old_hidden_states = old_hidden_states,
ref_hidden_states = ref_hidden_states,
n_chunks = self.args.unsloth_num_chunks,
loss_type = self.args.loss_type,
importance_sampling_level = self.importance_sampling_level,
epsilon_low = self.epsilon_low,
epsilon_high = self.epsilon_high,
max_completion_length = self.args.max_completion_length,
@@ -514,6 +540,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
completion_mask = completion_mask,
advantages = advantages,
old_hidden_states = old_hidden_states,
ref_hidden_states = ref_hidden_states,
n_chunks = self.args.unsloth_num_chunks,
temperature = self.args.temperature,
logit_softcapping = logit_softcapping,
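For orientation, the objective implied by the `epsilon_low` / `epsilon_high` / `beta` arguments above and by the commented-out formulas earlier in this diff is roughly the following. This is only a dense, token-level sketch with all names local to the example, not the chunked `grpo_accumulated_loss` implementation, which also handles the `loss_type` normalizations and the sequence-level (GSPO) importance-sampling variant.

```python
import torch

def grpo_per_token_loss(per_token_logps, old_per_token_logps, ref_per_token_logps,
                        advantages, completion_mask,
                        beta = 0.04, epsilon_low = 0.2, epsilon_high = 0.2):
    """Token-level GRPO objective on dense tensors.

    per_token_logps / old_per_token_logps / ref_per_token_logps: (B, L)
    advantages: (B,)   completion_mask: (B, L) 0/1 float mask
    """
    # PPO-style clipped importance ratio against the sampling (old) policy.
    ratio = torch.exp(per_token_logps - old_per_token_logps)
    adv = advantages.unsqueeze(1)
    unclipped = ratio * adv
    clipped = torch.clamp(ratio, 1.0 - epsilon_low, 1.0 + epsilon_high) * adv
    per_token_loss = -torch.min(unclipped, clipped)

    # k3 KL estimate against the reference policy, added when beta != 0.
    if beta != 0.0:
        delta = ref_per_token_logps - per_token_logps
        per_token_loss = per_token_loss + beta * (torch.exp(delta) - delta - 1.0)

    # Mask padding, average per completion, then over the batch.
    seq_loss = (per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min = 1.0)
    return seq_loss.mean()

# Toy usage with made-up shapes.
B, L = 2, 5
logps = torch.randn(B, L)
print(grpo_per_token_loss(logps, logps.detach(), logps.detach() + 0.1,
                          torch.randn(B), torch.ones(B, L)))
```

Roughly, the GSPO variant (`importance_sampling_level = "sequence"`) would average the per-token log-ratios over each completion and apply a single clipped ratio per sequence instead of per token.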