Skip to content

Commit 4b3279a

Browse files
author
Vincent Moens
committed
[BE] Add type annotation for tensor_keys to facilitate auto-complete
ghstack-source-id: b4a8fe3 Pull Request resolved: pytorch/rl#2696
1 parent b7a0d11 commit 4b3279a

File tree

19 files changed

+35
-0
lines changed

19 files changed

+35
-0
lines changed

torchrl/objectives/a2c.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ class _AcceptedKeys:
242242
terminated: NestedKey = "terminated"
243243
sample_log_prob: NestedKey = "sample_log_prob"
244244

245+
tensor_keys: _AcceptedKeys
245246
default_keys = _AcceptedKeys()
246247
default_value_estimator: ValueEstimators = ValueEstimators.GAE
247248

torchrl/objectives/common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ class _AcceptedKeys:
128128

129129
pass
130130

131+
tensor_keys: _AcceptedKeys
131132
_vmap_randomness = None
132133
default_value_estimator: ValueEstimators = None
133134

torchrl/objectives/cql.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,7 @@ class _AcceptedKeys:
260260
done: NestedKey = "done"
261261
terminated: NestedKey = "terminated"
262262

263+
tensor_keys: _AcceptedKeys
263264
default_keys = _AcceptedKeys()
264265
default_value_estimator = ValueEstimators.TD0
265266

@@ -1024,6 +1025,7 @@ class _AcceptedKeys:
10241025
terminated: NestedKey = "terminated"
10251026
pred_val: NestedKey = "pred_val"
10261027

1028+
tensor_keys: _AcceptedKeys
10271029
default_keys = _AcceptedKeys()
10281030
default_value_estimator = ValueEstimators.TD0
10291031
out_keys = [

torchrl/objectives/crossq.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ class _AcceptedKeys:
242242
terminated: NestedKey = "terminated"
243243
log_prob: NestedKey = "_log_prob"
244244

245+
tensor_keys: _AcceptedKeys
245246
default_keys = _AcceptedKeys()
246247
default_value_estimator = ValueEstimators.TD0
247248

torchrl/objectives/ddpg.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ class _AcceptedKeys:
173173
done: NestedKey = "done"
174174
terminated: NestedKey = "terminated"
175175

176+
tensor_keys: _AcceptedKeys
176177
default_keys = _AcceptedKeys()
177178
default_value_estimator: ValueEstimators = ValueEstimators.TD0
178179
out_keys = [

torchrl/objectives/decision_transformer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ class _AcceptedKeys:
7070
# the "action" output from the model
7171
action_pred: NestedKey = "action"
7272

73+
tensor_keys: _AcceptedKeys
7374
default_keys = _AcceptedKeys()
7475

7576
actor_network: TensorDictModule
@@ -280,6 +281,7 @@ class _AcceptedKeys:
280281
# the "action" output from the model
281282
action_pred: NestedKey = "action"
282283

284+
tensor_keys: _AcceptedKeys
283285
default_keys = _AcceptedKeys()
284286

285287
actor_network: TensorDictModule

torchrl/objectives/deprecated.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ class _AcceptedKeys:
127127
done: NestedKey = "done"
128128
terminated: NestedKey = "terminated"
129129

130+
tensor_keys: _AcceptedKeys
130131
default_keys = _AcceptedKeys()
131132
delay_actor: bool = False
132133
default_value_estimator = ValueEstimators.TD0

torchrl/objectives/dqn.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ class _AcceptedKeys:
164164
done: NestedKey = "done"
165165
terminated: NestedKey = "terminated"
166166

167+
tensor_keys: _AcceptedKeys
167168
default_keys = _AcceptedKeys()
168169
default_value_estimator = ValueEstimators.TD0
169170
out_keys = ["loss"]
@@ -435,6 +436,7 @@ class _AcceptedKeys:
435436
terminated: NestedKey = "terminated"
436437
steps_to_next_obs: NestedKey = "steps_to_next_obs"
437438

439+
tensor_keys: _AcceptedKeys
438440
default_keys = _AcceptedKeys()
439441
default_value_estimator = ValueEstimators.TD0
440442

torchrl/objectives/dreamer.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,13 @@ class _AcceptedKeys:
8989
pixels: NestedKey = "pixels"
9090
reco_pixels: NestedKey = "reco_pixels"
9191

92+
tensor_keys: _AcceptedKeys
9293
default_keys = _AcceptedKeys()
9394

95+
decoder: TensorDictModule
96+
reward_model: TensorDictModule
97+
world_mdel: TensorDictModule
98+
9499
def __init__(
95100
self,
96101
world_model: TensorDictModule,
@@ -238,9 +243,13 @@ class _AcceptedKeys:
238243
done: NestedKey = "done"
239244
terminated: NestedKey = "terminated"
240245

246+
tensor_keys: _AcceptedKeys
241247
default_keys = _AcceptedKeys()
242248
default_value_estimator = ValueEstimators.TDLambda
243249

250+
value_model: TensorDictModule
251+
actor_model: TensorDictModule
252+
244253
def __init__(
245254
self,
246255
actor_model: TensorDictModule,
@@ -392,8 +401,11 @@ class _AcceptedKeys:
392401

393402
value: NestedKey = "state_value"
394403

404+
tensor_keys: _AcceptedKeys
395405
default_keys = _AcceptedKeys()
396406

407+
value_model: TensorDictModule
408+
397409
def __init__(
398410
self,
399411
value_model: TensorDictModule,

torchrl/objectives/gail.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ class _AcceptedKeys:
5959
collector_observation: NestedKey = "collector_observation"
6060
discriminator_pred: NestedKey = "d_logits"
6161

62+
tensor_keys: _AcceptedKeys
6263
default_keys = _AcceptedKeys()
6364

6465
discriminator_network: TensorDictModule

0 commit comments

Comments
 (0)