diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py
index 2555f0df1..ec8110689 100644
--- a/unsloth/models/rl_replacements.py
+++ b/unsloth/models/rl_replacements.py
@@ -25,7 +25,7 @@
 import torch
 import inspect
 from collections import defaultdict
-from unsloth_zoo.rl_replacements import RL_REPLACEMENTS
+from unsloth_zoo.rl_replacements import RL_REPLACEMENTS, left_pack_padding
 from unsloth import DEVICE_TYPE
 
 RL_EXTRA_ARGS = defaultdict(list)
@@ -245,6 +245,18 @@ def grpo_trainer__generate_and_score_completions(function_name, function):
         "prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False",
     )
 
+    # Left pad prompt before calculating old and ref hidden states
+    line_to_replace = "batch_size = self.args.per_device_train_batch_size if mode == \"train\" else self.args.per_device_eval_batch_size"
+
+    # The new multi-line string that will replace the line above
+    replacement_lines = """
+        batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
+        if not has_images:
+            # Left pad prompt before calculating old and ref hidden states
+            prompt_completion_ids = left_pack_padding(prompt_completion_ids, self.processing_class.pad_token_id)"""
+
+    function = function.replace(line_to_replace, replacement_lines)
+
     # Always between max_prompt_length and use_vllm
     found = re.findall(
         r"\n(([ ]{8,})if self\.max_prompt_length is not None:.*?"\
@@ -263,26 +275,27 @@ def grpo_trainer__generate_and_score_completions(function_name, function):
         # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
         # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text,
         # because we can't use `skip_special_tokens=True` (some special tokens are still needed for generation).
-        prompt_ids = prompt_ids[:, -self.max_prompt_length :]
-        prompt_mask = prompt_mask[:, -self.max_prompt_length :]
-        prompts_text = self.processing_class.batch_decode(
-            prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
+        protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id]
+        protected = [token for token in protected if token is not None]
+        prompt_ids, prompt_mask = truncate_with_protected_tokens(
+            prompt_ids, prompt_mask, self.max_prompt_length, protected
         )
-        pad_token = self.processing_class.pad_token
-        def strip_leading_tokens(text):
-            while text.startswith(pad_token):
-                text = text.removeprefix(pad_token)
-            return text
-        if pad_token is not None:
+        prompts_text = [re.sub(rf"^({{re.escape(self.pad_token)}})+", "", text) for text in prompts_text]
+
+        # The chat template inserts a single image token into the prompt text. However, when this text is later
+        # tokenized, the single image token string is expanded into multiple image token IDs, depending on the
+        # image size. Since we're detokenizing here, we may see repeated image tokens in the decoded text. We
+        # collapse them back into a single token string to match the original template.
+        if self.image_token is not None:
             prompts_text = [
-                strip_leading_tokens(text) for text in prompts_text
+                re.sub(rf"({{re.escape(self.image_token)}})+", self.image_token, text) for text in prompts_text
             ]
-
         # Generate completions using either vLLM or regular generation
         if self.use_vllm:"""
     function = function.replace(replace_part, new_replacement)
-    pass
+
+    return function
 pass
 
 RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__generate_and_score_completions)
@@ -304,7 +317,7 @@ def _move_model_to_vllm(self, *args, **kwargs): return None
 
 def grpo_trainer__get_per_token_logps(function_name, function):
     if function_name != "_get_per_token_logps": return function
-    def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
+    def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, compute_efficient = False):
         if True: # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
             return None # Unsloth efficient GRPO
         # Otherwise, calculate normally:
@@ -350,28 +363,59 @@ def grpo_trainer__get_per_token_logps_and_entropies(function_name, function):
     if function_name != "_get_per_token_logps_and_entropies": return function
     # Just copy over from _get_per_token_logps replacement function above. For now this returns None anyway
-    def _get_per_token_logps_and_entropies(self, model, input_ids, attention_mask, logits_to_keep, batch_size = None, compute_entropy = False, *args, **kwargs):
-        if True: # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
-            return None, None # logps, entropies Unsloth efficient GRPO
-        # Otherwise, calculate normally:
-        if not hasattr(self, '_autocast_dtype'):
-            self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
-            if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16
-
-        os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
-        with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
-            # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
-            logits = model(
-                input_ids = input_ids,
-                attention_mask = attention_mask,
-                logits_to_keep = logits_to_keep + 1,
-            ).logits
-
-        entropies = None
-        if compute_entropy:
-            from trl.trainer.utils import entropy_from_logits
-            entropies = entropy_from_logits(logits)
-
+    def _get_per_token_logps_and_entropies(self, model, input_ids, attention_mask, logits_to_keep, batch_size = None,
+                                           compute_entropy = False, compute_efficient = False, *args, **kwargs):
+        # if True: # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
+        #     return None, None # logps, entropies Unsloth efficient GRPO
+        if compute_efficient:
+            return None, None
+        else:
+            # Otherwise, calculate normally:
+            if not hasattr(self, '_autocast_dtype'):
+                self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
+                if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16
+
+            pixel_values, image_grid_thw = kwargs.get("pixel_values", None), kwargs.get("image_grid_thw", None)
+            pixel_attention_mask, image_sizes = kwargs.get('pixel_attention_mask', None), kwargs.get('image_sizes', None)
+
+            os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
+
+            unwrapped_model = self.accelerator.unwrap_model(model, keep_fp32_wrapper=False)
+
+            with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
+                with torch.inference_mode():
+                    if pixel_values is None:
+                        attention_mask = input_ids != self.processing_class.pad_token_id
+                        attention_mask = attention_mask.to(input_ids.dtype)
+                        # We add 1 to `logits_to_keep` because the last logit of the sequence is later excluded
+                        logits = unwrapped_model(
+                            input_ids = input_ids,
+                            attention_mask = attention_mask,
+                            pixel_values = pixel_values,
+                            image_grid_thw = image_grid_thw,
+                            pixel_attention_mask = pixel_attention_mask,
+                            image_sizes = image_sizes,
+                            #logits_to_keep = logits_to_keep + 1,
+                        ).logits
+                    else:
+                        logits = unwrapped_model(
+                            input_ids = input_ids,
+                            attention_mask = attention_mask,
+                            pixel_values = pixel_values,
+                            image_grid_thw = image_grid_thw,
+                            pixel_attention_mask = pixel_attention_mask,
+                            image_sizes = image_sizes,
+                            logits_to_keep = logits_to_keep + 1,
+                        ).logits
+
+            entropies = None
+            if compute_entropy:
+                from trl.trainer.utils import entropy_from_logits
+                entropies = entropy_from_logits(logits)
+
+            os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0"
         # logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
         return logits, entropies # logps, entropies
     # input_ids = input_ids[:, -logits_to_keep:]
@@ -416,8 +460,12 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
         raise ValueError("The GRPOTrainer does not support returning outputs")
     # Compute the per-token log probabilities for the model
 
+    prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
     completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
+
+    pixel_values, image_grid_thw = inputs.get("pixel_values", None), inputs.get("image_grid_thw", None)
+    pixel_attention_mask, image_sizes = inputs.get('pixel_attention_mask', None), inputs.get('image_sizes', None)
+
     input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
     bsz, qlen = input_ids.shape
     attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
@@ -427,21 +475,21 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
     _logits_to_keep = logits_to_keep
 
     get_logps_func = \
-        lambda model, input_ids, attention_mask, logits_to_keep, batch_size=None, compute_entropy=False: \
-        self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) \
+        lambda model, input_ids, attention_mask, logits_to_keep, batch_size=None, compute_entropy=False, compute_efficient=False: \
+        self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep, compute_efficient) \
         if hasattr(self, "_get_per_token_logps") else \
-        self._get_per_token_logps_and_entropies(model, input_ids, attention_mask, logits_to_keep, batch_size, compute_entropy)[0] # logps
-
-    per_token_logps = get_logps_func(model, input_ids, attention_mask, logits_to_keep)
-
+        self._get_per_token_logps_and_entropies(model, input_ids, attention_mask, logits_to_keep, batch_size, compute_entropy, compute_efficient)[0] # logps
+
+    per_token_logps = get_logps_func(model, input_ids, attention_mask, logits_to_keep, compute_efficient = True)
 
     # Compute the KL divergence between the model and the reference model
     # _prepare_inputs doesn't return reference log probs anymore. We need to calculate it ourselves.
     # https://github.com/huggingface/trl/blob/05bc43e960396581e458195b8388efe6b82cae1f/trl/trainer/grpo_trainer.py#L1328
-    if self.beta != 0.0:
-        with torch.inference_mode(), model.disable_adapter():
-            ref_per_token_logps = per_token_logps = get_logps_func(model, input_ids, attention_mask, logits_to_keep)
-    else:
-        ref_per_token_logps = None
+    # if self.beta != 0.0:
+    #     with torch.inference_mode(), model.disable_adapter():
+    #         ref_per_token_logps = per_token_logps = get_logps_func(model, input_ids, attention_mask, logits_to_keep)
+    # else:
+    #     ref_per_token_logps = None
+    ref_hidden_states = inputs.get("ref_per_token_logps", None)
     # per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
     # x - x.detach() allows for preserving gradients from x
     advantages = inputs["advantages"]
@@ -449,6 +497,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
     # per_token_loss = -(per_token_loss - self.beta * per_token_kl)
     # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
     old_hidden_states = inputs.get("old_per_token_logps", None)
+
     input_ids = input_ids[:, -logits_to_keep:]
 
     # Get logit softcapping and logit scale
@@ -459,22 +508,26 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
     logit_scale_divide = getattr(model.config, "logits_scaling", 0) # Granite
     if logit_scale_divide is None: logit_scale_divide = 0
 
-    if per_token_logps is not None:
-        if ref_per_token_logps is not None:
-            ref_per_token_logps = ref_per_token_logps[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+    if ref_hidden_states is not None:
+        ref_hidden_states = ref_hidden_states[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+    if old_hidden_states is not None:
+        old_hidden_states = old_hidden_states[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
         per_token_logps = per_token_logps[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
 
         loss, completion_length, mean_kl = grpo_compute_loss_slow(
-            ref_per_token_logps,
+            ref_hidden_states,
             per_token_logps,
             old_hidden_states,
             input_ids,
             completion_mask,
             self.beta,
            advantages,
+            pixel_values = pixel_values,
+            image_grid_thw = image_grid_thw,
             loss_type = self.args.loss_type,
+            importance_sampling_level = self.importance_sampling_level,
             epsilon_low = self.epsilon_low,
             epsilon_high = self.epsilon_high,
             max_completion_length = self.args.max_completion_length,
@@ -489,12 +542,16 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
         loss, completion_length, mean_kl = grpo_accumulated_loss(
             trainer = self,
             input_ids = _input_ids,
+            pixel_values = pixel_values,
+            image_grid_thw = image_grid_thw,
             logits_to_keep = logits_to_keep,
             completion_mask = completion_mask,
             advantages = advantages,
             old_hidden_states = old_hidden_states,
+            ref_hidden_states = ref_hidden_states,
             n_chunks = self.args.unsloth_num_chunks,
             loss_type = self.args.loss_type,
+            importance_sampling_level = self.importance_sampling_level,
             epsilon_low = self.epsilon_low,
             epsilon_high = self.epsilon_high,
             max_completion_length = self.args.max_completion_length,
@@ -514,6 +571,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
             completion_mask = completion_mask,
             advantages = advantages,
             old_hidden_states = old_hidden_states,
+            ref_hidden_states = ref_hidden_states,
             n_chunks = self.args.unsloth_num_chunks,
             temperature = self.args.temperature,
             logit_softcapping = logit_softcapping,
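
Note for reviewers: left_pack_padding is imported from unsloth_zoo, so its implementation does not appear in this diff. Below is a minimal sketch of the behaviour the call above relies on (shifting every pad token in prompt_completion_ids to the left so the real tokens end up right-aligned), assuming a plain 2-D batch of token ids; this is a hypothetical re-implementation, not the unsloth_zoo source:

import torch

def left_pack_padding(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    # 1 where a real token sits, 0 where padding sits.
    is_token = (input_ids != pad_token_id).int()
    # A stable ascending sort pushes the 0s (padding) to the front of each row
    # while keeping the real tokens in their original relative order.
    order = torch.argsort(is_token, dim=1, stable=True)
    return torch.gather(input_ids, 1, order)

# e.g. with pad_token_id = 0: [[5, 6, 0, 0]] -> [[0, 0, 5, 6]]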
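
The broader pattern, visible in the first hunk, is that each grpo_trainer__* helper receives the source of a TRL trainer method as a plain string, splices new behaviour in with str.replace (the line_to_replace / replacement_lines pair above), and returns the edited source for unsloth to re-compile. A self-contained toy illustration of that mechanism, using hypothetical names (splice_after, example) rather than the real caller:

import inspect
import textwrap

def splice_after(source: str, anchor: str, extra: str) -> str:
    # Fail loudly if upstream renamed the anchor line; a silent no-op patch
    # is the failure mode this kind of source surgery has to guard against.
    assert anchor in source, "anchor not found - upstream source changed?"
    return source.replace(anchor, anchor + "\n" + extra)

def example(x):  # toy stand-in for a TRL trainer method
    y = x + 1
    return y

src = textwrap.dedent(inspect.getsource(example))
src = splice_after(src, "y = x + 1", "    y = y * 2  # injected behaviour")
scope = {}
exec(src, scope)  # re-compile the patched source, as unsloth's caller does
assert scope["example"](1) == 4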