@@ -30,6 +30,7 @@
     MaskedOneHotCategorical,
     TanhDelta,
 )
+from torchrl.modules.distributions.discrete import LLMMaskedCategorical
 from torchrl.modules.distributions.continuous import SafeTanhTransform
 from torchrl.modules.distributions.discrete import _generate_ordinal_logits

@@ -749,6 +750,313 @@ def test_reparam(self, grad_method, sparse):
         assert logits.grad is not None and logits.grad.norm() > 0


+class TestLLMMaskedCategorical:
| 754 | + """Test the LLM-optimized masked categorical distribution.""" |
+
+    def test_construction(self):
+        """Test basic construction and properties."""
+        torch.manual_seed(0)
+        logits = torch.randn(2, 3, 4)  # batch=2, seq=3, vocab=4
+        mask = torch.ones(2, 3, dtype=torch.bool)
+        mask[0, :1] = False  # mask first token of first sequence
+
+        dist = LLMMaskedCategorical(logits=logits, mask=mask, ignore_index=-100)
+
+        # Check properties
+        assert dist.mask.shape == mask.shape
+        assert torch.equal(dist.mask, mask)
+        assert dist.logits.shape == logits.shape
+        assert torch.equal(dist.logits, logits)
+        assert dist.ignore_index == -100
+
+        # Check that masked logits are created lazily
+        assert dist._masked_logits is None
+        assert dist._masked_dist is None
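+
+        # Laziness matters for LLM-sized tensors: log_prob goes through the
+        # ignore_index path and never needs the masked copy, so the
+        # O(batch * seq * vocab) masked_fill is presumably only paid when
+        # sampling or taking the mode.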
+
+        # Access masked logits to trigger creation
+        masked_logits = dist.masked_logits
+        assert dist._masked_logits is not None
+        assert masked_logits.shape == logits.shape
+
+        # Masked positions are filled with the dtype's most negative value
+        large_neg = torch.finfo(logits.dtype).min
+        assert (masked_logits[0, :1] == large_neg).all()
+        assert (masked_logits[0, 1:] != large_neg).all()
+
+    def test_log_prob_efficiency(self):
+        """Test that log_prob uses the ignore_index approach efficiently."""
+        torch.manual_seed(0)
+        logits = torch.randn(2, 3, 4, requires_grad=True)
+        mask = torch.ones(2, 3, dtype=torch.bool)
+        mask[0, :1] = False
+
+        dist = LLMMaskedCategorical(logits=logits, mask=mask, ignore_index=-100)
+
+        # Create target tokens with ignore_index for masked positions
+        target_tokens = torch.randint(0, 4, (2, 3))
+        target_tokens[~mask] = -100
+
+        # Compute log probabilities
+        log_probs = dist.log_prob(target_tokens)
+
+        # Check shapes
+        assert log_probs.shape == target_tokens.shape
+
+        # Check that masked positions have zero log probability
+        assert (log_probs[~mask] == 0.0).all()
+
+        # Check that valid positions have finite log probabilities
+        assert torch.isfinite(log_probs[mask]).all()
+
+        # Test backward pass
+        loss = -log_probs.sum()
+        loss.backward()
+        assert logits.grad is not None
+
+    def test_sampling_correctness(self):
+        """Test that sampling works correctly with masking."""
+        torch.manual_seed(0)
+        logits = torch.randn(2, 3, 4)
+        mask = torch.ones(2, 3, dtype=torch.bool)
+        mask[0, :1] = False  # mask first token of first sequence
+
+        dist = LLMMaskedCategorical(logits=logits, mask=mask, ignore_index=-100)
+
+        # Sample multiple times
+        num_samples = 1000
+        samples = dist.sample((num_samples,))
+
+        # Check shapes
+        assert samples.shape == (num_samples, 2, 3)
+
+        # Samples lie within the valid token range
+        assert (samples >= 0).all()
+        assert (samples < 4).all()
+
+        # At fully masked positions every logit equals the dtype minimum, so
+        # sampling there is effectively uniform rather than impossible; we
+        # only check that the distribution stays valid for a drawn sample.
+        single_sample = samples[0]
+        assert torch.isfinite(dist.log_prob(single_sample)).all()
+
+    def test_mode_correctness(self):
+        """Test that mode computation works correctly."""
+        torch.manual_seed(0)
+        logits = torch.randn(2, 3, 4)
+        mask = torch.ones(2, 3, dtype=torch.bool)
+        mask[0, :1] = False
+
+        dist = LLMMaskedCategorical(logits=logits, mask=mask, ignore_index=-100)
+
+        mode = dist.mode
+
+        # Check shapes
+        assert mode.shape == (2, 3)
+
+        # Check that mode values are valid indices
+        assert (mode >= 0).all()
+        assert (mode < 4).all()
+
+        # Check that mode matches the argmax of masked logits
+        masked_logits = dist.masked_logits
+        expected_mode = masked_logits.argmax(dim=-1)
+        torch.testing.assert_close(mode, expected_mode)
+
+    def test_entropy_correctness(self):
+        """Test that entropy computation works correctly."""
+        torch.manual_seed(0)
+        logits = torch.randn(2, 3, 4)
+        mask = torch.ones(2, 3, dtype=torch.bool)
+        mask[0, :1] = False
+
+        dist = LLMMaskedCategorical(logits=logits, mask=mask, ignore_index=-100)
+
+        entropy = dist.entropy()
+
+        # Check shapes
+        assert entropy.shape == (2, 3)
+
+        # Check that entropy is finite for valid positions
+        assert torch.isfinite(entropy[mask]).all()
+
+        # Entropy of a categorical is non-negative
+        assert (entropy[mask] >= 0).all()
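+
+        # At fully masked positions all logits are equal, so the distribution
+        # there is uniform and its entropy should be ~log(4) (assuming the
+        # entropy is computed from the masked/renormalized logits).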
+
+    def test_clear_cache(self):
+        """Test that cache clearing works correctly."""
+        torch.manual_seed(0)
+        logits = torch.randn(2, 3, 4)
+        mask = torch.ones(2, 3, dtype=torch.bool)
+
+        dist = LLMMaskedCategorical(logits=logits, mask=mask, ignore_index=-100)
+
+        # Access masked logits to populate cache
+        _ = dist.masked_logits
+        _ = dist.masked_dist
+
+        # Check that cache is populated
+        assert dist._masked_logits is not None
+        assert dist._masked_dist is not None
+
+        # Clear cache
+        dist.clear_cache()
+
+        # Check that cache is cleared
+        assert dist._masked_logits is None
+        assert dist._masked_dist is None
+
+    def test_memory_efficiency(self):
+        """Test that the distribution is memory efficient."""
+        import os
+        import psutil
+
+        def get_memory_usage():
+            process = psutil.Process(os.getpid())
+            return process.memory_info().rss / 1024 / 1024  # RSS in MB
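+
+        # Note: RSS is a coarse, process-wide proxy; a finer check could use
+        # a memory profiler or torch.cuda.memory_allocated() on GPU, but RSS
+        # suffices to compare the two approaches here.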
+
+        torch.manual_seed(0)
+        # Use larger tensors to make memory differences more apparent
+        logits = torch.randn(4, 512, 1000, requires_grad=True)
+        mask = torch.ones(4, 512, dtype=torch.bool)
+        mask[:, :100] = False  # mask first 100 tokens
+
+        target_tokens = torch.randint(0, 1000, (4, 512))
+        target_tokens[~mask] = -100
+
+        # Baseline: the existing MaskedCategorical approach
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        start_memory = get_memory_usage()
+
+        dist_current = MaskedCategorical(
+            logits=logits, mask=mask, use_cross_entropy=True
+        )
+        log_probs_current = dist_current.log_prob(target_tokens)
+        loss_current = -log_probs_current.sum()
+        loss_current.backward()
+
+        current_memory = get_memory_usage() - start_memory
+
+        # Reset gradients
+        logits.grad = None
+
+        # New LLMMaskedCategorical approach
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        start_memory = get_memory_usage()
+
+        dist_new = LLMMaskedCategorical(logits=logits, mask=mask, ignore_index=-100)
+        log_probs_new = dist_new.log_prob(target_tokens)
+        loss_new = -log_probs_new.sum()
+        loss_new.backward()
+
+        new_memory = get_memory_usage() - start_memory
+
+        # Both approaches should produce the same loss
+        torch.testing.assert_close(loss_current, loss_new, atol=1e-4, rtol=1e-4)
+
+        # We report rather than assert on memory: RSS is noisy and allocator
+        # fragmentation can hide the saving, though the new approach should
+        # generally use less.
+        print(f"Current approach memory: {current_memory:.2f} MB")
+        print(f"New approach memory: {new_memory:.2f} MB")
+
+    def test_ignore_index_variations(self):
+        """Test with different ignore_index values."""
+        torch.manual_seed(0)
+        logits = torch.randn(2, 3, 4)
+        mask = torch.ones(2, 3, dtype=torch.bool)
+        mask[0, :1] = False
+
+        for ignore_idx in [-100, -1, 999]:
+            dist = LLMMaskedCategorical(
+                logits=logits, mask=mask, ignore_index=ignore_idx
+            )
+
+            target_tokens = torch.randint(0, 4, (2, 3))
+            target_tokens[~mask] = ignore_idx
+
+            log_probs = dist.log_prob(target_tokens)
+
+            # Check that ignored positions have zero log probability
+            assert (log_probs[~mask] == 0.0).all()
+
+            # Check that valid positions have finite log probabilities
+            assert torch.isfinite(log_probs[mask]).all()
+
+    def test_large_vocabulary(self):
+        """Test with a large vocabulary size to ensure scalability."""
+        torch.manual_seed(0)
+        vocab_size = 50000  # LLM-sized vocabulary
+        logits = torch.randn(2, 10, vocab_size)
+        mask = torch.ones(2, 10, dtype=torch.bool)
+        mask[0, :3] = False  # mask first 3 tokens of first sequence
+
+        dist = LLMMaskedCategorical(logits=logits, mask=mask, ignore_index=-100)
+
+        # Test log_prob
+        target_tokens = torch.randint(0, vocab_size, (2, 10))
+        target_tokens[~mask] = -100
+
+        log_probs = dist.log_prob(target_tokens)
+        assert log_probs.shape == (2, 10)
+        assert (log_probs[~mask] == 0.0).all()
+        assert torch.isfinite(log_probs[mask]).all()
+
+        # Test sampling
+        samples = dist.sample()
+        assert samples.shape == (2, 10)
+        assert (samples >= 0).all()
+        assert (samples < vocab_size).all()
+
+        # Test mode
+        mode = dist.mode
+        assert mode.shape == (2, 10)
+        assert (mode >= 0).all()
+        assert (mode < vocab_size).all()
+
+    def test_comparison_with_masked_categorical(self):
+        """Compare results with the original MaskedCategorical."""
+        torch.manual_seed(0)
+        logits = torch.randn(2, 3, 4)
+        mask = torch.ones(2, 3, dtype=torch.bool)
+        mask[0, :1] = False
+
+        # Create both distributions
+        dist_original = MaskedCategorical(
+            logits=logits, mask=mask, use_cross_entropy=True
+        )
+        dist_new = LLMMaskedCategorical(logits=logits, mask=mask, ignore_index=-100)
+
+        # Test log_prob with valid tokens
+        target_tokens = torch.randint(0, 4, (2, 3))
+        target_tokens[~mask] = -100  # set masked positions to ignore_index
+
+        log_probs_original = dist_original.log_prob(target_tokens)
+        log_probs_new = dist_new.log_prob(target_tokens)
+
+        # Results should be similar (might have small numerical differences)
+        torch.testing.assert_close(
+            log_probs_original, log_probs_new, atol=1e-4, rtol=1e-4
+        )
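+
+        # At valid positions both implementations see the same logits (the
+        # position mask spans the whole vocab row), so log-probs should match;
+        # only the mechanism differs (logit masking vs. ignore_index).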
+
+        # Test sampling
+        samples_original = dist_original.sample()
+        samples_new = dist_new.sample()
+
+        # Both should produce valid samples
+        assert (samples_original >= 0).all()
+        assert (samples_original < 4).all()
+        assert (samples_new >= 0).all()
+        assert (samples_new < 4).all()
+
+    def test_error_handling(self):
+        """Test error handling for invalid inputs."""
+        torch.manual_seed(0)
+        logits = torch.randn(2, 3, 4)
+        mask = torch.ones(2, 3, dtype=torch.bool)
+
+        # Mismatched logits/mask shapes should be rejected
+        with pytest.raises(ValueError):
+            LLMMaskedCategorical(logits=logits, mask=torch.ones(3, 4, dtype=torch.bool))
+
+        # ignore_index is not validated, so any value (even one inside the
+        # vocabulary range) is accepted
+        dist = LLMMaskedCategorical(logits=logits, mask=mask, ignore_index=5)
+        assert dist.ignore_index == 5
+
+
 class TestOrdinal:
     @pytest.mark.parametrize("dtype", [torch.float, torch.double])
     @pytest.mark.parametrize("device", get_default_devices())