Merged

19 commits:
fbcb05e - Update rl_replacements.py (pluesclues, Aug 8, 2025)
a8a39b6 - Update rl_replacements.py, unsloth zoo (pluesclues, Aug 11, 2025)
bcad513 - Update rl_replacements.py, addition of GSPO (pluesclues, Aug 11, 2025)
7cc4df7 - Merge branch 'unslothai:main' into vlm_grpo_update_official (pluesclues, Aug 13, 2025)
0555875 - Update rl_replacements.py, fixed syntax error (pluesclues, Aug 13, 2025)
8a71e0b - Update rl_replacements.py, changed log ratios to be torch.ones (pluesclues, Aug 15, 2025)
d880306 - Merge branch 'unslothai:main' into vlm_grpo_update_official (pluesclues, Aug 18, 2025)
96ba57c - Update rl_replacements.py, fixed log ratio calculations (pluesclues, Aug 18, 2025)
d9be146 - Merge branch 'unslothai:main' into vlm_grpo_update_official (pluesclues, Aug 20, 2025)
f6824d7 - Merge branch 'unslothai:main' into vlm_grpo_update_official (pluesclues, Sep 2, 2025)
103eb83 - Merge branch 'unslothai:main' into vlm_grpo_update_official (pluesclues, Sep 8, 2025)
524fd94 - Merge branch 'unslothai:main' into vlm_grpo_update_official (pluesclues, Sep 8, 2025)
36a424f - Update rl_replacements.py (pluesclues, Sep 9, 2025)
728f237 - Update rl_replacements.py, VLM GRPO with NaN changes (pluesclues, Sep 9, 2025)
7ec2862 - Update rl_replacements.py, compute loss fix (pluesclues, Sep 9, 2025)
21080c9 - Update rl_replacements.py, officially merged the two branches (pluesclues, Sep 10, 2025)
94659fc - Update rl_replacements.py, removed comments (pluesclues, Sep 10, 2025)
9b5bcd2 - Update rl_replacements.py, added correct KL metrics (pluesclues, Sep 15, 2025)
9125ae1 - Merge branch 'main' into vlm_grpo_update_official (pluesclues, Sep 16, 2025)
190 changes: 153 additions & 37 deletions unsloth_zoo/rl_replacements.py
@@ -60,7 +60,66 @@ def chunked_selective_log_softmax(logits, index):
    return all_per_token_logps
pass
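For orientation, selective log-softmax gathers the log-probability of each realized token. A generic, unchunked sketch of the same operation (illustrative only, not the chunked kernel referenced above):

import torch
# Unchunked selective log-softmax: log p(chosen token) at every position.
logits = torch.randn(2, 5, 32)           # (batch, seq, vocab)
index = torch.randint(0, 32, (2, 5))     # realized token ids
per_token_logps = torch.log_softmax(logits, dim = -1) \
    .gather(-1, index.unsqueeze(-1)).squeeze(-1)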

def calculate_pad_tokens_in_prompt(
    input_ids: torch.Tensor,
    logits_to_keep: int,
    pad_token_id: int,
) -> torch.Tensor:
    """
    Given a prompt tensor, returns the number of left-pad tokens in each
    sequence, e.g. [pad, pad, pad, cat] -> 3 tokens.
    """
    if logits_to_keep >= input_ids.shape[1]:
        raise ValueError("logits_to_keep must be smaller than the sequence length.")

    prompt_section = input_ids[:, :-logits_to_keep]
    padding_mask = (prompt_section == pad_token_id)
    pad_token_counts = padding_mask.sum(dim = 1)
    return pad_token_counts
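A quick sanity check with hypothetical ids (pad_token_id = 0, two completion tokens per row):

import torch
prompts = torch.tensor([
    [0, 0, 0, 7, 5, 6],   # prompt section [0, 0, 0, 7] -> 3 left pads
    [0, 9, 8, 7, 5, 6],   # prompt section [0, 9, 8, 7] -> 1 left pad
])
print(calculate_pad_tokens_in_prompt(prompts, logits_to_keep = 2, pad_token_id = 0))
# tensor([3, 1])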

def create_completion_attention_mask(
    completion_input_ids: torch.Tensor,
    left_pad_tokens_per_prompt: torch.Tensor,
    max_left_pad: int,
    pad_token_id: int,
) -> torch.Tensor:
    """
    Given a sequence [p, p, p, c, c, c, pad, pad, pad], where p are extra
    prompt tokens picked up when slicing the packed tensor, c are completion
    tokens, and pad are pad tokens, build a completion mask that zeroes out
    the p and pad tokens: in this example [0, 0, 0, 1, 1, 1, 0, 0, 0].
    """
    batch_size, completion_len = completion_input_ids.shape
    device = completion_input_ids.device

    num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt

    indices = torch.arange(completion_len, device = device).unsqueeze(0)
    shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)

    non_padding_mask = (completion_input_ids != pad_token_id)

    final_mask = shift_mask & non_padding_mask
    return final_mask
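An illustrative call with hypothetical ids (pad_token_id = 0), on two rows already packed and sliced to logits_to_keep + max_left_pad = 4 tokens:

import torch
completion_ids = torch.tensor([
    [5, 6, 0, 0],   # row had 2 left pads: [c, c, pad, pad]
    [7, 8, 5, 6],   # row had 0 left pads: [p, p, c, c]
])
left_pads = torch.tensor([2, 0])
print(create_completion_attention_mask(completion_ids, left_pads, max_left_pad = 2, pad_token_id = 0))
# tensor([[ True,  True, False, False],
#         [False, False,  True,  True]])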

def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
    """
    Moves all padding tokens in each sequence of a batch to the right.
    """
    mask = (tensor != pad_id)
    sorted_indices = torch.argsort(mask, dim = 1, descending = True, stable = True)
    packed_tensor = torch.gather(tensor, 1, sorted_indices)
    return packed_tensor
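For example (pad_id = 0), pads migrate to the right while real tokens keep their order, since the argsort is stable:

import torch
x = torch.tensor([[0, 0, 4, 5],
                  [0, 7, 8, 9]])
print(left_pack_padding(x, pad_id = 0))
# tensor([[4, 5, 0, 0],
#         [7, 8, 9, 0]])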

RL_REPLACEMENTS["selective_log_softmax"] = chunked_selective_log_softmax
RL_REPLACEMENTS["create_completion_attention_mask"] = create_completion_attention_mask
RL_REPLACEMENTS["calculate_pad_tokens_in_prompt"] = calculate_pad_tokens_in_prompt
RL_REPLACEMENTS["left_pack_padding"] = left_pack_padding


# Custom compiled GRPO loss - creates 3 Triton kernels
@@ -85,6 +144,7 @@ def grpo_compute_loss(
    logit_scale_multiply = kwargs.get("logit_scale_multiply", 0.0)
    logit_scale_divide = kwargs.get("logit_scale_divide", 0.0)
    logit_softcapping = kwargs.get("logit_softcapping", 0.0)
    importance_sampling_level = kwargs.get("importance_sampling_level", "token")

    input_ids = input_ids.unsqueeze(-1)

Expand All @@ -98,7 +158,6 @@ def grpo_compute_loss(
    if temperature != 1.0: new_logits = new_logits / temperature
    new_x = torch.gather(new_logits, dim = -1, index = input_ids).squeeze(-1)
    new = new_x - torch.logsumexp(new_logits, dim = -1)
    # x_i - logsumexp(x_i)
    with torch.no_grad():
        if beta != 0.0:
@@ -143,9 +202,23 @@
    # Below is forward KL (normal KL)
    # kl_i = torch.exp(old) * (old - new)
    if old_logits is not None:
        log_ratio = new - old
    else:
        log_ratio = new - new.detach()

    if importance_sampling_level == "token":
        log_importance_weights = log_ratio
    elif importance_sampling_level == "sequence":
        log_importance_weights = (log_ratio * mask).sum(-1) / mask.sum(-1).clamp(min = 1.0)
        log_importance_weights = log_importance_weights.unsqueeze(-1)
    else:
        raise ValueError(
            f"Unknown importance sampling level: {importance_sampling_level}. Possible values are 'token' "
            "and 'sequence'."
        )

    coef_1 = torch.exp(log_importance_weights)
    coef_2 = torch.clamp(coef_1, 1 - epsilon_low, 1 + epsilon_high)

    if delta is not None:
@@ -178,15 +251,19 @@

    # loss = (loss_i * mask).sum() / mask.sum()

    completion_length = n_mask_per_reward.mean()

    # Get metrics as well which are folded
    def masked_batch_mean(x):
        with torch.inference_mode():
            if x.shape[1] == 1: # when importance_sampling_level == "sequence"
                return x.mean()
            else:
                n_mask = n_mask_per_reward.clamp(min = 1.0) # Counteracts division by 0
                mean_kl_per_reward = (x * mask).sum(1) / n_mask
                mean_kl = mean_kl_per_reward.mean()
                return mean_kl

    return loss, completion_length, masked_batch_mean(kl_i)
pass
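The importance_sampling_level switch follows the GSPO variant added in this PR: "token" keeps one importance ratio per token, while "sequence" averages the masked log-ratios into a single ratio per sequence and broadcasts it. A minimal sketch of the two reductions on hypothetical tensors:

import torch
log_ratio = torch.randn(2, 5)              # per-token log pi_new - log pi_old
mask = torch.ones(2, 5)                    # completion mask

token_weights = torch.exp(log_ratio)       # "token": shape (2, 5)
seq_log_w = (log_ratio * mask).sum(-1) / mask.sum(-1).clamp(min = 1.0)
sequence_weights = torch.exp(seq_log_w).unsqueeze(-1)   # "sequence": shape (2, 1), broadcasts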
RL_REPLACEMENTS["grpo_compute_loss"] = grpo_compute_loss
RL_REPLACEMENTS["grpo_compute_loss_slow"] = \
@@ -283,7 +360,10 @@ def accumulate_chunk(
        old_hidden_states = torch.chunk(_old_hidden_states, chunks = n_chunks, dim = 0)
    else:
        old_hidden_states = [None] * n_chunks
    if _ref_hidden_states is not None:
        ref_hidden_states = torch.chunk(_ref_hidden_states, chunks = n_chunks, dim = 0)
    else:
        ref_hidden_states = [None] * n_chunks
    input_ids = torch.chunk(_input_ids, chunks = n_chunks, dim = 0)
    mask = torch.chunk(_mask, chunks = n_chunks, dim = 0)
    advantages = torch.chunk(_advantages, chunks = n_chunks, dim = 0)
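Since _ref_hidden_states (and _old_hidden_states) can be None, e.g. when beta == 0.0 disables the KL term, the chunking has to be guarded. A compact equivalent of the pattern above, for illustration:

import torch
def chunk_or_none(t, n_chunks):
    # torch.chunk cannot take None, so fan out a list of Nones instead.
    if t is None:
        return [None] * n_chunks
    return torch.chunk(t, chunks = n_chunks, dim = 0)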
@@ -347,12 +427,17 @@ def grpo_accumulated_loss(
    completion_mask,
    advantages,
    old_hidden_states,
    ref_hidden_states,
    n_chunks = -1,
    **kwargs,
):
    # All Unsloth Zoo code licensed under LGPLv3
    bsz, qlen = input_ids.shape

    pixel_values = kwargs.get('pixel_values', None)
    image_grid_thw = kwargs.get('image_grid_thw', None)
    pixel_attention_mask = kwargs.get('pixel_attention_mask', None)
    image_sizes = kwargs.get('image_sizes', None)
    # Find closest multiple
    factors = [i for i in range(1, bsz + 1) if bsz % i == 0]
    if n_chunks == -1: n_chunks = bsz
@@ -364,41 +449,72 @@
    pass
    os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"

    lm_head = trainer.model.get_output_embeddings().weight

    if pixel_values is None:
        left_pad_tokens_per_prompt = calculate_pad_tokens_in_prompt(input_ids, logits_to_keep, trainer.processing_class.pad_token_id)
        max_left_pad = max(left_pad_tokens_per_prompt).item()

        input_ids = left_pack_padding(input_ids, trainer.processing_class.pad_token_id)
        completion_input_ids = input_ids[:, -(logits_to_keep + max_left_pad):]

        completion_mask = create_completion_attention_mask(completion_input_ids, left_pad_tokens_per_prompt, max_left_pad, trainer.processing_class.pad_token_id).to(attention_mask.dtype)
        attention_mask = (input_ids != trainer.processing_class.pad_token_id).to(completion_mask.dtype)
    else:
        completion_input_ids = input_ids[:, -logits_to_keep:]

    unwrapped_model = trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper = False)
    with torch.amp.autocast(device_type = trainer.model.device.type, dtype = trainer._autocast_dtype):
        if pixel_values is None:
            new_hidden_states = unwrapped_model(
                input_ids = input_ids,
                attention_mask = attention_mask,
                pixel_values = pixel_values,
                image_grid_thw = image_grid_thw,
                pixel_attention_mask = pixel_attention_mask,
                image_sizes = image_sizes,
                # logits_to_keep = logits_to_keep + 1,
            ).logits

            # Keep one extra logit as we generated a new token
            new_hidden_states = new_hidden_states[:, -(logits_to_keep + max_left_pad + 1):, :]
            if ref_hidden_states is not None:
                ref_hidden_states = ref_hidden_states[:, -(logits_to_keep + max_left_pad + 1):, :]
            if old_hidden_states is not None:
                old_hidden_states = old_hidden_states[:, -(logits_to_keep + max_left_pad + 1):, :]
        else:
            new_hidden_states = unwrapped_model(
                input_ids = input_ids,
                attention_mask = attention_mask,
                pixel_values = pixel_values,
                image_grid_thw = image_grid_thw,
                pixel_attention_mask = pixel_attention_mask,
                image_sizes = image_sizes,
                logits_to_keep = logits_to_keep + 1,
            ).logits
        pass


        loss, completion_length, mean_kl = UnslothEfficientGRPO.apply(
            new_hidden_states,
            old_hidden_states,
            ref_hidden_states,
            lm_head,
            completion_input_ids,
            completion_mask,
            advantages,
            trainer.beta,
            trainer.accelerator.scaler,
            n_chunks,
            kwargs, # pass kwargs as a dict
        )
    pass

    # Must force not returning hidden states but logits otherwise gibberish
    os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0"
    return loss, completion_length, mean_kl

    # Old non efficient code path
    new_logits = torch.matmul(new_hidden_states, lm_head.t())
    new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
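For context on the [:, :-1, :] slice: logits at position t score the token at position t + 1, so the final position is dropped and the labels are the ids shifted left by one. A generic sketch of the alignment (illustrative variable names, not this file's exact code):

import torch
logits = torch.randn(2, 6, 32)        # (batch, seq, vocab)
input_ids = torch.randint(0, 32, (2, 6))
shifted_logits = logits[:, :-1, :]    # drop the final next-token prediction
labels = input_ids[:, 1:]             # each position predicts the next id
per_token_logps = torch.log_softmax(shifted_logits, dim = -1) \
    .gather(-1, labels.unsqueeze(-1)).squeeze(-1)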