@@ -1483,8 +1483,9 @@ def env_fn(seed):
         assert_allclose_td(data10, data20)
 
     @pytest.mark.parametrize("use_async", [False, True])
+    @pytest.mark.parametrize("cudagraph", [False, True])
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device found")
-    def test_update_weights(self, use_async):
+    def test_update_weights(self, use_async, cudagraph):
         def create_env():
             return ContinuousActionVecMockEnv()
 
@@ -1504,48 +1505,51 @@ def create_env():
             storing_device=[torch.device("cuda:0")] * 3,
             frames_per_batch=20,
             cat_results="stack",
+            cudagraph_policy=cudagraph,
         )
-        # collect state_dict
-        state_dict = collector.state_dict()
-        policy_state_dict = policy.state_dict()
-        for worker in range(3):
-            for k in state_dict[f"worker{worker}"]["policy_state_dict"]:
-                torch.testing.assert_close(
-                    state_dict[f"worker{worker}"]["policy_state_dict"][k],
-                    policy_state_dict[k].cpu(),
-                )
-
-        # change policy weights
-        for p in policy.parameters():
-            p.data += torch.randn_like(p)
-
-        # collect state_dict
-        state_dict = collector.state_dict()
-        policy_state_dict = policy.state_dict()
-        # check they don't match
-        for worker in range(3):
-            for k in state_dict[f"worker{worker}"]["policy_state_dict"]:
-                with pytest.raises(AssertionError):
+        try:
+            # collect state_dict
+            state_dict = collector.state_dict()
+            policy_state_dict = policy.state_dict()
+            for worker in range(3):
+                assert "policy_state_dict" in state_dict[f"worker{worker}"], state_dict[f"worker{worker}"].keys()
+                for k in state_dict[f"worker{worker}"]["policy_state_dict"]:
                     torch.testing.assert_close(
                         state_dict[f"worker{worker}"]["policy_state_dict"][k],
                         policy_state_dict[k].cpu(),
                     )
 
-        # update weights
-        collector.update_policy_weights_()
-
-        # collect state_dict
-        state_dict = collector.state_dict()
-        policy_state_dict = policy.state_dict()
-        for worker in range(3):
-            for k in state_dict[f"worker{worker}"]["policy_state_dict"]:
-                torch.testing.assert_close(
-                    state_dict[f"worker{worker}"]["policy_state_dict"][k],
-                    policy_state_dict[k].cpu(),
-                )
-
-        collector.shutdown()
-        del collector
+            # change policy weights
+            for p in policy.parameters():
+                p.data += torch.randn_like(p)
+
+            # collect state_dict
+            state_dict = collector.state_dict()
+            policy_state_dict = policy.state_dict()
+            # check they don't match
+            for worker in range(3):
+                for k in state_dict[f"worker{worker}"]["policy_state_dict"]:
+                    with pytest.raises(AssertionError):
+                        torch.testing.assert_close(
+                            state_dict[f"worker{worker}"]["policy_state_dict"][k],
+                            policy_state_dict[k].cpu(),
+                        )
+
+            # update weights
+            collector.update_policy_weights_()
+
+            # collect state_dict
+            state_dict = collector.state_dict()
+            policy_state_dict = policy.state_dict()
+            for worker in range(3):
+                for k in state_dict[f"worker{worker}"]["policy_state_dict"]:
+                    torch.testing.assert_close(
+                        state_dict[f"worker{worker}"]["policy_state_dict"][k],
+                        policy_state_dict[k].cpu(),
+                    )
+        finally:
+            collector.shutdown()
+            del collector
 
 
 class TestCollectorDevices:
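The `try`/`finally` wrapping the test body guarantees that `collector.shutdown()` runs even when an assertion fails mid-test, so the multiprocess workers are not leaked across the parametrized runs. The same guarantee can also be expressed as a pytest fixture; below is a minimal, self-contained sketch of that pattern, where `_FakeCollector` is a hypothetical stand-in rather than the torchrl collector API:

```python
import pytest


class _FakeCollector:
    """Hypothetical stand-in for the multiprocess data collector."""

    def __init__(self):
        self.closed = False

    def shutdown(self):
        # In a real collector this would join the worker processes.
        self.closed = True


@pytest.fixture
def collector():
    c = _FakeCollector()
    try:
        yield c  # hand the live collector to the test body
    finally:
        c.shutdown()  # runs even if the test raised, mirroring the try/finally above
        del c


def test_fixture_cleans_up(collector):
    assert not collector.closed  # the collector is live inside the test
```

Guaranteed teardown matters here because a failed assertion would otherwise leave the collector's child processes running, which can hang the rest of the test session.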