Wensun/apo #96

Open · wants to merge 19 commits into base: main

Changes from all commits
171 changes: 90 additions & 81 deletions compose_rl/algorithms/online/callback.py
@@ -41,6 +41,7 @@
ComposerMPTPolicyLM,
)
from compose_rl.algorithms.online.model_methods import (
ALGORITHM_TYPE,
OnPolicyEnum,
)
from compose_rl.algorithms.online.reward_manager import (
@@ -590,6 +591,7 @@ def iteration_start(self, state: State, logger: Logger):
del logger # unused

batch = self._get_next_iter_prompts()

batch = state.device.batch_to_device(batch)

if self.vllm_engines is not None:
@@ -649,7 +651,12 @@ def _get_next_iter_prompts(self):
# Explode the batch into multiple batches for each generation
for _ in range(self.generations_per_prompt):
# For keys that do not require additional processing
if key in ['prompt_len', 'verified_answer', 'prompt_id']:
if key in [
'prompt_len',
'verified_answer',
'prompt_id',
'vstar',
]:
curr_values.append(batch[key])
continue

@@ -677,7 +684,7 @@ def _get_next_iter_prompts(self):
if isinstance(curr_values[0], torch.Tensor):
ret_batch[key] = torch.cat(curr_values)
else:
if key == 'verified_answer':
if key in ['verified_answer', 'vstar']:
ret_batch[key] = list(flatten(curr_values))
else:
# this is an edge case that we will not hit currently, but just handling it as needed
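For context on the key handling above, here is a small standalone sketch (not part of the PR; names and values are illustrative) of how the explode step behaves: each key is repeated `generations_per_prompt` times, tensor values are concatenated, and list values such as `verified_answer` and the new `vstar` are flattened.

```python
# Illustrative only; mirrors the explode/flatten behavior in _get_next_iter_prompts.
import torch
from itertools import chain

def flatten(list_of_lists):
    return chain.from_iterable(list_of_lists)

generations_per_prompt = 2
batch = {
    'prompt_id': torch.tensor([0, 1]),
    'vstar': ['answer_a', 'answer_b'],
}

# Repeat every key once per generation.
curr_values = {k: [batch[k] for _ in range(generations_per_prompt)] for k in batch}

ret_batch = {}
for key, values in curr_values.items():
    if isinstance(values[0], torch.Tensor):
        ret_batch[key] = torch.cat(values)      # tensor([0, 1, 0, 1])
    else:
        ret_batch[key] = list(flatten(values))  # ['answer_a', 'answer_b', 'answer_a', 'answer_b']
```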
@@ -871,86 +878,93 @@ def _resolve_outputs(
        env_outs['right_padded_attn_mask'] = torch.logical_not(
            torch.eq(env_outs['obs'], self.pad_token_idx),  # type: ignore
        )
        # Now that rewards are resolved, we can compute advantages
        if self.actor_critic.loss_type == OnPolicyEnum.PPO:
            env_outs['advantages'] = compute_advantages(
                rewards=env_outs['rewards'],
                values=env_outs['values'],
                gamma=self.gamma,
                lambda_gae=self.lambda_gae,
            )
        elif self.actor_critic.loss_type == OnPolicyEnum.GRPO:
            # compute GRPO advantages
            prompt_id = env_outs['prompt_id']
            rewards = env_outs['rewards']

            # Flatten the rewards by summing on sequence length/action_mask
            flat_rewards = masked_sum(
                rewards,
                env_outs['action_mask'],
                dim=-1,
            )

            # Get unique prompt IDs and their indices
            unique_prompt_ids, inverse_indices = torch.unique(
                prompt_id,
                return_inverse=True,
            )

            # Use scatter to compute means and standard deviations
            # First, we'll create a tensor to track counts, sums, and sum of squares
            n_unique = len(unique_prompt_ids)
            counts = torch.zeros(n_unique, device=prompt_id.device)
            sums = torch.zeros(n_unique, device=prompt_id.device)
            sum_squares = torch.zeros(n_unique, device=prompt_id.device)

            # Use scatter_add to accumulate values
            counts.scatter_add_(
                0,
                inverse_indices,
                torch.ones_like(flat_rewards),
            )
            sums.scatter_add_(0, inverse_indices, flat_rewards)
            sum_squares.scatter_add_(0, inverse_indices, flat_rewards**2)

            # Compute means and standard deviations
            means = sums / counts
            variances = (sum_squares / counts) - (means**2)
            stds = torch.sqrt(variances)

            # Map back to original tensor shape
            mean_rewards = means[inverse_indices]
            std_rewards = stds[inverse_indices]

            # Calculate GRPO advantage
            grpo_advantage = (flat_rewards - mean_rewards)
            # Only normalize the advantage if flag is set
            if self.actor_critic.normalize_advantage:
                grpo_advantage /= (std_rewards + 1e-4)

            # Create advantages of the same shape as original rewards
            advantages = torch.zeros_like(rewards)
            # Copy the flat grpo_advantage according to action_mask
            expanded_advantages = grpo_advantage.unsqueeze(1).expand_as(
                env_outs['action_mask'],
            )
            advantages = torch.where(
                env_outs['action_mask'].bool(),
                expanded_advantages,
                advantages,
            )
            env_outs['advantages'] = advantages
        else:
            raise ValueError(
                f'Invalid loss type: {self.actor_critic.loss_type}. ' +
                'Valid options are: ppo, grpo.',
            )

        batch_adv_mean, batch_adv_var = dist_compute_masked_mean_and_var(
            env_outs['advantages'],
            env_outs['action_mask'],
        )
        if self.actor_critic.loss_type not in ALGORITHM_TYPE.REGRESSION:
            # Now that rewards are resolved, we can compute advantages
            if self.actor_critic.loss_type == OnPolicyEnum.PPO:
                env_outs['advantages'] = compute_advantages(
                    rewards=env_outs['rewards'],
                    values=env_outs['values'],
                    gamma=self.gamma,
                    lambda_gae=self.lambda_gae,
                )
            elif self.actor_critic.loss_type == OnPolicyEnum.GRPO:
                # compute GRPO advantages
                prompt_id = env_outs['prompt_id']
                rewards = env_outs['rewards']

                # Flatten the rewards by summing on sequence length/action_mask
                flat_rewards = masked_sum(
                    rewards,
                    env_outs['action_mask'],
                    dim=-1,
                )

                # Get unique prompt IDs and their indices
                unique_prompt_ids, inverse_indices = torch.unique(
                    prompt_id,
                    return_inverse=True,
                )

                # Use scatter to compute means and standard deviations
                # First, we'll create a tensor to track counts, sums, and sum of squares
                n_unique = len(unique_prompt_ids)
                counts = torch.zeros(n_unique, device=prompt_id.device)
                sums = torch.zeros(n_unique, device=prompt_id.device)
                sum_squares = torch.zeros(n_unique, device=prompt_id.device)

                # Use scatter_add to accumulate values
                counts.scatter_add_(
                    0,
                    inverse_indices,
                    torch.ones_like(flat_rewards),
                )
                sums.scatter_add_(0, inverse_indices, flat_rewards)
                sum_squares.scatter_add_(0, inverse_indices, flat_rewards**2)

                # Compute means and standard deviations
                means = sums / counts
                variances = (sum_squares / counts) - (means**2)
                stds = torch.sqrt(variances)

                # Map back to original tensor shape
                mean_rewards = means[inverse_indices]
                std_rewards = stds[inverse_indices]

                # Calculate GRPO advantage
                grpo_advantage = (flat_rewards - mean_rewards)
                # Only normalize the advantage if flag is set
                if self.actor_critic.normalize_advantage:
                    grpo_advantage /= (std_rewards + 1e-4)

                # Create advantages of the same shape as original rewards
                advantages = torch.zeros_like(rewards)
                # Copy the flat grpo_advantage according to action_mask
                expanded_advantages = grpo_advantage.unsqueeze(1).expand_as(
                    env_outs['action_mask'],
                )
                advantages = torch.where(
                    env_outs['action_mask'].bool(),
                    expanded_advantages,
                    advantages,
                )
                env_outs['advantages'] = advantages
            else:
                raise ValueError(
                    f'Invalid loss type: {self.actor_critic.loss_type}. ' +
                    'Valid options are: ppo, grpo.',
                )

            batch_adv_mean, batch_adv_var = dist_compute_masked_mean_and_var(
                env_outs['advantages'],
                env_outs['action_mask'],
            )
            iter_batch.update({
                'adv_masked_mean':
                    torch.ones(self.iter_batch_size) * batch_adv_mean.cpu(),
                'adv_masked_var':
                    torch.ones(self.iter_batch_size) * batch_adv_var.cpu(),
            })

mean_ift = masked_mean(
env_outs['ift_kl'],
@@ -964,17 +978,12 @@ def _resolve_outputs(
'max_gen_len':
torch.ones(self.iter_batch_size).to(torch.int32) *
self.max_gen_len,
'adv_masked_mean':
torch.ones(self.iter_batch_size) * batch_adv_mean.cpu(),
'adv_masked_var':
torch.ones(self.iter_batch_size) * batch_adv_var.cpu(),
'ift_kl_scalar':
torch.ones(self.iter_batch_size) * self.kl_ctl.value,
'reward_std':
torch.ones(self.iter_batch_size) *
env_outs['rewards'].std().to('cpu'),
})

# Moving minibatches to CPU to not take additional GPU memory
for k, v in iter_batch.items():
if hasattr(v, 'cpu'):
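As a reference for the GRPO branch in the hunk above, this is a self-contained sketch of the per-prompt (group) reward normalization it performs with `torch.unique` and `scatter_add_`. The tensors here are toy values, and normalization by the group standard deviation is shown unconditionally, whereas the callback only applies it when `normalize_advantage` is set.

```python
import torch

# One scalar reward per sampled sequence, and the prompt each sequence came from.
flat_rewards = torch.tensor([1.0, 0.0, 2.0, 4.0])
prompt_id = torch.tensor([7, 7, 9, 9])

# Per-prompt counts, sums, and sums of squares via scatter_add_.
unique_ids, inverse = torch.unique(prompt_id, return_inverse=True)
n_unique = len(unique_ids)
counts = torch.zeros(n_unique).scatter_add_(0, inverse, torch.ones_like(flat_rewards))
sums = torch.zeros(n_unique).scatter_add_(0, inverse, flat_rewards)
sum_squares = torch.zeros(n_unique).scatter_add_(0, inverse, flat_rewards**2)

means = sums / counts
stds = torch.sqrt(sum_squares / counts - means**2)

# Group-normalized (GRPO) advantage, one scalar per sequence,
# later broadcast over the action mask in the callback.
grpo_advantage = (flat_rewards - means[inverse]) / (stds[inverse] + 1e-4)
# grpo_advantage ≈ tensor([ 1., -1., -1.,  1.])
```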
9 changes: 6 additions & 3 deletions compose_rl/algorithms/online/generation_utils/vllm_utils.py
@@ -45,7 +45,10 @@
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from compose_rl.algorithms.online.generation_utils.vllm_actor import LLMRayActor
from compose_rl.algorithms.online.model_methods import OnPolicyEnum
from compose_rl.algorithms.online.model_methods import (
ALGORITHM_TYPE,
OnPolicyEnum,
)

log = logging.getLogger(__name__)

@@ -362,7 +365,7 @@ def should_update_torch_module(
if parsed_module_name not in valid_non_leaf_module_names:
return False

if loss_type == OnPolicyEnum.GRPO:
if loss_type in ALGORITHM_TYPE.CRITIC_FREE:
return True

if loss_type == OnPolicyEnum.PPO and 'lm_backbone' in full_param_name:
@@ -394,7 +397,7 @@ def broadcast_to_vllm(
count, num_params = 0, len(
list(model.model.lm_backbone.named_parameters()), # type: ignore
)
elif loss_type == OnPolicyEnum.GRPO:
elif loss_type in ALGORITHM_TYPE.CRITIC_FREE:
# Directly use the model params
count, num_params = 0, len(
list(model.model.named_parameters()), # type: ignore
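The membership checks above rely on `ALGORITHM_TYPE` groups imported from `model_methods`, whose definition is not part of this diff. Below is a hypothetical sketch of how such a grouping could be expressed so that `loss_type in ALGORITHM_TYPE.CRITIC_FREE` works; the member names, and APO's inclusion in each group, are assumptions rather than the PR's actual definition.

```python
# Hypothetical sketch only; the real definitions live in
# compose_rl.algorithms.online.model_methods and may differ.
from enum import Enum

class OnPolicyEnum(str, Enum):
    PPO = 'ppo'
    GRPO = 'grpo'
    APO = 'apo'  # assumed member for this PR

class ALGORITHM_TYPE:
    # Critic-free methods broadcast the full policy to vLLM (no separate value head).
    CRITIC_FREE = frozenset({OnPolicyEnum.GRPO, OnPolicyEnum.APO})
    # Regression-style objectives skip the advantage computation in the callback.
    REGRESSION = frozenset({OnPolicyEnum.APO})

loss_type = OnPolicyEnum.GRPO
assert loss_type in ALGORITHM_TYPE.CRITIC_FREE
assert loss_type not in ALGORITHM_TYPE.REGRESSION
```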
6 changes: 6 additions & 0 deletions compose_rl/algorithms/online/model.py
@@ -104,6 +104,7 @@ def loss(self, outputs: MutableMapping, batch: MutableMapping):
value_clip_range=self.config.value_clip_range,
value_loss_weight=self.config.value_loss_weight,
policy_clip_ratio=self.config.policy_clip_ratio,
beta=self.config.beta,
add_direct_kl_loss=self.config.compute_kl_loss,
kl_estimator=self.config.kl_estimator,
kl_clip_range=self.config.kl_clip_range,
@@ -217,6 +218,7 @@ def loss(self, outputs: MutableMapping, batch: MutableMapping):
value_clip_range=self.config.value_clip_range,
value_loss_weight=self.config.value_loss_weight,
policy_clip_ratio=self.config.policy_clip_ratio,
beta=self.config.beta,
add_direct_kl_loss=self.config.compute_kl_loss,
kl_estimator=self.config.kl_estimator,
kl_clip_range=self.config.kl_clip_range,
@@ -255,6 +257,7 @@ def __init__(
length_normalize_policy_loss: bool = True,
policy_clip_ratio: float = 0.15,
policy_clip_high_ratio: float | None = None,
beta: float = 1e-3,
compute_kl_loss: bool = True,
entropy_loss_weight: float | None = None,
target_kl: float = 0.1,
@@ -275,6 +278,7 @@
target_kl (float): The target KL value. Default: ``0.1``.
kl_estimator (str): The KL estimator to use. Default: ``'k3'``.
kl_clip_range (float): The KL clip range. Default: ``40.0``.
beta (float): The pi_ref KL hyperparameter for APO. Default: ``1e-3``.
"""
super().__init__(**kwargs)
self.policy_kl = []
@@ -285,6 +289,7 @@
self.policy_clip_high_ratio = policy_clip_high_ratio
self.compute_kl_loss = compute_kl_loss
self.target_kl = target_kl
self.beta = beta
self.kl_estimator = kl_estimator
self.kl_clip_range = kl_clip_range
self.entropy_loss_weight = entropy_loss_weight
@@ -309,6 +314,7 @@ def loss(self, outputs: MutableMapping, batch: MutableMapping):
loss_type=self.loss_type,
policy_clip_ratio=self.policy_clip_ratio,
policy_clip_high_ratio=self.policy_clip_high_ratio,
beta=self.beta,
length_normalize_policy_loss=self.length_normalize_policy_loss,
add_direct_kl_loss=self.compute_kl_loss,
kl_estimator=self.kl_estimator,
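The changes above only thread the new `beta` hyperparameter (documented as the pi_ref KL weight for APO) through to the shared loss utility alongside `kl_estimator` and `kl_clip_range`; the objective itself is not shown in this diff. As a rough, hypothetical illustration of a beta-weighted reference-policy KL penalty using the k3 estimator, not the PR's actual loss:

```python
import torch

def kl_penalized_loss(
    surrogate_loss: torch.Tensor,
    policy_logprobs: torch.Tensor,
    ref_logprobs: torch.Tensor,
    beta: float = 1e-3,
) -> torch.Tensor:
    """Illustrative only: add a beta-weighted estimate of KL(pi || pi_ref) to a base loss."""
    # k3 estimator of KL(pi || pi_ref) from tokens sampled under pi.
    log_ratio = ref_logprobs - policy_logprobs
    kl_k3 = (log_ratio.exp() - 1.0) - log_ratio
    return surrogate_loss + beta * kl_k3.mean()
```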