Skip to content

Commit 3a9f244

Browse files
authored
[BE] _set_seed returns None + type annotations (#2903)
1 parent f5f3ae4 commit 3a9f244

30 files changed, with 85 additions and 91 deletions.

docs/source/reference/envs.rst

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -814,7 +814,7 @@ Here is a working example:
814814
... reward = self.full_reward_spec.zero()
815815
... return observation.update(done).update(reward)
816816
...
817-
... def _set_seed(self, seed: Optional[int]):
817+
... def _set_seed(self, seed: Optional[int]) -> None:
818818
... self.manual_seed = seed
819819
... return seed
820820
>>> env = EnvWithDynamicSpec()

test/mocking_classes.py

Lines changed: 24 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -128,29 +128,29 @@ def __init__(
128128
self.is_closed = False
129129

130130
@property
131-
def maxstep(self):
131+
def maxstep(self) -> int:
132132
return 100
133133

134-
def _set_seed(self, seed: int | None):
134+
def _set_seed(self, seed: int | None) -> None:
135135
self.seed = seed
136136
self.counter = seed % 17 # make counter a small number
137137

138-
def custom_fun(self):
138+
def custom_fun(self) -> int:
139139
return 0
140140

141141
custom_attr = 1
142142

143143
@property
144-
def custom_prop(self):
144+
def custom_prop(self) -> int:
145145
return 2
146146

147147
@property
148-
def custom_td(self):
148+
def custom_td(self) -> TensorDict:
149149
return TensorDict({"a": torch.zeros(3)}, [])
150150

151151

152152
class MockSerialEnv(EnvBase):
153-
"""A simple counting env that is reset after a predifined max number of steps."""
153+
"""A simple counting env that is reset after a predefined max number of steps."""
154154

155155
@classmethod
156156
def __new__(
@@ -219,13 +219,13 @@ def __init__(self, device="cpu"):
219219
super().__init__(device=device)
220220
self.is_closed = False
221221

222-
def _set_seed(self, seed: int | None):
222+
def _set_seed(self, seed: int | None) -> None:
223223
assert seed >= 1
224224
self.seed = seed
225225
self.counter = seed % 17 # make counter a small number
226226
self.max_val = max(self.counter + 100, self.counter * 2)
227227

228-
def _step(self, tensordict):
228+
def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
229229
self.counter += 1
230230
n = torch.tensor(
231231
[self.counter], device=self.device, dtype=torch.get_default_dtype()
@@ -341,13 +341,13 @@ def __init__(self, device="cpu", batch_size=None):
341341

342342
rand_step = MockSerialEnv.rand_step
343343

344-
def _set_seed(self, seed: int | None):
344+
def _set_seed(self, seed: int | None) -> None:
345345
assert seed >= 1
346346
self.seed = seed
347347
self.counter = seed % 17 # make counter a small number
348348
self.max_val = max(self.counter + 100, self.counter * 2)
349349

350-
def _step(self, tensordict):
350+
def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
351351
if len(self.batch_size):
352352
leading_batch_size = (
353353
tensordict.shape[: -len(self.batch_size)]
@@ -506,7 +506,7 @@ def _step(
506506
device=tensordict.device,
507507
)
508508

509-
def _set_seed(self, seed: int | None):
509+
def _set_seed(self, seed: int | None) -> None:
510510
...
511511

512512

@@ -738,7 +738,7 @@ def _get_in_obs(self, tensordict):
738738
obs = tensordict.get(*self.in_keys)
739739
return obs
740740

741-
def __call__(self, tensordict):
741+
def __call__(self, tensordict: TensorDictBase) -> TensorDictBase:
742742
obs = self._get_in_obs(tensordict)
743743
max_obs = (obs == obs.max(dim=-1, keepdim=True)[0]).cumsum(-1).argmax(-1)
744744
k = tensordict.get(*self.in_keys).shape[-1]
@@ -1106,7 +1106,7 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
11061106
torch.zeros((*self.batch_size, 1), device=self.device, dtype=torch.int),
11071107
)
11081108

1109-
def _set_seed(self, seed: int | None):
1109+
def _set_seed(self, seed: int | None) -> None:
11101110
torch.manual_seed(seed)
11111111

11121112
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
@@ -1285,7 +1285,7 @@ def __init__(
12851285
torch.zeros((*self.batch_size, 1), device=self.device, dtype=torch.int),
12861286
)
12871287

1288-
def _set_seed(self, seed: int | None):
1288+
def _set_seed(self, seed: int | None) -> None:
12891289
torch.manual_seed(seed)
12901290

12911291
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
@@ -1611,7 +1611,7 @@ def __init__(
16111611
elif start_val.numel() <= 1:
16121612
self.start_val = start_val.expand_as(self.count)
16131613

1614-
def _set_seed(self, seed: int | None):
1614+
def _set_seed(self, seed: int | None) -> None:
16151615
torch.manual_seed(seed)
16161616

16171617
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
@@ -1827,7 +1827,7 @@ def _step(
18271827

18281828
return td
18291829

1830-
def _set_seed(self, seed: int | None):
1830+
def _set_seed(self, seed: int | None) -> None:
18311831
torch.manual_seed(seed)
18321832

18331833

@@ -2058,7 +2058,7 @@ def _step(
20582058
assert td.batch_size == self.batch_size
20592059
return td
20602060

2061-
def _set_seed(self, seed: int | None):
2061+
def _set_seed(self, seed: int | None) -> None:
20622062
torch.manual_seed(seed)
20632063

20642064

@@ -2095,8 +2095,8 @@ def _step(
20952095
data.update(self._saved_full_reward_spec.zero())
20962096
return data
20972097

2098-
def _set_seed(self, seed: int | None):
2099-
return seed
2098+
def _set_seed(self, seed: int | None) -> None:
2099+
...
21002100

21012101

21022102
class AutoResettingCountingEnv(CountingEnv):
@@ -2221,9 +2221,8 @@ def _step(
22212221
reward = self.full_reward_spec.zero()
22222222
return observation.update(done).update(reward)
22232223

2224-
def _set_seed(self, seed: int | None):
2224+
def _set_seed(self, seed: int | None) -> None:
22252225
self.manual_seed = seed
2226-
return seed
22272226

22282227

22292228
class EnvWithScalarAction(EnvBase):
@@ -2291,7 +2290,7 @@ def _step(
22912290
),
22922291
)
22932292

2294-
def _set_seed(self, seed: int | None):
2293+
def _set_seed(self, seed: int | None) -> None:
22952294
...
22962295

22972296

@@ -2305,7 +2304,7 @@ def _step(
23052304
) -> TensorDictBase:
23062305
return TensorDict(batch_size=self.batch_size, device=self.device)
23072306

2308-
def _set_seed(self, seed):
2307+
def _set_seed(self, seed: int | None) -> None:
23092308
...
23102309

23112310

@@ -2339,10 +2338,9 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
23392338
def get_random_string(self):
23402339
return get_random_string(self.min_size, self.max_size)
23412340

2342-
def _set_seed(self, seed: int | None):
2341+
def _set_seed(self, seed: int | None) -> None:
23432342
random.seed(seed)
23442343
torch.manual_seed(0)
2345-
return seed
23462344

23472345

23482346
class EnvThatErrorsAfter10Iters(EnvBase):
@@ -2367,7 +2365,7 @@ def _step(self, tensordict: TensorDictBase, **kwargs) -> TensorDict:
23672365
.update(self.full_reward_spec.zero())
23682366
)
23692367

2370-
def _set_seed(self, seed: int | None):
2368+
def _set_seed(self, seed: int | None) -> None:
23712369
...
23722370

23732371

test/test_collector.py

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -915,7 +915,7 @@ def _step(self, tensordict: TensorDictBase, **kwargs) -> TensorDict:
915915
.update(self.full_reward_spec.zero())
916916
)
917917
918-
def _set_seed(self, seed: Optional[int]):
918+
def _set_seed(self, seed: Optional[int]) -> None:
919919
...
920920
921921
if __name__ == "__main__":
@@ -1617,8 +1617,8 @@ def _reset(self, tensordict=None):
16171617
device=None,
16181618
)
16191619

1620-
def _set_seed(self, seed: int | None = None):
1621-
return seed
1620+
def _set_seed(self, seed: int | None = None) -> None:
1621+
...
16221622

16231623
class EnvWithDevice(EnvBase):
16241624
def __init__(self, default_device):
@@ -1674,8 +1674,8 @@ def _reset(self, tensordict=None):
16741674
device=self.default_device,
16751675
)
16761676

1677-
def _set_seed(self, seed: int | None = None):
1678-
return seed
1677+
def _set_seed(self, seed: int | None = None) -> None:
1678+
...
16791679

16801680
class DeviceLessPolicy(TensorDictModuleBase):
16811681
in_keys = ["observation"]
@@ -1840,8 +1840,8 @@ def _step(
18401840
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
18411841
return self.full_done_specs.zeros().update(self.observation_spec.zeros())
18421842

1843-
def _set_seed(self, seed: int | None):
1844-
return seed
1843+
def _set_seed(self, seed: int | None) -> None:
1844+
...
18451845

18461846
@pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device")
18471847
@pytest.mark.parametrize("env_device", ["cuda:0", "cpu"])
@@ -2660,8 +2660,8 @@ def _reset(self, tensordict=None):
26602660
{"state": self.state.clone()}, self.batch_size, device=self.device
26612661
)
26622662

2663-
def _set_seed(self, seed):
2664-
return seed
2663+
def _set_seed(self, seed: int | None) -> None:
2664+
...
26652665

26662666
class Policy(TensorDictModuleBase):
26672667
def __init__(self):

test/test_cost.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -280,7 +280,7 @@ def _step(
280280
def _reset(self, tensordic):
281281
...
282282

283-
def _set_seed(self, seed: int | None):
283+
def _set_seed(self, seed: int | None) -> None:
284284
...
285285

286286

test/test_env.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -291,7 +291,7 @@ def _step(
291291
) -> TensorDictBase:
292292
...
293293

294-
def _set_seed(self, seed: int | None):
294+
def _set_seed(self, seed: int | None) -> None:
295295
...
296296

297297
def test_env_lock(self):
@@ -453,9 +453,8 @@ def __init__(self, device):
453453
)
454454
self.seed = 0
455455

456-
def _set_seed(self, seed):
456+
def _set_seed(self, seed: int | None) -> None:
457457
self.seed = seed
458-
return seed
459458

460459
def _reset(self, tensordict):
461460
td = self.observation_spec.zero().update(self.done_spec.zero())

test/test_libs.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -289,8 +289,8 @@ def _step(self, tensordict):
289289
batch_size=[],
290290
)
291291

292-
def _set_seed(self, seed):
293-
return seed + 1
292+
def _set_seed(self, seed: int | None) -> None:
293+
...
294294

295295
@implement_for("gym", None, "0.18")
296296
def _make_spec(self, batch_size, cat, cat_shape, multicat, multicat_shape):

test/test_transforms.py

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -4329,8 +4329,8 @@ def _reset(self, tensordict: TensorDictBase) -> TensorDictBase:
43294329
{"done": torch.zeros(1, dtype=torch.bool)}
43304330
)
43314331

4332-
def _set_seed(self, seed):
4333-
return seed + 1
4332+
def _set_seed(self, seed: int | None) -> None:
4333+
...
43344334

43354335
def test_single_trans_env_check(self):
43364336
t = Compose(
@@ -4567,8 +4567,8 @@ def _reset(self, tensordict: TensorDictBase) -> TensorDictBase:
45674567
{"done": torch.zeros(1, dtype=torch.bool)}
45684568
)
45694569

4570-
def _set_seed(self, seed):
4571-
return seed + 1
4570+
def _set_seed(self, seed: int | None) -> None:
4571+
...
45724572

45734573
def test_single_trans_env_check(self):
45744574
t = Compose(
@@ -9517,7 +9517,7 @@ def _step(
95179517
tensordict["reward"] = self.reward_spec.rand()
95189518
return tensordict
95199519

9520-
def _set_seed(self, seed: int | None):
9520+
def _set_seed(self, seed: int | None) -> None:
95219521
...
95229522

95239523
@pytest.mark.parametrize("batched", [False, True])
@@ -11880,8 +11880,8 @@ def _step(self, data):
1188011880
td.set("done", ~(mask.any().view(1)))
1188111881
return td
1188211882

11883-
def _set_seed(self, seed):
11884-
return seed
11883+
def _set_seed(self, seed: int | None) -> None:
11884+
...
1188511885

1188611886
return MaskedEnv
1188711887

@@ -13050,8 +13050,8 @@ def _step(self, tensordict):
1305013050
.update(self.full_reward_spec.rand())
1305113051
)
1305213052

13053-
def _set_seed(self, seed):
13054-
return seed + 1
13053+
def _set_seed(self, seed: int | None) -> None:
13054+
...
1305513055

1305613056
def test_single_trans_env_check(self):
1305713057
env = TransformedEnv(self.DummyEnv(), RemoveEmptySpecs())
@@ -13261,8 +13261,8 @@ def _step(
1326113261
result.update(self.full_reward_spec.zero(tensordict.batch_size))
1326213262
return result
1326313263

13264-
def _set_seed(self, seed: int):
13265-
pass
13264+
def _set_seed(self, seed: int | None) -> None:
13265+
...
1326613266

1326713267
@classmethod
1326813268
def reset_func(tensordict, tensordict_reset, env):
@@ -13886,7 +13886,7 @@ def _step(self, tensordict: TensorDict) -> TensorDict:
1388613886
}
1388713887
)
1388813888

13889-
def _set_seed(self) -> None:
13889+
def _set_seed(self, seed: int | None = None) -> None:
1389013890
pass
1389113891

1389213892
@pytest.mark.parametrize(

torchrl/envs/async_envs.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -305,7 +305,7 @@ def _sort_results(self, results, *other_results):
305305
return results, *other_results, idx
306306
return results, idx
307307

308-
def _set_seed(self, seed: int | None):
308+
def _set_seed(self, seed: int | None) -> None:
309309
raise NotImplementedError
310310

311311
@abc.abstractmethod

torchrl/envs/batched_envs.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -878,8 +878,9 @@ def close(self, *, raise_if_closed: bool = True) -> None:
878878
def _shutdown_workers(self) -> None:
879879
raise NotImplementedError
880880

881-
def _set_seed(self, seed: int | None):
881+
def _set_seed(self, seed: int | None) -> None:
882882
"""This method is not used in batched envs."""
883+
pass
883884

884885
@lazy
885886
def start(self) -> None:

torchrl/envs/common.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2899,7 +2899,7 @@ def set_seed(
28992899
return seed
29002900

29012901
@abc.abstractmethod
2902-
def _set_seed(self, seed: int | None):
2902+
def _set_seed(self, seed: int | None) -> None:
29032903
raise NotImplementedError
29042904

29052905
def set_state(self):

0 commit comments

Comments
 (0)