Skip to content

Commit 116c0e1

Browse files
Update base for Update on "[Feature] ConditionalPolicySwitch transform"
[ghstack-poisoned]
2 parents 4791529 + dbc8e2e commit 116c0e1

23 files changed

+755
-168
lines changed

.github/workflows/docs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ jobs:
2727
strategy:
2828
matrix:
2929
python_version: ["3.10"]
30-
cuda_arch_version: ["12.1"]
30+
cuda_arch_version: ["12.4"]
3131
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
3232
with:
3333
repository: pytorch/rl

.github/workflows/test-linux-habitat.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ jobs:
2424
strategy:
2525
matrix:
2626
python_version: ["3.9"]
27-
cuda_arch_version: ["12.1"]
27+
cuda_arch_version: ["12.4"]
2828
fail-fast: false
2929
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
3030
with:

.github/workflows/test-linux-libs.yml

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ jobs:
2525
strategy:
2626
matrix:
2727
python_version: ["3.9"]
28-
cuda_arch_version: ["12.1"]
28+
cuda_arch_version: ["12.4"]
2929
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Data') }}
3030
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
3131
with:
@@ -59,7 +59,7 @@ jobs:
5959
strategy:
6060
matrix:
6161
python_version: ["3.11"]
62-
cuda_arch_version: ["12.1"]
62+
cuda_arch_version: ["12.4"]
6363
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }}
6464
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
6565
with:
@@ -96,7 +96,7 @@ jobs:
9696
strategy:
9797
matrix:
9898
python_version: ["3.9"]
99-
cuda_arch_version: ["12.1"]
99+
cuda_arch_version: ["12.4"]
100100
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Data') }}
101101
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
102102
with:
@@ -131,7 +131,7 @@ jobs:
131131
strategy:
132132
matrix:
133133
python_version: ["3.9"]
134-
cuda_arch_version: ["12.1"]
134+
cuda_arch_version: ["12.4"]
135135
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }}
136136
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
137137
with:
@@ -166,7 +166,7 @@ jobs:
166166
strategy:
167167
matrix:
168168
python_version: ["3.9"]
169-
cuda_arch_version: ["12.1"]
169+
cuda_arch_version: ["12.4"]
170170
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Data') }}
171171
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
172172
with:
@@ -200,7 +200,7 @@ jobs:
200200
strategy:
201201
matrix:
202202
python_version: ["3.9"]
203-
cuda_arch_version: ["12.1"]
203+
cuda_arch_version: ["12.4"]
204204
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
205205
with:
206206
repository: pytorch/rl
@@ -235,7 +235,7 @@ jobs:
235235
strategy:
236236
matrix:
237237
python_version: ["3.9"]
238-
cuda_arch_version: ["12.1"]
238+
cuda_arch_version: ["12.4"]
239239
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }}
240240
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
241241
with:
@@ -256,7 +256,7 @@ jobs:
256256
257257
set -euo pipefail
258258
export PYTHON_VERSION="3.9"
259-
export CU_VERSION="12.1"
259+
export CU_VERSION="12.4"
260260
export TAR_OPTIONS="--no-same-owner"
261261
export UPLOAD_CHANNEL="nightly"
262262
export TF_CPP_MIN_LOG_LEVEL=0
@@ -277,7 +277,7 @@ jobs:
277277
repository: pytorch/rl
278278
runner: "linux.g5.4xlarge.nvidia.gpu"
279279
gpu-arch-type: cuda
280-
gpu-arch-version: "12.1"
280+
gpu-arch-version: "12.4"
281281
docker-image: "nvidia/cuda:12.4.1-runtime-ubuntu22.04"
282282
timeout: 120
283283
script: |
@@ -291,7 +291,7 @@ jobs:
291291
292292
set -euo pipefail
293293
export PYTHON_VERSION="3.11"
294-
export CU_VERSION="12.1"
294+
export CU_VERSION="12.4"
295295
export TAR_OPTIONS="--no-same-owner"
296296
export UPLOAD_CHANNEL="nightly"
297297
export TF_CPP_MIN_LOG_LEVEL=0
@@ -309,7 +309,7 @@ jobs:
309309
strategy:
310310
matrix:
311311
python_version: ["3.9"]
312-
cuda_arch_version: ["12.1"]
312+
cuda_arch_version: ["12.4"]
313313
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }}
314314
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
315315
with:
@@ -330,7 +330,7 @@ jobs:
330330
331331
set -euo pipefail
332332
export PYTHON_VERSION="3.9"
333-
export CU_VERSION="12.1"
333+
export CU_VERSION="12.4"
334334
export TAR_OPTIONS="--no-same-owner"
335335
export UPLOAD_CHANNEL="nightly"
336336
export TF_CPP_MIN_LOG_LEVEL=0
@@ -347,7 +347,7 @@ jobs:
347347
strategy:
348348
matrix:
349349
python_version: ["3.9"]
350-
cuda_arch_version: ["12.1"]
350+
cuda_arch_version: ["12.4"]
351351
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }}
352352
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
353353
with:
@@ -368,7 +368,7 @@ jobs:
368368
369369
set -euo pipefail
370370
export PYTHON_VERSION="3.9"
371-
export CU_VERSION="12.1"
371+
export CU_VERSION="12.4"
372372
export TAR_OPTIONS="--no-same-owner"
373373
export UPLOAD_CHANNEL="nightly"
374374
export TF_CPP_MIN_LOG_LEVEL=0
@@ -385,7 +385,7 @@ jobs:
385385
strategy:
386386
matrix:
387387
python_version: ["3.10.12"]
388-
cuda_arch_version: ["12.1"]
388+
cuda_arch_version: ["12.4"]
389389
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }}
390390
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
391391
with:
@@ -406,7 +406,7 @@ jobs:
406406
407407
set -euo pipefail
408408
export PYTHON_VERSION="3.10.12"
409-
export CU_VERSION="12.1"
409+
export CU_VERSION="12.4"
410410
export TAR_OPTIONS="--no-same-owner"
411411
export UPLOAD_CHANNEL="nightly"
412412
export TF_CPP_MIN_LOG_LEVEL=0
@@ -423,7 +423,7 @@ jobs:
423423
strategy:
424424
matrix:
425425
python_version: ["3.9"]
426-
cuda_arch_version: ["12.1"]
426+
cuda_arch_version: ["12.4"]
427427
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Data') }}
428428
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
429429
with:
@@ -458,7 +458,7 @@ jobs:
458458
strategy:
459459
matrix:
460460
python_version: ["3.9"]
461-
cuda_arch_version: ["12.1"]
461+
cuda_arch_version: ["12.4"]
462462
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Data') }}
463463
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
464464
with:
@@ -510,7 +510,7 @@ jobs:
510510
511511
set -euo pipefail
512512
export PYTHON_VERSION="3.9"
513-
export CU_VERSION="12.1"
513+
export CU_VERSION="12.4"
514514
export TAR_OPTIONS="--no-same-owner"
515515
export UPLOAD_CHANNEL="nightly"
516516
export TF_CPP_MIN_LOG_LEVEL=0
@@ -528,7 +528,7 @@ jobs:
528528
strategy:
529529
matrix:
530530
python_version: ["3.9"]
531-
cuda_arch_version: ["12.1"]
531+
cuda_arch_version: ["12.4"]
532532
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }}
533533
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
534534
with:
@@ -562,7 +562,7 @@ jobs:
562562
strategy:
563563
matrix:
564564
python_version: ["3.9"]
565-
cuda_arch_version: ["12.1"]
565+
cuda_arch_version: ["12.4"]
566566
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Data') }}
567567
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
568568
with:
@@ -597,7 +597,7 @@ jobs:
597597
strategy:
598598
matrix:
599599
python_version: ["3.9"]
600-
cuda_arch_version: ["12.1"]
600+
cuda_arch_version: ["12.4"]
601601
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
602602
with:
603603
repository: pytorch/rl
@@ -633,7 +633,7 @@ jobs:
633633
strategy:
634634
matrix:
635635
python_version: ["3.9"]
636-
cuda_arch_version: ["12.1"]
636+
cuda_arch_version: ["12.4"]
637637
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }}
638638
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
639639
with:
@@ -654,7 +654,7 @@ jobs:
654654
655655
set -euo pipefail
656656
export PYTHON_VERSION="3.9"
657-
export CU_VERSION="12.1"
657+
export CU_VERSION="12.4"
658658
export TAR_OPTIONS="--no-same-owner"
659659
export UPLOAD_CHANNEL="nightly"
660660
export TF_CPP_MIN_LOG_LEVEL=0
@@ -672,7 +672,7 @@ jobs:
672672
strategy:
673673
matrix:
674674
python_version: ["3.9"]
675-
cuda_arch_version: ["12.1"]
675+
cuda_arch_version: ["12.4"]
676676
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Data') }}
677677
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
678678
with:
@@ -707,7 +707,7 @@ jobs:
707707
strategy:
708708
matrix:
709709
python_version: ["3.9"]
710-
cuda_arch_version: ["12.1"]
710+
cuda_arch_version: ["12.4"]
711711
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }}
712712
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
713713
with:
@@ -728,7 +728,7 @@ jobs:
728728
729729
set -euo pipefail
730730
export PYTHON_VERSION="3.9"
731-
export CU_VERSION="12.1"
731+
export CU_VERSION="12.4"
732732
export TAR_OPTIONS="--no-same-owner"
733733
export UPLOAD_CHANNEL="nightly"
734734
export TF_CPP_MIN_LOG_LEVEL=0

.github/workflows/test-linux-rlhf.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ jobs:
2424
strategy:
2525
matrix:
2626
python_version: ["3.9"]
27-
cuda_arch_version: ["12.1"]
27+
cuda_arch_version: ["12.4"]
2828
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
2929
with:
3030
repository: pytorch/rl

.github/workflows/test-linux-sota.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ jobs:
2727
strategy:
2828
matrix:
2929
python_version: ["3.9"]
30-
cuda_arch_version: ["12.1"]
30+
cuda_arch_version: ["12.4"]
3131
fail-fast: false
3232
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
3333
with:

.github/workflows/test-linux.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ jobs:
8989
strategy:
9090
matrix:
9191
python_version: ["3.11"]
92-
cuda_arch_version: ["12.1"]
92+
cuda_arch_version: ["12.4"]
9393
fail-fast: false
9494
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
9595
with:
@@ -158,7 +158,7 @@ jobs:
158158
strategy:
159159
matrix:
160160
python_version: ["3.11"]
161-
cuda_arch_version: ["12.1"]
161+
cuda_arch_version: ["12.4"]
162162
fail-fast: false
163163
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
164164
with:

docs/source/reference/data.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ using the following components:
148148
LazyMemmapStorage
149149
LazyTensorStorage
150150
ListStorage
151+
LazyStackStorage
151152
ListStorageCheckpointer
152153
NestedStorageCheckpointer
153154
PrioritizedSampler

docs/source/reference/envs.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1117,7 +1117,7 @@ in the relevant functions:
11171117
>>> print(env2._env.env.env)
11181118
<gym.envs.classic_control.pendulum.PendulumEnv at 0x1629916a0>
11191119

1120-
We can see that the two libraries modify the value returned by :func:`~.gym.gym_backend()`
1120+
We can see that the two libraries modify the value returned by :func:`~torchrl.envs.gym.gym_backend()`
11211121
which can be further used to indicate which library needs to be used for
11221122
the current computation. :class:`~.gym.set_gym_backend` is also a decorator:
11231123
we can use it to tell a specific function which gym backend needs to be used
@@ -1188,3 +1188,4 @@ the following function will return ``1`` when queried:
11881188
VmasWrapper
11891189
gym_backend
11901190
set_gym_backend
1191+
register_gym_spec_conversion

test/mocking_classes.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,6 +1068,11 @@ def _step(
10681068
return tensordict
10691069

10701070

1071+
def get_random_string(min_size, max_size):
    """Return a random string of lowercase ASCII letters.

    The length is drawn with ``random.randint(min_size, max_size)`` (both
    bounds inclusive); every character is then picked independently from
    ``string.ascii_lowercase`` using the global :mod:`random` generator.

    Args:
        min_size: smallest admissible length.
        max_size: largest admissible length.

    Returns:
        The generated string.
    """
    length = random.randint(min_size, max_size)
    alphabet = string.ascii_lowercase
    # One ``random.choice`` call per character, in order, so that the stream of
    # random numbers consumed is the same for a given seed.
    letters = [random.choice(alphabet) for _ in range(length)]
    return "".join(letters)
1074+
1075+
10711076
class CountingEnvWithString(CountingEnv):
10721077
def __init__(self, *args, **kwargs):
10731078
self.max_size = kwargs.pop("max_size", 30)
@@ -1083,8 +1088,7 @@ def __init__(self, *args, **kwargs):
10831088
)
10841089

10851090
def get_random_string(self):
1086-
size = random.randint(self.min_size, self.max_size)
1087-
return "".join(random.choice(string.ascii_lowercase) for _ in range(size))
1091+
return get_random_string(self.min_size, self.max_size)
10881092

10891093
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
10901094
res = super()._reset(tensordict, **kwargs)
@@ -2202,3 +2206,39 @@ def _step(
22022206

22032207
def _set_seed(self, seed):
22042208
...
2209+
2210+
2211+
class Str2StrEnv(EnvBase):
    """Mock environment whose observations and actions are plain strings.

    Observations are random lowercase strings whose length lies between
    ``min_size`` and ``max_size``; ``done`` and ``reward`` are sampled from a
    Bernoulli distribution with probability 0.01.

    Args:
        min_size: minimum length of the generated observation strings.
        max_size: maximum length of the generated observation strings.
        **kwargs: forwarded to ``EnvBase.__init__``.
    """

    def __init__(self, min_size=4, max_size=10, **kwargs):
        # NOTE: the specs and size bounds are assigned before calling the
        # parent constructor; this order is kept exactly as in the original.
        self.observation_spec = Composite(
            observation=NonTensor(example_data="an observation!", shape=())
        )
        self.full_action_spec = Composite(
            action=NonTensor(example_data="an action!", shape=())
        )
        self.reward_spec = Unbounded(shape=(1,), dtype=torch.float)
        self.min_size = min_size
        self.max_size = max_size
        super().__init__(**kwargs)

    def _step(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
        """Check that the action is a string and emit a fresh random transition."""
        assert isinstance(tensordict["action"], str)
        out = tensordict.empty()
        out.set("observation", self.get_random_string())
        # done / reward are each true (resp. 1.0) with probability 0.01.
        out.set("done", torch.zeros(1, dtype=torch.bool).bernoulli_(0.01))
        out.set("reward", torch.zeros(1, dtype=torch.float).bernoulli_(0.01))
        return out

    def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
        """Return an initial random observation together with a ``done`` flag.

        ``tensordict`` may be ``None``, in which case a new ``TensorDict`` is
        created to hold the output.
        """
        out = tensordict.empty() if tensordict is not None else TensorDict()
        out.set("observation", self.get_random_string())
        out.set("done", torch.zeros(1, dtype=torch.bool).bernoulli_(0.01))
        return out

    def get_random_string(self):
        """Return a random string sized between ``min_size`` and ``max_size``."""
        return get_random_string(self.min_size, self.max_size)

    def _set_seed(self, seed: Optional[int]):
        """Seed both the Python and the torch random generators.

        The strings come from :mod:`random` while ``done``/``reward`` come from
        torch, so both generators must follow ``seed``. Previously torch was
        always seeded with ``0``, which made the done/reward draws identical
        for every seed.

        Returns:
            The seed that was passed in.
        """
        random.seed(seed)
        if seed is None:
            # Mirror ``random.seed(None)``: reseed non-deterministically.
            torch.seed()
        else:
            torch.manual_seed(seed)
        return seed

0 commit comments

Comments
 (0)