
Wensun/apo #96

Open
wants to merge 19 commits into base: main
Changes from 9 commits
220 changes: 125 additions & 95 deletions compose_rl/algorithms/online/callback.py
@@ -41,6 +41,7 @@
ComposerMPTPolicyLM,
)
from compose_rl.algorithms.online.model_methods import (
ALGORITHM_TYPE,
OnPolicyEnum,
)
from compose_rl.algorithms.online.reward_manager import (
@@ -589,6 +590,7 @@ def iteration_start(self, state: State, logger: Logger):
del logger # unused

batch = self._get_next_iter_prompts()

batch = state.device.batch_to_device(batch)

if self.vllm_engines is not None:
@@ -648,7 +650,7 @@ def _get_next_iter_prompts(self):
# Explode the batch into multiple batches for each generation
for _ in range(self.generations_per_prompt):
# For keys that do not require additional processing
if key in ['prompt_len', 'verified_answer', 'prompt_id']:
if key in ['prompt_len', 'verified_answer', 'prompt_id', 'vstar']:
curr_values.append(batch[key])
continue

@@ -678,6 +680,8 @@ def _get_next_iter_prompts(self):
else:
if key == 'verified_answer':
ret_batch[key] = list(flatten(curr_values))
elif key == 'vstar':
ret_batch[key] = list(flatten(curr_values))
else:
# this is an edge case that we will not hit currently, but just handling it as needed
ret_batch[key] = curr_values
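For readers skimming the diff: 'verified_answer' and the new 'vstar' key are carried through the batch as plain lists, repeated once per generation and then flattened. A self-contained toy sketch of that repeat-and-flatten pattern follows (hypothetical helper name and data, not the callback's actual code):

```python
import torch

def explode_batch(batch: dict, generations_per_prompt: int) -> dict:
    """Toy sketch: repeat each prompt `generations_per_prompt` times.

    Tensor values are tiled along dim 0; list-like values (e.g. 'verified_answer',
    'vstar') are repeated and flattened into one flat list.
    """
    out = {}
    for key, value in batch.items():
        repeated = [value for _ in range(generations_per_prompt)]
        if isinstance(value, torch.Tensor):
            out[key] = torch.cat(repeated, dim=0)
        else:
            out[key] = [item for chunk in repeated for item in chunk]
    return out

# Made-up example batch with two prompts and two generations per prompt.
batch = {
    'prompt_id': torch.tensor([0, 1]),
    'vstar': [0.25, 0.75],
}
exploded = explode_batch(batch, generations_per_prompt=2)
print(exploded['prompt_id'])  # tensor([0, 1, 0, 1])
print(exploded['vstar'])      # [0.25, 0.75, 0.25, 0.75]
```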
@@ -870,109 +874,135 @@ def _resolve_outputs(
env_outs['right_padded_attn_mask'] = torch.logical_not(
torch.eq(env_outs['obs'], self.pad_token_idx), # type: ignore
)
if self.actor_critic.loss_type not in ALGORITHM_TYPE.REGRESSION:
    # Now that rewards are resolved, we can compute advantages
    if self.actor_critic.loss_type == OnPolicyEnum.PPO:
        env_outs['advantages'] = compute_advantages(
            rewards=env_outs['rewards'],
            values=env_outs['values'],
            gamma=self.gamma,
            lambda_gae=self.lambda_gae,
        )
    elif self.actor_critic.loss_type == OnPolicyEnum.GRPO:
        # compute GRPO advantages
        prompt_id = env_outs['prompt_id']
        rewards = env_outs['rewards']

        # Flatten the rewards by summing on sequence length/action_mask
        flat_rewards = masked_sum(
            rewards,
            env_outs['action_mask'],
            dim=-1,
        )

        # Get unique prompt IDs and their indices
        unique_prompt_ids, inverse_indices = torch.unique(
            prompt_id,
            return_inverse=True,
        )

        # Use scatter to compute means and standard deviations
        # First, we'll create a tensor to track counts, sums, and sum of squares
        n_unique = len(unique_prompt_ids)
        counts = torch.zeros(n_unique, device=prompt_id.device)
        sums = torch.zeros(n_unique, device=prompt_id.device)
        sum_squares = torch.zeros(n_unique, device=prompt_id.device)

        # Use scatter_add to accumulate values
        counts.scatter_add_(
            0,
            inverse_indices,
            torch.ones_like(flat_rewards),
        )
        sums.scatter_add_(0, inverse_indices, flat_rewards)
        sum_squares.scatter_add_(0, inverse_indices, flat_rewards**2)

        # Compute means and standard deviations
        means = sums / counts
        variances = (sum_squares / counts) - (means**2)
        stds = torch.sqrt(variances)

        # Map back to original tensor shape
        mean_rewards = means[inverse_indices]
        std_rewards = stds[inverse_indices]

        # Calculate GRPO advantage
        grpo_advantage = (flat_rewards - mean_rewards)
        # Only normalize the advantage if flag is set
        if self.actor_critic.normalize_advantage:
            grpo_advantage /= (std_rewards + 1e-4)

        # Create advantages of the same shape as original rewards
        advantages = torch.zeros_like(rewards)
        # Copy the flat grpo_advantage according to action_mask
        expanded_advantages = grpo_advantage.unsqueeze(1).expand_as(
            env_outs['action_mask'],
        )
        advantages = torch.where(
            env_outs['action_mask'].bool(),
            expanded_advantages,
            advantages,
        )
        env_outs['advantages'] = advantages
    else:
        raise ValueError(
            f'Invalid loss type: {self.actor_critic.loss_type}. ' +
            'Valid options are: ppo, grpo.',
        )

    batch_adv_mean, batch_adv_var = dist_compute_masked_mean_and_var(
        env_outs['advantages'],
        env_outs['action_mask'],
    )

    mean_ift = masked_mean(
        env_outs['ift_kl'],
        env_outs['action_mask'],
    )
    self.kl_ift.append(mean_ift.cpu())

    iter_batch.update(env_outs)

    iter_batch.update({
        'max_gen_len':
            torch.ones(self.iter_batch_size).to(torch.int32) *
            self.max_gen_len,
        'adv_masked_mean':
            torch.ones(self.iter_batch_size) * batch_adv_mean.cpu(),
        'adv_masked_var':
            torch.ones(self.iter_batch_size) * batch_adv_var.cpu(),
        'ift_kl_scalar':
            torch.ones(self.iter_batch_size) * self.kl_ctl.value,
        'reward_std':
            torch.ones(self.iter_batch_size) *
            env_outs['rewards'].std().to('cpu'),
    })
else:
    # APO and REBEL
    mean_ift = masked_mean(
        env_outs['ift_kl'],
        env_outs['action_mask'],
    )
    self.kl_ift.append(mean_ift.cpu())

    iter_batch.update(env_outs)

    iter_batch.update({
        'max_gen_len':
            torch.ones(self.iter_batch_size).to(torch.int32) *
            self.max_gen_len,
        'adv_masked_mean':
            torch.ones(self.iter_batch_size),
        'adv_masked_var':
            torch.ones(self.iter_batch_size),
        'ift_kl_scalar':
            torch.ones(self.iter_batch_size) * self.kl_ctl.value,
        'reward_std':
            torch.ones(self.iter_batch_size) *
            env_outs['rewards'].std().to('cpu'),
    })
Collaborator

Isn't this block of code for both algorithms very similar to each other, except for the adv_masked_mean bit? If so can we condense it?

Collaborator

done
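
Following up on the thread above, a minimal sketch of one way the two metric blocks could be condensed, keeping the placeholder advantage statistics on the regression (APO/REBEL) path. This is only an illustration of the idea; the follow-up commit that resolved the comment may differ:

```python
# Illustrative condensation of the duplicated metric blocks (not the PR's exact code).
is_regression = self.actor_critic.loss_type in ALGORITHM_TYPE.REGRESSION

# Placeholder stats for regression-style losses, matching the original else branch.
adv_mean = torch.ones(())
adv_var = torch.ones(())
if not is_regression:
    # Advantages were computed above for PPO/GRPO.
    adv_mean, adv_var = dist_compute_masked_mean_and_var(
        env_outs['advantages'],
        env_outs['action_mask'],
    )

mean_ift = masked_mean(env_outs['ift_kl'], env_outs['action_mask'])
self.kl_ift.append(mean_ift.cpu())

iter_batch.update(env_outs)

ones = torch.ones(self.iter_batch_size)
iter_batch.update({
    'max_gen_len': ones.to(torch.int32) * self.max_gen_len,
    'adv_masked_mean': ones * adv_mean.cpu(),
    'adv_masked_var': ones * adv_var.cpu(),
    'ift_kl_scalar': ones * self.kl_ctl.value,
    'reward_std': ones * env_outs['rewards'].std().to('cpu'),
})
```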



# Moving minibatches to CPU to not take additional GPU memory
for k, v in iter_batch.items():
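As a standalone illustration of the GRPO branch above: per-prompt reward statistics are built with torch.unique plus scatter_add_, then broadcast back to each sample. A runnable toy version with made-up prompt IDs and rewards, using the same 1e-4 standard-deviation floor as the diff:

```python
import torch

# Toy data: 6 generations drawn from 3 prompts (2 generations per prompt).
prompt_id = torch.tensor([0, 0, 1, 1, 2, 2])
flat_rewards = torch.tensor([1.0, 0.0, 2.0, 4.0, -1.0, 1.0])

# Group rewards by prompt: unique prompt IDs and the group index of each sample.
unique_prompt_ids, inverse_indices = torch.unique(prompt_id, return_inverse=True)

n_unique = len(unique_prompt_ids)
counts = torch.zeros(n_unique)
sums = torch.zeros(n_unique)
sum_squares = torch.zeros(n_unique)

# Accumulate per-group counts, sums, and sums of squares with scatter_add_.
counts.scatter_add_(0, inverse_indices, torch.ones_like(flat_rewards))
sums.scatter_add_(0, inverse_indices, flat_rewards)
sum_squares.scatter_add_(0, inverse_indices, flat_rewards**2)

means = sums / counts
stds = torch.sqrt(sum_squares / counts - means**2)

# Broadcast the group statistics back to per-sample shape and normalize.
advantage = (flat_rewards - means[inverse_indices]) / (stds[inverse_indices] + 1e-4)
print(advantage)  # each prompt group is centered and scaled independently
```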
6 changes: 3 additions & 3 deletions compose_rl/algorithms/online/generation_utils/vllm_utils.py
@@ -45,7 +45,7 @@
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from compose_rl.algorithms.online.generation_utils.vllm_actor import LLMRayActor
from compose_rl.algorithms.online.model_methods import OnPolicyEnum
from compose_rl.algorithms.online.model_methods import ALGORITHM_TYPE, OnPolicyEnum

log = logging.getLogger(__name__)

@@ -362,7 +362,7 @@ def should_update_torch_module(
if parsed_module_name not in valid_non_leaf_module_names:
return False

if loss_type == OnPolicyEnum.GRPO:
if loss_type in ALGORITHM_TYPE.CRITIC_FREE:
return True

if loss_type == OnPolicyEnum.PPO and 'lm_backbone' in full_param_name:
@@ -394,7 +394,7 @@ def broadcast_to_vllm(
count, num_params = 0, len(
list(model.model.lm_backbone.named_parameters()), # type: ignore
)
elif loss_type == OnPolicyEnum.GRPO:
elif loss_type in ALGORITHM_TYPE.CRITIC_FREE:
# Directly use the model params
count, num_params = 0, len(
list(model.model.named_parameters()), # type: ignore
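The membership checks above rely on ALGORITHM_TYPE from model_methods, which this PR does not show. A hypothetical sketch of how such groupings could be defined, written only to make the `in` checks concrete (the member names beyond PPO and GRPO are assumptions):

```python
# Hypothetical sketch of the ALGORITHM_TYPE groupings referenced by the diff;
# the real definitions live in compose_rl/algorithms/online/model_methods.py.
from enum import Enum


class OnPolicyEnum(str, Enum):
    PPO = 'ppo'
    GRPO = 'grpo'
    APO = 'apo'      # assumed member
    REBEL = 'rebel'  # assumed member


class ALGORITHM_TYPE:
    # Losses trained without a learned value function (critic).
    CRITIC_FREE = {OnPolicyEnum.GRPO, OnPolicyEnum.APO, OnPolicyEnum.REBEL}
    # Losses that regress a target directly instead of using advantages.
    REGRESSION = {OnPolicyEnum.APO, OnPolicyEnum.REBEL}


# The diff's checks then read naturally as set membership tests.
assert OnPolicyEnum.GRPO in ALGORITHM_TYPE.CRITIC_FREE
assert OnPolicyEnum.PPO not in ALGORITHM_TYPE.REGRESSION
```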
5 changes: 5 additions & 0 deletions compose_rl/algorithms/online/model.py
@@ -104,6 +104,7 @@ def loss(self, outputs: MutableMapping, batch: MutableMapping):
value_clip_range=self.config.value_clip_range,
value_loss_weight=self.config.value_loss_weight,
policy_clip_ratio=self.config.policy_clip_ratio,
beta=self.config.beta,  # added beta
add_direct_kl_loss=self.config.compute_kl_loss,
kl_estimator=self.config.kl_estimator,
kl_clip_range=self.config.kl_clip_range,
@@ -217,6 +218,7 @@ def loss(self, outputs: MutableMapping, batch: MutableMapping):
value_clip_range=self.config.value_clip_range,
value_loss_weight=self.config.value_loss_weight,
policy_clip_ratio=self.config.policy_clip_ratio,
beta=self.config.beta,  # added beta parameter
add_direct_kl_loss=self.config.compute_kl_loss,
kl_estimator=self.config.kl_estimator,
kl_clip_range=self.config.kl_clip_range,
@@ -255,6 +257,7 @@ def __init__(
length_normalize_policy_loss: bool = True,
policy_clip_ratio: float = 0.15,
policy_clip_high_ratio: float | None = None,
beta: float = 1e-3,  # added beta
compute_kl_loss: bool = True,
target_kl: float = 0.1,
kl_estimator: str = 'k3',
@@ -283,6 +286,7 @@
self.policy_clip_high_ratio = policy_clip_high_ratio
self.compute_kl_loss = compute_kl_loss
self.target_kl = target_kl
self.beta = beta
self.kl_estimator = kl_estimator
self.kl_clip_range = kl_clip_range

@@ -306,6 +310,7 @@ def loss(self, outputs: MutableMapping, batch: MutableMapping):
loss_type=self.loss_type,
policy_clip_ratio=self.policy_clip_ratio,
policy_clip_high_ratio=self.policy_clip_high_ratio,
beta=self.beta,  # added beta
length_normalize_policy_loss=self.length_normalize_policy_loss,
add_direct_kl_loss=self.compute_kl_loss,
kl_estimator=self.kl_estimator,
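model.py now threads a beta coefficient into the loss, and callback.py carries per-prompt vstar values through the batch. As an illustration only, a generic regression-style policy loss in the spirit of APO/REBEL might regress the beta-scaled log-probability ratio onto reward minus vstar; the sketch below is an assumption about that shape, not the loss implemented in compose_rl:

```python
import torch

def regression_policy_loss(
    online_log_probs: torch.Tensor,  # log pi_theta(y|x), summed over tokens
    ref_log_probs: torch.Tensor,     # log pi_ref(y|x), summed over tokens
    rewards: torch.Tensor,           # sequence-level rewards
    vstar: torch.Tensor,             # per-prompt baseline values from the batch
    beta: float = 1e-3,
) -> torch.Tensor:
    """Illustrative regression-style loss: fit the beta-scaled log-ratio to (reward - vstar).

    A sketch of the general idea only, not compose_rl's implementation.
    """
    log_ratio = online_log_probs - ref_log_probs
    target = rewards - vstar
    return torch.mean((beta * log_ratio - target) ** 2)

# Toy usage with made-up numbers.
loss = regression_policy_loss(
    online_log_probs=torch.tensor([-12.0, -8.0]),
    ref_log_probs=torch.tensor([-11.0, -9.0]),
    rewards=torch.tensor([1.0, 0.0]),
    vstar=torch.tensor([0.5, 0.5]),
    beta=1e-3,
)
print(loss)
```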