Skip to content

Commit 83e0b05

Browse files
author
Vincent Moens
committed
[Feature] Composite.separates
ghstack-source-id: fbfc430 Pull Request resolved: #2599
1 parent 8d16c12 commit 83e0b05

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

torchrl/data/tensor_specs.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4339,6 +4339,33 @@ def pop(self, key: NestedKey, default: Any = NO_DEFAULT) -> Any:
43394339
return default
43404340
raise KeyError(f"{key} not found in composite spec.")
43414341

4342+
def separates(self, *keys: NestedKey, default: Any = None) -> Composite:
4343+
"""Splits the composite spec by extracting specified keys and their associated values into a new composite spec.
4344+
4345+
This method iterates over the provided keys, removes them from the current composite spec, and adds them to a new
4346+
composite spec. If a key is not found, the specified default value is used. The new composite spec is returned.
4347+
4348+
Args:
4349+
*keys (NestedKey):
4350+
One or more keys to be extracted from the composite spec. Each key can be a single key or a nested key.
4351+
default (Any, optional):
4352+
The value to use if a specified key is not found in the composite spec. Defaults to `None`.
4353+
4354+
Returns:
4355+
Composite: A new composite spec containing the extracted keys and their associated values.
4356+
4357+
Note:
4358+
If none of the specified keys are found, the method returns `None`.
4359+
"""
4360+
out = None
4361+
for key in keys:
4362+
result = self.pop(key, default=default)
4363+
if result is not None:
4364+
if out is None:
4365+
out = Composite(batch_size=self.batch_size, device=self.device)
4366+
out[key] = result
4367+
return out
4368+
43424369
def set(self, name, spec):
43434370
if self.locked:
43444371
raise RuntimeError("Cannot modify a locked Composite.")

0 commit comments

Comments
 (0)