Skip to content

Commit 526b38d

Browse files
author
Vincent Moens
committed
[Quality] IMPALA auto-device
ghstack-source-id: abbb304 Pull Request resolved: #2654
1 parent 187de7c commit 526b38d

File tree

6 files changed

+18
-7
lines changed

6 files changed

+18
-7
lines changed

sota-implementations/impala/config_multi_node_ray.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ ray_init_config:
2424
storage: null
2525

2626
# Device for the forward and backward passes
27-
local_device: "cuda:0"
27+
local_device:
2828

2929
# Resources assigned to each IMPALA rollout collection worker
3030
remote_worker_resources:

sota-implementations/impala/config_multi_node_submitit.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ env:
33
env_name: PongNoFrameskip-v4
44

55
# Device for the forward and backward passes
6-
local_device: "cuda:0"
6+
local_device:
77

88
# SLURM config
99
slurm_config:

sota-implementations/impala/config_single_node.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ env:
33
env_name: PongNoFrameskip-v4
44

55
# Device for the forward and backward passes
6-
device: "cuda:0"
6+
device:
77

88
# collector
99
collector:

sota-implementations/impala/impala_multi_node_ray.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,11 @@ def main(cfg: "DictConfig"): # noqa: F821
3232
from torchrl.record.loggers import generate_exp_name, get_logger
3333
from utils import eval_model, make_env, make_ppo_models
3434

35-
device = torch.device(cfg.local_device)
35+
device = cfg.local_device
36+
if not device:
37+
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
38+
else:
39+
device = torch.device(device)
3640

3741
# Correct for frame_skip
3842
frame_skip = 4

sota-implementations/impala/impala_multi_node_submitit.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,11 @@ def main(cfg: "DictConfig"): # noqa: F821
3434
from torchrl.record.loggers import generate_exp_name, get_logger
3535
from utils import eval_model, make_env, make_ppo_models
3636

37-
device = torch.device(cfg.local_device)
37+
device = cfg.local_device
38+
if not device:
39+
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
40+
else:
41+
device = torch.device(device)
3842

3943
# Correct for frame_skip
4044
frame_skip = 4

sota-implementations/impala/impala_single_node.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,11 @@ def main(cfg: "DictConfig"): # noqa: F821
3131
from torchrl.record.loggers import generate_exp_name, get_logger
3232
from utils import eval_model, make_env, make_ppo_models
3333

34-
device = torch.device(cfg.device)
34+
device = cfg.device
35+
if not device:
36+
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
37+
else:
38+
device = torch.device(device)
3539

3640
# Correct for frame_skip
3741
frame_skip = 4
@@ -55,7 +59,6 @@ def main(cfg: "DictConfig"): # noqa: F821
5559

5660
# Create models (check utils.py)
5761
actor, critic = make_ppo_models(cfg.env.env_name)
58-
actor, critic = actor.to(device), critic.to(device)
5962

6063
# Create collector
6164
collector = MultiaSyncDataCollector(

0 commit comments

Comments
 (0)