Skip to content

Commit 4b3279a

Browse files
author
Vincent Moens
committed
[BE] Add type annotation for tensor_keys to facilitate auto-complete
ghstack-source-id: b4a8fe3 Pull Request resolved: #2696
1 parent b7a0d11 commit 4b3279a

File tree

19 files changed

+35
-0
lines changed

19 files changed

+35
-0
lines changed

torchrl/objectives/a2c.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ class _AcceptedKeys:
242242
terminated: NestedKey = "terminated"
243243
sample_log_prob: NestedKey = "sample_log_prob"
244244

245+
tensor_keys: _AcceptedKeys
245246
default_keys = _AcceptedKeys()
246247
default_value_estimator: ValueEstimators = ValueEstimators.GAE
247248

torchrl/objectives/common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ class _AcceptedKeys:
128128

129129
pass
130130

131+
tensor_keys: _AcceptedKeys
131132
_vmap_randomness = None
132133
default_value_estimator: ValueEstimators = None
133134

torchrl/objectives/cql.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,7 @@ class _AcceptedKeys:
260260
done: NestedKey = "done"
261261
terminated: NestedKey = "terminated"
262262

263+
tensor_keys: _AcceptedKeys
263264
default_keys = _AcceptedKeys()
264265
default_value_estimator = ValueEstimators.TD0
265266

@@ -1024,6 +1025,7 @@ class _AcceptedKeys:
10241025
terminated: NestedKey = "terminated"
10251026
pred_val: NestedKey = "pred_val"
10261027

1028+
tensor_keys: _AcceptedKeys
10271029
default_keys = _AcceptedKeys()
10281030
default_value_estimator = ValueEstimators.TD0
10291031
out_keys = [

torchrl/objectives/crossq.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ class _AcceptedKeys:
242242
terminated: NestedKey = "terminated"
243243
log_prob: NestedKey = "_log_prob"
244244

245+
tensor_keys: _AcceptedKeys
245246
default_keys = _AcceptedKeys()
246247
default_value_estimator = ValueEstimators.TD0
247248

torchrl/objectives/ddpg.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ class _AcceptedKeys:
173173
done: NestedKey = "done"
174174
terminated: NestedKey = "terminated"
175175

176+
tensor_keys: _AcceptedKeys
176177
default_keys = _AcceptedKeys()
177178
default_value_estimator: ValueEstimators = ValueEstimators.TD0
178179
out_keys = [

torchrl/objectives/decision_transformer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ class _AcceptedKeys:
7070
# the "action" output from the model
7171
action_pred: NestedKey = "action"
7272

73+
tensor_keys: _AcceptedKeys
7374
default_keys = _AcceptedKeys()
7475

7576
actor_network: TensorDictModule
@@ -280,6 +281,7 @@ class _AcceptedKeys:
280281
# the "action" output from the model
281282
action_pred: NestedKey = "action"
282283

284+
tensor_keys: _AcceptedKeys
283285
default_keys = _AcceptedKeys()
284286

285287
actor_network: TensorDictModule

torchrl/objectives/deprecated.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ class _AcceptedKeys:
127127
done: NestedKey = "done"
128128
terminated: NestedKey = "terminated"
129129

130+
tensor_keys: _AcceptedKeys
130131
default_keys = _AcceptedKeys()
131132
delay_actor: bool = False
132133
default_value_estimator = ValueEstimators.TD0

torchrl/objectives/dqn.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ class _AcceptedKeys:
164164
done: NestedKey = "done"
165165
terminated: NestedKey = "terminated"
166166

167+
tensor_keys: _AcceptedKeys
167168
default_keys = _AcceptedKeys()
168169
default_value_estimator = ValueEstimators.TD0
169170
out_keys = ["loss"]
@@ -435,6 +436,7 @@ class _AcceptedKeys:
435436
terminated: NestedKey = "terminated"
436437
steps_to_next_obs: NestedKey = "steps_to_next_obs"
437438

439+
tensor_keys: _AcceptedKeys
438440
default_keys = _AcceptedKeys()
439441
default_value_estimator = ValueEstimators.TD0
440442

torchrl/objectives/dreamer.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,13 @@ class _AcceptedKeys:
8989
pixels: NestedKey = "pixels"
9090
reco_pixels: NestedKey = "reco_pixels"
9191

92+
tensor_keys: _AcceptedKeys
9293
default_keys = _AcceptedKeys()
9394

95+
decoder: TensorDictModule
96+
reward_model: TensorDictModule
97+
world_mdel: TensorDictModule
98+
9499
def __init__(
95100
self,
96101
world_model: TensorDictModule,
@@ -238,9 +243,13 @@ class _AcceptedKeys:
238243
done: NestedKey = "done"
239244
terminated: NestedKey = "terminated"
240245

246+
tensor_keys: _AcceptedKeys
241247
default_keys = _AcceptedKeys()
242248
default_value_estimator = ValueEstimators.TDLambda
243249

250+
value_model: TensorDictModule
251+
actor_model: TensorDictModule
252+
244253
def __init__(
245254
self,
246255
actor_model: TensorDictModule,
@@ -392,8 +401,11 @@ class _AcceptedKeys:
392401

393402
value: NestedKey = "state_value"
394403

404+
tensor_keys: _AcceptedKeys
395405
default_keys = _AcceptedKeys()
396406

407+
value_model: TensorDictModule
408+
397409
def __init__(
398410
self,
399411
value_model: TensorDictModule,

torchrl/objectives/gail.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ class _AcceptedKeys:
5959
collector_observation: NestedKey = "collector_observation"
6060
discriminator_pred: NestedKey = "d_logits"
6161

62+
tensor_keys: _AcceptedKeys
6263
default_keys = _AcceptedKeys()
6364

6465
discriminator_network: TensorDictModule

torchrl/objectives/iql.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ class _AcceptedKeys:
233233
done: NestedKey = "done"
234234
terminated: NestedKey = "terminated"
235235

236+
tensor_keys: _AcceptedKeys
236237
default_keys = _AcceptedKeys()
237238
default_value_estimator = ValueEstimators.TD0
238239
out_keys = [
@@ -709,6 +710,7 @@ class _AcceptedKeys:
709710
done: NestedKey = "done"
710711
terminated: NestedKey = "terminated"
711712

713+
tensor_keys: _AcceptedKeys
712714
default_keys = _AcceptedKeys()
713715
default_value_estimator = ValueEstimators.TD0
714716
out_keys = [

torchrl/objectives/multiagent/qmixer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ class _AcceptedKeys:
179179
done: NestedKey = "done"
180180
terminated: NestedKey = "terminated"
181181

182+
tensor_keys: _AcceptedKeys
182183
default_keys = _AcceptedKeys()
183184
default_value_estimator = ValueEstimators.TD0
184185
out_keys = ["loss"]

torchrl/objectives/ppo.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,7 @@ class _AcceptedKeys:
295295
done: NestedKey = "done"
296296
terminated: NestedKey = "terminated"
297297

298+
tensor_keys: _AcceptedKeys
298299
default_keys = _AcceptedKeys()
299300
default_value_estimator = ValueEstimators.GAE
300301

torchrl/objectives/redq.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ class _AcceptedKeys:
231231
done: NestedKey = "done"
232232
terminated: NestedKey = "terminated"
233233

234+
tensor_keys: _AcceptedKeys
234235
default_keys = _AcceptedKeys()
235236
delay_actor: bool = False
236237
default_value_estimator = ValueEstimators.TD0

torchrl/objectives/reinforce.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ class _AcceptedKeys:
211211
done: NestedKey = "done"
212212
terminated: NestedKey = "terminated"
213213

214+
tensor_keys: _AcceptedKeys
214215
default_keys = _AcceptedKeys()
215216
default_value_estimator = ValueEstimators.GAE
216217
out_keys = ["loss_actor", "loss_value"]

torchrl/objectives/sac.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,7 @@ class _AcceptedKeys:
290290
done: NestedKey = "done"
291291
terminated: NestedKey = "terminated"
292292

293+
tensor_keys: _AcceptedKeys
293294
default_keys = _AcceptedKeys()
294295
default_value_estimator = ValueEstimators.TD0
295296

@@ -1029,6 +1030,7 @@ class _AcceptedKeys:
10291030
terminated: NestedKey = "terminated"
10301031
log_prob: NestedKey = "log_prob"
10311032

1033+
tensor_keys: _AcceptedKeys
10321034
default_keys = _AcceptedKeys()
10331035
default_value_estimator = ValueEstimators.TD0
10341036
delay_actor: bool = False

torchrl/objectives/td3.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ class _AcceptedKeys:
204204
done: NestedKey = "done"
205205
terminated: NestedKey = "terminated"
206206

207+
tensor_keys: _AcceptedKeys
207208
default_keys = _AcceptedKeys()
208209
default_value_estimator = ValueEstimators.TD0
209210
out_keys = [

torchrl/objectives/td3_bc.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ class _AcceptedKeys:
217217
done: NestedKey = "done"
218218
terminated: NestedKey = "terminated"
219219

220+
tensor_keys: _AcceptedKeys
220221
default_keys = _AcceptedKeys()
221222
default_value_estimator = ValueEstimators.TD0
222223
out_keys = [

torchrl/objectives/value/advantages.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ class _AcceptedKeys:
143143
steps_to_next_obs: NestedKey = "steps_to_next_obs"
144144
sample_log_prob: NestedKey = "sample_log_prob"
145145

146+
tensor_keys: _AcceptedKeys
146147
default_keys = _AcceptedKeys()
147148
value_network: Union[TensorDictModule, Callable]
148149
_vmap_randomness = None

0 commit comments

Comments
 (0)