Skip to content

Commit 830f2f2

Browse files
author
Vincent Moens
committed
[BugFix] ActionDiscretizer scalar integration
ghstack-source-id: b22102f Pull Request resolved: #2619
1 parent de61e4d commit 830f2f2

File tree

3 files changed

+213
-48
lines changed

3 files changed

+213
-48
lines changed

test/mocking_classes.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1927,3 +1927,72 @@ def _step(
19271927
def _set_seed(self, seed: Optional[int]):
19281928
self.manual_seed = seed
19291929
return seed
1930+
1931+
1932+
class EnvWithScalarAction(EnvBase):
1933+
def __init__(self, singleton: bool = False, **kwargs):
1934+
super().__init__(**kwargs)
1935+
self.singleton = singleton
1936+
self.action_spec = Bounded(
1937+
-1,
1938+
1,
1939+
shape=(
1940+
*self.batch_size,
1941+
1,
1942+
)
1943+
if self.singleton
1944+
else self.batch_size,
1945+
)
1946+
self.observation_spec = Composite(
1947+
observation=Unbounded(
1948+
shape=(
1949+
*self.batch_size,
1950+
3,
1951+
)
1952+
),
1953+
shape=self.batch_size,
1954+
)
1955+
self.done_spec = Composite(
1956+
done=Unbounded(self.batch_size + (1,), dtype=torch.bool),
1957+
terminated=Unbounded(self.batch_size + (1,), dtype=torch.bool),
1958+
truncated=Unbounded(self.batch_size + (1,), dtype=torch.bool),
1959+
shape=self.batch_size,
1960+
)
1961+
self.reward_spec = Unbounded(
1962+
shape=(
1963+
*self.batch_size,
1964+
1,
1965+
)
1966+
)
1967+
1968+
def _reset(self, td: TensorDict):
1969+
return TensorDict(
1970+
observation=torch.randn(*self.batch_size, 3, device=self.device),
1971+
done=torch.zeros(*self.batch_size, 1, dtype=torch.bool, device=self.device),
1972+
truncated=torch.zeros(
1973+
*self.batch_size, 1, dtype=torch.bool, device=self.device
1974+
),
1975+
terminated=torch.zeros(
1976+
*self.batch_size, 1, dtype=torch.bool, device=self.device
1977+
),
1978+
device=self.device,
1979+
)
1980+
1981+
def _step(
1982+
self,
1983+
tensordict: TensorDictBase,
1984+
) -> TensorDictBase:
1985+
return TensorDict(
1986+
observation=torch.randn(*self.batch_size, 3, device=self.device),
1987+
reward=torch.zeros(1, device=self.device),
1988+
done=torch.zeros(*self.batch_size, 1, dtype=torch.bool, device=self.device),
1989+
truncated=torch.zeros(
1990+
*self.batch_size, 1, dtype=torch.bool, device=self.device
1991+
),
1992+
terminated=torch.zeros(
1993+
*self.batch_size, 1, dtype=torch.bool, device=self.device
1994+
),
1995+
)
1996+
1997+
def _set_seed(self, seed: Optional[int]):
1998+
...

test/test_transforms.py

Lines changed: 83 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
CountingEnvCountPolicy,
4242
DiscreteActionConvMockEnv,
4343
DiscreteActionConvMockEnvNumpy,
44+
EnvWithScalarAction,
4445
IncrementingEnv,
4546
MockBatchedLockedEnv,
4647
MockBatchedUnLockedEnv,
@@ -66,6 +67,7 @@
6667
CountingEnvCountPolicy,
6768
DiscreteActionConvMockEnv,
6869
DiscreteActionConvMockEnvNumpy,
70+
EnvWithScalarAction,
6971
IncrementingEnv,
7072
MockBatchedLockedEnv,
7173
MockBatchedUnLockedEnv,
@@ -11781,17 +11783,33 @@ def test_transform_inverse(self):
1178111783

1178211784
class TestActionDiscretizer(TransformBase):
1178311785
@pytest.mark.parametrize("categorical", [True, False])
11784-
def test_single_trans_env_check(self, categorical):
11785-
base_env = ContinuousActionVecMockEnv()
11786+
@pytest.mark.parametrize(
11787+
"env_cls",
11788+
[
11789+
ContinuousActionVecMockEnv,
11790+
partial(EnvWithScalarAction, singleton=True),
11791+
partial(EnvWithScalarAction, singleton=False),
11792+
],
11793+
)
11794+
def test_single_trans_env_check(self, categorical, env_cls):
11795+
base_env = env_cls()
1178611796
env = base_env.append_transform(
1178711797
ActionDiscretizer(num_intervals=5, categorical=categorical)
1178811798
)
1178911799
check_env_specs(env)
1179011800

1179111801
@pytest.mark.parametrize("categorical", [True, False])
11792-
def test_serial_trans_env_check(self, categorical):
11802+
@pytest.mark.parametrize(
11803+
"env_cls",
11804+
[
11805+
ContinuousActionVecMockEnv,
11806+
partial(EnvWithScalarAction, singleton=True),
11807+
partial(EnvWithScalarAction, singleton=False),
11808+
],
11809+
)
11810+
def test_serial_trans_env_check(self, categorical, env_cls):
1179311811
def make_env():
11794-
base_env = ContinuousActionVecMockEnv()
11812+
base_env = env_cls()
1179511813
return base_env.append_transform(
1179611814
ActionDiscretizer(num_intervals=5, categorical=categorical)
1179711815
)
@@ -11800,9 +11818,17 @@ def make_env():
1180011818
check_env_specs(env)
1180111819

1180211820
@pytest.mark.parametrize("categorical", [True, False])
11803-
def test_parallel_trans_env_check(self, categorical):
11821+
@pytest.mark.parametrize(
11822+
"env_cls",
11823+
[
11824+
ContinuousActionVecMockEnv,
11825+
partial(EnvWithScalarAction, singleton=True),
11826+
partial(EnvWithScalarAction, singleton=False),
11827+
],
11828+
)
11829+
def test_parallel_trans_env_check(self, categorical, env_cls):
1180411830
def make_env():
11805-
base_env = ContinuousActionVecMockEnv()
11831+
base_env = env_cls()
1180611832
env = base_env.append_transform(
1180711833
ActionDiscretizer(num_intervals=5, categorical=categorical)
1180811834
)
@@ -11812,17 +11838,33 @@ def make_env():
1181211838
check_env_specs(env)
1181311839

1181411840
@pytest.mark.parametrize("categorical", [True, False])
11815-
def test_trans_serial_env_check(self, categorical):
11816-
env = SerialEnv(2, ContinuousActionVecMockEnv).append_transform(
11841+
@pytest.mark.parametrize(
11842+
"env_cls",
11843+
[
11844+
ContinuousActionVecMockEnv,
11845+
partial(EnvWithScalarAction, singleton=True),
11846+
partial(EnvWithScalarAction, singleton=False),
11847+
],
11848+
)
11849+
def test_trans_serial_env_check(self, categorical, env_cls):
11850+
env = SerialEnv(2, env_cls).append_transform(
1181711851
ActionDiscretizer(num_intervals=5, categorical=categorical)
1181811852
)
1181911853
check_env_specs(env)
1182011854

1182111855
@pytest.mark.parametrize("categorical", [True, False])
11822-
def test_trans_parallel_env_check(self, categorical):
11823-
env = ParallelEnv(
11824-
2, ContinuousActionVecMockEnv, mp_start_method=mp_ctx
11825-
).append_transform(ActionDiscretizer(num_intervals=5, categorical=categorical))
11856+
@pytest.mark.parametrize(
11857+
"env_cls",
11858+
[
11859+
ContinuousActionVecMockEnv,
11860+
partial(EnvWithScalarAction, singleton=True),
11861+
partial(EnvWithScalarAction, singleton=False),
11862+
],
11863+
)
11864+
def test_trans_parallel_env_check(self, categorical, env_cls):
11865+
env = ParallelEnv(2, env_cls, mp_start_method=mp_ctx).append_transform(
11866+
ActionDiscretizer(num_intervals=5, categorical=categorical)
11867+
)
1182611868
check_env_specs(env)
1182711869

1182811870
def test_transform_no_env(self):
@@ -11838,7 +11880,6 @@ def test_transform_compose(self):
1183811880
check_env_specs(env)
1183911881

1184011882
@pytest.mark.skipif(not _has_gym, reason="gym required for this test")
11841-
@pytest.mark.parametrize("envname", ["cheetah", "pendulum"])
1184211883
@pytest.mark.parametrize("interval_as_tensor", [False, True])
1184311884
@pytest.mark.parametrize("categorical", [True, False])
1184411885
@pytest.mark.parametrize(
@@ -11851,15 +11892,37 @@ def test_transform_compose(self):
1185111892
ActionDiscretizer.SamplingStrategy.RANDOM,
1185211893
],
1185311894
)
11854-
def test_transform_env(self, envname, interval_as_tensor, categorical, sampling):
11895+
@pytest.mark.parametrize(
11896+
"env_cls",
11897+
[
11898+
"cheetah",
11899+
"pendulum",
11900+
partial(EnvWithScalarAction, singleton=True),
11901+
partial(EnvWithScalarAction, singleton=False),
11902+
],
11903+
)
11904+
def test_transform_env(self, env_cls, interval_as_tensor, categorical, sampling):
1185511905
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11856-
base_env = GymEnv(
11857-
HALFCHEETAH_VERSIONED() if envname == "cheetah" else PENDULUM_VERSIONED(),
11858-
device=device,
11859-
)
11860-
if interval_as_tensor:
11861-
num_intervals = torch.arange(5, 11 if envname == "cheetah" else 6)
11906+
if env_cls == "cheetah":
11907+
base_env = GymEnv(
11908+
HALFCHEETAH_VERSIONED(),
11909+
device=device,
11910+
)
11911+
num_intervals = torch.arange(5, 11)
11912+
elif env_cls == "pendulum":
11913+
base_env = GymEnv(
11914+
PENDULUM_VERSIONED(),
11915+
device=device,
11916+
)
11917+
num_intervals = torch.arange(5, 6)
1186211918
else:
11919+
base_env = env_cls(
11920+
device=device,
11921+
)
11922+
num_intervals = torch.arange(5, 6)
11923+
11924+
if not interval_as_tensor:
11925+
# override
1186311926
num_intervals = 5
1186411927
t = ActionDiscretizer(
1186511928
num_intervals=num_intervals,

torchrl/envs/transforms/transforms.py

Lines changed: 61 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8585,24 +8585,32 @@ def _indent(s):
85858585

85868586
def transform_input_spec(self, input_spec):
85878587
try:
8588-
action_spec = input_spec["full_action_spec", self.in_keys_inv[0]]
8588+
action_spec = self.parent.full_action_spec_unbatched[self.in_keys_inv[0]]
85898589
if not isinstance(action_spec, Bounded):
85908590
raise TypeError(
8591-
f"action spec type {type(action_spec)} is not supported."
8591+
f"action spec type {type(action_spec)} is not supported. The action spec type must be Bounded."
85928592
)
85938593

85948594
n_act = action_spec.shape
85958595
if not n_act:
8596-
n_act = 1
8596+
n_act = ()
8597+
empty_shape = True
85978598
else:
8598-
n_act = n_act[-1]
8599+
n_act = (n_act[-1],)
8600+
empty_shape = False
85998601
self.n_act = n_act
86008602

86018603
self.dtype = action_spec.dtype
8602-
interval = (action_spec.high - action_spec.low).unsqueeze(-1)
8604+
interval = action_spec.high - action_spec.low
86038605

86048606
num_intervals = self.num_intervals
86058607

8608+
if not empty_shape:
8609+
interval = interval.unsqueeze(-1)
8610+
elif isinstance(num_intervals, torch.Tensor):
8611+
num_intervals = int(num_intervals.squeeze())
8612+
self.num_intervals = torch.as_tensor(num_intervals)
8613+
86068614
def custom_arange(nint):
86078615
result = torch.arange(
86088616
start=0.0,
@@ -8625,11 +8633,13 @@ def custom_arange(nint):
86258633

86268634
if isinstance(num_intervals, int):
86278635
arange = (
8628-
custom_arange(num_intervals).expand(n_act, num_intervals) * interval
8629-
)
8630-
self.register_buffer(
8631-
"intervals", action_spec.low.unsqueeze(-1) + arange
8636+
custom_arange(num_intervals).expand((*n_act, num_intervals))
8637+
* interval
86328638
)
8639+
low = action_spec.low
8640+
if not empty_shape:
8641+
low = low.unsqueeze(-1)
8642+
self.register_buffer("intervals", low + arange)
86338643
else:
86348644
arange = [
86358645
custom_arange(_num_intervals) * interval
@@ -8644,20 +8654,17 @@ def custom_arange(nint):
86448654
)
86458655
]
86468656

8647-
cls = (
8648-
functools.partial(MultiCategorical, remove_singleton=False)
8649-
if self.categorical
8650-
else MultiOneHot
8651-
)
8652-
86538657
if not isinstance(num_intervals, torch.Tensor):
86548658
nvec = torch.as_tensor(num_intervals, device=action_spec.device)
86558659
else:
86568660
nvec = num_intervals
86578661
if nvec.ndim > 1:
86588662
raise RuntimeError(f"Cannot use num_intervals with shape {nvec.shape}")
86598663
if nvec.ndim == 0 or nvec.numel() == 1:
8660-
nvec = nvec.expand(action_spec.shape[-1])
8664+
if not empty_shape:
8665+
nvec = nvec.expand(action_spec.shape[-1])
8666+
else:
8667+
nvec = nvec.squeeze()
86618668
self.register_buffer("nvec", nvec)
86628669
if self.sampling == self.SamplingStrategy.RANDOM:
86638670
# compute jitters
@@ -8667,7 +8674,22 @@ def custom_arange(nint):
86678674
if self.categorical
86688675
else (*action_spec.shape[:-1], nvec.sum())
86698676
)
8670-
action_spec = cls(nvec=nvec, shape=shape, device=action_spec.device)
8677+
8678+
if not empty_shape:
8679+
cls = (
8680+
functools.partial(MultiCategorical, remove_singleton=False)
8681+
if self.categorical
8682+
else MultiOneHot
8683+
)
8684+
action_spec = cls(nvec=nvec, shape=shape, device=action_spec.device)
8685+
8686+
else:
8687+
cls = Categorical if self.categorical else OneHot
8688+
action_spec = cls(n=int(nvec), shape=shape, device=action_spec.device)
8689+
8690+
batch_size = self.parent.batch_size
8691+
if batch_size:
8692+
action_spec = action_spec.expand(batch_size + action_spec.shape)
86718693
input_spec["full_action_spec", self.out_keys_inv[0]] = action_spec
86728694

86738695
if self.out_keys_inv[0] != self.in_keys_inv[0]:
@@ -8705,6 +8727,8 @@ def _inv_call(self, tensordict):
87058727
if self.categorical:
87068728
action = action.unsqueeze(-1)
87078729
if isinstance(intervals, torch.Tensor):
8730+
shape = action.shape[: -intervals.ndim]
8731+
intervals = intervals.expand(shape + intervals.shape)
87088732
action = intervals.gather(index=action, dim=-1).squeeze(-1)
87098733
else:
87108734
action = torch.stack(
@@ -8715,17 +8739,26 @@ def _inv_call(self, tensordict):
87158739
-1,
87168740
)
87178741
else:
8718-
nvec = self.nvec.tolist()
8719-
action = action.split(nvec, dim=-1)
8720-
if isinstance(intervals, torch.Tensor):
8721-
intervals = intervals.unbind(-2)
8722-
action = torch.stack(
8723-
[
8724-
intervals[action].view(action.shape[:-1])
8725-
for (intervals, action) in zip(intervals, action)
8726-
],
8727-
-1,
8728-
)
8742+
nvec = self.nvec
8743+
empty_shape = not nvec.ndim
8744+
if not empty_shape:
8745+
nvec = nvec.tolist()
8746+
if isinstance(intervals, torch.Tensor):
8747+
shape = action.shape[: (-intervals.ndim + 1)]
8748+
intervals = intervals.expand(shape + intervals.shape)
8749+
intervals = intervals.unbind(-2)
8750+
action = action.split(nvec, dim=-1)
8751+
action = torch.stack(
8752+
[
8753+
intervals[action].view(action.shape[:-1])
8754+
for (intervals, action) in zip(intervals, action)
8755+
],
8756+
-1,
8757+
)
8758+
else:
8759+
shape = action.shape[: -intervals.ndim]
8760+
intervals = intervals.expand(shape + intervals.shape)
8761+
action = intervals[action].squeeze(-1)
87298762

87308763
if self.sampling == self.SamplingStrategy.RANDOM:
87318764
action = action + self.jitters * torch.rand_like(self.jitters)

0 commit comments

Comments
 (0)