Skip to content

Commit 863121a

Browse files
author
Vincent Moens
committed
[BugFix] Fix failing tests
ghstack-source-id: a43a2e3 Pull Request resolved: #2582
1 parent 408cf7d commit 863121a

File tree

20 files changed: +248 additions, -117 deletions

.github/workflows/docs.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@ jobs:
119119
120120
REF_TYPE=${{ github.ref_type }}
121121
REF_NAME=${{ github.ref_name }}
122+
apt-get update
123+
apt-get install rsync -y
122124
123125
if [[ "${REF_TYPE}" == branch ]]; then
124126
if [[ "${REF_NAME}" == main ]]; then

sota-implementations/ddpg/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,8 @@ def make_ddpg_agent(cfg, train_env, eval_env, device):
234234
OrnsteinUhlenbeckProcessModule(
235235
spec=action_spec,
236236
annealing_num_steps=1_000_000,
237-
).to(device),
237+
device=device,
238+
),
238239
)
239240
elif cfg.network.noise_type == "gaussian":
240241
actor_model_explore = TensorDictSequential(
@@ -245,7 +246,8 @@ def make_ddpg_agent(cfg, train_env, eval_env, device):
245246
sigma_init=1.0,
246247
mean=0.0,
247248
std=0.1,
248-
).to(device),
249+
device=device,
250+
),
249251
)
250252
else:
251253
raise NotImplementedError

sota-implementations/dreamer/dreamer_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,7 @@ def make_dreamer(
275275
annealing_num_steps=1,
276276
mean=0.0,
277277
std=cfg.networks.exploration_noise,
278+
device=device,
278279
),
279280
)
280281

sota-implementations/multiagent/maddpg_iddpg.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def train(cfg: "DictConfig"): # noqa: F821
108108
spec=env.unbatched_action_spec,
109109
annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
110110
action_key=env.action_key,
111+
device=cfg.train.device,
111112
),
112113
)
113114

sota-implementations/redq/config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ collector:
3030
async_collection: 1
3131
frames_per_batch: 1024
3232
total_frames: 1_000_000
33-
device: cpu
33+
device:
3434
env_per_collector: 1
3535
init_random_frames: 50_000
3636
multi_step: 1

sota-implementations/redq/redq.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,8 @@ def main(cfg: "DictConfig"): # noqa: F821
119119
annealing_num_steps=cfg.exploration.annealing_frames,
120120
sigma=cfg.exploration.ou_sigma,
121121
theta=cfg.exploration.ou_theta,
122-
).to(device),
122+
device=device,
123+
),
123124
)
124125
if device == torch.device("cpu"):
125126
# mostly for debugging

sota-implementations/redq/utils.py

Lines changed: 38 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -21,55 +21,59 @@
2121
from torchrl._utils import logger as torchrl_logger, VERBOSE
2222
from torchrl.collectors.collectors import DataCollectorBase
2323

24-
from torchrl.data import ReplayBuffer, TensorDictReplayBuffer
25-
from torchrl.data.postprocs import MultiStep
26-
from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler
27-
from torchrl.data.replay_buffers.storages import LazyMemmapStorage
24+
from torchrl.data import (
25+
LazyMemmapStorage,
26+
MultiStep,
27+
PrioritizedSampler,
28+
RandomSampler,
29+
ReplayBuffer,
30+
TensorDictReplayBuffer,
31+
)
2832
from torchrl.data.utils import DEVICE_TYPING
29-
from torchrl.envs import ParallelEnv
30-
from torchrl.envs.common import EnvBase
31-
from torchrl.envs.env_creator import env_creator, EnvCreator
32-
from torchrl.envs.libs.dm_control import DMControlEnv
33-
from torchrl.envs.libs.gym import GymEnv
34-
from torchrl.envs.transforms import (
33+
from torchrl.envs import (
3534
CatFrames,
3635
CatTensors,
3736
CenterCrop,
3837
Compose,
38+
DMControlEnv,
3939
DoubleToFloat,
40+
env_creator,
41+
EnvBase,
42+
EnvCreator,
43+
FlattenObservation,
4044
GrayScale,
45+
gSDENoise,
46+
GymEnv,
47+
InitTracker,
4148
NoopResetEnv,
4249
ObservationNorm,
50+
ParallelEnv,
4351
Resize,
4452
RewardScaling,
53+
StepCounter,
4554
ToTensorImage,
4655
TransformedEnv,
4756
VecNorm,
4857
)
49-
from torchrl.envs.transforms.transforms import (
50-
FlattenObservation,
51-
gSDENoise,
52-
InitTracker,
53-
StepCounter,
54-
)
5558
from torchrl.envs.utils import ExplorationType, set_exploration_type
5659
from torchrl.modules import (
5760
ActorCriticOperator,
5861
ActorValueOperator,
62+
DdpgCnnActor,
63+
DdpgCnnQNet,
64+
MLP,
5965
NoisyLinear,
6066
NormalParamExtractor,
67+
ProbabilisticActor,
6168
SafeModule,
6269
SafeSequential,
70+
TanhNormal,
71+
ValueOperator,
6372
)
64-
from torchrl.modules.distributions import TanhNormal
6573
from torchrl.modules.distributions.continuous import SafeTanhTransform
6674
from torchrl.modules.models.exploration import LazygSDEModule
67-
from torchrl.modules.models.models import DdpgCnnActor, DdpgCnnQNet, MLP
68-
from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
69-
from torchrl.objectives import HardUpdate, SoftUpdate
70-
from torchrl.objectives.common import LossModule
75+
from torchrl.objectives import HardUpdate, LossModule, SoftUpdate, TargetNetUpdater
7176
from torchrl.objectives.deprecated import REDQLoss_deprecated
72-
from torchrl.objectives.utils import TargetNetUpdater
7377
from torchrl.record.loggers import Logger
7478
from torchrl.record.recorder import VideoRecorder
7579
from torchrl.trainers.helpers import sync_async_collector, sync_sync_collector
@@ -518,7 +522,7 @@ def make_redq_model(
518522
actor_module = SafeSequential(
519523
actor_module,
520524
SafeModule(
521-
LazygSDEModule(transform=transform),
525+
LazygSDEModule(transform=transform, device=device),
522526
in_keys=["action", gSDE_state_key, "_eps_gSDE"],
523527
out_keys=["loc", "scale", "action", "_eps_gSDE"],
524528
),
@@ -606,7 +610,9 @@ def make_transformed_env(**kwargs) -> TransformedEnv:
606610
categorical_action_encoding = cfg.env.categorical_action_encoding
607611

608612
if custom_env is None and custom_env_maker is None:
609-
if isinstance(cfg.collector.device, str):
613+
if cfg.collector.device in ("", None):
614+
device = "cpu" if not torch.cuda.is_available() else "cuda:0"
615+
elif isinstance(cfg.collector.device, str):
610616
device = cfg.collector.device
611617
elif isinstance(cfg.collector.device, Sequence):
612618
device = cfg.collector.device[0]
@@ -1000,11 +1006,14 @@ def make_collector_offpolicy(
10001006
env_kwargs.update(make_env_kwargs)
10011007
elif make_env_kwargs is not None:
10021008
env_kwargs = make_env_kwargs
1003-
cfg.collector.device = (
1004-
cfg.collector.device
1005-
if len(cfg.collector.device) > 1
1006-
else cfg.collector.device[0]
1007-
)
1009+
if cfg.collector.device in ("", None):
1010+
cfg.collector.device = "cpu" if not torch.cuda.is_available() else "cuda:0"
1011+
else:
1012+
cfg.collector.device = (
1013+
cfg.collector.device
1014+
if len(cfg.collector.device) > 1
1015+
else cfg.collector.device[0]
1016+
)
10081017
collector_helper_kwargs = {
10091018
"env_fns": make_env,
10101019
"env_kwargs": env_kwargs,
@@ -1017,7 +1026,6 @@ def make_collector_offpolicy(
10171026
# we already took care of building the make_parallel_env function
10181027
"num_collectors": -cfg.num_workers // -cfg.collector.env_per_collector,
10191028
"device": cfg.collector.device,
1020-
"storing_device": cfg.collector.device,
10211029
"init_random_frames": cfg.collector.init_random_frames,
10221030
"split_trajs": True,
10231031
# trajectories must be separated if multi-step is used

sota-implementations/td3/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,8 @@ def make_td3_agent(cfg, train_env, eval_env, device):
242242
mean=0,
243243
std=0.1,
244244
spec=action_spec,
245-
).to(device),
245+
device=device,
246+
),
246247
)
247248
return model, actor_model_explore
248249

sota-implementations/td3_bc/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,8 @@ def make_td3_agent(cfg, train_env, device):
183183
mean=0,
184184
std=0.1,
185185
spec=action_spec,
186-
).to(device),
186+
device=device,
187+
),
187188
)
188189
return model, actor_model_explore
189190

test/_utils_internal.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,11 +167,11 @@ def get_available_devices():
167167
def get_default_devices():
168168
num_cuda = torch.cuda.device_count()
169169
if num_cuda == 0:
170+
if torch.mps.is_available():
171+
return [torch.device("mps:0")]
170172
return [torch.device("cpu")]
171173
elif num_cuda == 1:
172174
return [torch.device("cuda:0")]
173-
elif torch.mps.is_available():
174-
return [torch.device("mps:0")]
175175
else:
176176
# then run on all devices
177177
return get_available_devices()

Commit comments (0)