 
 from torch.distributions.utils import lazy_property, logits_to_probs, probs_to_logits
 
-__all__ = ["OneHotCategorical", "MaskedCategorical", "Ordinal", "OneHotOrdinal"]
+__all__ = ["OneHotCategorical", "MaskedCategorical", "Ordinal", "OneHotOrdinal", "LLMMaskedCategorical"]
 
 
 def _treat_categorical_params(
@@ -51,8 +51,8 @@ def wrapped(_self, *args, **kwargs):
 
 
 class ReparamGradientStrategy(Enum):
-    PassThrough: Any = 1
-    RelaxedOneHot: Any = 2
+    PassThrough = 1
+    RelaxedOneHot = 2
 
 
 class OneHotCategorical(D.Categorical):
@@ -105,7 +105,11 @@ def __init__(
         probs = _treat_categorical_params(probs)
         self.grad_method = grad_method
         super().__init__(probs=probs, logits=logits, **kwargs)
-        self.num_samples = self._param.shape[-1]
+        # Get num_samples from logits or probs shape
+        if logits is not None:
+            self.num_samples = logits.shape[-1]
+        else:
+            self.num_samples = probs.shape[-1]
 
     def log_prob(self, value: torch.Tensor) -> torch.Tensor:
         return super().log_prob(value.argmax(dim=-1))
@@ -381,7 +385,9 @@ def _mask_logits(
             return logits
 
         if not sparse_mask:
-            return logits.masked_fill(~expand_as_right(mask, logits), neg_inf)
+            # Use a large negative value instead of -inf to avoid numerical issues
+            large_neg = torch.finfo(logits.dtype).min
+            return logits.masked_fill(~expand_as_right(mask, logits), large_neg)
 
         if padding_value is not None:
             padding_mask = mask == padding_value
@@ -390,7 +396,8 @@ def _mask_logits(
                 mask = mask.masked_fill(padding_mask, 0)
         logits = logits.gather(dim=-1, index=mask)
         if padding_value is not None:
-            logits.masked_fill_(padding_mask, neg_inf)
+            large_neg = torch.finfo(logits.dtype).min
+            logits.masked_fill_(padding_mask, large_neg)
         return logits
 
     @property
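A quick aside on the change above (a sketch, not taken from the diff): one numerical issue that the dtype-minimum fill avoids is a row with no valid entries, which softmaxes to NaN under a -inf fill but stays finite under torch.finfo(dtype).min:

import torch

logits = torch.randn(4)
fully_masked = torch.zeros(4, dtype=torch.bool)  # no valid entry in this row

inf_fill = logits.masked_fill(~fully_masked, float("-inf"))
print(torch.softmax(inf_fill, dim=-1))     # tensor([nan, nan, nan, nan])

finite_fill = logits.masked_fill(~fully_masked, torch.finfo(logits.dtype).min)
print(torch.softmax(finite_fill, dim=-1))  # tensor([0.2500, 0.2500, 0.2500, 0.2500])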
@@ -658,3 +665,147 @@ def _generate_ordinal_logits(scores: torch.Tensor) -> torch.Tensor:
     )
 
     return larger_than_log_probs + smaller_than_log_probs
+
+
+class LLMMaskedCategorical(D.Categorical):
+    """LLM-optimized masked categorical distribution.
+
+    This class provides a more memory-efficient approach for LLM training by:
+    1. Using ignore_index=-100 for log_prob computation (no masking overhead)
+    2. Using traditional masking for sampling operations
+
+    This is particularly beneficial for large vocabulary sizes, where masking
+    all logits can be memory-intensive.
+
+    Args:
+        logits (torch.Tensor): event log probabilities (unnormalized).
+        mask (torch.Tensor): boolean mask indicating valid positions.
+        ignore_index (int, optional): index to ignore in log_prob computation.
+            Defaults to -100.
+
+    Examples:
+        >>> logits = torch.randn(2, 10, 50000)  # batch=2, seq_len=10, vocab=50000
+        >>> mask = torch.ones(2, 10, dtype=torch.bool)
+        >>> mask[0, :5] = False  # mask first 5 tokens of first sequence
+        >>> dist = LLMMaskedCategorical(logits=logits, mask=mask)
+        >>>
+        >>> # Efficient log_prob computation (no masking overhead)
+        >>> tokens = torch.randint(0, 50000, (2, 10))
+        >>> tokens[0, :5] = -100  # set masked positions to ignore_index
+        >>> log_probs = dist.log_prob(tokens)
+        >>>
+        >>> # Sampling still uses masking for correctness
+        >>> samples = dist.sample()
+    """
+
+    def __init__(
+        self,
+        logits: torch.Tensor,
+        mask: torch.Tensor,
+        ignore_index: int = -100,
+        **kwargs,
+    ) -> None:
+        self._original_logits = logits
+        self._mask = mask
+        self.ignore_index = ignore_index
+
+        # Create masked logits for sampling (only when needed)
+        self._masked_logits = None
+        self._masked_dist = None
+
+        # Initialize parent with original logits
+        super().__init__(logits=logits, **kwargs)
+
+    @property
+    def _sampling_logits(self):
+        """Get masked logits for sampling operations."""
+        if self._masked_logits is None:
+            # Only create masked logits when needed for sampling
+            large_neg = torch.finfo(self._original_logits.dtype).min
+            self._masked_logits = self._original_logits.masked_fill(
+                ~expand_as_right(self._mask, self._original_logits), large_neg
+            )
+        return self._masked_logits
+
+    @property
+    def _sampling_dist(self):
+        """Get masked distribution for sampling operations."""
+        if self._masked_dist is None:
+            self._masked_dist = D.Categorical(logits=self._sampling_logits)
+        return self._masked_dist
+
+    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
+        """Compute log probabilities using the ignore_index approach.
+
+        This is memory-efficient as it doesn't require masking the logits.
+        The value tensor should use ignore_index for masked positions.
+        """
+        # Use cross_entropy with ignore_index for efficiency
+        if value.ndim > 1:
+            # Reshape for cross_entropy: (batch, seq_len, vocab) -> (batch*seq_len, vocab)
+            logits_flat = self._original_logits.view(-1, self._original_logits.size(-1))
+            value_flat = value.view(-1)
+
+            # Per-token cross_entropy; ignore_index positions contribute zero
+            log_probs_flat = -F.cross_entropy(
+                logits_flat,
+                value_flat,
+                reduction="none",
+                ignore_index=self.ignore_index,
+            )
+
+            # Reshape back
+            log_probs = log_probs_flat.view_as(value)
+        else:
+            log_probs = -F.cross_entropy(
+                self._original_logits,
+                value,
+                reduction="none",
+                ignore_index=self.ignore_index,
+            )
+
+        return log_probs
+
+    def sample(self, sample_shape: torch.Size | Sequence[int] | None = None) -> torch.Tensor:
+        """Sample from the distribution using masked logits."""
+        if sample_shape is None:
+            sample_shape = torch.Size()
+        return self._sampling_dist.sample(sample_shape)
+
+    def rsample(self, sample_shape: torch.Size | Sequence | None = None) -> torch.Tensor:
+        """Reparameterized sampling using masked logits."""
+        # This would need to be implemented based on the specific reparameterization strategy.
+        # For now, fall back to regular sampling.
+        return self.sample(sample_shape)
+
+    @property
+    def mode(self) -> torch.Tensor:
+        """Get the mode using masked logits."""
+        masked_logits = self._sampling_logits
+        return (masked_logits == masked_logits.max(-1, True)[0]).to(torch.long)
+
+    def entropy(self) -> torch.Tensor:
+        """Compute entropy using masked logits."""
+        return self._sampling_dist.entropy()
+
+    def clear_cache(self):
+        """Clear cached masked tensors to free memory."""
+        self._masked_logits = None
+        self._masked_dist = None
+
+    @property
+    def mask(self) -> torch.Tensor:
+        """Get the mask."""
+        return self._mask
+
+    @property
+    def logits(self) -> torch.Tensor:
+        """Get the original (unnormalized) logits."""
+        return self._original_logits
+
+    @logits.setter
+    def logits(self, value: torch.Tensor) -> None:
+        # D.Categorical.__init__ assigns normalized logits to ``self.logits``; accept
+        # that write here so the getter above keeps exposing the logits passed at
+        # construction time.
+        self._normalized_logits = value
+
+    @property
+    def masked_logits(self) -> torch.Tensor:
+        """Get the masked logits for sampling operations."""
+        return self._sampling_logits
+
+    @property
+    def masked_dist(self) -> D.Categorical:
+        """Get the masked distribution for sampling operations."""
+        return self._sampling_dist
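A usage sketch of the class added above (not part of the diff; shapes, variable names, and the loss reduction are illustrative assumptions):

import torch

vocab, batch, seq = 128, 2, 6
logits = torch.randn(batch, seq, vocab, requires_grad=True)
valid = torch.ones(batch, seq, dtype=torch.bool)
valid[0, :3] = False  # e.g. prompt/padding positions excluded from the loss

dist = LLMMaskedCategorical(logits=logits, mask=valid)

# log_prob path: no masked copy of the (batch, seq, vocab) logits is materialized;
# excluded positions carry ignore_index (-100 by default) in the targets
targets = torch.randint(0, vocab, (batch, seq)).masked_fill(~valid, dist.ignore_index)
log_probs = dist.log_prob(targets)            # zeros at ignored positions
loss = -log_probs.sum() / valid.sum()
loss.backward()

# sampling / entropy path: masked logits are built lazily, then can be dropped
samples = dist.sample()
ent = dist.entropy()
dist.clear_cache()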