
Commit c1b2007

replay the world model embed correctly
1 parent a0cebd6 commit c1b2007

File tree

4 files changed: +72 -19 lines

improving_transformers_world_model/agent.py
improving_transformers_world_model/world_model.py
pyproject.toml
tests/test_agent.py
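At a glance: the actor can be conditioned on a world-model embedding at rollout time, but that embedding was not stored in the replay memories, so the policy update recomputed action logits without it. This commit adds the embedding to each Memory and threads it back through learn and policy_loss. Below is a minimal, hypothetical sketch of the pattern (simplified names, not the repo's exact code), assuming PyTorch:

# Sketch: store the conditioning embed with each transition, then replay the
# *same* embed when recomputing action logits during the policy update.

import torch
from torch import nn
from collections import namedtuple

Memory = namedtuple('Memory', ['state', 'action', 'log_prob', 'world_embed'])

class TinyActor(nn.Module):
    def __init__(self, dim_state = 8, dim_embed = 4, num_actions = 3):
        super().__init__()
        self.proj = nn.Linear(dim_state, num_actions)
        self.cond = nn.Linear(dim_embed, num_actions)  # stand-in for FiLM conditioning

    def forward(self, state, world_model_embed = None):
        logits = self.proj(state)
        if world_model_embed is not None:
            logits = logits + self.cond(world_model_embed)
        return logits

actor = TinyActor()
world_model_embedder = nn.Linear(8, 4)  # stand-in for the world model's embed output

# rollout: condition the actor on the embed and remember it alongside the transition

memories = []
for _ in range(5):
    state = torch.randn(8)
    embed = world_model_embedder(state).detach()
    logits = actor(state, world_model_embed = embed)
    action = torch.distributions.Categorical(logits = logits).sample()
    log_prob = logits.log_softmax(dim = -1)[action]
    memories.append(Memory(state, action, log_prob.detach(), embed))

# learning: replay the stored embeds so the recomputed logits match rollout-time conditioning

states, actions, old_log_probs, embeds = map(torch.stack, zip(*memories))
new_logits = actor(states, world_model_embed = embeds)
new_log_probs = new_logits.log_softmax(dim = -1).gather(-1, actions[:, None]).squeeze(-1)
ratio = (new_log_probs - old_log_probs).exp()  # PPO-style importance ratio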

improving_transformers_world_model/agent.py

Lines changed: 67 additions & 17 deletions

@@ -3,7 +3,7 @@
 from collections import deque
 
 import torch
-from torch import nn, cat, stack, tensor, Tensor
+from torch import nn, cat, stack, is_tensor, tensor, Tensor
 from torch.nn import Module, ModuleList, GRU
 
 import torch.nn.functional as F
@@ -48,6 +48,9 @@ def is_odd(num):
 
 # tensor helpers
 
+def is_empty(t):
+    return t.numel() == 0
+
 def log(t, eps = 1e-20):
     return torch.log(t.clamp(min = eps))
 
@@ -337,7 +340,7 @@ def forward(
         embed = self.proj_in(state)
 
         if exists(world_model_embed):
-            assert exists(self.world_model_film), f'`dim_world_model_embed` must be set on `Actor` to utilize world model for prediction'
+            assert self.can_cond_on_world_model, f'`dim_world_model_embed` must be set on `Actor` to utilize world model for prediction'
 
             embed = self.world_model_film(embed, world_model_embed)
 
@@ -421,11 +424,13 @@ class Memory(NamedTuple):
     reward: Scalar
     value: Scalar
     done: Bool['']
+    world_embed: Float['d'] | None
 
 class MemoriesWithNextState(NamedTuple):
-    memories: Deque[Memory]
-    next_state: FrameState
-    from_real_env: bool
+    memories: Deque[Memory]
+    next_state: FrameState
+    from_real_env: bool
+    has_world_model_embed: bool
 
 # actor critic agent
 
@@ -506,6 +511,7 @@ def policy_loss(
         old_log_probs: Float['b'],
         values: Float['b'],
         returns: Float['b'],
+        world_model_embeds: Float['b d'] | None = None
     ) -> Loss:
 
         self.actor.train()
@@ -514,7 +520,7 @@ def policy_loss(
         advantages = F.layer_norm(returns - values, (batch,))
 
         actor_critic_input, _ = self.impala(states)
-        action_logits = self.actor(actor_critic_input)
+        action_logits = self.actor(actor_critic_input, world_model_embed = world_model_embeds)
 
         prob = action_logits.softmax(dim = -1)
 
@@ -572,9 +578,11 @@ def learn(
         if isinstance(memories, MemoriesWithNextState):
            memories = [memories]
 
+        assert len({one_memory.has_world_model_embed for one_memory in memories}) == 1, 'memories must either all use world embed or not'
+
         datasets = []
 
-        for one_memories, next_state, from_real_env in memories:
+        for one_memories, next_state, from_real_env, _ in memories:
 
             with torch.no_grad():
                 self.critic.eval()
@@ -594,6 +602,7 @@ def learn(
                 rewards,
                 values,
                 dones,
+                world_model_embeds
             ) = map(stack, zip(*list(one_memories)))
 
             values_with_next = cat((values, next_value), dim = 0)
@@ -606,7 +615,7 @@ def learn(
 
             # memories dataset for updating actor and critic learning
 
-            dataset = TensorDataset(states, actions, action_log_probs, returns, values, dones)
+            dataset = TensorDataset(states, actions, action_log_probs, returns, values, dones, world_model_embeds)
 
             datasets.append(dataset)
 
@@ -630,7 +639,8 @@ def learn(
                     action_log_probs,
                     returns,
                     values,
-                    dones
+                    dones,
+                    world_model_embeds
                 ) = tuple(t.to(self.device) for t in batched_data)
 
                 returns = self.batchnorm_target(returns)
@@ -642,7 +652,8 @@ def learn(
                     actions = actions,
                     old_log_probs = action_log_probs,
                     values = values,
-                    returns = returns
+                    returns = returns,
+                    world_model_embeds = world_model_embeds if not is_empty(world_model_embeds) else None
                 )
 
                 actor_loss.mean().backward()
@@ -703,6 +714,9 @@ def interact_with_env(
 
         # maybe conditioning actor with learned world model embed
 
+        world_model_dim = world_model.dim if exists(world_model) else 0
+        world_model_embeds = torch.empty((1, 0, world_model_dim), device = device, dtype = torch.float32)
+
         if exists(world_model):
             world_model_cache = None
 
@@ -752,6 +766,12 @@ def interact_with_env(
             next_done = rearrange(next_done, '1 -> 1 1')
             dones = cat((dones, next_done), dim = -1)
 
+            if exists(world_model_embed):
+                next_embed = rearrange(world_model_embed, '... -> 1 ...')
+                world_model_embeds = cat((world_model_embeds, next_embed), dim = 1)
+            else:
+                world_model_embeds = world_model_embeds.reshape(1, time_step + 1, 0)
+
             time_step += 1
             last_done = dones[0, -1]
 
@@ -763,7 +783,7 @@ def interact_with_env(
 
         # move all intermediates to cpu and detach and store into memory for learning actor and critic
 
-        states, actions, action_log_probs, rewards, values, dones = tuple(rearrange(t, '1 ... -> ...').cpu() for t in (states, actions, action_log_probs, rewards, values, dones))
+        states, actions, action_log_probs, rewards, values, dones, world_model_embeds = tuple(rearrange(t, '1 ... -> ...').cpu() for t in (states, actions, action_log_probs, rewards, values, dones, world_model_embeds))
 
         states, next_state = states[:, :-1], states[:, -1:]
 
@@ -778,11 +798,12 @@ def interact_with_env(
             rewards,
             values,
             dones,
+            world_model_embeds
         ))
 
         memories.extend(episode_memories)
 
-        return MemoriesWithNextState(memories, next_state, from_real_env = True)
+        return MemoriesWithNextState(memories, next_state, from_real_env = True, has_world_model_embed = exists(world_model))
 
     @torch.no_grad()
     @inputs_to_model_device
@@ -791,8 +812,8 @@ def forward(
         world_model: WorldModel,
         init_state: FrameState,
         memories: Memories | None = None,
-        max_steps = float('inf')
-
+        max_steps = float('inf'),
+        use_world_model_embed = False
     ) -> MemoriesWithNextState:
 
        device = init_state.device
@@ -817,20 +838,48 @@ def forward(
         last_done = dones[0, -1]
         time_step = states.shape[2] + 1
 
+        world_model_dim = world_model.dim if use_world_model_embed else 0
+        world_model_embeds = torch.empty((1, 0, world_model_dim), device = device, dtype = torch.float32)
+
         world_model_cache = None
 
         while time_step < max_steps and not last_done:
 
+            world_model_embed = None
+
+            if use_world_model_embed:
+                with torch.no_grad():
+                    world_model.eval()
+
+                    world_model_embed, _ = world_model(
+                        state_or_token_ids = states[:, :, -1:],
+                        actions = actions[:, -1:],
+                        rewards = rewards[:, -1:],
+                        cache = world_model_cache,
+                        remove_cache_len_from_time = False,
+                        return_embed = True,
+                        return_cache = True,
+                        return_loss = False
+                    )
+
+                world_model_embed = rearrange(world_model_embed, '1 1 d -> 1 d')
+
             actor_critic_input, rnn_hiddens = self.impala(next_state)
 
-            action, action_log_prob = self.actor(actor_critic_input, sample_action = True)
+            action, action_log_prob = self.actor(actor_critic_input, world_model_embed = world_model_embed, sample_action = True)
 
             action = rearrange(action, 'b -> b 1 1')
             action_log_prob = rearrange(action_log_prob, 'b -> b 1')
 
             actions = cat((actions, action), dim = 1)
             action_log_probs = cat((action_log_probs, action_log_prob), dim = 1)
 
+            if exists(world_model_embed):
+                next_embed = rearrange(world_model_embed, '... -> 1 ...')
+                world_model_embeds = cat((world_model_embeds, next_embed), dim = 1)
+            else:
+                world_model_embeds = world_model_embeds.reshape(1, time_step + 1, 0)
+
             (states, rewards, dones), world_model_cache = world_model.sample(
                 prompt = states,
                 actions = actions,
@@ -852,7 +901,7 @@ def forward(
 
         # move all intermediates to cpu and detach and store into memory for learning actor and critic
 
-        states, actions, action_log_probs, rewards, values, dones = tuple(rearrange(t, '1 ... -> ...').cpu() for t in (states, actions, action_log_probs, rewards, values, dones))
+        states, actions, action_log_probs, rewards, values, dones, world_model_embeds = tuple(rearrange(t, '1 ... -> ...').cpu() for t in (states, actions, action_log_probs, rewards, values, dones, world_model_embeds))
 
         states, next_state = states[:, :-1], states[:, -1:]
 
@@ -867,8 +916,9 @@ def forward(
             rewards,
             values,
             dones,
+            world_model_embeds
         ))
 
         memories.extend(episode_memories)
 
-        return MemoriesWithNextState(memories, next_state, from_real_env = False)
+        return MemoriesWithNextState(memories, next_state, from_real_env = False, has_world_model_embed = use_world_model_embed)
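One detail worth calling out, as a sketch under assumptions rather than the repo's exact semantics: when no world model is used, the code still has to put something in the world_model_embeds slot so that stack and TensorDataset can treat every memory uniformly. It does this with a zero-width tensor that grows to shape (1, T, 0), and is_empty (numel() == 0) later converts it back to None before the actor is called in policy_loss. Roughly:

# zero-width placeholder trick (standalone illustration, hypothetical shapes)

import torch

def is_empty(t):
    return t.numel() == 0

T, dim = 4, 16

# with a world model: the embeds carry real features, shape (1, T, dim)
with_embed = torch.randn(1, T, dim)

# without one: a placeholder that stacks and batches fine but holds no data
placeholder = torch.empty((1, 0, 0))
placeholder = placeholder.reshape(1, T, 0)   # zero-width along the last axis, numel() == 0

for embeds in (with_embed, placeholder):
    cond = embeds if not is_empty(embeds) else None
    print(embeds.shape, '->', 'condition on embed' if cond is not None else 'no conditioning')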

improving_transformers_world_model/world_model.py

Lines changed: 2 additions & 0 deletions

@@ -576,6 +576,8 @@ def __init__(
             transformer = BlockCausalTransformer(**transformer)
 
         self.transformer = transformer
+        self.dim = transformer.dim
+
         assert transformer.block_size == patches_per_image, f'transformer block size is recommended to be the number of patches per game image, which is {patches_per_image}'
 
         if isinstance(tokenizer, dict):
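Exposing transformer.dim as world_model.dim presumably exists so the agent can size its placeholder embed buffer directly from the world model (the new world_model_dim = world_model.dim if exists(world_model) else 0 lines in agent.py), rather than reaching into world_model.transformer.dim.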

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "improving-transformers-world-model"
-version = "0.0.58"
+version = "0.0.59"
 description = "Improving Transformers World Model for RL"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_agent.py

Lines changed: 2 additions & 1 deletion

@@ -68,7 +68,8 @@ def test_agent(
     dream_memories = agent(
         world_model,
         state[0, :, 0],
-        max_steps = 5
+        max_steps = 5,
+        use_world_model_embed = actor_use_world_model_embed
     )
 
     real_memories = agent.interact_with_env(
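Since actor_use_world_model_embed is already in scope inside test_agent, the new flag is presumably driven by pytest parametrization along these lines (a hypothetical sketch, not the repo's exact test):

# hypothetical parametrization; the real test builds the world model, agent and env first
import pytest

@pytest.mark.parametrize('actor_use_world_model_embed', (False, True))
def test_agent(actor_use_world_model_embed):
    ...  # run both the dream rollout and interact_with_env with and without the embed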
