Commit 471cb48

amend
1 parent 2505ede commit 471cb48

2 files changed: +164 -23 lines changed

torchrl/modules/distributions/discrete.py

Lines changed: 157 additions & 6 deletions
@@ -15,7 +15,7 @@
 
 from torch.distributions.utils import lazy_property, logits_to_probs, probs_to_logits
 
-__all__ = ["OneHotCategorical", "MaskedCategorical", "Ordinal", "OneHotOrdinal"]
+__all__ = ["OneHotCategorical", "MaskedCategorical", "Ordinal", "OneHotOrdinal", "LLMMaskedCategorical"]
 
 
 def _treat_categorical_params(
@@ -51,8 +51,8 @@ def wrapped(_self, *args, **kwargs):
 
 
 class ReparamGradientStrategy(Enum):
-    PassThrough: Any = 1
-    RelaxedOneHot: Any = 2
+    PassThrough = 1
+    RelaxedOneHot = 2
 
 
 class OneHotCategorical(D.Categorical):
@@ -105,7 +105,11 @@ def __init__(
         probs = _treat_categorical_params(probs)
         self.grad_method = grad_method
         super().__init__(probs=probs, logits=logits, **kwargs)
-        self.num_samples = self._param.shape[-1]
+        # Get num_samples from logits or probs shape
+        if logits is not None:
+            self.num_samples = logits.shape[-1]
+        else:
+            self.num_samples = probs.shape[-1]
 
     def log_prob(self, value: torch.Tensor) -> torch.Tensor:
         return super().log_prob(value.argmax(dim=-1))
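
The replaced line above relied on the parent class's private `_param` attribute; the new code reads `num_samples` from whichever of `logits`/`probs` was actually provided. A quick sketch of the resulting behaviour (illustrative only, not part of the commit):

    import torch
    from torchrl.modules import OneHotCategorical

    # num_samples now comes from the last dimension of the provided parameter,
    # whether the distribution is built from logits or from probs.
    d_logits = OneHotCategorical(logits=torch.randn(4))
    d_probs = OneHotCategorical(probs=torch.tensor([0.1, 0.2, 0.3, 0.4]))
    assert d_logits.num_samples == d_probs.num_samples == 4
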
@@ -381,7 +385,9 @@ def _mask_logits(
             return logits
 
         if not sparse_mask:
-            return logits.masked_fill(~expand_as_right(mask, logits), neg_inf)
+            # Use a large negative value instead of -inf to avoid numerical issues
+            large_neg = torch.finfo(logits.dtype).min
+            return logits.masked_fill(~expand_as_right(mask, logits), large_neg)
 
         if padding_value is not None:
             padding_mask = mask == padding_value
@@ -390,7 +396,8 @@ def _mask_logits(
             mask = mask.masked_fill(padding_mask, 0)
         logits = logits.gather(dim=-1, index=mask)
         if padding_value is not None:
-            logits.masked_fill_(padding_mask, neg_inf)
+            large_neg = torch.finfo(logits.dtype).min
+            logits.masked_fill_(padding_mask, large_neg)
         return logits
 
     @property
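
Both hunks above swap the `-inf` fill for `torch.finfo(logits.dtype).min`. The difference matters when every class at a position happens to be masked: a row of all `-inf` logits makes softmax (and anything downstream) return NaN, while the finite minimum merely degrades to a uniform distribution. A small illustration, not part of the commit:

    import torch

    logits = torch.randn(4)
    mask = torch.zeros(4, dtype=torch.bool)  # every class masked out

    # Filling with -inf: softmax over a fully masked row yields NaN.
    inf_filled = logits.masked_fill(~mask, float("-inf"))
    print(inf_filled.softmax(-1))  # tensor([nan, nan, nan, nan])

    # Filling with the finite dtype minimum keeps the computation finite;
    # the fully masked row degenerates to a uniform distribution instead.
    min_filled = logits.masked_fill(~mask, torch.finfo(logits.dtype).min)
    print(min_filled.softmax(-1))  # tensor([0.2500, 0.2500, 0.2500, 0.2500])
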
@@ -658,3 +665,147 @@ def _generate_ordinal_logits(scores: torch.Tensor) -> torch.Tensor:
     )
 
     return larger_than_log_probs + smaller_than_log_probs
+
+
+class LLMMaskedCategorical(D.Categorical):
+    """LLM-optimized masked categorical distribution.
+
+    This class provides a more memory-efficient approach for LLM training by:
+    1. Using ignore_index=-100 for log_prob computation (no masking overhead)
+    2. Using traditional masking for sampling operations
+
+    This is particularly beneficial for large vocabulary sizes, where masking
+    all logits can be memory-intensive.
+
+    Args:
+        logits (torch.Tensor): event log probabilities (unnormalized)
+        mask (torch.Tensor): boolean mask indicating valid positions
+        ignore_index (int, optional): index to ignore in log_prob computation. Defaults to -100.
+
+    Examples:
+        >>> logits = torch.randn(2, 10, 50000)  # batch=2, seq_len=10, vocab=50000
+        >>> mask = torch.ones(2, 10, dtype=torch.bool)
+        >>> mask[0, :5] = False  # mask first 5 tokens of first sequence
+        >>> dist = LLMMaskedCategorical(logits=logits, mask=mask)
+        >>>
+        >>> # Efficient log_prob computation (no masking overhead)
+        >>> tokens = torch.randint(0, 50000, (2, 10))
+        >>> tokens[0, :5] = -100  # set masked positions to ignore_index
+        >>> log_probs = dist.log_prob(tokens)
+        >>>
+        >>> # Sampling still uses masking for correctness
+        >>> samples = dist.sample()
+    """
+
+    def __init__(
+        self,
+        logits: torch.Tensor,
+        mask: torch.Tensor,
+        ignore_index: int = -100,
+        **kwargs,
+    ) -> None:
+        self._original_logits = logits
+        self._mask = mask
+        self.ignore_index = ignore_index
+
+        # Create masked logits for sampling (only when needed)
+        self._masked_logits = None
+        self._masked_dist = None
+
+        # Initialize parent with original logits
+        super().__init__(logits=logits, **kwargs)
+
+    @property
+    def _sampling_logits(self):
+        """Get masked logits for sampling operations."""
+        if self._masked_logits is None:
+            # Only create masked logits when needed for sampling
+            large_neg = torch.finfo(self._original_logits.dtype).min
+            self._masked_logits = self._original_logits.masked_fill(
+                ~expand_as_right(self._mask, self._original_logits), large_neg
+            )
+        return self._masked_logits
+
+    @property
+    def _sampling_dist(self):
+        """Get masked distribution for sampling operations."""
+        if self._masked_dist is None:
+            self._masked_dist = D.Categorical(logits=self._sampling_logits)
+        return self._masked_dist
+
+    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
+        """Compute log probabilities using the ignore_index approach.
+
+        This is memory-efficient as it does not require masking the logits.
+        The value tensor should use ignore_index for masked positions.
+        """
+        # Use cross_entropy with ignore_index for efficiency
+        if value.ndim > 1:
+            # Reshape for cross_entropy: (batch, seq_len, vocab) -> (batch*seq_len, vocab)
+            logits_flat = self._original_logits.view(-1, self._original_logits.size(-1))
+            value_flat = value.view(-1)
+
+            # Compute cross_entropy with ignore_index
+            log_probs_flat = -F.cross_entropy(
+                logits_flat, value_flat,
+                reduction="none",
+                ignore_index=self.ignore_index,
+            )
+
+            # Reshape back
+            log_probs = log_probs_flat.view_as(value)
+        else:
+            log_probs = -F.cross_entropy(
+                self._original_logits, value,
+                reduction="none",
+                ignore_index=self.ignore_index,
+            )
+
+        return log_probs
+
+    def sample(self, sample_shape: torch.Size | Sequence[int] | None = None) -> torch.Tensor:
+        """Sample from the distribution using masked logits."""
+        sample_shape = sample_shape if sample_shape is not None else torch.Size()
+        return self._sampling_dist.sample(sample_shape)
+
+    def rsample(self, sample_shape: torch.Size | Sequence[int] | None = None) -> torch.Tensor:
+        """Reparameterized sampling using masked logits."""
+        # A dedicated reparameterization strategy is not implemented yet;
+        # fall back to regular (non-differentiable) sampling for now.
+        return self.sample(sample_shape)
+
+    @property
+    def mode(self) -> torch.Tensor:
+        """Get the mode using masked logits."""
+        masked_logits = self._sampling_logits
+        return (masked_logits == masked_logits.max(-1, True)[0]).to(torch.long)
+
+    def entropy(self) -> torch.Tensor:
+        """Compute entropy using masked logits."""
+        return self._sampling_dist.entropy()
+
+    def clear_cache(self):
+        """Clear cached masked tensors to free memory."""
+        self._masked_logits = None
+        self._masked_dist = None
+
+    @property
+    def mask(self) -> torch.Tensor:
+        """Get the mask."""
+        return self._mask
+
+    @property
+    def logits(self) -> torch.Tensor:
+        """Get the original logits."""
+        return self._original_logits
+
+    @property
+    def masked_logits(self) -> torch.Tensor:
+        """Get the masked logits for sampling operations."""
+        return self._sampling_logits
+
+    @property
+    def masked_dist(self) -> D.Categorical:
+        """Get the masked distribution for sampling operations."""
+        return self._sampling_dist
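
Taken together, the new class keeps two separate code paths: `log_prob` runs `F.cross_entropy` with `ignore_index` directly on the original logits, so no masked copy of a `(batch, seq_len, vocab)` tensor is ever allocated, while `sample`, `entropy` and `mode` lazily build a masked `Categorical`. A rough usage sketch, assuming the class as added above:

    import torch
    from torchrl.modules.distributions.discrete import LLMMaskedCategorical

    logits = torch.randn(2, 10, 50_000)        # (batch, seq_len, vocab)
    mask = torch.ones(2, 10, dtype=torch.bool)
    mask[0, :5] = False                        # e.g. prompt tokens of the first sequence

    dist = LLMMaskedCategorical(logits=logits, mask=mask)

    # log_prob path: mark ignored positions in the *targets* with -100;
    # the original logits are used as-is, no masked copy is allocated.
    tokens = torch.randint(0, 50_000, (2, 10))
    tokens[0, :5] = -100
    lp = dist.log_prob(tokens)                 # shape (2, 10), zeros at ignored positions

    # sampling path: masked logits and a plain Categorical are built lazily.
    samples = dist.sample()                    # shape (2, 10)
    ent = dist.entropy()
    dist.clear_cache()                         # drop the cached masked tensors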

torchrl/modules/llm/policies/common.py

Lines changed: 7 additions & 17 deletions
@@ -17,7 +17,7 @@
 from torch.nn.utils.rnn import pad_sequence
 from torchrl.data.llm import History
 from torchrl.data.tensor_specs import Unbounded
-from torchrl.modules import MaskedCategorical
+from torchrl.modules.distributions.discrete import LLMMaskedCategorical
 
 # TODOs:
 # - [ ] Remove the useless view(-1) calls when num_samples is not > 1
@@ -446,7 +446,7 @@ def get_dist(
             **kwargs: Additional arguments
 
         Returns:
-            Distribution (Categorical or MaskedCategorical)
+            Distribution (Categorical or LLMMaskedCategorical)
         """
         if self.generate:
             raise NotImplementedError(
@@ -489,11 +489,9 @@ def get_dist(
             mask = logits != padding_value
 
         if mask is not None:
-            return MaskedCategorical(
+            return LLMMaskedCategorical(
                 logits=logits,
                 mask=mask,
-                use_cross_entropy=True,
-                padding_side=padding_side,
             )
         return Categorical(logits)

@@ -603,11 +601,9 @@ def _get_dist_with_prompt_mask(
                 padding_side=padding_side,
             )
 
-        return MaskedCategorical(
+        return LLMMaskedCategorical(
             logits=logits,
             mask=response_mask.bool(),
-            use_cross_entropy=True,
-            padding_side=padding_side,
         )
 
     def _get_dist_with_assistant_mask(
@@ -668,11 +664,9 @@ def _get_dist_with_assistant_mask(
                 f"Assistant mask not found in tensordict at key {assistant_mask_key}. {post_msg}"
             )
 
-        return MaskedCategorical(
+        return LLMMaskedCategorical(
             logits=logits,
             mask=assistant_mask,
-            use_cross_entropy=True,
-            padding_side=padding_side,
         )
 
     def _get_dist_with_attention_mask(
@@ -722,11 +716,9 @@ def _get_dist_with_attention_mask(
                 f"Attention mask not found in tensordict at key {attention_mask_key}"
             )
 
-        return MaskedCategorical(
+        return LLMMaskedCategorical(
             logits=logits,
             mask=attention_mask,
-            use_cross_entropy=True,
-            padding_side=padding_side,
         )
 
     def _get_dist_with_custom_mask(
@@ -764,11 +756,9 @@ def _get_dist_with_custom_mask(
         if logits is None:
             raise ValueError(f"Logits not found in tensordict at key {logits_key}")
 
-        return MaskedCategorical(
+        return LLMMaskedCategorical(
             logits=logits,
             mask=mask,
-            use_cross_entropy=True,
-            padding_side=padding_side,
         )
 
     # Convenience methods for common LLM training scenarios
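
One behavioural difference for callers of these `_get_dist_*` helpers: `MaskedCategorical(use_cross_entropy=True, padding_side=...)` handled invalid positions inside the distribution, whereas `LLMMaskedCategorical.log_prob` expects the targets themselves to carry `ignore_index` (-100) wherever a position should not contribute. A hedged sketch of what downstream loss code might look like; the names `dist`, `tokens` and `valid_mask` are illustrative and not from this commit:

    import torch

    def masked_log_prob(dist, tokens: torch.Tensor, valid_mask: torch.Tensor) -> torch.Tensor:
        # `dist` is assumed to be the LLMMaskedCategorical returned by get_dist(...),
        # `tokens` the target token ids, and `valid_mask` the same boolean mask that
        # was handed to the distribution. Invalid positions are marked in the targets
        # with ignore_index instead of being masked inside the distribution.
        targets = tokens.masked_fill(~valid_mask, dist.ignore_index)
        return dist.log_prob(targets)  # contributes 0.0 at ignored positions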
