Commit 3e1f4ff

Author: Vincent Moens

[Refactor] MaskedCategorical cross_entropy usage for faster loss

ghstack-source-id: 84330cf
Pull Request resolved: #2882

1 parent: 9c4c086

4 files changed: 100 additions & 10 deletions
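The refactor rests on a simple identity: for a categorical distribution, the log-probability of an index is the negative of the unreduced cross-entropy of the logits at that index, and torch.nn.functional.cross_entropy computes this in one fused call rather than a log-softmax followed by a gather, which is typically faster. Below is a minimal sketch of the identity, not taken from the commit (note that reduction="none" is the current spelling of the deprecated reduce=False argument that appears in the diff):

    import torch
    import torch.nn.functional as F
    from torch.distributions import Categorical

    logits = torch.randn(8, 5)       # batch of 8, 5 categories
    value = torch.randint(5, (8,))   # one category index per batch element

    # Categorical normalizes logits internally (log_softmax), so the
    # gather-based log_prob and the fused cross_entropy agree exactly.
    log_prob_gather = Categorical(logits=logits).log_prob(value)
    log_prob_ce = -F.cross_entropy(logits, value, reduction="none")
    torch.testing.assert_close(log_prob_gather, log_prob_ce)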

test/test_distributions.py

Lines changed: 40 additions & 0 deletions

@@ -488,6 +488,46 @@ def test_sample_sparse(self, neg_inf: float) -> None:
         sample_probs = torch.bincount(samples) / num_samples
         torch.testing.assert_close(sample_probs, ref_probs, rtol=1e-5, atol=1e-2)
 
+    @pytest.mark.parametrize("neg_inf", [-1e20, float("-inf")])
+    @pytest.mark.parametrize("sparse", [False, True])
+    @pytest.mark.parametrize("ndim", [2, 1, 3])
+    def test_crossentropy(self, sparse: bool, neg_inf: float, ndim: int):
+        torch.manual_seed(0)
+        logits = torch.randn(4).log_softmax(dim=-1)
+        # probs = logits.exp()
+        mask = torch.tensor([True, False, True, True])
+        indices = torch.tensor([0, 2, 3])
+
+        if ndim >= 2:
+            mask = mask.unsqueeze(0)
+            logits = logits.unsqueeze(0)
+            indices = indices.unsqueeze(0)
+        if ndim == 3:
+            mask = mask.unsqueeze(0)
+            logits = logits.unsqueeze(0)
+            indices = indices.unsqueeze(0)
+
+        dist_ce = MaskedCategorical(
+            logits=logits,
+            neg_inf=neg_inf,
+            mask=mask if not sparse else None,
+            indices=indices if sparse else None,
+            use_cross_entropy=True,
+        )
+        dist = MaskedCategorical(
+            logits=logits,
+            neg_inf=neg_inf,
+            mask=mask if not sparse else None,
+            indices=indices if sparse else None,
+            use_cross_entropy=False,
+        )
+        data = torch.tensor(0)
+        if ndim >= 2:
+            data = data.unsqueeze(0)
+        if ndim == 3:
+            data = data.unsqueeze(0)
+        torch.testing.assert_close(dist.log_prob(data), dist_ce.log_prob(data))
+
 
 class TestOneHotCategorical:
     def test_one_hot(self):
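To exercise the new test on its own, an invocation along these lines should work (assuming the repository's usual pytest setup):

    python -m pytest test/test_distributions.py -k test_crossentropy -v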

torchrl/modules/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -93,7 +93,7 @@
 )
 from .utils import get_primers_from_module
 from .planners import CEMPlanner, MPCPlannerBase, MPPIPlanner  # usort:skip
-from .llm import TransformersWrapper, vLLMWrapper
+from .llm import CategoricalSequential, TransformersWrapper, vLLMWrapper
 
 __all__ = [
     "Actor",
@@ -109,6 +109,7 @@
     "Conv3dNet",
     "ConvNet",
     "DTActor",
+    "CategoricalSequential",
     "DdpgCnnActor",
     "DdpgCnnQNet",
     "DdpgMlpActor",

torchrl/modules/distributions/discrete.py

Lines changed: 24 additions & 3 deletions

@@ -184,6 +184,8 @@ class MaskedCategorical(D.Categorical):
             invalid (out-of-mask) indices. Defaults to -inf.
         padding_value: The padding value in the mask tensor. When
             sparse_mask == True, the padding_value will be ignored.
+        use_cross_entropy (bool, optional): For faster computation of the log-probability,
+            the cross_entropy loss functional can be used. Defaults to ``False``.
 
     Examples:
         >>> torch.manual_seed(0)
@@ -225,6 +227,7 @@ def __init__(
         indices: torch.Tensor = None,
         neg_inf: float = float("-inf"),
         padding_value: int | None = None,
+        use_cross_entropy: bool = False,
     ) -> None:
         if not ((mask is None) ^ (indices is None)):
             raise ValueError(
@@ -247,6 +250,7 @@
             probs = probs / probs.sum(-1, keepdim=True)
             logits = probs.log()
         num_samples = logits.shape[-1]
+        self.use_cross_entropy = use_cross_entropy
         logits = self._mask_logits(
             logits,
             mask,
@@ -282,19 +286,36 @@ def sample(
 
     def log_prob(self, value: torch.Tensor) -> torch.Tensor:
         if not self._sparse_mask:
-            return super().log_prob(value)
+            if self.use_cross_entropy:
+                logits = self.logits
+                if logits.ndim > 2:
+                    # Bring the category dim to dim 1, as cross_entropy expects
+                    logits = logits.transpose(-1, 1)
+                result = -torch.nn.functional.cross_entropy(logits, value, reduce=False)
+            else:
+                result = super().log_prob(value)
+            result = torch.where(torch.isfinite(result), result, self.neg_inf)
+            return result
 
         idx_3d = self._mask.view(1, -1, self._num_events)
         val_3d = value.view(-1, idx_3d.size(1), 1)
         mask = idx_3d == val_3d
         idx = mask.int().argmax(dim=-1, keepdim=True)
-        ret = super().log_prob(idx.view_as(value))
+        idx = idx.view_as(value)
+        if self.use_cross_entropy:
+            logits = self.logits
+            if logits.ndim > 2:
+                # Bring the category dim to dim 1, as cross_entropy expects
+                logits = logits.transpose(-1, 1)
+            ret = -torch.nn.functional.cross_entropy(logits, idx, reduce=False)
+        else:
+            ret = super().log_prob(idx)
         # Fill masked values with neg_inf.
         ret = ret.view_as(val_3d)
         ret = ret.masked_fill(
             torch.logical_not(mask.any(dim=-1, keepdim=True)), self.neg_inf
         )
-        return ret.resize_as(value)
+        return ret.view_as(value)
 
     @staticmethod
     def _mask_logits(
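The transpose(-1, 1) in log_prob is needed because the two APIs disagree on where the category dimension lives: torch distributions keep categories in the last dimension, while cross_entropy expects them in dim 1 (input shapes (N, C) or (N, C, d1, ...)). A small illustrative sketch of that layout conversion, not taken from the commit:

    import torch
    import torch.nn.functional as F
    from torch.distributions import Categorical

    # Distribution convention: categories last -> (batch, time, categories)
    logits = torch.randn(2, 7, 5)
    value = torch.randint(5, (2, 7))

    log_prob_dist = Categorical(logits=logits).log_prob(value)
    # cross_entropy convention: categories in dim 1 -> (batch, categories, time)
    log_prob_ce = -F.cross_entropy(logits.transpose(-1, 1), value, reduction="none")
    torch.testing.assert_close(log_prob_dist, log_prob_ce)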

torchrl/modules/llm/common.py

Lines changed: 34 additions & 6 deletions

@@ -4,14 +4,12 @@
 # LICENSE file in the root directory of this source tree.
 from __future__ import annotations
 
+import torch
 from tensordict import NestedKey, TensorDictBase
-from tensordict.nn import (
-    ProbabilisticTensorDictModule,
-    TensorDictModuleBase,
-    TensorDictSequential,
-)
+from tensordict.nn import TensorDictModuleBase, TensorDictSequential
 from torch import distributions as D
 from torch.distributions import Categorical
+from torchrl.modules import MaskedCategorical
 
 
 class CategoricalSequential(TensorDictModuleBase):
@@ -21,14 +19,44 @@ class CategoricalSequential(TensorDictModuleBase):
 
     """
 
+    generate: bool
+
     def get_dist(
         self,
         tensordict: TensorDictBase,
         tensordict_out: TensorDictBase | None = None,
+        as_padded_tensor: bool | None = None,
+        as_nested_tensor: bool | None = None,
+        padding_value: float | None = None,
+        padding_side: str = "right",
+        layout: torch.layout | None = None,
         **kwargs,
     ) -> D.Distribution:
         td_out = self(tensordict.copy())
-        return Categorical(td_out.get("logits"))
+        # By default, pad the logits and use a masked categorical distribution
+        if as_padded_tensor is None:
+            as_padded_tensor = as_nested_tensor is not True
+        if padding_value is None:
+            padding_value = 0.0
+        if as_nested_tensor is None:
+            as_nested_tensor = False
+        logits = td_out.get(
+            "logits",
+            as_padded_tensor=as_padded_tensor,
+            as_nested_tensor=as_nested_tensor,
+            padding_value=padding_value,
+            padding_side=padding_side,
+            layout=layout,
+        )
+        if as_padded_tensor:
+            # Padded entries can be masked out with MaskedCategorical
+            dist = MaskedCategorical(
+                logits=logits,
+                mask=logits != padding_value,
+                # use_cross_entropy=True,
+            )
+            return dist
+        return Categorical(logits)
 
     # Sampling is taken care of by the sub-modules
     forward = TensorDictSequential.forward
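In get_dist, ragged per-token logits are padded into a dense tensor with padding_value, and every entry still equal to that value is then declared invalid through mask=logits != padding_value (which implicitly assumes no real logit is ever exactly equal to the padding value). A toy sketch of that construction, with made-up shapes and values:

    import torch
    from torchrl.modules import MaskedCategorical

    padding_value = 0.0
    # Two rows of logits padded to the same width with padding_value.
    logits = torch.tensor(
        [
            [0.5, -1.0, 0.3, padding_value],           # one padded slot
            [1.2, 0.1, padding_value, padding_value],  # two padded slots
        ]
    )
    # Entries equal to padding_value are masked out and can never be sampled.
    dist = MaskedCategorical(logits=logits, mask=logits != padding_value)
    samples = dist.sample((100,))
    assert (samples[:, 0] < 3).all() and (samples[:, 1] < 2).all()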
