Skip to content

Commit 0355a01

Browse files
authored
[BugFix] update_policy_weights_() with cudagraph (#3003)
1 parent 16b70be commit 0355a01

File tree

1 file changed

+41
-37
lines changed

1 file changed

+41
-37
lines changed

test/test_collector.py

Lines changed: 41 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1483,8 +1483,9 @@ def env_fn(seed):
14831483
assert_allclose_td(data10, data20)
14841484

14851485
@pytest.mark.parametrize("use_async", [False, True])
1486+
@pytest.mark.parametrize("cudagraph", [False, True])
14861487
@pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device found")
1487-
def test_update_weights(self, use_async):
1488+
def test_update_weights(self, use_async, cudagraph):
14881489
def create_env():
14891490
return ContinuousActionVecMockEnv()
14901491

@@ -1504,48 +1505,51 @@ def create_env():
15041505
storing_device=[torch.device("cuda:0")] * 3,
15051506
frames_per_batch=20,
15061507
cat_results="stack",
1508+
cudagraph_policy=cudagraph,
15071509
)
1508-
# collect state_dict
1509-
state_dict = collector.state_dict()
1510-
policy_state_dict = policy.state_dict()
1511-
for worker in range(3):
1512-
for k in state_dict[f"worker{worker}"]["policy_state_dict"]:
1513-
torch.testing.assert_close(
1514-
state_dict[f"worker{worker}"]["policy_state_dict"][k],
1515-
policy_state_dict[k].cpu(),
1516-
)
1517-
1518-
# change policy weights
1519-
for p in policy.parameters():
1520-
p.data += torch.randn_like(p)
1521-
1522-
# collect state_dict
1523-
state_dict = collector.state_dict()
1524-
policy_state_dict = policy.state_dict()
1525-
# check they don't match
1526-
for worker in range(3):
1527-
for k in state_dict[f"worker{worker}"]["policy_state_dict"]:
1528-
with pytest.raises(AssertionError):
1510+
try:
1511+
# collect state_dict
1512+
state_dict = collector.state_dict()
1513+
policy_state_dict = policy.state_dict()
1514+
for worker in range(3):
1515+
assert "policy_state_dict" in state_dict[f"worker{worker}"], state_dict[f"worker{worker}"].keys()
1516+
for k in state_dict[f"worker{worker}"]["policy_state_dict"]:
15291517
torch.testing.assert_close(
15301518
state_dict[f"worker{worker}"]["policy_state_dict"][k],
15311519
policy_state_dict[k].cpu(),
15321520
)
15331521

1534-
# update weights
1535-
collector.update_policy_weights_()
1536-
1537-
# collect state_dict
1538-
state_dict = collector.state_dict()
1539-
policy_state_dict = policy.state_dict()
1540-
for worker in range(3):
1541-
for k in state_dict[f"worker{worker}"]["policy_state_dict"]:
1542-
torch.testing.assert_close(
1543-
state_dict[f"worker{worker}"]["policy_state_dict"][k],
1544-
policy_state_dict[k].cpu(),
1545-
)
1546-
1547-
collector.shutdown()
1548-
del collector
1522+
# change policy weights
1523+
for p in policy.parameters():
1524+
p.data += torch.randn_like(p)
1525+
1526+
# collect state_dict
1527+
state_dict = collector.state_dict()
1528+
policy_state_dict = policy.state_dict()
1529+
# check they don't match
1530+
for worker in range(3):
1531+
for k in state_dict[f"worker{worker}"]["policy_state_dict"]:
1532+
with pytest.raises(AssertionError):
1533+
torch.testing.assert_close(
1534+
state_dict[f"worker{worker}"]["policy_state_dict"][k],
1535+
policy_state_dict[k].cpu(),
1536+
)
1537+
1538+
# update weights
1539+
collector.update_policy_weights_()
1540+
1541+
# collect state_dict
1542+
state_dict = collector.state_dict()
1543+
policy_state_dict = policy.state_dict()
1544+
for worker in range(3):
1545+
for k in state_dict[f"worker{worker}"]["policy_state_dict"]:
1546+
torch.testing.assert_close(
1547+
state_dict[f"worker{worker}"]["policy_state_dict"][k],
1548+
policy_state_dict[k].cpu(),
1549+
)
1550+
finally:
1551+
collector.shutdown()
1552+
del collector
15491553

15501554

15511555
class TestCollectorDevices:

0 commit comments

Comments
 (0)