Skip to content

Commit 15717b1

Browse files
committed
amend
1 parent 471cb48 commit 15717b1

File tree

4 files changed

+335
-10
lines changed

4 files changed

+335
-10
lines changed

test/test_distributions.py

Lines changed: 308 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
MaskedOneHotCategorical,
3131
TanhDelta,
3232
)
33+
from torchrl.modules.distributions.discrete import LLMMaskedCategorical
3334
from torchrl.modules.distributions.continuous import SafeTanhTransform
3435
from torchrl.modules.distributions.discrete import _generate_ordinal_logits
3536

@@ -749,6 +750,313 @@ def test_reparam(self, grad_method, sparse):
749750
assert logits.grad is not None and logits.grad.norm() > 0
750751

751752

753+
class TestLLMMaskedCategorical:
754+
"""Test the LLM-optimized masked categorical distribution."""
755+
756+
def test_construction(self):
757+
"""Test basic construction and properties."""
758+
torch.manual_seed(0)
759+
logits = torch.randn(2, 3, 4) # batch=2, seq=3, vocab=4
760+
mask = torch.ones(2, 3, dtype=torch.bool)
761+
mask[0, :1] = False # mask first token of first sequence
762+
763+
dist = LLMMaskedCategorical(logits=logits, mask=mask, ignore_index=-100)
764+
765+
# Check properties
766+
assert dist.mask.shape == mask.shape
767+
assert torch.equal(dist.mask, mask)
768+
assert dist.logits.shape == logits.shape
769+
assert torch.equal(dist.logits, logits)
770+
assert dist.ignore_index == -100
771+
772+
# Check that masked logits are created lazily
773+
assert dist._masked_logits is None
774+
assert dist._masked_dist is None
775+
776+
# Access masked logits to trigger creation
777+
masked_logits = dist.masked_logits
778+
assert dist._masked_logits is not None
779+
assert masked_logits.shape == logits.shape
780+
781+
# Check that masked positions have large negative values
782+
large_neg = torch.finfo(logits.dtype).min
783+
assert (masked_logits[0, :1] == large_neg).all()
784+
assert (masked_logits[0, 1:] != large_neg).all()
785+
786+
def test_log_prob_efficiency(self):
787+
"""Test that log_prob uses ignore_index approach efficiently."""
788+
torch.manual_seed(0)
789+
logits = torch.randn(2, 3, 4, requires_grad=True)
790+
mask = torch.ones(2, 3, dtype=torch.bool)
791+
mask[0, :1] = False
792+
793+
dist = LLMMaskedCategorical(logits=logits, mask=mask, ignore_index=-100)
794+
795+
# Create target tokens with ignore_index for masked positions
796+
target_tokens = torch.randint(0, 4, (2, 3))
797+
target_tokens[~mask] = -100
798+
799+
# Compute log probabilities
800+
log_probs = dist.log_prob(target_tokens)
801+
802+
# Check shapes
803+
assert log_probs.shape == target_tokens.shape
804+
805+
# Check that masked positions have zero log probability
806+
assert (log_probs[~mask] == 0.0).all()
807+
808+
# Check that valid positions have finite log probabilities
809+
assert torch.isfinite(log_probs[mask]).all()
810+
811+
# Test backward pass
812+
loss = -log_probs.sum()
813+
loss.backward()
814+
assert logits.grad is not None
815+
816+
def test_sampling_correctness(self):
817+
"""Test that sampling works correctly with masking."""
818+
torch.manual_seed(0)
819+
logits = torch.randn(2, 3, 4)
820+
mask = torch.ones(2, 3, dtype=torch.bool)
821+
mask[0, :1] = False # mask first token of first sequence
822+
823+
dist = LLMMaskedCategorical(logits=logits, mask=mask, ignore_index=-100)
824+
825+
# Sample multiple times
826+
num_samples = 1000
827+
samples = dist.sample((num_samples,))
828+
829+
# Check shapes
830+
assert samples.shape == (num_samples, 2, 3)
831+
832+
# Check that valid positions are sampled within valid range
833+
assert (samples >= 0).all()
834+
assert (samples < 4).all()
835+
836+
# Check that sampling respects the mask (masked positions should have
837+
# very low probability of being sampled, but not impossible)
838+
# We'll just check that the distribution is valid for a single sample
839+
single_sample = samples[0]
840+
assert torch.isfinite(dist.log_prob(single_sample)).all()
841+
842+
def test_mode_correctness(self):
843+
"""Test that mode computation works correctly."""
844+
torch.manual_seed(0)
845+
logits = torch.randn(2, 3, 4)
846+
mask = torch.ones(2, 3, dtype=torch.bool)
847+
mask[0, :1] = False
848+
849+
dist = LLMMaskedCategorical(logits=logits, mask=mask, ignore_index=-100)
850+
851+
mode = dist.mode
852+
853+
# Check shapes
854+
assert mode.shape == (2, 3)
855+
856+
# Check that mode values are valid indices
857+
assert (mode >= 0).all()
858+
assert (mode < 4).all()
859+
860+
# Check that mode matches the argmax of masked logits
861+
masked_logits = dist.masked_logits
862+
expected_mode = masked_logits.argmax(dim=-1)
863+
torch.testing.assert_close(mode, expected_mode)
864+
865+
def test_entropy_correctness(self):
866+
"""Test that entropy computation works correctly."""
867+
torch.manual_seed(0)
868+
logits = torch.randn(2, 3, 4)
869+
mask = torch.ones(2, 3, dtype=torch.bool)
870+
mask[0, :1] = False
871+
872+
dist = LLMMaskedCategorical(logits=logits, mask=mask, ignore_index=-100)
873+
874+
entropy = dist.entropy()
875+
876+
# Check shapes
877+
assert entropy.shape == (2, 3)
878+
879+
# Check that entropy is finite for valid positions
880+
assert torch.isfinite(entropy[mask]).all()
881+
882+
# Check that entropy is reasonable (positive for valid positions)
883+
assert (entropy[mask] >= 0).all()
884+
885+
def test_clear_cache(self):
886+
"""Test that cache clearing works correctly."""
887+
torch.manual_seed(0)
888+
logits = torch.randn(2, 3, 4)
889+
mask = torch.ones(2, 3, dtype=torch.bool)
890+
891+
dist = LLMMaskedCategorical(logits=logits, mask=mask, ignore_index=-100)
892+
893+
# Access masked logits to populate cache
894+
_ = dist.masked_logits
895+
_ = dist.masked_dist
896+
897+
# Check that cache is populated
898+
assert dist._masked_logits is not None
899+
assert dist._masked_dist is not None
900+
901+
# Clear cache
902+
dist.clear_cache()
903+
904+
# Check that cache is cleared
905+
assert dist._masked_logits is None
906+
assert dist._masked_dist is None
907+
908+
def test_memory_efficiency(self):
909+
"""Test that the distribution is memory efficient."""
910+
import psutil
911+
import os
912+
913+
def get_memory_usage():
914+
process = psutil.Process(os.getpid())
915+
return process.memory_info().rss / 1024 / 1024
916+
917+
torch.manual_seed(0)
918+
# Use larger tensors to make memory differences more apparent
919+
logits = torch.randn(4, 512, 1000, requires_grad=True)
920+
mask = torch.ones(4, 512, dtype=torch.bool)
921+
mask[:, :100] = False # mask first 100 tokens
922+
923+
target_tokens = torch.randint(0, 1000, (4, 512))
924+
target_tokens[~mask] = -100
925+
926+
# Test current MaskedCategorical approach
927+
torch.cuda.empty_cache() if torch.cuda.is_available() else None
928+
start_memory = get_memory_usage()
929+
930+
dist_current = MaskedCategorical(logits=logits, mask=mask, use_cross_entropy=True)
931+
log_probs_current = dist_current.log_prob(target_tokens)
932+
loss_current = -log_probs_current.sum()
933+
loss_current.backward()
934+
935+
current_memory = get_memory_usage() - start_memory
936+
937+
# Reset gradients
938+
logits.grad = None
939+
940+
# Test new LLMMaskedCategorical approach
941+
torch.cuda.empty_cache() if torch.cuda.is_available() else None
942+
start_memory = get_memory_usage()
943+
944+
dist_new = LLMMaskedCategorical(logits=logits, mask=mask, ignore_index=-100)
945+
log_probs_new = dist_new.log_prob(target_tokens)
946+
loss_new = -log_probs_new.sum()
947+
loss_new.backward()
948+
949+
new_memory = get_memory_usage() - start_memory
950+
951+
# Check that results are similar
952+
torch.testing.assert_close(loss_current, loss_new, atol=1e-4, rtol=1e-4)
953+
954+
# Check that new approach uses less memory
955+
# Note: This might not always be true due to memory fragmentation,
956+
# but it should generally be the case
957+
print(f"Current approach memory: {current_memory:.2f} MB")
958+
print(f"New approach memory: {new_memory:.2f} MB")
959+
960+
def test_ignore_index_variations(self):
961+
"""Test with different ignore_index values."""
962+
torch.manual_seed(0)
963+
logits = torch.randn(2, 3, 4)
964+
mask = torch.ones(2, 3, dtype=torch.bool)
965+
mask[0, :1] = False
966+
967+
# Test with different ignore_index values
968+
for ignore_idx in [-100, -1, 999]:
969+
dist = LLMMaskedCategorical(logits=logits, mask=mask, ignore_index=ignore_idx)
970+
971+
target_tokens = torch.randint(0, 4, (2, 3))
972+
target_tokens[~mask] = ignore_idx
973+
974+
log_probs = dist.log_prob(target_tokens)
975+
976+
# Check that ignored positions have zero log probability
977+
assert (log_probs[~mask] == 0.0).all()
978+
979+
# Check that valid positions have finite log probabilities
980+
assert torch.isfinite(log_probs[mask]).all()
981+
982+
def test_large_vocabulary(self):
983+
"""Test with large vocabulary size to ensure scalability."""
984+
torch.manual_seed(0)
985+
vocab_size = 50000 # Large vocabulary like in LLMs
986+
logits = torch.randn(2, 10, vocab_size)
987+
mask = torch.ones(2, 10, dtype=torch.bool)
988+
mask[0, :3] = False # mask first 3 tokens of first sequence
989+
990+
dist = LLMMaskedCategorical(logits=logits, mask=mask, ignore_index=-100)
991+
992+
# Test log_prob
993+
target_tokens = torch.randint(0, vocab_size, (2, 10))
994+
target_tokens[~mask] = -100
995+
996+
log_probs = dist.log_prob(target_tokens)
997+
assert log_probs.shape == (2, 10)
998+
assert (log_probs[~mask] == 0.0).all()
999+
assert torch.isfinite(log_probs[mask]).all()
1000+
1001+
# Test sampling
1002+
samples = dist.sample()
1003+
assert samples.shape == (2, 10)
1004+
assert (samples >= 0).all()
1005+
assert (samples < vocab_size).all()
1006+
1007+
# Test mode
1008+
mode = dist.mode
1009+
assert mode.shape == (2, 10)
1010+
assert (mode >= 0).all()
1011+
assert (mode < vocab_size).all()
1012+
1013+
def test_comparison_with_masked_categorical(self):
1014+
"""Compare results with the original MaskedCategorical."""
1015+
torch.manual_seed(0)
1016+
logits = torch.randn(2, 3, 4)
1017+
mask = torch.ones(2, 3, dtype=torch.bool)
1018+
mask[0, :1] = False
1019+
1020+
# Create both distributions
1021+
dist_original = MaskedCategorical(logits=logits, mask=mask, use_cross_entropy=True)
1022+
dist_new = LLMMaskedCategorical(logits=logits, mask=mask, ignore_index=-100)
1023+
1024+
# Test log_prob with valid tokens
1025+
target_tokens = torch.randint(0, 4, (2, 3))
1026+
target_tokens[~mask] = -100 # Set masked positions to ignore_index
1027+
1028+
log_probs_original = dist_original.log_prob(target_tokens)
1029+
log_probs_new = dist_new.log_prob(target_tokens)
1030+
1031+
# Results should be similar (might have small numerical differences)
1032+
torch.testing.assert_close(log_probs_original, log_probs_new, atol=1e-4, rtol=1e-4)
1033+
1034+
# Test sampling
1035+
samples_original = dist_original.sample()
1036+
samples_new = dist_new.sample()
1037+
1038+
# Both should produce valid samples
1039+
assert (samples_original >= 0).all()
1040+
assert (samples_original < 4).all()
1041+
assert (samples_new >= 0).all()
1042+
assert (samples_new < 4).all()
1043+
1044+
def test_error_handling(self):
1045+
"""Test error handling for invalid inputs."""
1046+
torch.manual_seed(0)
1047+
logits = torch.randn(2, 3, 4)
1048+
mask = torch.ones(2, 3, dtype=torch.bool)
1049+
1050+
# Test with mismatched shapes
1051+
with pytest.raises(ValueError):
1052+
LLMMaskedCategorical(logits=logits, mask=torch.ones(3, 4, dtype=torch.bool))
1053+
1054+
# Test with invalid ignore_index (should be fine since we don't validate this)
1055+
# The current implementation doesn't validate ignore_index, so this should work
1056+
dist = LLMMaskedCategorical(logits=logits, mask=mask, ignore_index=5)
1057+
assert dist.ignore_index == 5
1058+
1059+
7521060
class TestOrdinal:
7531061
@pytest.mark.parametrize("dtype", [torch.float, torch.double])
7541062
@pytest.mark.parametrize("device", get_default_devices())

torchrl/modules/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
ReparamGradientStrategy,
2020
TanhDelta,
2121
TanhNormal,
22+
LLMMaskedCategorical,
2223
TruncatedNormal,
2324
)
2425
from .models import (
@@ -138,6 +139,7 @@
138139
"MaskedOneHotCategorical",
139140
"MultiAgentConvNet",
140141
"MultiAgentMLP",
142+
"LLMMaskedCategorical",
141143
"MultiAgentNetBase",
142144
"MultiStepActorWrapper",
143145
"NoisyLazyLinear",

torchrl/modules/distributions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
OneHotOrdinal,
2222
Ordinal,
2323
ReparamGradientStrategy,
24+
LLMMaskedCategorical,
2425
)
2526

2627
distributions_maps = {
@@ -57,6 +58,7 @@
5758
"distributions",
5859
"Delta",
5960
"IndependentNormal",
61+
"LLMMaskedCategorical",
6062
"NormalParamWrapper",
6163
"TanhDelta",
6264
"TanhNormal",

0 commit comments

Comments
 (0)