Skip to content

Commit 601483e

Browse files
author
Vincent Moens
committed
[Feature] lock_ / unlock_ graphs
ghstack-source-id: 01e375e Pull Request resolved: #2729
1 parent dc63e82 commit 601483e

File tree

4 files changed

+267
-60
lines changed

4 files changed

+267
-60
lines changed

test/test_specs.py

Lines changed: 88 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -697,36 +697,100 @@ def test_create_composite_nested(shape, device):
697697
assert c["a"].device == device
698698

699699

700-
@pytest.mark.parametrize("recurse", [True, False])
701-
def test_lock(recurse):
702-
shape = [3, 4, 5]
703-
spec = Composite(
704-
a=Composite(b=Composite(shape=shape[:3], device="cpu"), shape=shape[:2]),
705-
shape=shape[:1],
706-
)
707-
spec["a"] = spec["a"].clone()
708-
spec["a", "b"] = spec["a", "b"].clone()
709-
assert not spec.locked
710-
spec.lock_(recurse=recurse)
711-
assert spec.locked
712-
with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."):
700+
class TestLock:
701+
@pytest.mark.parametrize("recurse", [None, True, False])
702+
def test_lock(self, recurse):
703+
catch_warn = (
704+
pytest.warns(DeprecationWarning, match="recurse")
705+
if recurse is None
706+
else contextlib.nullcontext()
707+
)
708+
709+
shape = [3, 4, 5]
710+
spec = Composite(
711+
a=Composite(b=Composite(shape=shape[:3], device="cpu"), shape=shape[:2]),
712+
shape=shape[:1],
713+
)
713714
spec["a"] = spec["a"].clone()
714-
with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."):
715-
spec.set("a", spec["a"].clone())
716-
if recurse:
717-
assert spec["a"].locked
715+
spec["a", "b"] = spec["a", "b"].clone()
716+
assert not spec.locked
717+
with catch_warn:
718+
spec.lock_(recurse=recurse)
719+
assert spec.locked
718720
with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."):
719-
spec["a"].set("b", spec["a", "b"].clone())
721+
spec["a"] = spec["a"].clone()
720722
with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."):
723+
spec.set("a", spec["a"].clone())
724+
if recurse:
725+
assert spec["a"].locked
726+
with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."):
727+
spec["a"].set("b", spec["a", "b"].clone())
728+
with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."):
729+
spec["a", "b"] = spec["a", "b"].clone()
730+
else:
731+
assert not spec["a"].locked
721732
spec["a", "b"] = spec["a", "b"].clone()
722-
else:
723-
assert not spec["a"].locked
733+
spec["a"].set("b", spec["a", "b"].clone())
734+
with catch_warn:
735+
spec.unlock_(recurse=recurse)
736+
spec["a"] = spec["a"].clone()
724737
spec["a", "b"] = spec["a", "b"].clone()
725738
spec["a"].set("b", spec["a", "b"].clone())
726-
spec.unlock_(recurse=recurse)
727-
spec["a"] = spec["a"].clone()
728-
spec["a", "b"] = spec["a", "b"].clone()
729-
spec["a"].set("b", spec["a", "b"].clone())
739+
740+
def test_edge_cases(self):
741+
level3 = Composite()
742+
level2 = Composite(level3=level3)
743+
level1 = Composite(level2=level2)
744+
level0 = Composite(level1=level1)
745+
# locking level0 locks them all
746+
level0.lock_(recurse=True)
747+
assert level3.is_locked
748+
# We cannot unlock level3
749+
with pytest.raises(
750+
RuntimeError,
751+
match="Cannot unlock a Composite that is part of a locked graph",
752+
):
753+
level3.unlock_(recurse=True)
754+
assert level3.is_locked
755+
# Adding level2 to a new spec and locking it makes it hard to unlock the level0 root
756+
new_spec = Composite(level2=level2)
757+
new_spec.lock_(recurse=True)
758+
with pytest.raises(
759+
RuntimeError,
760+
match="Cannot unlock a Composite that is part of a locked graph",
761+
):
762+
level0.unlock_(recurse=True)
763+
assert level0.is_locked
764+
765+
def test_lock_mix_recurse_nonrecurse(self):
766+
# lock with recurse
767+
level3 = Composite()
768+
level2 = Composite(level3=level3)
769+
level1 = Composite(level2=level2)
770+
level0 = Composite(level1=level1)
771+
# locking level0 locks them all
772+
level0.lock_(recurse=True)
773+
new_spec = Composite(level2=level2)
774+
new_spec.lock_(recurse=True)
775+
776+
# Unlock with recurse=False
777+
with pytest.raises(RuntimeError, match="Cannot unlock"):
778+
level3.unlock_(recurse=False)
779+
assert level3.is_locked
780+
assert level2.is_locked
781+
assert new_spec.is_locked
782+
with pytest.raises(RuntimeError, match="Cannot unlock"):
783+
level2.unlock_(recurse=False)
784+
with pytest.raises(RuntimeError, match="Cannot unlock"):
785+
level1.unlock_(recurse=False)
786+
level0.unlock_(recurse=False)
787+
assert level3.is_locked
788+
assert level2.is_locked
789+
assert level1.is_locked
790+
new_spec.unlock_(recurse=False)
791+
assert level3.is_locked
792+
assert level2.is_locked
793+
assert level1.is_locked
730794

731795

732796
def test_keys_to_empty_composite_spec():

torchrl/data/tensor_specs.py

Lines changed: 160 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77

88
import abc
99
import enum
10+
import gc
1011
import math
1112
import warnings
13+
import weakref
1214
from collections.abc import Iterable
1315
from copy import deepcopy
1416
from dataclasses import dataclass
@@ -4428,7 +4430,7 @@ class Composite(TensorSpec):
44284430
@classmethod
44294431
def __new__(cls, *args, **kwargs):
44304432
cls._device = None
4431-
cls._locked = False
4433+
cls._is_locked = False
44324434
return super().__new__(cls)
44334435

44344436
@property
@@ -4959,6 +4961,10 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Composite:
49594961
return self.__class__(**kwargs, device=_device, shape=self.shape)
49604962

49614963
def clone(self) -> Composite:
4964+
"""Clones the Composite spec.
4965+
4966+
Locked specs will not produce locked clones.
4967+
"""
49624968
try:
49634969
device = self.device
49644970
except RuntimeError:
@@ -5170,14 +5176,82 @@ def unbind(self, dim: int = 0):
51705176
for i in range(self.shape[dim])
51715177
)
51725178

5173-
def lock_(self, recurse=False):
5174-
"""Locks the Composite and prevents modification of its content.
5179+
# Locking functionality
5180+
@property
5181+
def is_locked(self) -> bool:
5182+
return self._is_locked
5183+
5184+
@is_locked.setter
5185+
def is_locked(self, value: bool) -> None:
5186+
if value:
5187+
self.lock_()
5188+
else:
5189+
self.unlock_()
5190+
5191+
def __getstate__(self):
5192+
result = self.__dict__.copy()
5193+
__lock_parents_weakrefs = result.pop("__lock_parents_weakrefs", None)
5194+
if __lock_parents_weakrefs is not None:
5195+
result["_lock_recurse"] = True
5196+
return result
5197+
5198+
def __setstate__(self, state):
5199+
_lock_recurse = state.pop("_lock_recurse", False)
5200+
for key, value in state.items():
5201+
setattr(self, key, value)
5202+
if self._is_locked:
5203+
self._is_locked = False
5204+
self.lock_(recurse=_lock_recurse)
5205+
5206+
def _propagate_lock(
5207+
self, *, recurse: bool, lock_parents_weakrefs=None, is_compiling
5208+
):
5209+
"""Registers the parent composite that handles the lock."""
5210+
self._is_locked = True
5211+
if lock_parents_weakrefs is not None:
5212+
lock_parents_weakrefs = [
5213+
ref
5214+
for ref in lock_parents_weakrefs
5215+
if not any(refref is ref for refref in self._lock_parents_weakrefs)
5216+
]
5217+
if not is_compiling:
5218+
is_root = lock_parents_weakrefs is None
5219+
if is_root:
5220+
lock_parents_weakrefs = []
5221+
else:
5222+
self._lock_parents_weakrefs = (
5223+
self._lock_parents_weakrefs + lock_parents_weakrefs
5224+
)
5225+
lock_parents_weakrefs = list(lock_parents_weakrefs)
5226+
lock_parents_weakrefs.append(weakref.ref(self))
51755227

5176-
This is only a first-level lock, unless specified otherwise through the
5177-
``recurse`` arg.
5228+
if recurse:
5229+
for value in self.values():
5230+
if isinstance(value, Composite):
5231+
value._propagate_lock(
5232+
recurse=True,
5233+
lock_parents_weakrefs=lock_parents_weakrefs,
5234+
is_compiling=is_compiling,
5235+
)
51785236

5179-
Leaf specs can always be modified in place, but they cannot be replaced
5180-
in their Composite parent.
5237+
@property
5238+
def _lock_parents_weakrefs(self):
5239+
_lock_parents_weakrefs = self.__dict__.get("__lock_parents_weakrefs")
5240+
if _lock_parents_weakrefs is None:
5241+
self.__dict__["__lock_parents_weakrefs"] = []
5242+
_lock_parents_weakrefs = self.__dict__["__lock_parents_weakrefs"]
5243+
return _lock_parents_weakrefs
5244+
5245+
@_lock_parents_weakrefs.setter
5246+
def _lock_parents_weakrefs(self, value: list):
5247+
self.__dict__["__lock_parents_weakrefs"] = value
5248+
5249+
def lock_(self, recurse: bool | None = None) -> T:
5250+
"""Locks the Composite and prevents modification of its content.
5251+
5252+
The recurse argument control whether the lock will be propagated to sub-specs.
5253+
The current default is ``False`` but it will be turned to ``True`` for consistency
5254+
with the TensorDict API in v0.8.
51815255
51825256
Examples:
51835257
>>> shape = [3, 4, 5]
@@ -5211,30 +5285,99 @@ def lock_(self, recurse=False):
52115285
failed!
52125286
52135287
"""
5214-
self._locked = True
5288+
if self.is_locked:
5289+
return self
5290+
is_comp = is_compiling()
5291+
if is_comp:
5292+
# TODO: See what to do when compiling
5293+
pass
5294+
if recurse is None:
5295+
warnings.warn(
5296+
"You have not specified a value for recurse when calling CompositeSpec.lock_(). "
5297+
"The current default is False but it will be turned to True in v0.8. To adapt to these changes "
5298+
"and silence this warning, pass the value of recurse explicitly.",
5299+
category=DeprecationWarning,
5300+
)
5301+
recurse = False
5302+
self._propagate_lock(recurse=recurse, is_compiling=is_comp)
5303+
return self
5304+
5305+
def _propagate_unlock(self, recurse: bool):
5306+
# if we end up here, we can clear the graph associated with this td
5307+
self._is_locked = False
5308+
5309+
self._is_shared = False
5310+
self._is_memmap = False
5311+
52155312
if recurse:
5313+
sub_specs = []
52165314
for value in self.values():
52175315
if isinstance(value, Composite):
5218-
value.lock_(recurse)
5219-
return self
5316+
sub_specs.extend(value._propagate_unlock(recurse=recurse))
5317+
sub_specs.append(value)
5318+
return sub_specs
5319+
return []
5320+
5321+
def _check_unlock(self, first_attempt=True):
5322+
if not first_attempt:
5323+
gc.collect()
5324+
obj = None
5325+
for ref in self._lock_parents_weakrefs:
5326+
obj = ref()
5327+
# check if the locked parent exists and if it's locked
5328+
# we check _is_locked because it can be False or None in the case of Lazy stacks,
5329+
# but if we check obj.is_locked it will be True for this class.
5330+
if obj is not None and obj._is_locked:
5331+
break
52205332

5221-
def unlock_(self, recurse=False):
5333+
else:
5334+
try:
5335+
self._lock_parents_weakrefs = []
5336+
except AttributeError:
5337+
# Some tds (eg, LazyStack) have an automated way of creating the _lock_parents_weakref
5338+
pass
5339+
return
5340+
5341+
if first_attempt:
5342+
del obj
5343+
return self._check_unlock(False)
5344+
raise RuntimeError(
5345+
"Cannot unlock a Composite that is part of a locked graph. "
5346+
"Graphs are locked when a Composite is locked with recurse=True. "
5347+
"Unlock the root Composite first. If the Composite is part of multiple graphs, "
5348+
"group the graphs under a common Composite an unlock this root. "
5349+
f"self: {self}, obj: {obj}"
5350+
)
5351+
5352+
def unlock_(self, recurse: bool | None = None) -> T:
52225353
"""Unlocks the Composite and allows modification of its content.
52235354
52245355
This is only a first-level lock modification, unless specified
52255356
otherwise through the ``recurse`` arg.
52265357
52275358
"""
5228-
self._locked = False
5229-
if recurse:
5230-
for value in self.values():
5231-
if isinstance(value, Composite):
5232-
value.unlock_(recurse)
5359+
try:
5360+
if recurse is None:
5361+
warnings.warn(
5362+
"You have not specified a value for recurse when calling CompositeSpec.unlock_(). "
5363+
"The current default is False but it will be turned to True in v0.8. To adapt to these changes "
5364+
"and silence this warning, pass the value of recurse explicitly.",
5365+
category=DeprecationWarning,
5366+
)
5367+
recurse = False
5368+
sub_specs = self._propagate_unlock(recurse=recurse)
5369+
if recurse:
5370+
for sub_spec in sub_specs:
5371+
sub_spec._check_unlock()
5372+
self._check_unlock()
5373+
except RuntimeError as err:
5374+
self.lock_()
5375+
raise err
52335376
return self
52345377

52355378
@property
52365379
def locked(self):
5237-
return self._locked
5380+
return self._is_locked
52385381

52395382

52405383
class StackedComposite(_LazyStackedMixin[Composite], Composite):

torchrl/envs/libs/isaacgym.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,9 @@ def _make_specs(self, env: "gym.Env") -> None: # noqa: F821
8080
specs = make_composite_from_td(data)
8181

8282
obs_spec = self.observation_spec
83-
obs_spec.unlock_()
83+
obs_spec.unlock_(recurse=True)
8484
obs_spec.update(specs)
85-
obs_spec.lock_()
85+
obs_spec.lock_(recurse=True)
8686

8787
def _output_transform(self, output):
8888
obs, reward, done, info = output

0 commit comments

Comments
 (0)