 from collections import deque

 import torch
-from torch import nn, cat, stack, tensor, Tensor
+from torch import nn, cat, stack, is_tensor, tensor, Tensor
 from torch.nn import Module, ModuleList, GRU

 import torch.nn.functional as F
@@ -48,6 +48,9 @@ def is_odd(num):

 # tensor helpers

+def is_empty(t):
+    return t.numel() == 0
+
 def log(t, eps = 1e-20):
     return torch.log(t.clamp(min = eps))

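A minimal sketch (not part of the patch) of how the new `is_empty` helper behaves on the zero-width placeholder tensors built further down in the rollout code, assuming only `torch`:

```python
import torch

def is_empty(t):
    return t.numel() == 0

# a (1, 0, d) placeholder has zero elements regardless of d, so it reads as empty
assert is_empty(torch.empty((1, 0, 512)))
assert not is_empty(torch.zeros((1, 3, 512)))
```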
@@ -337,7 +340,7 @@ def forward(
         embed = self.proj_in(state)

         if exists(world_model_embed):
-            assert exists(self.world_model_film), f'`dim_world_model_embed` must be set on `Actor` to utilize world model for prediction'
+            assert self.can_cond_on_world_model, f'`dim_world_model_embed` must be set on `Actor` to utilize world model for prediction'

             embed = self.world_model_film(embed, world_model_embed)

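The `world_model_film` module itself is not part of this diff; a generic FiLM-style conditioner, shown only as an illustration of what `embed = self.world_model_film(embed, world_model_embed)` could be doing, might look like:

```python
import torch
from torch import nn

class FiLM(nn.Module):
    # illustration only - predicts a per-channel scale and shift from the conditioning embed
    def __init__(self, dim, dim_cond):
        super().__init__()
        self.to_gamma_beta = nn.Linear(dim_cond, dim * 2)

    def forward(self, x, cond):
        gamma, beta = self.to_gamma_beta(cond).chunk(2, dim = -1)
        return x * (gamma + 1.) + beta
```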
@@ -421,11 +424,13 @@ class Memory(NamedTuple):
     reward: Scalar
     value: Scalar
     done: Bool['']
+    world_embed: Float['d'] | None

 class MemoriesWithNextState(NamedTuple):
-    memories:      Deque[Memory]
-    next_state:    FrameState
-    from_real_env: bool
+    memories:              Deque[Memory]
+    next_state:            FrameState
+    from_real_env:         bool
+    has_world_model_embed: bool

 # actor critic agent

@@ -506,6 +511,7 @@ def policy_loss(
         old_log_probs: Float['b'],
         values: Float['b'],
         returns: Float['b'],
+        world_model_embeds: Float['b d'] | None = None
     ) -> Loss:

         self.actor.train()
@@ -514,7 +520,7 @@ def policy_loss(
         advantages = F.layer_norm(returns - values, (batch,))

         actor_critic_input, _ = self.impala(states)
-        action_logits = self.actor(actor_critic_input)
+        action_logits = self.actor(actor_critic_input, world_model_embed = world_model_embeds)

         prob = action_logits.softmax(dim = -1)

@@ -572,9 +578,11 @@ def learn(
         if isinstance(memories, MemoriesWithNextState):
             memories = [memories]

+        assert len({one_memory.has_world_model_embed for one_memory in memories}) == 1, 'memories must either all use world embed or not'
+
         datasets = []

-        for one_memories, next_state, from_real_env in memories:
+        for one_memories, next_state, from_real_env, _ in memories:

             with torch.no_grad():
                 self.critic.eval()
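The set comprehension above collapses the per-rollout `has_world_model_embed` flags to their distinct values, so mixing conditioned and unconditioned rollouts in one `learn` call trips the assert. A small check of the pattern (not part of the patch):

```python
flags_uniform = [True, True, True]
flags_mixed = [True, False, True]

assert len(set(flags_uniform)) == 1   # all rollouts agree - passes
assert len(set(flags_mixed)) != 1     # mixed rollouts would trip the assert in learn()
```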
@@ -594,6 +602,7 @@ def learn(
                 rewards,
                 values,
                 dones,
+                world_model_embeds
             ) = map(stack, zip(*list(one_memories)))

             values_with_next = cat((values, next_value), dim = 0)
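Stacking still works when no world model was used because, as far as this diff shows, each `Memory` then carries a zero-width embed rather than `None`. A sketch (not part of the patch):

```python
import torch
from torch import stack

# one (0,)-shaped embed per timestep, as produced by the (1, t, 0) placeholder after rearrange
per_step_embeds = [torch.empty(0) for _ in range(4)]

world_model_embeds = stack(per_step_embeds)
assert world_model_embeds.shape == (4, 0) and world_model_embeds.numel() == 0
```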
@@ -606,7 +615,7 @@ def learn(

             # memories dataset for updating actor and critic learning

-            dataset = TensorDataset(states, actions, action_log_probs, returns, values, dones)
+            dataset = TensorDataset(states, actions, action_log_probs, returns, values, dones, world_model_embeds)

             datasets.append(dataset)

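`TensorDataset` only requires matching first dimensions, so a zero-width `world_model_embeds` column indexes and batches like any other tensor. A sketch (not part of the patch), with made-up shapes:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

states = torch.randn(16, 4)   # hypothetical shapes, just for the shape check
embeds = torch.empty(16, 0)   # zero-width column when no world model was used

dataset = TensorDataset(states, embeds)
batched_states, batched_embeds = next(iter(DataLoader(dataset, batch_size = 8)))
assert batched_embeds.shape == (8, 0)
```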
@@ -630,7 +639,8 @@ def learn(
                 action_log_probs,
                 returns,
                 values,
-                dones
+                dones,
+                world_model_embeds
             ) = tuple(t.to(self.device) for t in batched_data)

             returns = self.batchnorm_target(returns)
@@ -642,7 +652,8 @@ def learn(
                 actions = actions,
                 old_log_probs = action_log_probs,
                 values = values,
-                returns = returns
+                returns = returns,
+                world_model_embeds = world_model_embeds if not is_empty(world_model_embeds) else None
             )

             actor_loss.mean().backward()
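The ternary above undoes the placeholder trick at the last moment: a batch collected without world model conditioning arrives as a `(b, 0)` tensor and is mapped back to `None`, so `Actor` skips the FiLM path entirely. Sketch (not part of the patch):

```python
import torch

def is_empty(t):
    return t.numel() == 0

world_model_embeds = torch.empty(8, 0)   # batch collected without a world model
cond = world_model_embeds if not is_empty(world_model_embeds) else None
assert cond is None
```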
@@ -703,6 +714,9 @@ def interact_with_env(

         # maybe conditioning actor with learned world model embed

+        world_model_dim = world_model.dim if exists(world_model) else 0
+        world_model_embeds = torch.empty((1, 0, world_model_dim), device = device, dtype = torch.float32)
+
         if exists(world_model):
             world_model_cache = None

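The placeholder starts with zero timesteps and, when a world model exists, grows by one `(1, 1, d)` slice per step via `cat` in the next hunk. A sketch of that growth (not part of the patch), with an assumed embed dim of 512:

```python
import torch

d = 512                                    # assumed embed dim for illustration
world_model_embeds = torch.empty((1, 0, d))

next_embed = torch.randn(1, 1, d)          # stands in for the embed of the newest step
world_model_embeds = torch.cat((world_model_embeds, next_embed), dim = 1)
assert world_model_embeds.shape == (1, 1, d)
```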
@@ -752,6 +766,12 @@ def interact_with_env(
             next_done = rearrange(next_done, '1 -> 1 1')
             dones = cat((dones, next_done), dim = -1)

+            if exists(world_model_embed):
+                next_embed = rearrange(world_model_embed, '... -> 1 ...')
+                world_model_embeds = cat((world_model_embeds, next_embed), dim = 1)
+            else:
+                world_model_embeds = world_model_embeds.reshape(1, time_step + 1, 0)
+
             time_step += 1
             last_done = dones[0, -1]

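On the no-world-model path the placeholder has zero elements, so `reshape` can freely keep its time axis in lockstep with the other rollout tensors, and the later per-timestep zip into `Memory` stays aligned. Sketch (not part of the patch):

```python
import torch

time_step = 0
world_model_embeds = torch.empty((1, 0, 0))   # no world model -> dim 0 placeholder

for _ in range(3):
    world_model_embeds = world_model_embeds.reshape(1, time_step + 1, 0)
    time_step += 1

assert world_model_embeds.shape == (1, 3, 0)
```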
@@ -763,7 +783,7 @@ def interact_with_env(

         # move all intermediates to cpu and detach and store into memory for learning actor and critic

-        states, actions, action_log_probs, rewards, values, dones = tuple(rearrange(t, '1 ... -> ...').cpu() for t in (states, actions, action_log_probs, rewards, values, dones))
+        states, actions, action_log_probs, rewards, values, dones, world_model_embeds = tuple(rearrange(t, '1 ... -> ...').cpu() for t in (states, actions, action_log_probs, rewards, values, dones, world_model_embeds))

         states, next_state = states[:, :-1], states[:, -1:]

@@ -778,11 +798,12 @@ def interact_with_env(
             rewards,
             values,
             dones,
+            world_model_embeds
         ))

         memories.extend(episode_memories)

-        return MemoriesWithNextState(memories, next_state, from_real_env = True)
+        return MemoriesWithNextState(memories, next_state, from_real_env = True, has_world_model_embed = exists(world_model))

     @torch.no_grad()
     @inputs_to_model_device
@@ -791,8 +812,8 @@ def forward(
         world_model: WorldModel,
         init_state: FrameState,
         memories: Memories | None = None,
-        max_steps = float('inf')
-
+        max_steps = float('inf'),
+        use_world_model_embed = False
     ) -> MemoriesWithNextState:

         device = init_state.device
@@ -817,20 +838,48 @@ def forward(
         last_done = dones[0, -1]
         time_step = states.shape[2] + 1

+        world_model_dim = world_model.dim if use_world_model_embed else 0
+        world_model_embeds = torch.empty((1, 0, world_model_dim), device = device, dtype = torch.float32)
+
         world_model_cache = None

         while time_step < max_steps and not last_done:

+            world_model_embed = None
+
+            if use_world_model_embed:
+                with torch.no_grad():
+                    world_model.eval()
+
+                    world_model_embed, _ = world_model(
+                        state_or_token_ids = states[:, :, -1:],
+                        actions = actions[:, -1:],
+                        rewards = rewards[:, -1:],
+                        cache = world_model_cache,
+                        remove_cache_len_from_time = False,
+                        return_embed = True,
+                        return_cache = True,
+                        return_loss = False
+                    )
+
+                world_model_embed = rearrange(world_model_embed, '1 1 d -> 1 d')
+
             actor_critic_input, rnn_hiddens = self.impala(next_state)

-            action, action_log_prob = self.actor(actor_critic_input, sample_action = True)
+            action, action_log_prob = self.actor(actor_critic_input, world_model_embed = world_model_embed, sample_action = True)

             action = rearrange(action, 'b -> b 1 1')
             action_log_prob = rearrange(action_log_prob, 'b -> b 1')

             actions = cat((actions, action), dim = 1)
             action_log_probs = cat((action_log_probs, action_log_prob), dim = 1)

+            if exists(world_model_embed):
+                next_embed = rearrange(world_model_embed, '... -> 1 ...')
+                world_model_embeds = cat((world_model_embeds, next_embed), dim = 1)
+            else:
+                world_model_embeds = world_model_embeds.reshape(1, time_step + 1, 0)
+
             (states, rewards, dones), world_model_cache = world_model.sample(
                 prompt = states,
                 actions = actions,
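Each dream step feeds only the newest frame, action and reward through the world model with its rolling cache, and the returned single-timestep embed is flattened for the actor. The einops step in isolation (not part of the patch), with an assumed embed dim:

```python
import torch
from einops import rearrange

world_model_embed = torch.randn(1, 1, 512)   # (batch, time = 1, d), d assumed
world_model_embed = rearrange(world_model_embed, '1 1 d -> 1 d')
assert world_model_embed.shape == (1, 512)
```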
@@ -852,7 +901,7 @@ def forward(

         # move all intermediates to cpu and detach and store into memory for learning actor and critic

-        states, actions, action_log_probs, rewards, values, dones = tuple(rearrange(t, '1 ... -> ...').cpu() for t in (states, actions, action_log_probs, rewards, values, dones))
+        states, actions, action_log_probs, rewards, values, dones, world_model_embeds = tuple(rearrange(t, '1 ... -> ...').cpu() for t in (states, actions, action_log_probs, rewards, values, dones, world_model_embeds))

         states, next_state = states[:, :-1], states[:, -1:]

@@ -867,8 +916,9 @@ def forward(
             rewards,
             values,
             dones,
+            world_model_embeds
         ))

         memories.extend(episode_memories)

-        return MemoriesWithNextState(memories, next_state, from_real_env = False)
+        return MemoriesWithNextState(memories, next_state, from_real_env = False, has_world_model_embed = use_world_model_embed)
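Putting the two new flags together, a hypothetical end-to-end use (not from the repo; the construction of `agent`, `world_model` and `init_state` is not shown in this diff) could be:

```python
# dream a rollout conditioned on the world model's own embeddings...
dreamed = agent(
    world_model,
    init_state,
    max_steps = 64,
    use_world_model_embed = True
)

# ...then learn from it; learn() asserts all passed rollouts agree on has_world_model_embed
agent.learn(dreamed)
```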