Skip to content

Commit 3da76f0

Browse files
[Feature] ActionDiscretizer custom sampling (#2609)
Co-authored-by: Oliver Slumbers <oliver.slumbers@helsing.ai>
1 parent 607ebc5 commit 3da76f0

File tree

1 file changed

+24
-22
lines changed

1 file changed

+24
-22
lines changed

torchrl/envs/transforms/transforms.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8583,6 +8583,26 @@ def _indent(s):
85838583
f"\n{_indent(out_action_key)},\n{_indent(sampling)},\n{_indent(categorical)})"
85848584
)
85858585

8586+
def _custom_arange(self, nint, device):
    """Return ``nint`` evenly spaced offsets inside the unit interval.

    The base grid is ``[0, 1/nint, ..., (nint - 1)/nint]``. It is then
    adjusted according to ``self.sampling``:

    * ``SamplingStrategy.HIGH``: the grid is mirrored to
      ``[1/nint, ..., 1]``.
    * ``SamplingStrategy.MEDIAN``: the element-wise mean of the base grid
      and the mirrored grid (i.e. the midpoint of each sub-interval).
    * any other strategy: the base grid is returned unchanged.

    Args:
        nint: number of intervals (length of the returned tensor).
        device: device forwarded to ``torch.arange``.

    Returns:
        A 1-D tensor with exactly ``nint`` elements, created with
        ``dtype=self.dtype`` on ``device``.
    """
    # Build the grid from an integer range and scale it afterwards.
    # ``torch.arange(0.0, 1.0, 1 / nint)`` derives its length from
    # ``ceil((end - start) / step)`` in floating point, which rounds up to
    # ``nint + 1`` for some values (e.g. nint=49) and breaks callers that
    # expect exactly ``nint`` entries.
    result = torch.arange(nint, dtype=self.dtype, device=device) / nint
    if self.sampling in (
        self.SamplingStrategy.HIGH,
        self.SamplingStrategy.MEDIAN,
    ):
        # Mirror the grid so it ends at 1 instead of starting at 0.
        result_ = (1 - result).flip(0)
        if self.sampling == self.SamplingStrategy.MEDIAN:
            # Midpoint between the lower and upper edge of each interval.
            result = (result + result_) / 2
        else:
            result = result_
    return result
8605+
85868606
def transform_input_spec(self, input_spec):
85878607
try:
85888608
action_spec = self.parent.full_action_spec_unbatched[self.in_keys_inv[0]]
@@ -8611,29 +8631,11 @@ def transform_input_spec(self, input_spec):
86118631
num_intervals = int(num_intervals.squeeze())
86128632
self.num_intervals = torch.as_tensor(num_intervals)
86138633

8614-
def custom_arange(nint):
8615-
result = torch.arange(
8616-
start=0.0,
8617-
end=1.0,
8618-
step=1 / nint,
8619-
dtype=self.dtype,
8620-
device=action_spec.device,
8621-
)
8622-
result_ = result
8623-
if self.sampling in (
8624-
self.SamplingStrategy.HIGH,
8625-
self.SamplingStrategy.MEDIAN,
8626-
):
8627-
result_ = (1 - result).flip(0)
8628-
if self.sampling == self.SamplingStrategy.MEDIAN:
8629-
result = (result + result_) / 2
8630-
else:
8631-
result = result_
8632-
return result
8633-
86348634
if isinstance(num_intervals, int):
86358635
arange = (
8636-
custom_arange(num_intervals).expand((*n_act, num_intervals))
8636+
self._custom_arange(num_intervals, action_spec.device).expand(
8637+
(*n_act, num_intervals)
8638+
)
86378639
* interval
86388640
)
86398641
low = action_spec.low
@@ -8642,7 +8644,7 @@ def custom_arange(nint):
86428644
self.register_buffer("intervals", low + arange)
86438645
else:
86448646
arange = [
8645-
custom_arange(_num_intervals) * interval
8647+
self._custom_arange(_num_intervals, action_spec.device) * interval
86468648
for _num_intervals, interval in zip(
86478649
num_intervals.tolist(), interval.unbind(-2)
86488650
)

0 commit comments

Comments
 (0)