|
7 | 7 |
|
8 | 8 | import abc
|
9 | 9 | import enum
|
| 10 | +import gc |
10 | 11 | import math
|
11 | 12 | import warnings
|
| 13 | +import weakref |
12 | 14 | from collections.abc import Iterable
|
13 | 15 | from copy import deepcopy
|
14 | 16 | from dataclasses import dataclass
|
@@ -4428,7 +4430,7 @@ class Composite(TensorSpec):
|
4428 | 4430 | @classmethod
|
4429 | 4431 | def __new__(cls, *args, **kwargs):
|
4430 | 4432 | cls._device = None
|
4431 |
| - cls._locked = False |
| 4433 | + cls._is_locked = False |
4432 | 4434 | return super().__new__(cls)
|
4433 | 4435 |
|
4434 | 4436 | @property
|
@@ -4959,6 +4961,10 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Composite:
|
4959 | 4961 | return self.__class__(**kwargs, device=_device, shape=self.shape)
|
4960 | 4962 |
|
4961 | 4963 | def clone(self) -> Composite:
|
| 4964 | + """Clones the Composite spec. |
| 4965 | +
|
| 4966 | + Locked specs will not produce locked clones. |
| 4967 | + """ |
4962 | 4968 | try:
|
4963 | 4969 | device = self.device
|
4964 | 4970 | except RuntimeError:
|
@@ -5170,14 +5176,82 @@ def unbind(self, dim: int = 0):
|
5170 | 5176 | for i in range(self.shape[dim])
|
5171 | 5177 | )
|
5172 | 5178 |
|
5173 |
| - def lock_(self, recurse=False): |
5174 |
| - """Locks the Composite and prevents modification of its content. |
| 5179 | + # Locking functionality |
| 5180 | + @property |
| 5181 | + def is_locked(self) -> bool: |
| 5182 | + return self._is_locked |
| 5183 | + |
| 5184 | + @is_locked.setter |
| 5185 | + def is_locked(self, value: bool) -> None: |
| 5186 | + if value: |
| 5187 | + self.lock_() |
| 5188 | + else: |
| 5189 | + self.unlock_() |
| 5190 | + |
| 5191 | + def __getstate__(self): |
| 5192 | + result = self.__dict__.copy() |
| 5193 | + __lock_parents_weakrefs = result.pop("__lock_parents_weakrefs", None) |
| 5194 | + if __lock_parents_weakrefs is not None: |
| 5195 | + result["_lock_recurse"] = True |
| 5196 | + return result |
| 5197 | + |
| 5198 | + def __setstate__(self, state): |
| 5199 | + _lock_recurse = state.pop("_lock_recurse", False) |
| 5200 | + for key, value in state.items(): |
| 5201 | + setattr(self, key, value) |
| 5202 | + if self._is_locked: |
| 5203 | + self._is_locked = False |
| 5204 | + self.lock_(recurse=_lock_recurse) |
| 5205 | + |
| 5206 | + def _propagate_lock( |
| 5207 | + self, *, recurse: bool, lock_parents_weakrefs=None, is_compiling |
| 5208 | + ): |
| 5209 | + """Registers the parent composite that handles the lock.""" |
| 5210 | + self._is_locked = True |
| 5211 | + if lock_parents_weakrefs is not None: |
| 5212 | + lock_parents_weakrefs = [ |
| 5213 | + ref |
| 5214 | + for ref in lock_parents_weakrefs |
| 5215 | + if not any(refref is ref for refref in self._lock_parents_weakrefs) |
| 5216 | + ] |
| 5217 | + if not is_compiling: |
| 5218 | + is_root = lock_parents_weakrefs is None |
| 5219 | + if is_root: |
| 5220 | + lock_parents_weakrefs = [] |
| 5221 | + else: |
| 5222 | + self._lock_parents_weakrefs = ( |
| 5223 | + self._lock_parents_weakrefs + lock_parents_weakrefs |
| 5224 | + ) |
| 5225 | + lock_parents_weakrefs = list(lock_parents_weakrefs) |
| 5226 | + lock_parents_weakrefs.append(weakref.ref(self)) |
5175 | 5227 |
|
5176 |
| - This is only a first-level lock, unless specified otherwise through the |
5177 |
| - ``recurse`` arg. |
| 5228 | + if recurse: |
| 5229 | + for value in self.values(): |
| 5230 | + if isinstance(value, Composite): |
| 5231 | + value._propagate_lock( |
| 5232 | + recurse=True, |
| 5233 | + lock_parents_weakrefs=lock_parents_weakrefs, |
| 5234 | + is_compiling=is_compiling, |
| 5235 | + ) |
5178 | 5236 |
|
5179 |
| - Leaf specs can always be modified in place, but they cannot be replaced |
5180 |
| - in their Composite parent. |
| 5237 | + @property |
| 5238 | + def _lock_parents_weakrefs(self): |
| 5239 | + _lock_parents_weakrefs = self.__dict__.get("__lock_parents_weakrefs") |
| 5240 | + if _lock_parents_weakrefs is None: |
| 5241 | + self.__dict__["__lock_parents_weakrefs"] = [] |
| 5242 | + _lock_parents_weakrefs = self.__dict__["__lock_parents_weakrefs"] |
| 5243 | + return _lock_parents_weakrefs |
| 5244 | + |
| 5245 | + @_lock_parents_weakrefs.setter |
| 5246 | + def _lock_parents_weakrefs(self, value: list): |
| 5247 | + self.__dict__["__lock_parents_weakrefs"] = value |
| 5248 | + |
| 5249 | + def lock_(self, recurse: bool | None = None) -> T: |
| 5250 | + """Locks the Composite and prevents modification of its content. |
| 5251 | +
|
| 5252 | + The recurse argument control whether the lock will be propagated to sub-specs. |
| 5253 | + The current default is ``False`` but it will be turned to ``True`` for consistency |
| 5254 | + with the TensorDict API in v0.8. |
5181 | 5255 |
|
5182 | 5256 | Examples:
|
5183 | 5257 | >>> shape = [3, 4, 5]
|
@@ -5211,30 +5285,99 @@ def lock_(self, recurse=False):
|
5211 | 5285 | failed!
|
5212 | 5286 |
|
5213 | 5287 | """
|
5214 |
| - self._locked = True |
| 5288 | + if self.is_locked: |
| 5289 | + return self |
| 5290 | + is_comp = is_compiling() |
| 5291 | + if is_comp: |
| 5292 | + # TODO: See what to do when compiling |
| 5293 | + pass |
| 5294 | + if recurse is None: |
| 5295 | + warnings.warn( |
| 5296 | + "You have not specified a value for recurse when calling CompositeSpec.lock_(). " |
| 5297 | + "The current default is False but it will be turned to True in v0.8. To adapt to these changes " |
| 5298 | + "and silence this warning, pass the value of recurse explicitly.", |
| 5299 | + category=DeprecationWarning, |
| 5300 | + ) |
| 5301 | + recurse = False |
| 5302 | + self._propagate_lock(recurse=recurse, is_compiling=is_comp) |
| 5303 | + return self |
| 5304 | + |
| 5305 | + def _propagate_unlock(self, recurse: bool): |
| 5306 | + # if we end up here, we can clear the graph associated with this td |
| 5307 | + self._is_locked = False |
| 5308 | + |
| 5309 | + self._is_shared = False |
| 5310 | + self._is_memmap = False |
| 5311 | + |
5215 | 5312 | if recurse:
|
| 5313 | + sub_specs = [] |
5216 | 5314 | for value in self.values():
|
5217 | 5315 | if isinstance(value, Composite):
|
5218 |
| - value.lock_(recurse) |
5219 |
| - return self |
| 5316 | + sub_specs.extend(value._propagate_unlock(recurse=recurse)) |
| 5317 | + sub_specs.append(value) |
| 5318 | + return sub_specs |
| 5319 | + return [] |
| 5320 | + |
| 5321 | + def _check_unlock(self, first_attempt=True): |
| 5322 | + if not first_attempt: |
| 5323 | + gc.collect() |
| 5324 | + obj = None |
| 5325 | + for ref in self._lock_parents_weakrefs: |
| 5326 | + obj = ref() |
| 5327 | + # check if the locked parent exists and if it's locked |
| 5328 | + # we check _is_locked because it can be False or None in the case of Lazy stacks, |
| 5329 | + # but if we check obj.is_locked it will be True for this class. |
| 5330 | + if obj is not None and obj._is_locked: |
| 5331 | + break |
5220 | 5332 |
|
5221 |
| - def unlock_(self, recurse=False): |
| 5333 | + else: |
| 5334 | + try: |
| 5335 | + self._lock_parents_weakrefs = [] |
| 5336 | + except AttributeError: |
| 5337 | + # Some tds (eg, LazyStack) have an automated way of creating the _lock_parents_weakref |
| 5338 | + pass |
| 5339 | + return |
| 5340 | + |
| 5341 | + if first_attempt: |
| 5342 | + del obj |
| 5343 | + return self._check_unlock(False) |
| 5344 | + raise RuntimeError( |
| 5345 | + "Cannot unlock a Composite that is part of a locked graph. " |
| 5346 | + "Graphs are locked when a Composite is locked with recurse=True. " |
| 5347 | + "Unlock the root Composite first. If the Composite is part of multiple graphs, " |
| 5348 | + "group the graphs under a common Composite an unlock this root. " |
| 5349 | + f"self: {self}, obj: {obj}" |
| 5350 | + ) |
| 5351 | + |
| 5352 | + def unlock_(self, recurse: bool | None = None) -> T: |
5222 | 5353 | """Unlocks the Composite and allows modification of its content.
|
5223 | 5354 |
|
5224 | 5355 | This is only a first-level lock modification, unless specified
|
5225 | 5356 | otherwise through the ``recurse`` arg.
|
5226 | 5357 |
|
5227 | 5358 | """
|
5228 |
| - self._locked = False |
5229 |
| - if recurse: |
5230 |
| - for value in self.values(): |
5231 |
| - if isinstance(value, Composite): |
5232 |
| - value.unlock_(recurse) |
| 5359 | + try: |
| 5360 | + if recurse is None: |
| 5361 | + warnings.warn( |
| 5362 | + "You have not specified a value for recurse when calling CompositeSpec.unlock_(). " |
| 5363 | + "The current default is False but it will be turned to True in v0.8. To adapt to these changes " |
| 5364 | + "and silence this warning, pass the value of recurse explicitly.", |
| 5365 | + category=DeprecationWarning, |
| 5366 | + ) |
| 5367 | + recurse = False |
| 5368 | + sub_specs = self._propagate_unlock(recurse=recurse) |
| 5369 | + if recurse: |
| 5370 | + for sub_spec in sub_specs: |
| 5371 | + sub_spec._check_unlock() |
| 5372 | + self._check_unlock() |
| 5373 | + except RuntimeError as err: |
| 5374 | + self.lock_() |
| 5375 | + raise err |
5233 | 5376 | return self
|
5234 | 5377 |
|
5235 | 5378 | @property
|
5236 | 5379 | def locked(self):
|
5237 |
| - return self._locked |
| 5380 | + return self._is_locked |
5238 | 5381 |
|
5239 | 5382 |
|
5240 | 5383 | class StackedComposite(_LazyStackedMixin[Composite], Composite):
|
|
0 commit comments