Skip to content

Commit 759ea27

Browse files
author
Vincent Moens
committed
Update
[ghstack-poisoned]
1 parent b70f558 commit 759ea27

File tree

1 file changed

+0
-17
lines changed

1 file changed

+0
-17
lines changed

torchrl/envs/transforms/transforms.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9974,20 +9974,3 @@ def _apply_transform(self, reward: Tensor) -> TensorDictBase:
99749974
)
99759975

99769976
return (self.weights * reward).sum(dim=-1)
9977-
9978-
class ConditionalPolicySwitch(Transform):
9979-
def __init__(self, policy: Callable[[TensorDictBase], TensorDictBase], condition: Callable[[TensorDictBase], bool]):
9980-
super().__init__([], [])
9981-
self.__dict__["policy"] = policy
9982-
self.condition = condition
9983-
def _step(
9984-
self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
9985-
) -> TensorDictBase:
9986-
if self.condition(tensordict):
9987-
parent: TransformedEnv = self.parent
9988-
tensordict = parent.step(tensordict)
9989-
tensordict_ = parent.step_mdp(tensordict)
9990-
tensordict_ = self.policy(tensordict_)
9991-
return parent.step(tensordict_)
9992-
return tensordict
9993-
return

0 commit comments

Comments
 (0)