Skip to content

Buffer observation equals next observation when using ShmemVectorEnv #1281

@heitor57

Description

@heitor57

Hi, thanks for the great work.

I've noticed an issue where obs and obs_next end up being equal even when they shouldn’t, after collecting data using a ShmemVectorEnv environment with complex observation structures.

This problem seems specific to ShmemVectorEnv; DummyVectorEnv is not affected. When the environment returns complex observation objects (such as the dictionary observation in the example below), the shared memory mechanism appears to entangle object references in a way that even copy.deepcopy (as suggested in the documentation) doesn’t fully resolve.

As a result, environments with complex observation structures cause ShmemVectorEnv to store obs and obs_next with the same values — meaning obs is overwritten and no longer reflects its true state.

I’ve written a small piece of code that demonstrates this issue clearly and helps reproduce the problem.

import copy
import numpy as np
import gymnasium as gym
import tianshou as ts
from typing import Dict, Any
from tianshou.algorithm.algorithm_base import RandomActionPolicy
from tianshou.data.buffer.vecbuf import VectorReplayBuffer
from tianshou.data import Collector


def make_generic_env():
    """Create a minimal Gymnasium environment with a dictionary observation.

    The environment class is defined inside this factory, so the factory
    itself can be handed to a vector env as the per-worker constructor.

    Returns:
        A fresh ``GenericEnv`` instance.
    """

    class GenericEnv(gym.Env):
        """Toy environment whose observation is a dict of small int32 arrays.

        The ``step`` entry of the observation mirrors the internal step
        counter, which makes it easy to tell ``obs`` and ``obs_next`` apart.
        An episode terminates once three steps have been taken, and the
        reward is the action index converted to ``float``.
        """

        def __init__(self):
            super().__init__()
            self.num_features = 5
            self.entity_id = 0
            self.step_count = 0
            self.observation_space = gym.spaces.Dict(
                {
                    "id": gym.spaces.Box(low=0, high=100, shape=(1,), dtype=np.int32),
                    "step": gym.spaces.Box(low=0, high=100, shape=(1,), dtype=np.int32),
                    "features": gym.spaces.Box(
                        low=0, high=9999, shape=(self.num_features,), dtype=np.int32
                    ),
                }
            )
            self.action_space = gym.spaces.Discrete(self.num_features)

        def _get_obs(self) -> Dict[str, Any]:
            """Build the observation dict for the current step counter."""
            return {
                "id": np.array([self.entity_id], dtype=np.int32),
                "step": np.array([self.step_count], dtype=np.int32),
                "features": np.arange(self.num_features, dtype=np.int32)
                + self.step_count,
            }

        def reset(self, *, seed=None, options=None):
            """Restart the episode and return ``(observation, info)``."""
            # Forward the seed to gym.Env so that it is not silently ignored;
            # the Gymnasium API expects subclasses to make this call.
            super().reset(seed=seed)
            self.step_count = 0
            obs = self._get_obs()
            # The deepcopy is kept on purpose: it shows that copying the
            # observation inside the env does not avoid the reported problem.
            return copy.deepcopy(obs), {}

        def step(self, action):
            """Advance one step.

            Returns:
                ``(observation, reward, terminated, truncated, info)`` where
                ``terminated`` becomes true from the third step on and
                ``truncated`` is always false.
            """
            self.step_count += 1
            reward = float(action)
            terminated = self.step_count >= 3
            truncated = False
            obs_next = self._get_obs()
            return copy.deepcopy(obs_next), reward, terminated, truncated, {}

    return GenericEnv()


def run_test(vector_env_type):
    """Collect a few transitions with ``vector_env_type`` and print the buffer.

    Two copies of the generic environment are wrapped in the given vector env
    class, a random policy collects ``3 * num_envs`` steps into a
    ``VectorReplayBuffer``, and the first five ``step`` entries of ``obs`` and
    ``obs_next`` are printed so they can be compared by eye.

    Args:
        vector_env_type: Vector env class to test, called with a list of
            environment factories (e.g. ``ts.env.DummyVectorEnv``).
    """
    print(f"Testing Vector Env Type: {vector_env_type.__name__} ")
    num_envs = 2
    env = vector_env_type([make_generic_env for _ in range(num_envs)])
    try:
        # The returned observation is not needed; only the reset itself is.
        env.reset()

        policy = RandomActionPolicy(env.get_env_attr("action_space")[0])
        buffer = VectorReplayBuffer(total_size=num_envs * 10, buffer_num=num_envs)
        collector = Collector(policy=policy, env=env, buffer=buffer)

        collector.collect(n_step=3 * num_envs, reset_before_collect=True)
        print("\n Collected dictionary observations.")

        # Inspect transitions
        print("\n=== ReplayBuffer Contents ===")
        print("Buffer observation: step variable [:5]")
        print(buffer.obs.step[:5])
        print("Buffer next observation: step variable [:5]")
        print(buffer.obs_next.step[:5])
    finally:
        # Always release the vector env, even if collection fails, so that
        # whatever it holds (e.g. worker processes of the subprocess-based
        # variants) does not outlive this call.
        env.close()


# Run the comparison only when executed as a script. Vector envs backed by
# multiprocessing may re-import this module in their child processes (the
# "spawn" start method does so); without the guard each child would try to
# run the whole test again.
if __name__ == "__main__":
    for vector_env_type in [ts.env.DummyVectorEnv, ts.env.ShmemVectorEnv]:
        run_test(vector_env_type)

The output will be something like:

Testing Vector Env Type: DummyVectorEnv =====++++===== 

 Collected dictionary observations.

=== ReplayBuffer Contents ===
Buffer observation: step variable [:5]
[[0]
 [1]
 [2]
 [0]
 [0]]
Buffer next observation: step variable
[[1]
 [2]
 [3]
 [0]
 [0]]
Testing Vector Env Type: ShmemVectorEnv =====++++===== 
 Collected dictionary observations.

=== ReplayBuffer Contents ===
Buffer observation: step variable [:5]
[[1]
 [2]
 [3]
 [0]
 [0]]
Buffer next observation: step variable
[[1]
 [2]
 [3]
 [0]
 [0]]

Therefore, it's clear that there is an inconsistency between the ShmemVectorEnv and DummyVectorEnv buffers.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions