Skip to content

Commit 2046bc5

Browse files
author
Vincent Moens
committed
[Deprecation] Softly change default behavior of auto_unwrap
ghstack-source-id: c28c11e Pull Request resolved: #2793
1 parent 3e42e7a commit 2046bc5

File tree

5 files changed

+154
-20
lines changed

5 files changed

+154
-20
lines changed

docs/source/reference/utils.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
.. currentmodule:: torchrl._utils
1+
.. currentmodule:: torchrl
22

33
torchrl._utils package
44
======================
@@ -11,3 +11,5 @@ Set of utility methods that are used internally by the library.
1111
:template: rl_template.rst
1212

1313
implement_for
14+
set_auto_unwrap_transformed_env
15+
auto_unwrap_transformed_env

test/test_transforms.py

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from tensordict.nn import TensorDictSequential
3434
from tensordict.utils import _unravel_key_to_tuple, assert_allclose_td
3535
from torch import multiprocessing as mp, nn, Tensor
36-
from torchrl._utils import _replace_last, prod
36+
from torchrl._utils import _replace_last, prod, set_auto_unwrap_transformed_env
3737

3838
from torchrl.collectors import MultiSyncDataCollector
3939
from torchrl.data import (
@@ -9846,6 +9846,40 @@ def test_added_transforms_are_in_eval_mode():
98469846

98479847

98489848
class TestTransformedEnv:
9849+
@pytest.mark.filterwarnings("error")
def test_nested_transformed_env(self):
    """Check how nested ``TransformedEnv`` instances behave for each ``auto_unwrap`` setting.

    Warnings are turned into errors for this test, so the only warning that
    may be emitted is the ``FutureWarning`` explicitly expected below.
    """
    base_env = ContinuousActionVecMockEnv()
    t1 = RewardScaling(0, 1)
    t2 = RewardScaling(0, 2)

    def check_unwrapped():
        # The inner env is dissolved: the outer env sits directly on the base
        # env and carries both transforms, in order, inside a Compose.
        env = TransformedEnv(TransformedEnv(base_env, t1), t2)
        assert env.base_env is base_env
        assert isinstance(env.transform, Compose)
        children = list(env.transform.transforms.children())
        assert len(children) == 2
        assert children[0].scale == 1
        assert children[1].scale == 2

    def check_wrapped(auto_unwrap=None):
        # The inner env is kept as-is: each level holds its own transform.
        env = TransformedEnv(
            TransformedEnv(base_env, t1), t2, auto_unwrap=auto_unwrap
        )
        assert env.base_env is not base_env
        assert isinstance(env.base_env.transform, RewardScaling)
        assert isinstance(env.transform, RewardScaling)

    # Nothing configured: the envs are unwrapped and a FutureWarning is raised.
    with pytest.warns(FutureWarning):
        check_unwrapped()

    # Explicit keyword argument.
    check_wrapped(False)

    # Global setting through the context manager, both values.
    with set_auto_unwrap_transformed_env(True):
        check_unwrapped()

    with set_auto_unwrap_transformed_env(False):
        check_wrapped()
9882+
98499883
def test_attr_error(self):
98509884
class BuggyTransform(Transform):
98519885
def transform_observation_spec(
@@ -9936,20 +9970,6 @@ def test_allow_done_after_reset(self):
99369970
assert not t1._allow_done_after_reset
99379971

99389972

9939-
def test_nested_transformed_env():
9940-
base_env = ContinuousActionVecMockEnv()
9941-
t1 = RewardScaling(0, 1)
9942-
t2 = RewardScaling(0, 2)
9943-
env = TransformedEnv(TransformedEnv(base_env, t1), t2)
9944-
9945-
assert env.base_env is base_env
9946-
assert isinstance(env.transform, Compose)
9947-
children = list(env.transform.transforms.children())
9948-
assert len(children) == 2
9949-
assert children[0].scale == 1
9950-
assert children[1].scale == 2
9951-
9952-
99539973
def test_transform_parent():
99549974
base_env = ContinuousActionVecMockEnv()
99559975
t1 = RewardScaling(0, 1)

torchrl/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,13 @@
5252
import torchrl.modules
5353
import torchrl.objectives
5454
import torchrl.trainers
55-
from torchrl._utils import compile_with_warmup, timeit
55+
from torchrl._utils import (
56+
auto_unwrap_transformed_env,
57+
compile_with_warmup,
58+
implement_for,
59+
set_auto_unwrap_transformed_env,
60+
timeit,
61+
)
5662

5763
# Filter warnings in subprocesses: True by default given the multiple optional
5864
# deps of the library. This can be turned on via `torchrl.filter_warnings_subprocess = False`.

torchrl/_utils.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -984,3 +984,86 @@ def count_and_compile(*model_args, **model_kwargs):
984984
return compiled_model(*model_args, **model_kwargs)
985985

986986
return count_and_compile
987+
988+
989+
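# Auto-unwrap control for nested TransformedEnv instances.
# ``_AUTO_UNWRAP`` is ``None`` when nothing has been configured, a ``str`` when it was
# read from the ``AUTO_UNWRAP_TRANSFORMED_ENV`` environment variable, and a ``bool``
# once ``set_auto_unwrap_transformed_env`` has been applied.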
990+
_DEFAULT_AUTO_UNWRAP = True
991+
_AUTO_UNWRAP = os.environ.get("AUTO_UNWRAP_TRANSFORMED_ENV")
992+
993+
994+
class set_auto_unwrap_transformed_env(_DecoratorContextManager):
    """A context manager or decorator to control whether TransformedEnv should automatically unwrap nested TransformedEnv instances.

    Args:
        mode (bool): Whether to automatically unwrap nested :class:`~torchrl.envs.TransformedEnv`
            instances. If ``False``, :class:`~torchrl.envs.TransformedEnv` will not unwrap nested instances.

    .. note:: Until v0.9, nesting :class:`~torchrl.envs.TransformedEnv` instances raises a
        ``FutureWarning`` if the value is not set explicitly (the envs are then unwrapped,
        i.e. the ``auto_unwrap=True`` behavior).
        You can set the value returned by :func:`~torchrl.auto_unwrap_transformed_env`
        through:

        - The ``AUTO_UNWRAP_TRANSFORMED_ENV`` environment variable;
        - By calling ``torchrl.set_auto_unwrap_transformed_env(val: bool).set()`` at the
          beginning of your script;
        - By using ``torchrl.set_auto_unwrap_transformed_env(val: bool)`` as a context
          manager or a decorator.

    .. seealso:: :class:`~torchrl.envs.TransformedEnv`

    Examples:
        >>> with set_auto_unwrap_transformed_env(False):
        ...     env = TransformedEnv(TransformedEnv(env))
        ...     # the inner TransformedEnv is preserved
        ...     assert isinstance(env.base_env, TransformedEnv)
        >>> @set_auto_unwrap_transformed_env(False)
        ... def my_function(base_env):
        ...     env = TransformedEnv(TransformedEnv(base_env))
        ...     assert isinstance(env.base_env, TransformedEnv)
        ...     return env

    """

    def __init__(self, mode: bool) -> None:
        super().__init__()
        self.mode = mode

    def clone(self) -> "set_auto_unwrap_transformed_env":
        # Overridden because ``__init__`` takes a parameter: a copy has to be
        # rebuilt with the same ``mode``.
        return type(self)(self.mode)

    def __enter__(self) -> None:
        self.set()

    def set(self) -> None:
        """Apply ``mode`` globally, remembering the previous value for ``__exit__``."""
        global _AUTO_UNWRAP
        self._old_mode = _AUTO_UNWRAP
        _AUTO_UNWRAP = bool(self.mode)
        # Mirror the setting in the environment so that sub-processes see the
        # same auto-unwrap value as the main process.
        os.environ["AUTO_UNWRAP_TRANSFORMED_ENV"] = str(_AUTO_UNWRAP)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        global _AUTO_UNWRAP
        _AUTO_UNWRAP = self._old_mode
        if _AUTO_UNWRAP is None:
            # Nothing was configured before entering: remove the variable rather
            # than exporting the literal string "None", which is not a boolean.
            os.environ.pop("AUTO_UNWRAP_TRANSFORMED_ENV", None)
        else:
            os.environ["AUTO_UNWRAP_TRANSFORMED_ENV"] = str(_AUTO_UNWRAP)
1050+
1051+
def auto_unwrap_transformed_env(allow_none=False):
    """Get the current setting for automatically unwrapping TransformedEnv instances.

    Args:
        allow_none (bool, optional): If ``True``, returns ``None`` if no setting has been
            specified. Otherwise, returns the default setting. Defaults to ``False``.

    Returns:
        bool or None: The current setting for automatically unwrapping TransformedEnv
        instances.

    .. seealso:: :class:`~torchrl.set_auto_unwrap_transformed_env`

    """
    if _AUTO_UNWRAP is None:
        # Nothing configured: report that as ``None`` only when the caller asked for it.
        return None if allow_none else _DEFAULT_AUTO_UNWRAP
    if isinstance(_AUTO_UNWRAP, str):
        # Value read from the environment variable: parse it, and coerce the
        # result so that the documented ``bool`` return type holds.
        return bool(strtobool(_AUTO_UNWRAP))
    return _AUTO_UNWRAP

torchrl/envs/transforms/transforms.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
_ends_with,
6262
_make_ordinal_device,
6363
_replace_last,
64+
auto_unwrap_transformed_env,
6465
logger as torchrl_logger,
6566
)
6667

@@ -705,7 +706,11 @@ class TransformedEnv(EnvBase, metaclass=_TEnvPostInit):
705706
Keyword Args:
706707
auto_unwrap (bool, optional): if ``True``, wrapping a transformed env in transformed env
707708
unwraps the transforms of the inner TransformedEnv in the outer one (the new instance).
708-
Defaults to ``True``
709+
Defaults to ``True``.
710+
711+
.. note:: This behavior will switch to ``False`` in v0.9.
712+
713+
.. seealso:: :class:`~torchrl.set_auto_unwrap_transformed_env`
709714
710715
Examples:
711716
>>> env = GymEnv("Pendulum-v0")
@@ -724,7 +729,7 @@ def __init__(
724729
transform: Optional[Transform] = None,
725730
cache_specs: bool = True,
726731
*,
727-
auto_unwrap: bool = True,
732+
auto_unwrap: bool | None = None,
728733
**kwargs,
729734
):
730735
self._transform = None
@@ -737,7 +742,24 @@ def __init__(
737742

738743
# Type matching must be exact here, because subtyping could introduce differences in behavior that must
739744
# be contained within the subclass.
740-
if type(env) is TransformedEnv and type(self) is TransformedEnv and auto_unwrap:
745+
if type(env) is TransformedEnv and type(self) is TransformedEnv:
746+
if auto_unwrap is None:
747+
auto_unwrap = auto_unwrap_transformed_env(allow_none=True)
748+
if auto_unwrap is None:
749+
warnings.warn(
750+
"The default behavior of TransformedEnv will change in version 0.9. "
751+
"Nested TransformedEnvs will no longer be automatically unwrapped by default. "
752+
"To prepare for this change, use set_auto_unwrap_transformed_env(val: bool) "
753+
"as a decorator or context manager, or set the environment variable "
754+
"AUTO_UNWRAP_TRANSFORMED_ENV to 'False'.",
755+
FutureWarning,
756+
stacklevel=2,
757+
)
758+
auto_unwrap = True
759+
else:
760+
auto_unwrap = False
761+
762+
if auto_unwrap:
741763
self._set_env(env.base_env, device)
742764
if type(transform) is not Compose:
743765
# we don't use isinstance as some transforms may be subclassed from
@@ -768,6 +790,7 @@ def __init__(
768790
self._set_env(env, device)
769791
if transform is None:
770792
transform = Compose()
793+
771794
self.transform = transform
772795

773796
self._last_obs = None

0 commit comments

Comments
 (0)