Skip to content

Conversation

Copilot
Copy link
Contributor

@Copilot Copilot AI commented Sep 8, 2025

Fixes #2172

Problem

The environment checker was failing when Sequence observation spaces were nested inside composite spaces like Dict, Tuple, or OneOf. This caused assertion errors instead of the expected warnings and graceful handling.

import gymnasium as gym
from gymnasium import spaces
from stable_baselines3.common.env_checker import check_env

# This would fail with AssertionError
env = gym.Env()
env.observation_space = spaces.Dict({
    "seq": spaces.Sequence(spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float32))
})
env.action_space = spaces.Discrete(2)
check_env(env)  # AssertionError: The observation returned by the `reset()` method should be a single value, not a tuple

Root Cause

The _check_unsupported_spaces() function properly detected Sequence spaces at the top level and set should_skip=True to bypass return value checks. However, it didn't recursively check for Sequence spaces within composite spaces, so nested Sequence spaces weren't detected and the checker tried to validate their return values, leading to failures.

Solution

Updated _check_unsupported_spaces() to recursively detect Sequence spaces within:

  • Dict spaces (checking all sub-spaces)
  • Tuple spaces (checking all elements)
  • OneOf spaces (checking all possible spaces)

When any nested Sequence space is found, the function now:

  1. Sets should_skip=True to bypass return value validation
  2. Emits the appropriate warning about Sequence spaces not being supported

Changes

  • Modified stable_baselines3/common/env_checker.py to add recursive Sequence space detection
  • Added comprehensive test cases covering all composite space types with both stack=True and stack=False
  • Updated changelog to document the bug fix

Testing

The fix now handles all the problematic cases from the original issue:

# All of these now pass with appropriate warnings instead of crashing
check_env(env_with_sequence_in_dict)    # ✓ PASSED
check_env(env_with_sequence_in_tuple)   # ✓ PASSED  
check_env(env_with_sequence_in_oneof)   # ✓ PASSED

Existing functionality remains unchanged - all existing tests continue to pass.


💬 Share your feedback on Copilot coding agent for the chance to win a $200 gift card! Click here to start the survey.

Co-authored-by: araffin <1973948+araffin@users.noreply.github.com>
@Copilot Copilot AI changed the title [WIP] Fix issue #2172 by updating the env checker (you can take inspiration from https://github.com/DLR-RM/stable-baselines3/commit/c40b5e461be734d733b8fa0c7e2b1e1451fd1c54 . Don't forget to update the changelog and the tests too Fix env checker to handle Sequence spaces in composite spaces (Dict, Tuple, OneOf) Sep 8, 2025
@Copilot Copilot AI requested a review from araffin September 8, 2025 15:03
Copilot finished work on behalf of araffin September 8, 2025 15:03
@araffin araffin marked this pull request as ready for review September 8, 2025 15:21
@araffin araffin merged commit 7883ed4 into master Sep 8, 2025
4 checks passed
@araffin araffin deleted the copilot/fix-ad8a84cb-6777-4415-9d16-878c05823070 branch September 8, 2025 15:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug]: putting Sequence space inside Dict, Tuple, or OneOf space will fail the env checker

2 participants