Commit 01399e0

[BugFix] Fix get_original_weights in collectors

Author: Vincent Moens
ghstack-source-id: bf77b22
Pull-Request-resolved: #2951
1 parent 6b48e08
File tree: 3 files changed (+22, -16 lines)

sota-implementations/sac/utils.py

Lines changed: 4 additions & 4 deletions

@@ -128,9 +128,7 @@ def make_collector(cfg, train_env, actor_model_explore, compile_mode):
     device = cfg.collector.device
     if device in ("", None):
         if torch.cuda.is_available():
-            if torch.cuda.device_count() < 2:
-                raise RuntimeError("Requires >= 2 GPUs")
-            device = torch.device("cuda:1")
+            device = torch.device("cuda:0")
         else:
             device = torch.device("cpu")
     collector = SyncDataCollector(
@@ -158,7 +156,9 @@ def make_collector_async(
     device = cfg.collector.device
     if device in ("", None):
         if torch.cuda.is_available():
-            device = torch.device("cuda:0")
+            if torch.cuda.device_count() < 2:
+                raise RuntimeError("Requires >= 2 GPUs")
+            device = torch.device("cuda:1")
         else:
             device = torch.device("cpu")
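Taken together, the two hunks swap the device policies of the synchronous and asynchronous helpers: the synchronous collector can share cuda:0 with training, while the asynchronous one, which collects concurrently, now reserves a second GPU and carries the two-GPU guard. A minimal sketch of the resulting selection logic (the async_mode flag is hypothetical, folded in here only to show both branches side by side):

    import torch

    def pick_collector_device(cfg_device, async_mode):
        # An explicitly configured device always wins.
        if cfg_device not in ("", None):
            return torch.device(cfg_device)
        if not torch.cuda.is_available():
            return torch.device("cpu")
        if async_mode:
            # Async collection runs alongside training and claims its own GPU.
            if torch.cuda.device_count() < 2:
                raise RuntimeError("Requires >= 2 GPUs")
            return torch.device("cuda:1")
        return torch.device("cuda:0")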

test/test_collector.py

Lines changed: 1 addition & 1 deletion

@@ -3419,7 +3419,7 @@ def test_collector_rb_multisync(
         assert len(rb) == pred_len
         collector.shutdown()
         assert len(rb) == 256
-        if not extend_buffer:
+        if extend_buffer:
             steps_counts = rb["step_count"].squeeze().split(16)
             collector_ids = rb["collector", "traj_ids"].squeeze().split(16)
             for step_count, ids in zip(steps_counts, collector_ids):
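The flipped condition matches the buffer semantics: when extend_buffer is set, each collector writes whole 16-frame rollouts into the buffer in one extend call, so every contiguous block of 16 entries comes from a single trajectory; without it, entries from different workers may interleave and the per-chunk checks guarded by the if would not hold. A hedged illustration of the invariant being asserted (the tensors are synthetic stand-ins for rb["step_count"] and rb["collector", "traj_ids"]):

    import torch

    step_count = torch.arange(256) % 16   # 16 rollouts of 16 consecutive steps
    traj_ids = torch.arange(256) // 16    # one trajectory id per rollout

    for steps, ids in zip(step_count.split(16), traj_ids.split(16)):
        assert (ids == ids[0]).all()      # a chunk never mixes trajectories
        assert (steps.diff() == 1).all()  # and its step counts are consecutive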

torchrl/collectors/collectors.py

Lines changed: 17 additions & 11 deletions

@@ -208,7 +208,7 @@ def _get_policy_and_device(
             return policy, None

         if isinstance(policy, nn.Module):
-            param_and_buf = TensorDict.from_module(policy, as_module=True).data
+            param_and_buf = TensorDict.from_module(policy, as_module=True)
         else:
             # Because we want to reach the warning
             param_and_buf = TensorDict()
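Dropping .data at this point keeps param_and_buf as the live parameter container; detached views are now taken only at the call sites that need them (note the explicit .data uses in the next hunk). For context, .data on a TensorDict detaches every leaf, as the small illustration below suggests (assumed to hold for tensordict's .data property as of this commit):

    from tensordict import TensorDict
    from torch import nn

    module = nn.Linear(3, 3)
    td = TensorDict.from_module(module, as_module=True)

    print(td["weight"].requires_grad)       # True: the live nn.Parameter
    print(td.data["weight"].requires_grad)  # False: .data detaches each leaf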
@@ -231,19 +231,25 @@ def _get_policy_and_device(
             return policy, None

         # Create a stateless policy, then populate this copy with params on device
-        def get_original_weights(policy):
+        def get_original_weights(policy=policy):
             td = TensorDict.from_module(policy)
             return td.data

         # We need to use ".data" otherwise buffers may disappear from the `get_original_weights` function
         with param_and_buf.data.to("meta").to_module(policy):
-            policy = deepcopy(policy)
+            policy_new_device = deepcopy(policy)

-        param_and_buf.apply(
+        param_and_buf_new_device = param_and_buf.apply(
             functools.partial(_map_weight, policy_device=policy_device),
             filter_empty=False,
-        ).to_module(policy)
-        return policy, get_original_weights
+        )
+        param_and_buf_new_device.to_module(policy_new_device)
+        # Sanity check
+        if set(TensorDict.from_module(policy_new_device).keys(True, True)) != set(
+            get_original_weights().keys(True, True)
+        ):
+            raise RuntimeError("Failed to map weights. The weight sets mismatch.")
+        return policy_new_device, get_original_weights

     def start(self):
         """Starts the collector for asynchronous data collection.
@@ -1976,17 +1982,17 @@ def __init__(
         for policy_device, env_maker, env_maker_kwargs in _zip_strict(
             self.policy_device, self.create_env_fn, self.create_env_kwargs
         ):
-            (policy_copy, get_weights_fn,) = self._get_policy_and_device(
+            (policy_new_device, get_weights_fn,) = self._get_policy_and_device(
                 policy=policy,
                 policy_device=policy_device,
                 env_maker=env_maker,
                 env_maker_kwargs=env_maker_kwargs,
             )
-            if type(policy_copy) is not type(policy):
-                policy = policy_copy
+            if type(policy_new_device) is not type(policy):
+                policy = policy_new_device
             weights = (
-                TensorDict.from_module(policy_copy).data
-                if isinstance(policy_copy, nn.Module)
+                TensorDict.from_module(policy_new_device).data
+                if isinstance(policy_new_device, nn.Module)
                 else TensorDict()
             )
             self._policy_weights_dict[policy_device] = weights
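The rename propagates to the multiprocessing collector's __init__ so that policy is only rebound when _get_policy_and_device actually returns a different type. The meta-device copy pattern used above is worth isolating: swapping the real tensors for meta tensors before deepcopy duplicates the module structure without copying storage, after which the (possibly device-mapped) weights are loaded into the copy. A hedged sketch built on tensordict's public API, mirroring the commit's sanity check:

    from copy import deepcopy

    from tensordict import TensorDict
    from torch import nn

    policy = nn.Sequential(nn.Linear(4, 32), nn.Tanh(), nn.Linear(32, 2))
    params = TensorDict.from_module(policy, as_module=True)

    # Meta tensors carry shape/dtype but no storage, so this deepcopy is cheap;
    # on exit the context manager restores the real parameters to `policy`.
    with params.data.to("meta").to_module(policy):
        policy_new_device = deepcopy(policy)

    # Map the real weights (here a no-op .to("cpu"); per-device in the collector)
    # and load them into the structural copy.
    params.data.to("cpu").to_module(policy_new_device)

    # Same check as the commit: both modules expose identical weight keys.
    assert set(TensorDict.from_module(policy_new_device).keys(True, True)) == set(
        TensorDict.from_module(policy).keys(True, True)
    )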
