
Commit f81c9e3

Author: Vincent Moens (committed)

Update (base update)
[ghstack-poisoned]

2 parents d6ca42f + 61e05b3 · commit f81c9e3

10 files changed: +485 −12 lines


docs/source/reference/envs.rst (1 addition, 0 deletions)

@@ -829,6 +829,7 @@ to be able to create this other composition:
     GrayScale
     InitTracker
     KLRewardTransform
+    LineariseRewards
     NoopResetEnv
     ObservationNorm
     ObservationTransform
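For orientation, the transform documented above collapses a multi-objective reward into a weighted sum. A minimal sketch of the usage that the tests added below exercise (the three-component reward and the weight values are made up for illustration, not part of this diff):

    import torch
    from tensordict import TensorDict
    from torch import nn
    from torchrl.envs.transforms import LineariseRewards

    # Hypothetical weights for a three-component reward; LineariseRewards falls back
    # to all-ones weights when `weights` is omitted (see the tests below).
    transform = LineariseRewards(
        in_keys=("reward",), out_keys=("scalar_reward",), weights=[1.0, 2.0, 0.5]
    )
    td = TensorDict({"reward": torch.randn(3)}, [])
    # Same call pattern as test_transform_model below: the transform writes the
    # weighted sum of td["reward"] into td["scalar_reward"].
    nn.Sequential(transform, nn.Identity())(td)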

test/test_transforms.py (330 additions, 1 deletion)

@@ -84,6 +84,7 @@
 from torchrl._utils import _replace_last, prod
 from torchrl.data import (
     Bounded,
+    BoundedContinuous,
     Categorical,
     Composite,
     LazyTensorStorage,
@@ -92,6 +93,7 @@
     TensorSpec,
     TensorStorage,
     Unbounded,
+    UnboundedContinuous,
 )
 from torchrl.envs import (
     ActionMask,
@@ -117,6 +119,7 @@
     GrayScale,
     gSDENoise,
     InitTracker,
+    LineariseRewards,
     MultiStepTransform,
     NoopResetEnv,
     ObservationNorm,
@@ -412,7 +415,7 @@ def test_transform_rb(self, rbclass):
         assert ((sample["reward"] == 0) | (sample["reward"] == 1)).all()

     def test_transform_inverse(self):
-        raise pytest.skip("No inverse for BinerizedReward")
+        raise pytest.skip("No inverse for BinarizedReward")


 class TestClipTransform(TransformBase):
@@ -12403,6 +12406,332 @@ def test_transform_inverse(self):
         pytest.skip("Tested elsewhere")


+class TestLineariseRewards(TransformBase):
+    def test_weight_shape_error(self):
+        with pytest.raises(
+            ValueError, match="Expected weights to be a unidimensional tensor"
+        ):
+            LineariseRewards(in_keys=("reward",), weights=torch.ones(size=(2, 4)))
+
+    def test_weight_sign_error(self):
+        with pytest.raises(ValueError, match="Expected all weights to be >0"):
+            LineariseRewards(in_keys=("reward",), weights=-torch.ones(size=(2,)))
+
+    def test_discrete_spec_error(self):
+        with pytest.raises(
+            NotImplementedError,
+            match="Aggregation of rewards that take discrete values is not supported.",
+        ):
+            transform = LineariseRewards(in_keys=("reward",))
+            reward_spec = Categorical(n=2)
+            transform.transform_reward_spec(reward_spec)
+
+    @pytest.mark.parametrize(
+        "reward_spec",
+        [
+            UnboundedContinuous(shape=3),
+            BoundedContinuous(0, 1, shape=2),
+        ],
+    )
+    def test_single_trans_env_check(self, reward_spec: TensorSpec):
+        env = TransformedEnv(
+            ContinuousActionVecMockEnv(reward_spec=reward_spec),
+            LineariseRewards(in_keys=["reward"]),  # will use default weights
+        )
+        check_env_specs(env)
+
+    @pytest.mark.parametrize(
+        "reward_spec",
+        [
+            UnboundedContinuous(shape=3),
+            BoundedContinuous(0, 1, shape=2),
+        ],
+    )
+    def test_serial_trans_env_check(self, reward_spec: TensorSpec):
+        def make_env():
+            return TransformedEnv(
+                ContinuousActionVecMockEnv(reward_spec=reward_spec),
+                LineariseRewards(in_keys=["reward"]),  # will use default weights
+            )
+
+        env = SerialEnv(2, make_env)
+        check_env_specs(env)
+
+    @pytest.mark.parametrize(
+        "reward_spec",
+        [
+            UnboundedContinuous(shape=3),
+            BoundedContinuous(0, 1, shape=2),
+        ],
+    )
+    def test_parallel_trans_env_check(
+        self, maybe_fork_ParallelEnv, reward_spec: TensorSpec
+    ):
+        def make_env():
+            return TransformedEnv(
+                ContinuousActionVecMockEnv(reward_spec=reward_spec),
+                LineariseRewards(in_keys=["reward"]),  # will use default weights
+            )
+
+        env = maybe_fork_ParallelEnv(2, make_env)
+        try:
+            check_env_specs(env)
+        finally:
+            try:
+                env.close()
+            except RuntimeError:
+                pass
+
+    @pytest.mark.parametrize(
+        "reward_spec",
+        [
+            UnboundedContinuous(shape=3),
+            BoundedContinuous(0, 1, shape=2),
+        ],
+    )
+    def test_trans_serial_env_check(self, reward_spec: TensorSpec):
+        def make_env():
+            return ContinuousActionVecMockEnv(reward_spec=reward_spec)
+
+        env = TransformedEnv(
+            SerialEnv(2, make_env), LineariseRewards(in_keys=["reward"])
+        )
+        check_env_specs(env)
+
+    @pytest.mark.parametrize(
+        "reward_spec",
+        [
+            UnboundedContinuous(shape=3),
+            BoundedContinuous(0, 1, shape=2),
+        ],
+    )
+    def test_trans_parallel_env_check(
+        self, maybe_fork_ParallelEnv, reward_spec: TensorSpec
+    ):
+        def make_env():
+            return ContinuousActionVecMockEnv(reward_spec=reward_spec)
+
+        env = TransformedEnv(
+            maybe_fork_ParallelEnv(2, make_env),
+            LineariseRewards(in_keys=["reward"]),
+        )
+        try:
+            check_env_specs(env)
+        finally:
+            try:
+                env.close()
+            except RuntimeError:
+                pass
+
+    @pytest.mark.parametrize("reward_key", [("reward",), ("agents", "reward")])
+    @pytest.mark.parametrize(
+        "num_rewards, weights",
+        [
+            (1, None),
+            (3, None),
+            (2, [1.0, 2.0]),
+        ],
+    )
+    def test_transform_no_env(self, reward_key, num_rewards, weights):
+        out_keys = reward_key[:-1] + ("scalar_reward",)
+        t = LineariseRewards(in_keys=[reward_key], out_keys=[out_keys], weights=weights)
+        td = TensorDict({reward_key: torch.randn(num_rewards)}, [])
+        t._call(td)
+
+        weights = torch.ones(num_rewards) if weights is None else torch.tensor(weights)
+        expected = sum(
+            w * r
+            for w, r in zip(
+                weights,
+                td[reward_key],
+            )
+        )
+        torch.testing.assert_close(td[out_keys], expected)
+
+    @pytest.mark.parametrize("reward_key", [("reward",), ("agents", "reward")])
+    @pytest.mark.parametrize(
+        "num_rewards, weights",
+        [
+            (1, None),
+            (3, None),
+            (2, [1.0, 2.0]),
+        ],
+    )
+    def test_transform_compose(self, reward_key, num_rewards, weights):
+        out_keys = reward_key[:-1] + ("scalar_reward",)
+        t = Compose(
+            LineariseRewards(in_keys=[reward_key], out_keys=[out_keys], weights=weights)
+        )
+        td = TensorDict({reward_key: torch.randn(num_rewards)}, [])
+        t._call(td)
+
+        weights = torch.ones(num_rewards) if weights is None else torch.tensor(weights)
+        expected = sum(
+            w * r
+            for w, r in zip(
+                weights,
+                td[reward_key],
+            )
+        )
+        torch.testing.assert_close(td[out_keys], expected)
+
+    class _DummyMultiObjectiveEnv(EnvBase):
+        """A dummy multi-objective environment."""
+
+        def __init__(self, num_rewards: int) -> None:
+            super().__init__()
+            self._num_rewards = num_rewards
+
+            self.observation_spec = Composite(
+                observation=UnboundedContinuous((*self.batch_size, 3))
+            )
+            self.action_spec = Categorical(2, (*self.batch_size, 1), dtype=torch.bool)
+            self.done_spec = Categorical(2, (*self.batch_size, 1), dtype=torch.bool)
+            self.full_done_spec["truncated"] = self.full_done_spec["terminated"].clone()
+            self.reward_spec = UnboundedContinuous(*self.batch_size, num_rewards)
+
+        def _reset(self, tensordict: TensorDict) -> TensorDict:
+            return self.observation_spec.sample()
+
+        def _step(self, tensordict: TensorDict) -> TensorDict:
+            done, terminated = False, False
+            reward = torch.randn((self._num_rewards,))
+
+            return TensorDict(
+                {
+                    ("observation"): self.observation_spec["observation"].sample(),
+                    ("done"): done,
+                    ("terminated"): terminated,
+                    ("reward"): reward,
+                }
+            )
+
+        def _set_seed(self) -> None:
+            pass
+
+    @pytest.mark.parametrize(
+        "num_rewards, weights",
+        [
+            (1, None),
+            (3, None),
+            (2, [1.0, 2.0]),
+        ],
+    )
+    def test_transform_env(self, num_rewards, weights):
+        weights = weights if weights is not None else [1.0 for _ in range(num_rewards)]
+
+        transform = LineariseRewards(
+            in_keys=("reward",), out_keys=("scalar_reward",), weights=weights
+        )
+        env = TransformedEnv(self._DummyMultiObjectiveEnv(num_rewards), transform)
+        rollout = env.rollout(10)
+        scalar_reward = rollout.get(("next", "scalar_reward"))
+        assert scalar_reward.shape[-1] == 1
+
+        expected = sum(
+            w * r
+            for w, r in zip(weights, rollout.get(("next", "reward")).split(1, dim=-1))
+        )
+        torch.testing.assert_close(scalar_reward, expected)
+
+    @pytest.mark.parametrize(
+        "num_rewards, weights",
+        [
+            (1, None),
+            (3, None),
+            (2, [1.0, 2.0]),
+        ],
+    )
+    def test_transform_model(self, num_rewards, weights):
+        weights = weights if weights is not None else [1.0 for _ in range(num_rewards)]
+        transform = LineariseRewards(
+            in_keys=("reward",), out_keys=("scalar_reward",), weights=weights
+        )
+
+        model = nn.Sequential(transform, nn.Identity())
+        td = TensorDict({"reward": torch.randn(num_rewards)}, [])
+        model(td)
+
+        expected = sum(w * r for w, r in zip(weights, td["reward"]))
+        torch.testing.assert_close(td["scalar_reward"], expected)
+
+    @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
+    def test_transform_rb(self, rbclass):
+        num_rewards = 3
+        weights = None
+        transform = LineariseRewards(
+            in_keys=("reward",), out_keys=("scalar_reward",), weights=weights
+        )
+
+        rb = rbclass(storage=LazyTensorStorage(10))
+        td = TensorDict({"reward": torch.randn(num_rewards)}, []).expand(10)
+        rb.append_transform(transform)
+        rb.extend(td)
+
+        td = rb.sample(2)
+        torch.testing.assert_close(td["scalar_reward"], td["reward"].sum(-1))
+
+    def test_transform_inverse(self):
+        raise pytest.skip("No inverse for LineariseReward")
+
+    @pytest.mark.parametrize(
+        "weights, reward_spec, expected_spec",
+        [
+            (None, UnboundedContinuous(shape=3), UnboundedContinuous(shape=1)),
+            (
+                None,
+                BoundedContinuous(0, 1, shape=3),
+                BoundedContinuous(0, 3, shape=1),
+            ),
+            (
+                None,
+                BoundedContinuous(low=[-1.0, -2.0], high=[1.0, 2.0]),
+                BoundedContinuous(low=-3.0, high=3.0, shape=1),
+            ),
+            (
+                [1.0, 0.0],
+                BoundedContinuous(
+                    low=[-1.0, -2.0],
+                    high=[1.0, 2.0],
+                    shape=2,
+                ),
+                BoundedContinuous(low=-1.0, high=1.0, shape=1),
+            ),
+        ],
+    )
+    def test_reward_spec(
+        self,
+        weights,
+        reward_spec: TensorSpec,
+        expected_spec: TensorSpec,
+    ) -> None:
+        transform = LineariseRewards(in_keys=("reward",), weights=weights)
+        assert transform.transform_reward_spec(reward_spec) == expected_spec
+
+    def test_composite_reward_spec(self) -> None:
+        weights = None
+        reward_spec = Composite(
+            agent_0=Composite(
+                reward=BoundedContinuous(low=[0, 0, 0], high=[1, 1, 1], shape=3)
+            ),
+            agent_1=Composite(
+                reward=BoundedContinuous(
+                    low=[-1, -1, -1],
+                    high=[1, 1, 1],
+                    shape=3,
+                )
+            ),
+        )
+        expected_reward_spec = Composite(
+            agent_0=Composite(reward=BoundedContinuous(low=0, high=3, shape=1)),
+            agent_1=Composite(reward=BoundedContinuous(low=-3, high=3, shape=1)),
+        )
+        transform = LineariseRewards(
+            in_keys=[("agent_0", "reward"), ("agent_1", "reward")], weights=weights
+        )
+        assert transform.transform_reward_spec(reward_spec) == expected_reward_spec
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
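The test_reward_spec cases above pin down the bound arithmetic of the scalarisation: with weights w_i and per-component bounds [low_i, high_i], the aggregated reward is bounded by the weighted sums of the lows and of the highs. A small standalone check of that arithmetic (plain PyTorch; linearised_bounds is a throwaway helper written for this note, not part of torchrl):

    import torch

    def linearised_bounds(low, high, weights=None):
        # Weighted-sum bounds of a multi-objective reward; all-ones weights by default.
        low = torch.as_tensor(low, dtype=torch.float)
        high = torch.as_tensor(high, dtype=torch.float)
        w = torch.ones_like(low) if weights is None else torch.as_tensor(weights, dtype=torch.float)
        return (w * low).sum().item(), (w * high).sum().item()

    # Mirrors the parametrisation of test_reward_spec above:
    assert linearised_bounds([-1.0, -2.0], [1.0, 2.0]) == (-3.0, 3.0)
    assert linearised_bounds([-1.0, -2.0], [1.0, 2.0], weights=[1.0, 0.0]) == (-1.0, 1.0)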

torchrl/data/__init__.py (1 addition, 0 deletions)

@@ -72,6 +72,7 @@
     Binary,
     BinaryDiscreteTensorSpec,
     Bounded,
+    BoundedContinuous,
     BoundedTensorSpec,
     Categorical,
     Composite,

torchrl/envs/__init__.py (1 addition, 0 deletions)

@@ -69,6 +69,7 @@
     gSDENoise,
     InitTracker,
     KLRewardTransform,
+    LineariseRewards,
     MultiStepTransform,
     NoopResetEnv,
     ObservationNorm,

torchrl/envs/transforms/__init__.py (1 addition, 0 deletions)

@@ -32,6 +32,7 @@
     GrayScale,
     gSDENoise,
     InitTracker,
+    LineariseRewards,
     NoopResetEnv,
     ObservationNorm,
     ObservationTransform,
