From f911c32a1a4ea7289f49a372e25a7c076a4eac4d Mon Sep 17 00:00:00 2001
From: pluesclues <136766175+pluesclues@users.noreply.github.com>
Date: Sun, 22 Jun 2025 20:59:55 -0400
Subject: [PATCH 01/25] Kept padding logic

---
 unsloth/models/llama.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index 125bc7e61..d617129b3 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -762,8 +762,7 @@ def LlamaModel_fast_forward(
     # Ignore attention_mask
     if attention_mask is None:
         padding_mask = None
-    elif self.training:
-        # elif attention_mask is None:
+    elif self.training and os.environ.get("UNSLOTH_KEEP_PADDING", "0") != '1':
         attention_mask = None
         padding_mask = None
     else:

From 2ba7f502ce8a84cb5158525cd8a9aac6e0319192 Mon Sep 17 00:00:00 2001
From: pluesclues <136766175+pluesclues@users.noreply.github.com>
Date: Sun, 22 Jun 2025 21:02:48 -0400
Subject: [PATCH 02/25] Made sure prediction step in rl.py allows logging for
 callbacks in RL trainers

---
 unsloth/models/rl.py | 115 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 115 insertions(+)

diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py
index 889bbd480..25f34d7ef 100644
--- a/unsloth/models/rl.py
+++ b/unsloth/models/rl.py
@@ -83,6 +83,120 @@ def generate_with_clone(*args, **kwargs):
         pass
     pass
 pass
+
+    from transformers import Trainer
+    from transformers.utils import is_sagemaker_mp_enabled
+    if is_sagemaker_mp_enabled():
+        import smdistributed.modelparallel.torch as smp
+        from smdistributed.modelparallel import __version__ as SMP_VERSION
+
+        IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
+
+        from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
+    else:
+        IS_SAGEMAKER_MP_POST_1_10 = False
+    from transformers.trainer_pt_utils import nested_detach
+    @torch.no_grad()
+    def unsloth_prediction_step(self, model, inputs, prediction_loss_only, ignore_keys,):
+        """
+        Perform an evaluation step on `model` using `inputs`.
+
+        Subclass and override to inject custom behavior.
+
+        Args:
+            model (`nn.Module`):
+                The model to evaluate.
+            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
+                The inputs and targets of the model.
+
+                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
+                argument `labels`. Check your model's documentation for all accepted arguments.
+            prediction_loss_only (`bool`):
+                Whether or not to return the loss only.
+            ignore_keys (`List[str]`, *optional*):
+                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+                gathering predictions.
+
+        Return:
+            Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
+            logits and labels (each being optional).
+        """
+        has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)
+        # For CLIP-like models capable of returning loss values.
+        # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`
+        # is `True` in `model.forward`.
+        return_loss = inputs.get("return_loss", None)
+        if return_loss is None:
+            return_loss = self.can_return_loss
+        loss_without_labels = True if len(self.label_names) == 0 and return_loss else False
+
+        inputs = self._prepare_inputs(inputs)
+        if ignore_keys is None:
+            if hasattr(self.model, "config"):
+                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
+            else:
+                ignore_keys = []
+
+        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
+        if has_labels or loss_without_labels:
+            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
+            if len(labels) == 1:
+                labels = labels[0]
+        else:
+            labels = None
+
+        with torch.no_grad():
+            if is_sagemaker_mp_enabled():
+                raw_outputs = smp_forward_only(model, inputs)
+                if has_labels or loss_without_labels:
+                    if isinstance(raw_outputs, dict):
+                        loss_mb = raw_outputs["loss"]
+                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
+                    else:
+                        loss_mb = raw_outputs[0]
+                        logits_mb = raw_outputs[1:]
+
+                    loss = loss_mb.reduce_mean().detach().cpu()
+                    logits = smp_nested_concat(logits_mb)
+                else:
+                    loss = None
+                    if isinstance(raw_outputs, dict):
+                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
+                    else:
+                        logits_mb = raw_outputs
+                    logits = smp_nested_concat(logits_mb)
+            else:
+                if has_labels or loss_without_labels:
+                    with self.compute_loss_context_manager():
+                        loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
+                    loss = loss.mean().detach()
+
+                    if isinstance(outputs, dict):
+                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
+                    else:
+                        logits = outputs[1:]
+                else:
+                    loss = None
+                    with self.compute_loss_context_manager():
+                        #breakpoint()
+                        tokenized_output = self.processing_class(inputs["prompt"], padding=True, truncation=True, return_tensors="pt")
+                        outputs = model(**tokenized_output)
+                    if isinstance(outputs, dict):
+                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
+                    else:
+                        logits = outputs
+                    # TODO: this needs to be fixed and made cleaner later.
+                    if self.args.past_index >= 0:
+                        self._past = outputs[self.args.past_index - 1]
+
+        if prediction_loss_only:
+            return (loss, None, None)
+
+        logits = nested_detach(logits)
+        if len(logits) == 1:
+            logits = logits[0]
+
+        return (loss, logits, labels)
 
     import trl.trainer
     trainers = dir(trl.trainer)
@@ -94,6 +208,7 @@
             if hasattr(current_trainer, unwrap):
                 try:    exec(f"trl.trainer.{trainer}.{unwrap} = unsloth_{unwrap}")
                 except: continue
+        exec(f"Trainer.prediction_step=unsloth_prediction_step")
     pass
 pass
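A note on the gate patch 01 introduces before the series continues. During training Unsloth normally drops the attention mask entirely (its fast path assumes no meaningful right padding), and UNSLOTH_KEEP_PADDING=1 forces the masked SDPA path to stay active. A minimal sketch of that condition, assuming nothing beyond the diff above; the helper name `should_drop_mask` is illustrative, not from the patch:

import os

def should_drop_mask(training: bool) -> bool:
    # Mirrors the condition added in patch 01: drop the mask during training
    # unless UNSLOTH_KEEP_PADDING=1 asks to keep it.
    return training and os.environ.get("UNSLOTH_KEEP_PADDING", "0") != "1"

assert should_drop_mask(True) is True
os.environ["UNSLOTH_KEEP_PADDING"] = "1"
assert should_drop_mask(True) is False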
From 78336ce8ec883b390cc3e4df37de176860779a7b Mon Sep 17 00:00:00 2001
From: pluesclues <136766175+pluesclues@users.noreply.github.com>
Date: Mon, 23 Jun 2025 19:18:23 -0400
Subject: [PATCH 03/25] updated llama.py to new online_dpo changes

---
 unsloth/models/llama.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index d617129b3..9db8abdd4 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -69,7 +69,7 @@
 from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING
 from transformers import set_seed as transformers_set_seed
 from peft import LoraConfig, TaskType, get_peft_model as _get_peft_model
-from peft import PeftModelForCausalLM
+from peft import PeftModelForCausalLM, PeftModelForSequenceClassification
 from ..save import patch_saving_functions
 import re, os, inspect, math, sys
 import types
@@ -2078,7 +2078,8 @@ def from_pretrained(
         model.for_inference = functools.partial(FastLlamaModel.for_inference, model)
 
         # Patch generate
-        if model.generate.__name__ != "unsloth_fast_generate":
+        is_classification = "Classification" in str(type(model))
+        if not is_classification and model.generate.__name__ != "unsloth_fast_generate":
             model._old_generate = model.generate
             unsloth_fast_generate.__doc__ = model._old_generate.__doc__
             model.generate = types.MethodType(unsloth_fast_generate, model)
@@ -2158,7 +2159,7 @@ def get_peft_model(
         if r <= 0:
             raise TypeError(f"Unsloth: Rank of {str(r)} must be larger than 0.")
 
-        if isinstance(model, PeftModelForCausalLM):
+        if isinstance(model, PeftModelForCausalLM) or isinstance(model, PeftModelForSequenceClassification):
             # Check if exactly the same and then pass through!
             assert(hasattr(model, "peft_config"))
@@ -2427,7 +2428,7 @@ def get_peft_model(
         is_classification = "Classification" in str(type(model))
 
         # Get LoRA
-        # if not is_classification else TaskType.SEQ_CLS
+        #
         arguments = dict(
             r = r,
@@ -2435,7 +2436,7 @@ def get_peft_model(
             target_modules = final_modules,
             lora_dropout = lora_dropout,
             bias = bias,
-            task_type = TaskType.CAUSAL_LM,
+            task_type = TaskType.CAUSAL_LM if not is_classification else TaskType.SEQ_CLS,
             layers_to_transform = layers_to_transform,
             init_lora_weights = init_lora_weights,
             loftq_config = loftq_config,
@@ -2449,7 +2450,6 @@ def get_peft_model(
         _saved_temp_tokenizer = model._saved_temp_tokenizer
 
         lora_config = LoraConfig(**arguments)
-
         # First offload lm_head and embed_tokens to disk
         input_embeddings_device = model.get_input_embeddings().weight.device
         if is_classification:
@@ -2571,7 +2571,7 @@ def patch_peft_model(
             use_gradient_checkpointing = use_gradient_checkpointing,
         )
     pass
-    if not isinstance(model, PeftModelForCausalLM):
+    if not isinstance(model, PeftModelForCausalLM) and not isinstance(model, PeftModelForSequenceClassification):
         raise TypeError(
             "Unsloth: Your model needs to call `.get_peft_model` first!"
         )
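Patch 03 above routes sequence-classification (reward) models through PEFT's SEQ_CLS task type instead of hard-coding CAUSAL_LM. A rough sketch of the resulting configuration logic; `make_lora_config` is an illustrative name, not a function from the diff:

from peft import LoraConfig, TaskType

def make_lora_config(model, r: int = 16) -> LoraConfig:
    # Same heuristic the patch uses: PeftModelForSequenceClassification and
    # friends carry "Classification" in their class name.
    is_classification = "Classification" in str(type(model))
    return LoraConfig(
        r = r,
        lora_alpha = r,
        task_type = TaskType.SEQ_CLS if is_classification else TaskType.CAUSAL_LM,
    )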
From 383aa9c18647bb96d30e61d79976176c9243fc77 Mon Sep 17 00:00:00 2001
From: pluesclues <136766175+pluesclues@users.noreply.github.com>
Date: Mon, 23 Jun 2025 19:18:53 -0400
Subject: [PATCH 04/25] Update rl.py to make logic simpler

From 532af4f9716de2c9b4c123cc951b2f2bec0d3da4 Mon Sep 17 00:00:00 2001
From: pluesclues <136766175+pluesclues@users.noreply.github.com>
Date: Mon, 23 Jun 2025 20:03:51 -0400
Subject: [PATCH 05/25] Update rl.py, made sure tokenized_output on eval step
 was on same device

---
 unsloth/models/rl.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py
index 25f34d7ef..41402b2d6 100644
--- a/unsloth/models/rl.py
+++ b/unsloth/models/rl.py
@@ -178,8 +178,7 @@ def unsloth_prediction_step(self, model, inputs, prediction_loss_only,ignore_key
                 else:
                     loss = None
                     with self.compute_loss_context_manager():
-                        #breakpoint()
-                        tokenized_output = self.processing_class(inputs["prompt"], padding=True, truncation=True, return_tensors="pt")
+                        tokenized_output = self.processing_class(inputs["prompt"], padding=True, truncation=True, return_tensors="pt").to(model.device)
                         outputs = model(**tokenized_output)

From 49f77c102eb0cd13c529804d7ffbf411e55910ea Mon Sep 17 00:00:00 2001
From: pluesclues <136766175+pluesclues@users.noreply.github.com>
Date: Tue, 24 Jun 2025 18:18:31 -0400
Subject: [PATCH 06/25] Update rl.py, corrected tokenized_outputs to inputs

---
 unsloth/models/rl.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py
index 41402b2d6..570cc9c9c 100644
--- a/unsloth/models/rl.py
+++ b/unsloth/models/rl.py
@@ -178,8 +178,8 @@ def unsloth_prediction_step(self, model, inputs, prediction_loss_only,ignore_key
                 else:
                     loss = None
                     with self.compute_loss_context_manager():
-                        tokenized_output = self.processing_class(inputs["prompt"], padding=True, truncation=True, return_tensors="pt").to(model.device)
-                        outputs = model(**tokenized_output)
+                        inputs = self.processing_class(inputs["prompt"], padding=True, truncation=True, return_tensors="pt").to(model.device)
+                        outputs = model(**inputs)

From 7921aa7c3a49d7765e18526655588a6360d2030f Mon Sep 17 00:00:00 2001
From: pluesclues <136766175+pluesclues@users.noreply.github.com>
Date: Wed, 25 Jun 2025 19:50:10 -0400
Subject: [PATCH 07/25] Update rl.py, removed sagemaker stuff

---
 unsloth/models/rl.py | 66 ++++++++++++--------------------------------
 1 file changed, 18 insertions(+), 48 deletions(-)

diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py
index 570cc9c9c..7e4478283 100644
--- a/unsloth/models/rl.py
+++ b/unsloth/models/rl.py
@@ -85,16 +85,6 @@ def generate_with_clone(*args, **kwargs):
 pass
 
     from transformers import Trainer
-    from transformers.utils import is_sagemaker_mp_enabled
-    if is_sagemaker_mp_enabled():
-        import smdistributed.modelparallel.torch as smp
-        from smdistributed.modelparallel import __version__ as SMP_VERSION
-
-        IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
-
-        from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
-    else:
-        IS_SAGEMAKER_MP_POST_1_10 = False
     from transformers.trainer_pt_utils import nested_detach
     @torch.no_grad()
     def unsloth_prediction_step(self, model, inputs, prediction_loss_only,ignore_keys,):
@@ -146,47 +136,27 @@ def unsloth_prediction_step(self, model, inputs, prediction_loss_only,ignore_key
                 labels = None
 
         with torch.no_grad():
-            if is_sagemaker_mp_enabled():
-                raw_outputs = smp_forward_only(model, inputs)
-                if has_labels or loss_without_labels:
-                    if isinstance(raw_outputs, dict):
-                        loss_mb = raw_outputs["loss"]
-                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
-                    else:
-                        loss_mb = raw_outputs[0]
-                        logits_mb = raw_outputs[1:]
-
-                    loss = loss_mb.reduce_mean().detach().cpu()
-                    logits = smp_nested_concat(logits_mb)
+            if has_labels or loss_without_labels:
+                with self.compute_loss_context_manager():
+                    loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
+                loss = loss.mean().detach()
+
+                if isinstance(outputs, dict):
+                    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
                 else:
-                    loss = None
-                    if isinstance(raw_outputs, dict):
-                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
-                    else:
-                        logits_mb = raw_outputs
-                    logits = smp_nested_concat(logits_mb)
+                    logits = outputs[1:]
             else:
-                if has_labels or loss_without_labels:
-                    with self.compute_loss_context_manager():
-                        loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
-                    loss = loss.mean().detach()
-
-                    if isinstance(outputs, dict):
-                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
-                    else:
-                        logits = outputs[1:]
+                loss = None
+                with self.compute_loss_context_manager():
+                    tokenized_output = self.processing_class(inputs["prompt"], padding=True, truncation=True, return_tensors="pt").to(model.device)
+                    outputs = model(**tokenized_output)
+                if isinstance(outputs, dict):
+                    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
                 else:
-                    loss = None
-                    with self.compute_loss_context_manager():
-                        inputs = self.processing_class(inputs["prompt"], padding=True, truncation=True, return_tensors="pt").to(model.device)
-                        outputs = model(**inputs)
-                    if isinstance(outputs, dict):
-                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
-                    else:
-                        logits = outputs
-                    # TODO: this needs to be fixed and made cleaner later.
-                    if self.args.past_index >= 0:
-                        self._past = outputs[self.args.past_index - 1]
+                    logits = outputs
+                # TODO: this needs to be fixed and made cleaner later.
+        if self.args.past_index >= 0:
+            self._past = outputs[self.args.past_index - 1]
 
         if prediction_loss_only:
             return (loss, None, None)

From 54f03eec34b09a38008ce71b548010d417f169a6 Mon Sep 17 00:00:00 2001
From: pluesclues <136766175+pluesclues@users.noreply.github.com>
Date: Wed, 2 Jul 2025 18:40:19 -0400
Subject: [PATCH 08/25] Update llama.py, figures out if there is right padding
 automatically

---
 unsloth/models/llama.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index 9db8abdd4..56b3c8316 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -758,7 +758,11 @@ def LlamaModel_fast_forward(
         inputs_embeds *= attention_mask.unsqueeze(0).transpose(0, 1).transpose(1, 2)
         if inputs_requires_grad: inputs_embeds.requires_grad_(True)
     pass
-
+    #Figure out if there is right padding
+    if attention_mask is not None:
+        pads_right = (attention_mask[:, -1] == 0).any()
+        if pads_right.item():
+            os.environ["UNSLOTH_KEEP_PADDING"] = '1'
     # Ignore attention_mask
     if attention_mask is None:
         padding_mask = None

From a8d416895418502de1e3c36ae1da431f35c05766 Mon Sep 17 00:00:00 2001
From: pluesclues <136766175+pluesclues@users.noreply.github.com>
Date: Wed, 2 Jul 2025 18:46:12 -0400
Subject: [PATCH 09/25] Update llama.py, changed conditional statement for
 right padding slightly

---
 unsloth/models/llama.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index 56b3c8316..66bbf2e96 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -763,6 +763,8 @@ def LlamaModel_fast_forward(
         pads_right = (attention_mask[:, -1] == 0).any()
         if pads_right.item():
             os.environ["UNSLOTH_KEEP_PADDING"] = '1'
+        else:
+            os.environ["UNSLOTH_KEEP_PADDING"] = '0'
     # Ignore attention_mask
     if attention_mask is None:
         padding_mask = None

From 236b9241055dd7a57142d0947a8c158d114c963b Mon Sep 17 00:00:00 2001
From: pluesclues <136766175+pluesclues@users.noreply.github.com>
Date: Mon, 7 Jul 2025 21:50:46 -0400
Subject: [PATCH 10/25] Update llama.py, updated OS.environ variable to temp
 variable

---
 unsloth/models/llama.py | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index 66bbf2e96..9b5181b7a 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -759,24 +759,21 @@ def LlamaModel_fast_forward(
         if inputs_requires_grad: inputs_embeds.requires_grad_(True)
     pass
     #Figure out if there is right padding
+    keep_padding = False
     if attention_mask is not None:
         pads_right = (attention_mask[:, -1] == 0).any()
         if pads_right.item():
-            os.environ["UNSLOTH_KEEP_PADDING"] = '1'
+            keep_padding = True
         else:
-            os.environ["UNSLOTH_KEEP_PADDING"] = '0'
+            keep_padding = False
     # Ignore attention_mask
     if attention_mask is None:
         padding_mask = None
-    elif self.training and os.environ.get("UNSLOTH_KEEP_PADDING", "0") != '1':
+    elif self.training and not keep_padding:
         attention_mask = None
         padding_mask = None
     else:
-        # if 0 in attention_mask:
-        #     padding_mask = attention_mask
-        # else:
         padding_mask = None
-
         attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
             attention_mask,
             (batch_size, seq_length),

From fa2e18e7f8d9e6e25e713aa854bd6296f32c2665 Mon Sep 17 00:00:00 2001
From: pluesclues <136766175+pluesclues@users.noreply.github.com>
Date: Mon, 7 Jul 2025 23:30:05 -0400
Subject: [PATCH 11/25] Update rl.py, made it account for right padding in
 online dpo and reward modeling

---
 unsloth/models/rl.py | 29 ++++++++++++++++++++++++++++-
 1 file changed, 28 insertions(+), 1 deletion(-)

diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py
index 7e4478283..0bd0ce62a 100644
--- a/unsloth/models/rl.py
+++ b/unsloth/models/rl.py
@@ -785,7 +785,34 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import
     init = init.replace("if peft_config is None:", "if False:")
     init = init.replace("if peft_config is not None:", "if False:")
     init = init.replace("get_peft_model(model, peft_config)", "model")
-
+    if RLTrainer_name == "OnlineDPOTrainer" or RLTrainer_name == "RewardTrainer":
+        pattern = re.compile(
+            r'(^(\s*)# Add tags for models that have been loaded with the correct transformers version)',
+            re.MULTILINE
+        )
+        if RLTrainer_name == "OnlineDPOTrainer":
+            replacement = (
+                r'\2# Check for nested model attribute before setting right padding\n'
+                r'\2if hasattr(self.model.model, "model"):\n'
+                r'\2 self.model.model.model._needs_right_padding = True\n'
+                r'\2else:\n'
+                r'\2 self.model.model._needs_right_padding = True\n'
+                r'\2# Check for nested model attribute before setting right padding\n'
+                r'\2if hasattr(self.reward_model, "model"):\n'
+                r'\2 self.reward_model.model._needs_right_padding = True\n'
+                r'\2print("Unsloth reward model type: ", type(self.reward_model.model))'
+                r'\1'
+            )
+        else:
+            replacement = (
+                r'\2# Check for nested model attribute before setting right padding\n'
+                r'\2if hasattr(self.model.model, "model"):\n'
+                r'\2 self.model.model.model._needs_right_padding = True\n'
+                r'\2else:\n'
+                r'\2 self.model.model._needs_right_padding = True\n'
+                r'\1'
+            )
+        init = pattern.sub(replacement, init, count=1)
 
     # Set use_vllm if not set
     if "args.use_vllm" in init and "model" in init and "args" in init:
        # .*? matches first match. .+? matches final match.
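Patch 11 works by rewriting the trainer's `__init__` source text with a regex before it is exec'd. A minimal runnable sketch of that mechanism, using a toy stand-in for the `init` string (only the anchor comment and the `_needs_right_padding` attribute come from the diff; everything else is illustrative):

import re

# Toy stand-in for the trainer __init__ source the patch rewrites.
init = '''
        # Add tags for models that have been loaded with the correct transformers version
        self._tag_names = ["trl"]
'''

# \1 captures the whole anchor line, \2 just its indentation, so the injected
# lines line up with the surrounding code and the anchor line is re-emitted.
pattern = re.compile(
    r'(^(\s*)# Add tags for models that have been loaded with the correct transformers version)',
    re.MULTILINE,
)
replacement = (
    r'\2# Check for nested model attribute before setting right padding\n'
    r'\2if hasattr(self.model.model, "model"):\n'
    r'\2    self.model.model.model._needs_right_padding = True\n'
    r'\2else:\n'
    r'\2    self.model.model._needs_right_padding = True\n'
    r'\1'
)
print(pattern.sub(replacement, init, count=1))

Running this prints the patched source with the right-padding flag set just above the original anchor comment, which is exactly the shape of edit the patch performs on the real TRL trainers.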
From 80f9cd274244d7bc244cf6950722b5ae430aaabb Mon Sep 17 00:00:00 2001
From: pluesclues <136766175+pluesclues@users.noreply.github.com>
Date: Mon, 7 Jul 2025 23:31:07 -0400
Subject: [PATCH 12/25] Update llama.py, automatically figures out if right
 padding is needed

---
 unsloth/models/llama.py | 10 +---------
 1 file changed, 1 insertion(+), 9 deletions(-)

diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index e69482523..fb723041f 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -771,18 +771,10 @@ def LlamaModel_fast_forward(
         inputs_embeds *= attention_mask.unsqueeze(0).transpose(0, 1).transpose(1, 2)
         if inputs_requires_grad: inputs_embeds.requires_grad_(True)
     pass
-    #Figure out if there is right padding
-    keep_padding = False
-    if attention_mask is not None:
-        pads_right = (attention_mask[:, -1] == 0).any()
-        if pads_right.item():
-            keep_padding = True
-        else:
-            keep_padding = False
     # Ignore attention_mask
     if attention_mask is None:
         padding_mask = None
-    elif self.training and not keep_padding:
+    elif self.training and not hasattr(self, "_needs_right_padding"):
         attention_mask = None
         padding_mask = None
     else:

From 30f3366a1eac8ceb8e3e0167d657aa8223a24780 Mon Sep 17 00:00:00 2001
From: pluesclues <136766175+pluesclues@users.noreply.github.com>
Date: Fri, 8 Aug 2025 16:47:51 -0400
Subject: [PATCH 13/25] Update rl_replacements.py, fixed up passing image data
 to functions

---
 unsloth/models/rl_replacements.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py
index 2555f0df1..7efd413ed 100644
--- a/unsloth/models/rl_replacements.py
+++ b/unsloth/models/rl_replacements.py
@@ -416,8 +416,13 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
             raise ValueError("The GRPOTrainer does not support returning outputs")
         # Compute the per-token log probabilities for the model
+        #breakpoint()
+
         prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
         completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
+        pixel_values, image_grid_thw = inputs.get("pixel_values", None), inputs.get("image_grid_thw", None)
+        pixel_attention_mask, image_sizes = inputs.get('pixel_attention_mask',None), inputs.get('image_sizes',None)
+
         input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
         bsz, qlen = input_ids.shape
         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
@@ -433,7 +438,6 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
             self._get_per_token_logps_and_entropies(model, input_ids, attention_mask, logits_to_keep, batch_size, compute_entropy)[0] # logps
 
         per_token_logps = get_logps_func(model, input_ids, attention_mask, logits_to_keep)
-
         # Compute the KL divergence between the model and the reference model
         # _prepare_inputs doesn't return reference log probs anymore. We need to calculate it ourselves.
         # https://github.com/huggingface/trl/blob/05bc43e960396581e458195b8388efe6b82cae1f/trl/trainer/grpo_trainer.py#L1328
@@ -444,11 +448,13 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
             ref_per_token_logps = None
         # per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
         # x - x.detach() allows for preserving gradients from x
+        #breakpoint()
         advantages = inputs["advantages"]
         # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
         # per_token_loss = -(per_token_loss - self.beta * per_token_kl)
         # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
         old_hidden_states = inputs.get("old_per_token_logps", None)
+
         input_ids = input_ids[:, -logits_to_keep:]
 
         # Get logit softcapping and logit scale
@@ -474,6 +480,8 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
             completion_mask,
             self.beta,
             advantages,
+            pixel_values = pixel_values,
+            image_grid_thw = image_grid_thw,
             loss_type = self.args.loss_type,
             epsilon_low = self.epsilon_low,
             epsilon_high = self.epsilon_high,
@@ -489,6 +497,8 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
         loss, completion_length, mean_kl = grpo_accumulated_loss(
             trainer = self,
             input_ids = _input_ids,
+            pixel_values = pixel_values,
+            image_grid_thw = image_grid_thw,
             logits_to_keep = logits_to_keep,
             completion_mask = completion_mask,
             advantages = advantages,

From 8af680ffd4fa330cfc8571e522f9413bdf600fc2 Mon Sep 17 00:00:00 2001
From: pluesclues <136766175+pluesclues@users.noreply.github.com>
Date: Mon, 11 Aug 2025 14:25:06 -0400
Subject: [PATCH 14/25] Update rl_replacements.py, for VLM GRPO support with
 TRL

---
 unsloth/models/rl_replacements.py | 128 +++++++++++++++++-------------
 1 file changed, 72 insertions(+), 56 deletions(-)

diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py
index 7efd413ed..efa662f09 100644
--- a/unsloth/models/rl_replacements.py
+++ b/unsloth/models/rl_replacements.py
@@ -304,7 +304,7 @@ def _move_model_to_vllm(self, *args, **kwargs): return None
 def grpo_trainer__get_per_token_logps(function_name, function):
     if function_name != "_get_per_token_logps": return function
-    def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
+    def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, compute_efficient = False):
         if True: # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
             return None # Unsloth efficient GRPO
         # Otherwise, calculate normally:
         if not hasattr(self, '_autocast_dtype'):
@@ -350,46 +350,58 @@ def grpo_trainer__get_per_token_logps_and_entropies(function_name, function):
     if function_name != "_get_per_token_logps_and_entropies": return function
     # Just copy over from _get_per_token_logps replacement function above. For now this returns None anyway
-    def _get_per_token_logps_and_entropies(self, model, input_ids, attention_mask, logits_to_keep, batch_size = None, compute_entropy = False, *args, **kwargs):
-        if True: # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
-            return None, None # logps, entropies Unsloth efficient GRPO
-        # Otherwise, calculate normally:
-        if not hasattr(self, '_autocast_dtype'):
-            self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
-            if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16
-
-        os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
-        with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
-            # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
-            logits = model(
-                input_ids = input_ids,
-                attention_mask = attention_mask,
-                logits_to_keep = logits_to_keep + 1,
-            ).logits
-
-            entropies = None
-            if compute_entropy:
-                from trl.trainer.utils import entropy_from_logits
-                entropies = entropy_from_logits(logits)
-
-        # logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
-        return logits, entropies # logps, entropies
-        # input_ids = input_ids[:, -logits_to_keep:]
-        # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
-        # See https://github.com/huggingface/trl/issues/2770
-        # logits = logits[:, -logits_to_keep:]
-        # return logits
-        # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
-        # logits = logits / self.temperature
-        # logps = selective_log_softmax(logits, input_ids)
-
-        # row_indices, col_indices = torch.where(logps < -20)
-
-        # # Method 1: Check if tensors have elements
-        # if len(row_indices) > 0 and len(col_indices) > 0:
-        #     breakpoint() # Breakpoint triggered here
-        #     print("Found high values!")
-        # return logps # compute logprobs for the input tokens
+    def _get_per_token_logps_and_entropies(self, model, input_ids, attention_mask, logits_to_keep, batch_size = None,
+                                           compute_entropy = False, compute_efficient = False, *args, **kwargs):
+        # if True: # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
+        #     return None, None # logps, entropies Unsloth efficient GRPO
+        if compute_efficient:
+            return None, None
+        else:
+            # Otherwise, calculate normally:
+            if not hasattr(self, '_autocast_dtype'):
+                self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
+                if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16
+
+            pixel_values, image_grid_thw = kwargs.get("pixel_values", None), kwargs.get("image_grid_thw", None)
+            pixel_attention_mask, image_sizes = kwargs.get('pixel_attention_mask',None), kwargs.get('image_sizes',None)
+
+            os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
+            with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
+                # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
+                logits = model(
+                    input_ids = input_ids,
+                    attention_mask = attention_mask,
+                    pixel_values = pixel_values,
+                    image_grid_thw = image_grid_thw,
+                    pixel_attention_mask = pixel_attention_mask,
+                    image_sizes = image_sizes,
+                    logits_to_keep = logits_to_keep + 1,
+                ).logits
+                pass
+
+                entropies = None
+                if compute_entropy:
+                    from trl.trainer.utils import entropy_from_logits
+                    entropies = entropy_from_logits(logits)
+                #breakpoint()
+                # logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+                return logits, entropies # logps, entropies
+                # input_ids = input_ids[:, -logits_to_keep:]
+                # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
+                # See https://github.com/huggingface/trl/issues/2770
+                # logits = logits[:, -logits_to_keep:]
+                # return logits
+                # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
+                # logits = logits / self.temperature
+                # logps = selective_log_softmax(logits, input_ids)
+
+                # row_indices, col_indices = torch.where(logps < -20)
+
+                # # Method 1: Check if tensors have elements
+                # if len(row_indices) > 0 and len(col_indices) > 0:
+                #     breakpoint() # Breakpoint triggered here
+                #     print("Found high values!")
+                # return logps # compute logprobs for the input tokens
     pass
 pass
@@ -432,23 +444,23 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
         _logits_to_keep = logits_to_keep
 
         get_logps_func = \
-            lambda model, input_ids, attention_mask, logits_to_keep, batch_size=None, compute_entropy=False: \
-            self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) \
+            lambda model, input_ids, attention_mask, logits_to_keep, batch_size=None, compute_entropy=False, compute_efficient = False: \
+            self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep, compute_efficient) \
             if hasattr(self, "_get_per_token_logps") else \
-            self._get_per_token_logps_and_entropies(model, input_ids, attention_mask, logits_to_keep, batch_size, compute_entropy)[0] # logps
-
-        per_token_logps = get_logps_func(model, input_ids, attention_mask, logits_to_keep)
+            self._get_per_token_logps_and_entropies(model, input_ids, attention_mask, logits_to_keep, batch_size, compute_entropy, compute_efficient)[0] # logps
+        #breakpoint()
+        per_token_logps = get_logps_func(model, input_ids, attention_mask, logits_to_keep, compute_efficient = True)
 
         # Compute the KL divergence between the model and the reference model
         # _prepare_inputs doesn't return reference log probs anymore. We need to calculate it ourselves.
         # https://github.com/huggingface/trl/blob/05bc43e960396581e458195b8388efe6b82cae1f/trl/trainer/grpo_trainer.py#L1328
-        if self.beta != 0.0:
-            with torch.inference_mode(), model.disable_adapter():
-                ref_per_token_logps = per_token_logps = get_logps_func(model, input_ids, attention_mask, logits_to_keep)
-        else:
-            ref_per_token_logps = None
+        # if self.beta != 0.0:
+        #     with torch.inference_mode(), model.disable_adapter():
+        #         ref_per_token_logps = per_token_logps = get_logps_func(model, input_ids, attention_mask, logits_to_keep)
+        # else:
+        #     ref_per_token_logps = None
+        ref_hidden_states = inputs.get("ref_per_token_logps", None)
         # per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
         # x - x.detach() allows for preserving gradients from x
-        #breakpoint()
         advantages = inputs["advantages"]
         # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
         # per_token_loss = -(per_token_loss - self.beta * per_token_kl)
         # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
         old_hidden_states = inputs.get("old_per_token_logps", None)
@@ -465,15 +477,16 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
         logit_scale_divide = getattr(model.config, "logits_scaling", 0) # Granite
         if logit_scale_divide is None: logit_scale_divide = 0
 
-        if per_token_logps is not None:
-            if ref_per_token_logps is not None:
-                ref_per_token_logps = ref_per_token_logps[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+        if ref_hidden_states is not None:
+            ref_hidden_states = ref_hidden_states[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+        if old_hidden_states is not None:
+            old_hidden_states = old_hidden_states[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
 
             per_token_logps = per_token_logps[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
 
             loss, completion_length, mean_kl = grpo_compute_loss_slow(
-                ref_per_token_logps,
+                ref_hidden_states,
                 per_token_logps,
                 old_hidden_states,
                 input_ids,
@@ -493,6 +506,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
                 logit_scale_divide = logit_scale_divide,
             )
         else:
+            #breakpoint()
             if hasattr(self.args, "loss_type"):
                 loss, completion_length, mean_kl = grpo_accumulated_loss(
                     trainer = self,
@@ -503,6 +517,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
                     completion_mask = completion_mask,
                     advantages = advantages,
                     old_hidden_states = old_hidden_states,
+                    ref_hidden_states = ref_hidden_states,
                     n_chunks = self.args.unsloth_num_chunks,
                     loss_type = self.args.loss_type,
                     epsilon_low = self.epsilon_low,
                    epsilon_high = self.epsilon_high,
                     max_completion_length = self.args.max_completion_length,
@@ -524,6 +539,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
                     completion_mask = completion_mask,
                     advantages = advantages,
                     old_hidden_states = old_hidden_states,
+                    ref_hidden_states = ref_hidden_states,
                     n_chunks = self.args.unsloth_num_chunks,
                     temperature = self.args.temperature,
                     logit_softcapping = logit_softcapping,

From 5e0fbdbbd63debc6e2113c373ff997e5d953ffd1 Mon Sep 17 00:00:00 2001
From: pluesclues <136766175+pluesclues@users.noreply.github.com>
Date: Mon, 11 Aug 2025 17:33:08 -0400
Subject: [PATCH 15/25] Update rl_replacements.py, gspo added

---
 unsloth/models/rl_replacements.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py
index efa662f09..4d8cdd1d4 100644
--- a/unsloth/models/rl_replacements.py
+++ b/unsloth/models/rl_replacements.py
@@ -496,6 +496,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
             pixel_values = pixel_values,
             image_grid_thw = image_grid_thw,
             loss_type = self.args.loss_type,
+            importance_sampling_level = self.importance_sampling_level,
             epsilon_low = self.epsilon_low,
             epsilon_high = self.epsilon_high,
             max_completion_length = self.args.max_completion_length,
@@ -506,7 +507,6 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
                 logit_scale_divide = logit_scale_divide,
             )
         else:
-            #breakpoint()
             if hasattr(self.args, "loss_type"):
                 loss, completion_length, mean_kl = grpo_accumulated_loss(
                     trainer = self,
@@ -520,6 +520,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
                     ref_hidden_states = ref_hidden_states,
                     n_chunks = self.args.unsloth_num_chunks,
                     loss_type = self.args.loss_type,
+                    importance_sampling_level = self.importance_sampling_level,
                     epsilon_low = self.epsilon_low,
                     epsilon_high = self.epsilon_high,
                     max_completion_length = self.args.max_completion_length,
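Patch 15 threads `importance_sampling_level` through to the loss, which is what separates GSPO from token-level GRPO: GSPO uses one length-normalized importance ratio per sequence instead of one per token. A rough sketch of that distinction under the usual definitions (this code is illustrative, not taken from the diff):

import torch

def importance_ratio(logp_new, logp_old, mask, level = "token"):
    # Per-token log-ratios, with padding masked out.
    log_ratio = (logp_new - logp_old) * mask
    if level == "sequence":
        # GSPO: average the log-ratio over the sequence, so every token in a
        # sequence shares the same (broadcast) importance weight.
        log_ratio = log_ratio.sum(-1, keepdim = True) / mask.sum(-1, keepdim = True)
    return torch.exp(log_ratio)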
From ba4fc3981d0e35730c8397773f0c62b64f320f7c Mon Sep 17 00:00:00 2001
From: pluesclues <136766175+pluesclues@users.noreply.github.com>
Date: Tue, 12 Aug 2025 08:02:49 -0400
Subject: [PATCH 16/25] Update rl.py, forgot about Online_DPO changes in this
 branch

---
 unsloth/models/rl.py | 112 -------------------------------------------
 1 file changed, 112 deletions(-)

diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py
index f0fe6211c..deb779588 100644
--- a/unsloth/models/rl.py
+++ b/unsloth/models/rl.py
@@ -83,89 +83,6 @@ def generate_with_clone(*args, **kwargs):
         pass
     pass
 pass
-
-    from transformers import Trainer
-    from transformers.trainer_pt_utils import nested_detach
-    @torch.no_grad()
-    def unsloth_prediction_step(self, model, inputs, prediction_loss_only,ignore_keys,):
-        """
-        Perform an evaluation step on `model` using `inputs`.
-
-        Subclass and override to inject custom behavior.
-
-        Args:
-            model (`nn.Module`):
-                The model to evaluate.
-            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
-                The inputs and targets of the model.
-
-                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
-                argument `labels`. Check your model's documentation for all accepted arguments.
-            prediction_loss_only (`bool`):
-                Whether or not to return the loss only.
-            ignore_keys (`List[str]`, *optional*):
-                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
-                gathering predictions.
-
-        Return:
-            Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
-            logits and labels (each being optional).
-        """
-        has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)
-        # For CLIP-like models capable of returning loss values.
-        # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`
-        # is `True` in `model.forward`.
-        return_loss = inputs.get("return_loss", None)
-        if return_loss is None:
-            return_loss = self.can_return_loss
-        loss_without_labels = True if len(self.label_names) == 0 and return_loss else False
-
-        inputs = self._prepare_inputs(inputs)
-        if ignore_keys is None:
-            if hasattr(self.model, "config"):
-                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
-            else:
-                ignore_keys = []
-
-        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
-        if has_labels or loss_without_labels:
-            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
-            if len(labels) == 1:
-                labels = labels[0]
-        else:
-            labels = None
-
-        with torch.no_grad():
-            if has_labels or loss_without_labels:
-                with self.compute_loss_context_manager():
-                    loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
-                loss = loss.mean().detach()
-
-                if isinstance(outputs, dict):
-                    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
-                else:
-                    logits = outputs[1:]
-            else:
-                loss = None
-                with self.compute_loss_context_manager():
-                    tokenized_output = self.processing_class(inputs["prompt"], padding=True, truncation=True, return_tensors="pt").to(model.device)
-                    outputs = model(**tokenized_output)
-                if isinstance(outputs, dict):
-                    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
-                else:
-                    logits = outputs
-                # TODO: this needs to be fixed and made cleaner later.
-        if self.args.past_index >= 0:
-            self._past = outputs[self.args.past_index - 1]
-
-        if prediction_loss_only:
-            return (loss, None, None)
-
-        logits = nested_detach(logits)
-        if len(logits) == 1:
-            logits = logits[0]
-
-        return (loss, logits, labels)
 
     import trl.trainer
     trainers = dir(trl.trainer)
@@ -177,7 +94,6 @@
             if hasattr(current_trainer, unwrap):
                 try:    exec(f"trl.trainer.{trainer}.{unwrap} = unsloth_{unwrap}")
                 except: continue
-        exec(f"Trainer.prediction_step=unsloth_prediction_step")
     pass
 pass
@@ -792,34 +708,6 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import
     init = init.replace("if peft_config is None:", "if False:")
     init = init.replace("if peft_config is not None:", "if False:")
     init = init.replace("get_peft_model(model, peft_config)", "model")
-    if RLTrainer_name == "OnlineDPOTrainer" or RLTrainer_name == "RewardTrainer":
-        pattern = re.compile(
-            r'(^(\s*)# Add tags for models that have been loaded with the correct transformers version)',
-            re.MULTILINE
-        )
-        if RLTrainer_name == "OnlineDPOTrainer":
-            replacement = (
-                r'\2# Check for nested model attribute before setting right padding\n'
-                r'\2if hasattr(self.model.model, "model"):\n'
-                r'\2 self.model.model.model._needs_right_padding = True\n'
-                r'\2else:\n'
-                r'\2 self.model.model._needs_right_padding = True\n'
-                r'\2# Check for nested model attribute before setting right padding\n'
-                r'\2if hasattr(self.reward_model, "model"):\n'
-                r'\2 self.reward_model.model._needs_right_padding = True\n'
-                r'\2print("Unsloth reward model type: ", type(self.reward_model.model))'
-                r'\1'
-            )
-        else:
-            replacement = (
-                r'\2# Check for nested model attribute before setting right padding\n'
-                r'\2if hasattr(self.model.model, "model"):\n'
-                r'\2 self.model.model.model._needs_right_padding = True\n'
-                r'\2else:\n'
-                r'\2 self.model.model._needs_right_padding = True\n'
-                r'\1'
-            )
-        init = pattern.sub(replacement, init, count=1)
 
     # New TRL 0.20.0
     init = init.replace("if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)):", "if False:")
     # New TRL 0.20.0

From f9a2c185444497354a2570006a14d770493b03e5 Mon Sep 17 00:00:00 2001
From: pluesclues <136766175+pluesclues@users.noreply.github.com>
Date: Tue, 12 Aug 2025 08:04:15 -0400
Subject: [PATCH 17/25] Update rl.py, forgot to exclude Online DPO PR changes

From 36d3f970ec3aff970711f8420f33896e177ed8ad Mon Sep 17 00:00:00 2001
From: pluesclues <136766175+pluesclues@users.noreply.github.com>
Date: Tue, 12 Aug 2025 08:05:19 -0400
Subject: [PATCH 18/25] Update llama.py, forgot to exclude Online DPO PR
 changes

---
 unsloth/models/llama.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index 40bc3cd5a..3c0d5012a 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -794,14 +794,19 @@ def LlamaModel_fast_forward(
         inputs_embeds *= attention_mask.unsqueeze(0).transpose(0, 1).transpose(1, 2)
         if inputs_requires_grad: inputs_embeds.requires_grad_(True)
     pass
+
     # Ignore attention_mask
     if attention_mask is None:
         padding_mask = None
-    elif self.training and not hasattr(self, "_needs_right_padding"):
+    elif self.training:
         attention_mask = None
         padding_mask = None
     else:
+        # if 0 in attention_mask:
+        #     padding_mask = attention_mask
+        # else:
         padding_mask = None
+
         attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
             attention_mask,
             (batch_size, seq_length),

From 8266f9e55cd9b8dad2e643ca33d9be37214d51ce Mon Sep 17 00:00:00 2001
From: pluesclues <136766175+pluesclues@users.noreply.github.com>
Date: Fri, 22 Aug 2025 14:51:35 -0400
Subject: [PATCH 19/25] Update rl_replacements.py, updated generate and score
 completions to be up to date for trl

---
 unsloth/models/rl_replacements.py | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py
index 4d8cdd1d4..1c06383d1 100644
--- a/unsloth/models/rl_replacements.py
+++ b/unsloth/models/rl_replacements.py
@@ -263,22 +263,22 @@ def grpo_trainer__generate_and_score_completions(function_name, function):
     # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
     # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text,
     # because we can't use `skip_special_tokens=True` (some special tokens are still needed for generation).
-    prompt_ids = prompt_ids[:, -self.max_prompt_length :]
-    prompt_mask = prompt_mask[:, -self.max_prompt_length :]
-    prompts_text = self.processing_class.batch_decode(
-        prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
+    protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id]
+    protected = [token for token in protected if token is not None]
+    prompt_ids, prompt_mask = truncate_with_protected_tokens(
+        prompt_ids, prompt_mask, self.max_prompt_length, protected
     )
-    pad_token = self.processing_class.pad_token
-    def strip_leading_tokens(text):
-        while text.startswith(pad_token):
-            text = text.removeprefix(pad_token)
-        return text
-    if pad_token is not None:
+    prompts_text = [re.sub(rf"^({{re.escape(self.pad_token)}})+", "", text) for text in prompts_text]
+
+    # The chat template inserts a single image token into the prompt text. However, when this text is later
+    # tokenized, the single image token string is expanded into multiple image token IDs, depending on the
+    # image size. Since we're detokenizing here, we may see repeated image tokens in the decoded text. We
+    # collapse them back into a single token string to match the original template.
+    if self.image_token is not None:
         prompts_text = [
-            strip_leading_tokens(text) for text in prompts_text
+            re.sub(rf"({{re.escape(self.image_token)}})+", self.image_token, text) for text in prompts_text
         ]
-
     # Generate completions using either vLLM or regular generation
     if self.use_vllm:"""
     function = function.replace(replace_part, new_replacement)

From 2b379ade38baa2277c696b40a39f235d4f3c4fb0 Mon Sep 17 00:00:00 2001
From: pluesclues <136766175+pluesclues@users.noreply.github.com>
Date: Tue, 9 Sep 2025 11:59:55 -0400
Subject: [PATCH 20/25] Update rl_replacements.py

---
 unsloth/models/rl_replacements.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py
index 1c06383d1..7766f4059 100644
--- a/unsloth/models/rl_replacements.py
+++ b/unsloth/models/rl_replacements.py
@@ -375,7 +375,7 @@ def _get_per_token_logps_and_entropies(self, model, input_ids, attention_mask, l
                     image_grid_thw = image_grid_thw,
                     pixel_attention_mask = pixel_attention_mask,
                     image_sizes = image_sizes,
-                    logits_to_keep = logits_to_keep + 1,
+                    #logits_to_keep = logits_to_keep + 1,
                 ).logits
                 pass
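Patch 20 disables the in-model `logits_to_keep` trim for this path. If the same trim were done by the caller instead, it would look roughly like the slice below (an assumption on my part, not code from the series); TRL's convention keeps one extra position because the final logit is dropped later:

import torch

def trim_to_completion(hidden_states: torch.Tensor, logits_to_keep: int) -> torch.Tensor:
    # Keep the last `logits_to_keep + 1` positions of a (B, L, H) tensor.
    return hidden_states[:, -(logits_to_keep + 1):, :]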
From 0eaf2ec87f8470af96dc738080acc02e289c2f79 Mon Sep 17 00:00:00 2001
From: pluesclues <136766175+pluesclues@users.noreply.github.com>
Date: Tue, 9 Sep 2025 13:42:46 -0400
Subject: [PATCH 21/25] Update rl_replacements.py, fixed nan issues with vlms

---
 unsloth/models/rl_replacements.py | 90 +++++++++++++++++++------------
 1 file changed, 57 insertions(+), 33 deletions(-)

diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py
index 7766f4059..eae2dc1d6 100644
--- a/unsloth/models/rl_replacements.py
+++ b/unsloth/models/rl_replacements.py
@@ -245,6 +245,17 @@ def grpo_trainer__generate_and_score_completions(function_name, function):
         "prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False",
     )
 
+    # Left pad prompt before calculation old and ref hidden states
+    line_to_replace = "batch_size = self.args.per_device_train_batch_size if mode == \"train\" else self.args.per_device_eval_batch_size"
+
+    # The new multi-line string that will replace the line above
+    replacement_lines = """batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
+        if not has_images:
+            # Left pad prompt before calculation old and ref hidden states
+            prompt_completion_ids = left_pack_padding(prompt_completion_ids, self.processing_class.pad_token_id)"""
+
+    function = function.replace(line_to_replace, replacement_lines)
+
     # Always between max_prompt_length and use_vllm
     found = re.findall(
         r"\n(([ ]{8,})if self\.max_prompt_length is not None:.*?"\
@@ -282,7 +293,8 @@ def grpo_trainer__generate_and_score_completions(function_name, function):
     # Generate completions using either vLLM or regular generation
     if self.use_vllm:"""
     function = function.replace(replace_part, new_replacement)
-    pass
+
+    return function
 pass
 RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__generate_and_score_completions)
@@ -366,42 +378,54 @@ def _get_per_token_logps_and_entropies(self, model, input_ids, attention_mask, l
             pixel_attention_mask, image_sizes = kwargs.get('pixel_attention_mask',None), kwargs.get('image_sizes',None)
 
             os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
-            with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
-                # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
-                logits = model(
-                    input_ids = input_ids,
-                    attention_mask = attention_mask,
-                    pixel_values = pixel_values,
-                    image_grid_thw = image_grid_thw,
-                    pixel_attention_mask = pixel_attention_mask,
-                    image_sizes = image_sizes,
-                    #logits_to_keep = logits_to_keep + 1,
-                ).logits
-                pass
+            if pixel_values is None:
+                with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
+                    # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
+                    logits = model(
+                        input_ids = input_ids,
+                        attention_mask = attention_mask,
+                        pixel_values = pixel_values,
+                        image_grid_thw = image_grid_thw,
+                        pixel_attention_mask = pixel_attention_mask,
+                        image_sizes = image_sizes,
+                        #logits_to_keep = logits_to_keep + 1,
+                    ).logits
+            else:
+                with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
+                    logits = model(
+                        input_ids = input_ids,
+                        attention_mask = attention_mask,
+                        pixel_values = pixel_values,
+                        image_grid_thw = image_grid_thw,
+                        pixel_attention_mask = pixel_attention_mask,
+                        image_sizes = image_sizes,
+                        logits_to_keep = logits_to_keep + 1,
+                    ).logits
+            entropies = None
 
-                entropies = None
+            with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
                 if compute_entropy:
                     from trl.trainer.utils import entropy_from_logits
                     entropies = entropy_from_logits(logits)
-                #breakpoint()
-                # logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
-                return logits, entropies # logps, entropies
-                # input_ids = input_ids[:, -logits_to_keep:]
-                # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
-                # See https://github.com/huggingface/trl/issues/2770
-                # logits = logits[:, -logits_to_keep:]
-                # return logits
-                # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
-                # logits = logits / self.temperature
-                # logps = selective_log_softmax(logits, input_ids)
-
-                # row_indices, col_indices = torch.where(logps < -20)
-
-                # # Method 1: Check if tensors have elements
-                # if len(row_indices) > 0 and len(col_indices) > 0:
-                #     breakpoint() # Breakpoint triggered here
-                #     print("Found high values!")
-                # return logps # compute logprobs for the input tokens
+            #breakpoint()
+            # logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+            return logits, entropies # logps, entropies
+            # input_ids = input_ids[:, -logits_to_keep:]
+            # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
+            # See https://github.com/huggingface/trl/issues/2770
+            # logits = logits[:, -logits_to_keep:]
+            # return logits
+            # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
+            # logits = logits / self.temperature
+            # logps = selective_log_softmax(logits, input_ids)
+
+            # row_indices, col_indices = torch.where(logps < -20)
+
+            # # Method 1: Check if tensors have elements
+            # if len(row_indices) > 0 and len(col_indices) > 0:
+            #     breakpoint() # Breakpoint triggered here
+            #     print("Found high values!")
+            # return logps # compute logprobs for the input tokens
     pass
 pass

From fb115fb16cb2592caf99a9414b7d1f95f1f819ca Mon Sep 17 00:00:00 2001
From: pluesclues <136766175+pluesclues@users.noreply.github.com>
Date: Tue, 9 Sep 2025 13:52:09 -0400
Subject: [PATCH 22/25] Update rl_replacements.py, added indent

---
 unsloth/models/rl_replacements.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py
index eae2dc1d6..3fae8cef6 100644
--- a/unsloth/models/rl_replacements.py
+++ b/unsloth/models/rl_replacements.py
@@ -249,7 +249,8 @@ def grpo_trainer__generate_and_score_completions(function_name, function):
     line_to_replace = "batch_size = self.args.per_device_train_batch_size if mode == \"train\" else self.args.per_device_eval_batch_size"
 
     # The new multi-line string that will replace the line above
-    replacement_lines = """batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
+    replacement_lines = """
+        batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
         if not has_images:
             # Left pad prompt before calculation old and ref hidden states
             prompt_completion_ids = left_pack_padding(prompt_completion_ids, self.processing_class.pad_token_id)"""

From 694f88e72b36e77a95c4e0fdbe836f364aa73078 Mon Sep 17 00:00:00 2001
From: pluesclues <136766175+pluesclues@users.noreply.github.com>
Date: Tue, 9 Sep 2025 20:50:52 -0400
Subject: [PATCH 23/25] Update rl_replacements.py, added attention mask to
 calculations of old and ref hidden states

---
 unsloth/models/rl_replacements.py | 44 ++++++++++++++++++-------------
 1 file changed, 26 insertions(+), 18 deletions(-)

diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py
index 3fae8cef6..13b4e4f6a 100644
--- a/unsloth/models/rl_replacements.py
+++ b/unsloth/models/rl_replacements.py
@@ -25,7 +25,7 @@
 import torch
 import inspect
 from collections import defaultdict
-from unsloth_zoo.rl_replacements import RL_REPLACEMENTS
+from unsloth_zoo.rl_replacements import RL_REPLACEMENTS, left_pack_padding
 from unsloth import DEVICE_TYPE
 
 RL_EXTRA_ARGS = defaultdict(list)
@@ -379,21 +379,26 @@ def _get_per_token_logps_and_entropies(self, model, input_ids, attention_mask, l
             pixel_attention_mask, image_sizes = kwargs.get('pixel_attention_mask',None), kwargs.get('image_sizes',None)
 
             os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
-            if pixel_values is None:
-                with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
-                    # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
-                    logits = model(
-                        input_ids = input_ids,
-                        attention_mask = attention_mask,
-                        pixel_values = pixel_values,
-                        image_grid_thw = image_grid_thw,
-                        pixel_attention_mask = pixel_attention_mask,
-                        image_sizes = image_sizes,
-                        #logits_to_keep = logits_to_keep + 1,
-                    ).logits
-            else:
-                with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
-                    logits = model(
+
+            unwrapped_model = self.accelerator.unwrap_model(model, keep_fp32_wrapper=False)
+
+            with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
+                with torch.inference_mode():
+                    if pixel_values is None:
+                        attention_mask = input_ids != self.processing_class.pad_token_id
+                        attention_mask = attention_mask.to(attention_mask.dtype)
+                        # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
+                        logits = unwrapped_model(
+                            input_ids = input_ids,
+                            attention_mask = attention_mask,
+                            pixel_values = pixel_values,
+                            image_grid_thw = image_grid_thw,
+                            pixel_attention_mask = pixel_attention_mask,
+                            image_sizes = image_sizes,
+                            #logits_to_keep = logits_to_keep + 1,
+                        ).logits
+                    else:
+                        logits = unwrapped_model(
                         input_ids = input_ids,
                         attention_mask = attention_mask,
                         pixel_values = pixel_values,
@@ -402,13 +407,16 @@ def _get_per_token_logps_and_entropies(self, model, input_ids, attention_mask, l
                         image_sizes = image_sizes,
                         logits_to_keep = logits_to_keep + 1,
                     ).logits
-            entropies = None
+
-            with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
+            entropies = None
             if compute_entropy:
                 from trl.trainer.utils import entropy_from_logits
                 entropies = entropy_from_logits(logits)
+
+            #breakpoint()
+            os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0"
             # logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
             return logits, entropies # logps, entropies
             # input_ids = input_ids[:, -logits_to_keep:]

From b4b6c65110130d1e30c01c6d05be3eab333acc81 Mon Sep 17 00:00:00 2001
From: Daniel Han
Date: Tue, 16 Sep 2025 05:42:53 -0700
Subject: [PATCH 24/25] Update unsloth/models/rl_replacements.py

---
 unsloth/models/rl_replacements.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py
index 13b4e4f6a..3eb99a887 100644
--- a/unsloth/models/rl_replacements.py
+++ b/unsloth/models/rl_replacements.py
@@ -461,7 +461,6 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
             raise ValueError("The GRPOTrainer does not support returning outputs")
         # Compute the per-token log probabilities for the model
-        #breakpoint()
 
         prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
         completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]

From 24712caa5086272c9db3c23890467b8e81ecbddf Mon Sep 17 00:00:00 2001
From: Daniel Han
Date: Tue, 16 Sep 2025 05:43:00 -0700
Subject: [PATCH 25/25] Update unsloth/models/rl_replacements.py

---
 unsloth/models/rl_replacements.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py
index 3eb99a887..ec8110689 100644
--- a/unsloth/models/rl_replacements.py
+++ b/unsloth/models/rl_replacements.py
@@ -415,7 +415,6 @@ def _get_per_token_logps_and_entropies(self, model, input_ids, attention_mask, l
                 entropies = entropy_from_logits(logits)
 
-            #breakpoint()
             os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0"
             # logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
             return logits, entropies # logps, entropies
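Patches 21-23 lean on `left_pack_padding` from unsloth_zoo, whose implementation is not shown anywhere in this series. As a guess at its contract from the call sites alone (a tensor of token ids plus a pad token id, returning ids with all padding packed to the left), it could be approximated as:

import torch

def left_pack_padding(ids: torch.Tensor, pad_id: int) -> torch.Tensor:
    # Stable ascending sort: pads (0) come before real tokens (1), so each
    # row becomes [pad, ..., pad, tok, ..., tok] with token order preserved.
    mask = (ids != pad_id).int()
    order = torch.argsort(mask, dim = 1, stable = True)
    return torch.gather(ids, 1, order)

ids = torch.tensor([[5, 7, 0, 0], [1, 2, 3, 0]])  # 0 = pad, right-padded rows
print(left_pack_padding(ids, 0))  # tensor([[0, 0, 5, 7], [0, 1, 2, 3]])

The actual helper may differ (it could also shift attention masks or pixel inputs); treat this only as a sketch of why the series moves padding to the left before computing old and reference hidden states.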