Commit 96c2ecb
TRL Updated version of VLM GRPO update along with GSPO (#3132)
* Kept padding logic
* Made sure the prediction step in rl.py allows logging for callbacks in RL trainers
* Updated llama.py to the new online DPO changes
* Update rl.py to make the logic simpler
* Update rl.py, made sure tokenized_output on the eval step was on the same device
* Update rl.py, corrected tokenized_outputs to inputs
* Update rl.py, removed SageMaker code
* Update llama.py, figures out automatically if there is right padding
* Update llama.py, changed the conditional statement for right padding slightly
* Update llama.py, moved the os.environ variable into a temp variable
* Update rl.py, made it account for right padding in online DPO and reward modeling
* Update llama.py, automatically figures out if right padding is needed
* Update rl_replacements.py, fixed up passing image data to functions
* Update rl_replacements.py, VLM GRPO support with TRL
* Update rl_replacements.py, added GSPO
* Update rl.py, forgot about Online DPO changes in this branch
* Update rl.py, forgot to exclude the Online DPO PR changes
* Update llama.py, forgot to exclude the Online DPO PR changes
* Update rl_replacements.py, updated generate and score completions to be up to date with TRL
* Update rl_replacements.py
* Update rl_replacements.py, fixed NaN issues with VLMs
* Update rl_replacements.py, added indent
* Update rl_replacements.py, added attention mask to calculations of old and ref hidden states
* Update unsloth/models/rl_replacements.py
* Update unsloth/models/rl_replacements.py

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
1 parent a6176ad commit 96c2ecb

File tree

1 file changed: +110 additions, −52 deletions


unsloth/models/rl_replacements.py

Lines changed: 110 additions & 52 deletions
@@ -25,7 +25,7 @@
 import torch
 import inspect
 from collections import defaultdict
-from unsloth_zoo.rl_replacements import RL_REPLACEMENTS
+from unsloth_zoo.rl_replacements import RL_REPLACEMENTS, left_pack_padding
 from unsloth import DEVICE_TYPE

 RL_EXTRA_ARGS = defaultdict(list)
@@ -245,6 +245,18 @@ def grpo_trainer__generate_and_score_completions(function_name, function):
         "prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False",
     )

+    # Left pad prompt before calculation old and ref hidden states
+    line_to_replace = "batch_size = self.args.per_device_train_batch_size if mode == \"train\" else self.args.per_device_eval_batch_size"
+
+    # The new multi-line string that will replace the line above
+    replacement_lines = """
+        batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
+        if not has_images:
+            # Left pad prompt before calculation old and ref hidden states
+            prompt_completion_ids = left_pack_padding(prompt_completion_ids, self.processing_class.pad_token_id)"""
+
+    function = function.replace(line_to_replace, replacement_lines)
+
     # Always between max_prompt_length and use_vllm
     found = re.findall(
         r"\n(([ ]{8,})if self\.max_prompt_length is not None:.*?"\
@@ -263,26 +275,27 @@ def grpo_trainer__generate_and_score_completions(function_name, function):
             # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
             # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text,
             # because we can't use `skip_special_tokens=True` (some special tokens are still needed for generation).
-            prompt_ids = prompt_ids[:, -self.max_prompt_length :]
-            prompt_mask = prompt_mask[:, -self.max_prompt_length :]
-            prompts_text = self.processing_class.batch_decode(
-                prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
+            protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id]
+            protected = [token for token in protected if token is not None]
+            prompt_ids, prompt_mask = truncate_with_protected_tokens(
+                prompt_ids, prompt_mask, self.max_prompt_length, protected
             )
-            pad_token = self.processing_class.pad_token
-            def strip_leading_tokens(text):
-                while text.startswith(pad_token):
-                    text = text.removeprefix(pad_token)
-                return text

-            if pad_token is not None:
+            prompts_text = [re.sub(rf"^({{re.escape(self.pad_token)}})+", "", text) for text in prompts_text]
+
+            # The chat template inserts a single image token into the prompt text. However, when this text is later
+            # tokenized, the single image token string is expanded into multiple image token IDs, depending on the
+            # image size. Since we're detokenizing here, we may see repeated image tokens in the decoded text. We
+            # collapse them back into a single token string to match the original template.
+            if self.image_token is not None:
                 prompts_text = [
-                    strip_leading_tokens(text) for text in prompts_text
+                    re.sub(rf"({{re.escape(self.image_token)}})+", self.image_token, text) for text in prompts_text
                 ]
-
         # Generate completions using either vLLM or regular generation
         if self.use_vllm:"""
     function = function.replace(replace_part, new_replacement)
-    pass
+
+
     return function
 pass
 RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__generate_and_score_completions)
@@ -304,7 +317,7 @@ def _move_model_to_vllm(self, *args, **kwargs): return None
 def grpo_trainer__get_per_token_logps(function_name, function):
     if function_name != "_get_per_token_logps": return function

-    def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
+    def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, compute_efficient = False):
         if True: # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
             return None # Unsloth efficient GRPO
         # Otherwise, calculate normally:
@@ -350,28 +363,59 @@ def grpo_trainer__get_per_token_logps_and_entropies(function_name, function):
     if function_name != "_get_per_token_logps_and_entropies": return function

     # Just copy over from _get_per_token_logps replacement function above. For now this returns None anyway
-    def _get_per_token_logps_and_entropies(self, model, input_ids, attention_mask, logits_to_keep, batch_size = None, compute_entropy = False, *args, **kwargs):
-        if True: # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
-            return None, None # logps, entropies Unsloth efficient GRPO
-        # Otherwise, calculate normally:
-        if not hasattr(self, '_autocast_dtype'):
-            self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
-            if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16
-
-        os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
-        with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
-            # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
-            logits = model(
-                input_ids = input_ids,
-                attention_mask = attention_mask,
-                logits_to_keep = logits_to_keep + 1,
-            ).logits
-
-        entropies = None
-        if compute_entropy:
-            from trl.trainer.utils import entropy_from_logits
-            entropies = entropy_from_logits(logits)
-
+    def _get_per_token_logps_and_entropies(self, model, input_ids, attention_mask, logits_to_keep, batch_size = None,
+                                           compute_entropy = False, compute_efficient = False, *args, **kwargs):
+        # if True: # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
+        #     return None, None # logps, entropies Unsloth efficient GRPO
+        if compute_efficient:
+            return None, None
+        else:
+            # Otherwise, calculate normally:
+            if not hasattr(self, '_autocast_dtype'):
+                self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
+                if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16
+
+            pixel_values, image_grid_thw = kwargs.get("pixel_values", None), kwargs.get("image_grid_thw", None)
+            pixel_attention_mask, image_sizes = kwargs.get('pixel_attention_mask',None), kwargs.get('image_sizes',None)
+
+            os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
+
+            unwrapped_model = self.accelerator.unwrap_model(model, keep_fp32_wrapper=False)
+
+            with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
+                with torch.inference_mode():
+                    if pixel_values is None:
+                        attention_mask = input_ids != self.processing_class.pad_token_id
+                        attention_mask = attention_mask.to(attention_mask.dtype)
+                        # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
+                        logits = unwrapped_model(
+                            input_ids = input_ids,
+                            attention_mask = attention_mask,
+                            pixel_values = pixel_values,
+                            image_grid_thw = image_grid_thw,
+                            pixel_attention_mask = pixel_attention_mask,
+                            image_sizes = image_sizes,
+                            #logits_to_keep = logits_to_keep + 1,
+                        ).logits
+                    else:
+                        logits = unwrapped_model(
+                            input_ids = input_ids,
+                            attention_mask = attention_mask,
+                            pixel_values = pixel_values,
+                            image_grid_thw = image_grid_thw,
+                            pixel_attention_mask = pixel_attention_mask,
+                            image_sizes = image_sizes,
+                            logits_to_keep = logits_to_keep + 1,
+                        ).logits
+
+
+            entropies = None
+            if compute_entropy:
+                from trl.trainer.utils import entropy_from_logits
+                entropies = entropy_from_logits(logits)
+
+
+        os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0"
         # logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
         return logits, entropies # logps, entropies
         # input_ids = input_ids[:, -logits_to_keep:]
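
When compute_entropy is set, the per-token entropies come from TRL's entropy_from_logits. As a rough sketch of what that computes (the TRL implementation is more memory-careful, e.g. it may chunk over the batch):

import torch
import torch.nn.functional as F

def entropy_from_logits_sketch(logits: torch.Tensor) -> torch.Tensor:
    # Entropy of the softmax distribution at every position: H = -sum(p * log p)
    logps = F.log_softmax(logits, dim = -1)
    return -(logps.exp() * logps).sum(dim = -1)

logits = torch.randn(2, 5, 32)                    # (batch, seq_len, vocab)
print(entropy_from_logits_sketch(logits).shape)   # torch.Size([2, 5])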
@@ -416,8 +460,12 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
         raise ValueError("The GRPOTrainer does not support returning outputs")
     # Compute the per-token log probabilities for the model

+
     prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
     completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
+    pixel_values, image_grid_thw = inputs.get("pixel_values", None), inputs.get("image_grid_thw", None)
+    pixel_attention_mask, image_sizes = inputs.get('pixel_attention_mask',None), inputs.get('image_sizes',None)
+
     input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
     bsz, qlen = input_ids.shape
     attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
@@ -427,28 +475,29 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
     _logits_to_keep = logits_to_keep

     get_logps_func = \
-        lambda model, input_ids, attention_mask, logits_to_keep, batch_size=None, compute_entropy=False: \
-        self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) \
+        lambda model, input_ids, attention_mask, logits_to_keep, batch_size=None, compute_entropy=False, compute_efficient = False: \
+        self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep, compute_efficient) \
         if hasattr(self, "_get_per_token_logps") else \
-        self._get_per_token_logps_and_entropies(model, input_ids, attention_mask, logits_to_keep, batch_size, compute_entropy)[0] # logps
-
-    per_token_logps = get_logps_func(model, input_ids, attention_mask, logits_to_keep)
-
+        self._get_per_token_logps_and_entropies(model, input_ids, attention_mask, logits_to_keep, batch_size, compute_entropy, compute_efficient)[0] # logps
+    #breakpoint()
+    per_token_logps = get_logps_func(model, input_ids, attention_mask, logits_to_keep, compute_efficient = True)
     # Compute the KL divergence between the model and the reference model
     # _prepare_inputs doesn't return reference log probs anymore. We need to calculate it ourselves.
     # https://github.com/huggingface/trl/blob/05bc43e960396581e458195b8388efe6b82cae1f/trl/trainer/grpo_trainer.py#L1328
-    if self.beta != 0.0:
-        with torch.inference_mode(), model.disable_adapter():
-            ref_per_token_logps = per_token_logps = get_logps_func(model, input_ids, attention_mask, logits_to_keep)
-    else:
-        ref_per_token_logps = None
+    # if self.beta != 0.0:
+    #     with torch.inference_mode(), model.disable_adapter():
+    #         ref_per_token_logps = per_token_logps = get_logps_func(model, input_ids, attention_mask, logits_to_keep)
+    # else:
+    #     ref_per_token_logps = None
+    ref_hidden_states = inputs.get("ref_per_token_logps", None)
     # per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
     # x - x.detach() allows for preserving gradients from x
     advantages = inputs["advantages"]
     # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
     # per_token_loss = -(per_token_loss - self.beta * per_token_kl)
     # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
     old_hidden_states = inputs.get("old_per_token_logps", None)
+
     input_ids = input_ids[:, -logits_to_keep:]

     # Get logit softcapping and logit scale
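
The "x - x.detach()" comment kept in the code above refers to the usual trick for getting a ratio whose value is exactly 1 while gradients still flow through the current policy's log-probabilities. A short demonstration:

import torch

x = torch.tensor([1.5, -0.5], requires_grad = True)
y = torch.exp(x - x.detach())   # value is exactly 1.0 everywhere
y.sum().backward()
print(y)       # tensor([1., 1.], grad_fn=<ExpBackward0>)
print(x.grad)  # tensor([1., 1.])  gradients still flow through x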
@@ -459,22 +508,26 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
     logit_scale_divide = getattr(model.config, "logits_scaling", 0) # Granite
     if logit_scale_divide is None: logit_scale_divide = 0

-
     if per_token_logps is not None:

-        if ref_per_token_logps is not None:
-            ref_per_token_logps = ref_per_token_logps[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+        if ref_hidden_states is not None:
+            ref_hidden_states = ref_hidden_states[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+        if old_hidden_states is not None:
+            old_hidden_states = old_hidden_states[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
         per_token_logps = per_token_logps[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred

         loss, completion_length, mean_kl = grpo_compute_loss_slow(
-            ref_per_token_logps,
+            ref_hidden_states,
             per_token_logps,
             old_hidden_states,
             input_ids,
             completion_mask,
             self.beta,
             advantages,
+            pixel_values = pixel_values,
+            image_grid_thw = image_grid_thw,
             loss_type = self.args.loss_type,
+            importance_sampling_level = self.importance_sampling_level,
             epsilon_low = self.epsilon_low,
             epsilon_high = self.epsilon_high,
             max_completion_length = self.args.max_completion_length,
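
importance_sampling_level is the GSPO switch: GRPO computes one importance ratio per token, while GSPO uses a single, length-normalized ratio per sequence. How grpo_compute_loss_slow and grpo_accumulated_loss consume the flag is internal to Unsloth, but a minimal sketch of the distinction (assuming the TRL-style values "token" and "sequence") looks like:

import torch

def importance_ratio_sketch(per_token_logps, old_per_token_logps, completion_mask,
                            importance_sampling_level = "token"):
    log_ratio = per_token_logps - old_per_token_logps                  # (B, T)
    if importance_sampling_level == "token":
        return torch.exp(log_ratio)                                    # GRPO: one ratio per token
    if importance_sampling_level == "sequence":
        # GSPO: average per-token log-ratios over the completion, then broadcast back
        seq = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min = 1.0)
        return torch.exp(seq).unsqueeze(-1).expand_as(log_ratio)
    raise ValueError(f"Unknown importance_sampling_level: {importance_sampling_level}")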
@@ -489,12 +542,16 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
         loss, completion_length, mean_kl = grpo_accumulated_loss(
             trainer = self,
             input_ids = _input_ids,
+            pixel_values = pixel_values,
+            image_grid_thw = image_grid_thw,
             logits_to_keep = logits_to_keep,
             completion_mask = completion_mask,
             advantages = advantages,
             old_hidden_states = old_hidden_states,
+            ref_hidden_states = ref_hidden_states,
             n_chunks = self.args.unsloth_num_chunks,
             loss_type = self.args.loss_type,
+            importance_sampling_level = self.importance_sampling_level,
             epsilon_low = self.epsilon_low,
             epsilon_high = self.epsilon_high,
             max_completion_length = self.args.max_completion_length,
@@ -514,6 +571,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
             completion_mask = completion_mask,
             advantages = advantages,
             old_hidden_states = old_hidden_states,
+            ref_hidden_states = ref_hidden_states,
             n_chunks = self.args.unsloth_num_chunks,
             temperature = self.args.temperature,
             logit_softcapping = logit_softcapping,
