
Commit ab3fade

VLM GRPO TRL updated version (#233)
* Update rl_replacements.py, temporary but not yet correct GRPO implementation
* Update rl_replacements.py, unsloth zoo
* Update rl_replacements.py, addition of GSPO
* Update rl_replacements.py, fixed syntax error
* Update rl_replacements.py, changed log ratios to be torch.ones
* Update rl_replacements.py, fixed log ratio calculations
* Update rl_replacements.py
* Update rl_replacements.py, VLM GRPO with NaN changes
* Update rl_replacements.py, compute loss fix
* Update rl_replacements.py, officially merged the two branches
* Update rl_replacements.py, removed comments
* Update rl_replacements.py, added correct KL metrics
1 parent 5287def commit ab3fade

File tree

1 file changed: +153 −37 lines

unsloth_zoo/rl_replacements.py

Lines changed: 153 additions & 37 deletions
```diff
@@ -60,7 +60,66 @@ def chunked_selective_log_softmax(logits, index):
     return all_per_token_logps
 pass
 
+def calculate_pad_tokens_in_prompt(
+    input_ids: torch.Tensor,
+    logits_to_keep: int,
+    pad_token_id: int
+) -> torch.Tensor:
+    """
+    Given a prompt tensor, returns the number of left-pad tokens in each
+    sequence, e.g. [pad, pad, pad, cat] -> 3 tokens.
+    """
+    if logits_to_keep >= input_ids.shape[1]:
+        raise ValueError("logits_to_keep must be smaller than the sequence length.")
+
+    prompt_section = input_ids[:, :-logits_to_keep]
+    padding_mask = (prompt_section == pad_token_id)
+    pad_token_counts = padding_mask.sum(dim=1)
+    return pad_token_counts
+
+def create_completion_attention_mask(
+    completion_input_ids: torch.Tensor,
+    left_pad_tokens_per_prompt: torch.Tensor,
+    max_left_pad: int,
+    pad_token_id: int
+) -> torch.Tensor:
+    """
+    Given a sequence [p, p, p, c, c, c, pad, pad, pad], where p are extra prompt
+    tokens left over from slicing the tensor, c are completion tokens, and pad
+    are pad tokens, build a completion mask that zeroes out the p and pad
+    tokens: here [0, 0, 0, 1, 1, 1, 0, 0, 0].
+    """
+    batch_size, completion_len = completion_input_ids.shape
+    device = completion_input_ids.device
+
+    num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt
+
+    indices = torch.arange(completion_len, device=device).unsqueeze(0)
+    shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)
+
+    non_padding_mask = (completion_input_ids != pad_token_id)
+
+    final_mask = shift_mask & non_padding_mask
+    return final_mask
+
+def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
+    """
+    Moves all padding tokens in each sequence of a batch to the right.
+    """
+    mask = (tensor != pad_id)
+    sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
+    packed_tensor = torch.gather(tensor, 1, sorted_indices)
+    return packed_tensor
+
 RL_REPLACEMENTS["selective_log_softmax"] = chunked_selective_log_softmax
+RL_REPLACEMENTS["create_completion_attention_mask"] = create_completion_attention_mask
+RL_REPLACEMENTS["calculate_pad_tokens_in_prompt"] = calculate_pad_tokens_in_prompt
+RL_REPLACEMENTS["left_pack_padding"] = left_pack_padding
 
 
 # Custom compiled GRPO loss - creates 3 Triton kernels
```
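How the three helpers above compose, as a minimal sketch (not part of the commit; toy tensors, with `0` assumed as the pad token id):

```python
import torch

# Toy batch: 0 is the assumed pad token id; prompts are left-padded,
# completions right-padded. The last 4 positions are completion tokens.
input_ids = torch.tensor([
    [0, 0, 5, 6, 7, 8, 0],
    [0, 3, 4, 6, 9, 2, 1],
])
logits_to_keep = 4

pads = calculate_pad_tokens_in_prompt(input_ids, logits_to_keep, pad_token_id = 0)
# tensor([2, 1]) -- left-pad tokens per prompt

packed = left_pack_padding(input_ids, pad_id = 0)
# tensor([[5, 6, 7, 8, 0, 0, 0],
#         [3, 4, 6, 9, 2, 1, 0]]) -- all pads moved to the right

max_left_pad = int(pads.max())
completion_ids = packed[:, -(logits_to_keep + max_left_pad):]
mask = create_completion_attention_mask(completion_ids, pads, max_left_pad, pad_token_id = 0)
# tensor([[ True,  True,  True, False, False, False],
#         [False,  True,  True,  True,  True, False]])
# -- ones only over the true completion tokens (6,7,8 and 6,9,2,1)
```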
```diff
@@ -85,6 +144,7 @@ def grpo_compute_loss(
     logit_scale_multiply = kwargs.get("logit_scale_multiply", 0.0)
     logit_scale_divide = kwargs.get("logit_scale_divide", 0.0)
     logit_softcapping = kwargs.get("logit_softcapping", 0.0)
+    importance_sampling_level = kwargs.get("importance_sampling_level", "token")
 
     input_ids = input_ids.unsqueeze(-1)
 
@@ -98,7 +158,6 @@ def grpo_compute_loss(
     if temperature != 1.0: new_logits = new_logits / temperature
     new_x = torch.gather(new_logits, dim = -1, index = input_ids).squeeze(-1)
     new = new_x - torch.logsumexp(new_logits, dim = -1)
-
     # x_i - logsumexp(x_i)
     with torch.no_grad():
         if beta != 0.0:
```
```diff
@@ -143,9 +202,23 @@ def grpo_compute_loss(
     # Below is forward KL (normal KL)
     # kl_i = torch.exp(old) * (old - new)
     if old_logits is not None:
-        coef_1 = torch.exp(new - old)
+        log_ratio = new - old
+    else:
+        log_ratio = new - new.detach()
+
+    if importance_sampling_level == "token":
+        log_importance_weights = log_ratio
+    elif importance_sampling_level == "sequence":
+        log_importance_weights = (log_ratio * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
+        log_importance_weights = log_importance_weights.unsqueeze(-1)
     else:
-        coef_1 = torch.exp(new - new.detach())
+        raise ValueError(
+            f"Unknown importance sampling level: {importance_sampling_level}. Possible values are 'token' "
+            "and 'sequence'."
+        )
+
+    coef_1 = torch.exp(log_importance_weights)
+
     coef_2 = torch.clamp(coef_1, 1 - epsilon_low, 1 + epsilon_high)
 
     if delta is not None:
```
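For intuition, a hedged sketch (not from the commit) contrasting the two levels: token-level keeps one importance ratio per token as in GRPO, while sequence-level (the GSPO addition the commit message mentions) averages the masked log-ratio per sequence, so every token in a sequence shares a single weight:

```python
import torch

log_ratio = torch.tensor([[ 0.2, -0.1,  0.3],
                          [ 0.1,  0.4, -0.2]])   # per-token log pi_new - log pi_old
mask      = torch.tensor([[1., 1., 0.],
                          [1., 1., 1.]])         # 0 marks padding

# "token": one weight per token (GRPO-style)
coef_token = torch.exp(log_ratio)                           # shape (2, 3)

# "sequence": masked mean over tokens, one weight per sequence (GSPO-style)
seq_log_w = (log_ratio * mask).sum(-1) / mask.sum(-1).clamp(min = 1.0)
coef_seq = torch.exp(seq_log_w).unsqueeze(-1)               # shape (2, 1), broadcasts over tokens
```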
```diff
@@ -178,15 +251,19 @@ def grpo_compute_loss(
 
     # loss = (loss_i * mask).sum() / mask.sum()
 
-    # Get metrics as well which are folded
-    with torch.inference_mode():
-        completion_length = n_mask_per_reward.mean()
-        n_mask_per_reward = n_mask_per_reward.clamp(min = 1.0) # Counteracts division by 0
-        mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask_per_reward
-        mean_kl = mean_kl_per_reward.mean()
-    pass
+    completion_length = n_mask_per_reward.mean()
 
-    return loss, completion_length, mean_kl
+    # Get metrics as well which are folded
+    def masked_batch_mean(x):
+        with torch.inference_mode():
+            if x.shape[1] == 1: # when importance_sampling_level == "sequence"
+                return x.mean()
+            else:
+                mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask_per_reward
+                mean_kl = mean_kl_per_reward.mean()
+                return mean_kl
+
+    return loss, completion_length, masked_batch_mean(kl_i)
 pass
 RL_REPLACEMENTS["grpo_compute_loss"] = grpo_compute_loss
 RL_REPLACEMENTS["grpo_compute_loss_slow"] = \
```
```diff
@@ -283,7 +360,10 @@ def accumulate_chunk(
         old_hidden_states = torch.chunk(_old_hidden_states, chunks = n_chunks, dim = 0)
     else:
         old_hidden_states = [None] * n_chunks
-    ref_hidden_states = torch.chunk(_ref_hidden_states, chunks = n_chunks, dim = 0)
+    if _ref_hidden_states is not None:
+        ref_hidden_states = torch.chunk(_ref_hidden_states, chunks = n_chunks, dim = 0)
+    else:
+        ref_hidden_states = [None] * n_chunks
     input_ids = torch.chunk(_input_ids, chunks = n_chunks, dim = 0)
     mask = torch.chunk(_mask, chunks = n_chunks, dim = 0)
     advantages = torch.chunk(_advantages, chunks = n_chunks, dim = 0)
@@ -347,12 +427,17 @@ def grpo_accumulated_loss(
     completion_mask,
     advantages,
     old_hidden_states,
+    ref_hidden_states,
     n_chunks = -1,
     **kwargs,
 ):
     # All Unsloth Zoo code licensed under LGPLv3
     bsz, qlen = input_ids.shape
 
+    pixel_values = kwargs.get('pixel_values', None)
+    image_grid_thw = kwargs.get('image_grid_thw', None)
+    pixel_attention_mask = kwargs.get('pixel_attention_mask', None)
+    image_sizes = kwargs.get('image_sizes', None)
     # Find closest multiple
     factors = [i for i in range(1, bsz + 1) if bsz % i == 0]
     if n_chunks == -1: n_chunks = bsz
```
```diff
@@ -364,41 +449,72 @@ def grpo_accumulated_loss(
     pass
     os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
 
-    completion_input_ids = input_ids[:, -logits_to_keep:]
     lm_head = trainer.model.get_output_embeddings().weight
 
-    with torch.amp.autocast(device_type = trainer.model.device.type, dtype = trainer._autocast_dtype):
-        with torch.inference_mode(), trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper = False).disable_adapter():
-            ref_hidden_states = trainer.model(
+    if pixel_values is None:
+        left_pad_tokens_per_prompt = calculate_pad_tokens_in_prompt(input_ids, logits_to_keep, trainer.processing_class.pad_token_id)
+
+        max_left_pad = max(left_pad_tokens_per_prompt).item()
+
+        input_ids = left_pack_padding(input_ids, trainer.processing_class.pad_token_id)
+
+        completion_input_ids = input_ids[:, -(logits_to_keep + max_left_pad):]
+
+        completion_mask = create_completion_attention_mask(completion_input_ids, left_pad_tokens_per_prompt, max_left_pad, trainer.processing_class.pad_token_id).to(attention_mask.dtype)
+        attention_mask = input_ids != trainer.processing_class.pad_token_id
+        attention_mask = attention_mask.to(attention_mask.dtype)
+    else:
+        completion_input_ids = input_ids[:, -logits_to_keep:]
+
+    unwrapped_model = trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper = False)
+    with torch.amp.autocast(device_type = trainer.model.device.type, dtype = trainer._autocast_dtype):
+        if pixel_values is None:
+            new_hidden_states = unwrapped_model(
+                input_ids = input_ids,
+                attention_mask = attention_mask,
+                pixel_values = pixel_values,
+                image_grid_thw = image_grid_thw,
+                pixel_attention_mask = pixel_attention_mask,
+                image_sizes = image_sizes,
+                # logits_to_keep = logits_to_keep + 1,
+            ).logits
+
+            # Keep one extra logit since we generated a new token
+            new_hidden_states = new_hidden_states[:, -(logits_to_keep + max_left_pad + 1):, :]
+            if ref_hidden_states is not None:
+                ref_hidden_states = ref_hidden_states[:, -(logits_to_keep + max_left_pad + 1):, :]
+            if old_hidden_states is not None:
+                old_hidden_states = old_hidden_states[:, -(logits_to_keep + max_left_pad + 1):, :]
+        else:
+            new_hidden_states = unwrapped_model(
                 input_ids = input_ids,
                 attention_mask = attention_mask,
+                pixel_values = pixel_values,
+                image_grid_thw = image_grid_thw,
+                pixel_attention_mask = pixel_attention_mask,
+                image_sizes = image_sizes,
                 logits_to_keep = logits_to_keep + 1,
             ).logits
-        pass
-        new_hidden_states = trainer.model(
-            input_ids = input_ids,
-            attention_mask = attention_mask,
-            logits_to_keep = logits_to_keep + 1,
-        ).logits
-
-        loss, completion_length, mean_kl = UnslothEfficientGRPO.apply(
-            new_hidden_states,
-            old_hidden_states,
-            ref_hidden_states,
-            lm_head,
-            completion_input_ids,
-            completion_mask,
-            advantages,
-            trainer.beta,
-            trainer.accelerator.scaler,
-            n_chunks,
-            kwargs # pass kwargs as a dict
-        )
+
+        loss, completion_length, mean_kl = UnslothEfficientGRPO.apply(
+            new_hidden_states,
+            old_hidden_states,
+            ref_hidden_states,
+            lm_head,
+            completion_input_ids,
+            completion_mask,
+            advantages,
+            trainer.beta,
+            trainer.accelerator.scaler,
+            n_chunks,
+            kwargs # pass kwargs as a dict
+        )
     pass
+
     # Must force not returning hidden states but logits otherwise gibberish
     os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0"
-    return loss, completion_length, mean_kl
 
+    return loss, completion_length, mean_kl
     # Old non efficient code path
     new_logits = torch.matmul(new_hidden_states, lm_head.t())
     new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
```