Description
Hi, thanks for the great work.
I've noticed an issue where obs and obs_next end up being equal even though they shouldn't be, after collecting data with a ShmemVectorEnv whose environments return complex observation structures.
This problem seems specific to ShmemVectorEnv. Unlike with DummyVectorEnv, when the environment returns complex observation objects, the shared-memory mechanism appears to entangle object references in a way that even copy.deepcopy (as suggested in the documentation) does not fully resolve.
As a result, environments with complex observation structures cause ShmemVectorEnv to store obs and obs_next with the same values, meaning obs is overwritten and no longer reflects the state before the step.
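To make my suspicion concrete, here is a tiny plain-numpy sketch (no Tianshou involved, and purely my assumption about the mechanism, not the actual ShmemVectorEnv code): if obs and obs_next are built from references into a buffer that is later overwritten in place, they end up equal, and a deepcopy taken after the overwrite cannot recover the old values.

import copy
import numpy as np

# Stand-in for a shared-memory block that gets overwritten in place on each step.
shared = {"step": np.zeros(1, dtype=np.int32)}

obs = {"step": shared["step"]}        # view into the shared buffer, not a copy
shared["step"][:] = 1                 # the next step overwrites the buffer in place
obs_next = {"step": shared["step"]}

print(obs["step"], obs_next["step"])  # [1] [1] -> the old value 0 is gone
print(copy.deepcopy(obs)["step"])     # still [1]; the deepcopy happens too late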
I've written a small script that demonstrates the issue and makes it easy to reproduce:
import copy
import numpy as np
import gymnasium as gym
import tianshou as ts
from typing import Dict, Any
from tianshou.algorithm.algorithm_base import RandomActionPolicy
from tianshou.data.buffer.vecbuf import VectorReplayBuffer
from tianshou.data import Collector


def make_generic_env():
    class GenericEnv(gym.Env):
        """Toy env with a Dict observation space; the step counter makes obs easy to inspect."""

        def __init__(self):
            super().__init__()
            self.num_features = 5
            self.entity_id = 0
            self.step_count = 0
            self.observation_space = gym.spaces.Dict(
                {
                    "id": gym.spaces.Box(low=0, high=100, shape=(1,), dtype=np.int32),
                    "step": gym.spaces.Box(low=0, high=100, shape=(1,), dtype=np.int32),
                    "features": gym.spaces.Box(
                        low=0, high=9999, shape=(self.num_features,), dtype=np.int32
                    ),
                }
            )
            self.action_space = gym.spaces.Discrete(self.num_features)

        def _get_obs(self) -> Dict[str, Any]:
            return {
                "id": np.array([self.entity_id], dtype=np.int32),
                "step": np.array([self.step_count], dtype=np.int32),
                "features": np.arange(self.num_features, dtype=np.int32)
                + self.step_count,
            }

        def reset(self, *, seed=None, options=None):
            self.step_count = 0
            obs = self._get_obs()
            # Deep copy the observation, as suggested in the documentation.
            return copy.deepcopy(obs), {}

        def step(self, action):
            self.step_count += 1
            reward = float(action)
            terminated = self.step_count >= 3
            truncated = False
            obs_next = self._get_obs()
            # Deep copy the observation, as suggested in the documentation.
            return copy.deepcopy(obs_next), reward, terminated, truncated, {}

    return GenericEnv()


def run_test(vector_env_type):
    print(f"Testing Vector Env Type: {vector_env_type.__name__} ")
    num_envs = 2
    env = vector_env_type([make_generic_env for _ in range(num_envs)])
    initial_obs, _ = env.reset()
    policy = RandomActionPolicy(env.get_env_attr("action_space")[0])
    buffer = VectorReplayBuffer(total_size=num_envs * 10, buffer_num=num_envs)
    collector = Collector(policy=policy, env=env, buffer=buffer)
    collector.collect(n_step=3 * num_envs, reset_before_collect=True)
    print("\n Collected dictionary observations.")

    # Inspect transitions: obs.step and obs_next.step should differ by one step.
    print("\n=== ReplayBuffer Contents ===")
    print("Buffer observation: step variable [:5]")
    print(buffer.obs.step[:5])
    print("Buffer next observation: step variable [:5]")
    print(buffer.obs_next.step[:5])


for vector_env_type in [ts.env.DummyVectorEnv, ts.env.ShmemVectorEnv]:
    run_test(vector_env_type)

The output will be something like:
Testing Vector Env Type: DummyVectorEnv =====++++=====
Collected dictionary observations.
=== ReplayBuffer Contents ===
Buffer observation: step variable [:5]
[[0]
[1]
[2]
[0]
[0]]
Buffer next observation: step variable
[[1]
[2]
[3]
[0]
[0]]
Testing Vector Env Type: ShmemVectorEnv =====++++=====
Collected dictionary observations.
=== ReplayBuffer Contents ===
Buffer observation: step variable [:5]
[[1]
[2]
[3]
[0]
[0]]
Buffer next observation: step variable
[[1]
[2]
[3]
[0]
[0]]
This makes it clear that there is an inconsistency between ShmemVectorEnv and DummyVectorEnv: with ShmemVectorEnv, obs ends up holding the same values as obs_next in the buffer.
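As a possible interim workaround (just a suggestion on my part, not verified beyond the repro above), switching the vectorized env class avoids the shared-memory path entirely:

# Possible interim workaround (my assumption, not an official fix): use
# SubprocVectorEnv or DummyVectorEnv instead of ShmemVectorEnv, since they do not
# hand out views into a shared-memory buffer.
for vector_env_type in [ts.env.DummyVectorEnv, ts.env.SubprocVectorEnv]:
    run_test(vector_env_type)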