
Questions about CNN policy input channel [question] #1195

@DavidLudl


Hello,

I am learning how to implement a custom CNN policy and environment with Stable-Baselines3 (SB3). I am following the "Custom Feature Extractor" example at this link:
https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html

I am confused about the number of input channels, which the example defines as observation_space.shape[0]. When I examine the observation space of a Gymnasium Atari environment:

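```python
import gymnasium as gym

# Requires the Atari dependencies (ale-py) to be installed
env = gym.make("BreakoutNoFrameskip-v4")
print("Observation Space Shape: ", env.observation_space.shape)
# Index 0 of an (H, W, C) shape is the height, not the channel count
print("Image Channel: ", env.observation_space.shape[0])
```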

The output is:

```
Observation Space Shape:  (210, 160, 3)
Image Channel:  210
```
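To make the layout explicit: Gymnasium's Atari frames are laid out as height x width x channels, so shape[0] is the height (210), not the channel count. Below is a minimal NumPy sketch of the channel-first (C, H, W) layout that PyTorch's Conv2d expects. It uses a dummy all-zero frame in place of a real observation, so nothing here depends on SB3 or Gymnasium.

```python
import numpy as np

# Dummy Atari-like frame in Gymnasium's default layout: (height, width, channels)
frame_hwc = np.zeros((210, 160, 3), dtype=np.uint8)

height, width, channels = frame_hwc.shape
print("H, W, C:", height, width, channels)  # H, W, C: 210 160 3

# Move the channel axis to the front: (channels, height, width)
frame_chw = np.transpose(frame_hwc, (2, 0, 1))
print("Transposed shape:", frame_chw.shape)  # Transposed shape: (3, 210, 160)

# After the transpose, shape[0] really is the number of channels
print("Channels:", frame_chw.shape[0])  # Channels: 3
```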

When I execute the code from the link, there is no error. However, if I use the last element of the observation space shape instead, n_input_channels = observation_space.shape[2], which I assume is the correct channel count, an error is raised. So I want to ask: does SB3 reorder the observation space shape? And when I define my own env, should I declare the observation space shape as C x H x W or H x W x C (i.e., where should the channel dimension go)?
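My guess, which still needs confirming against the SB3 docs and source, is that SB3 wraps the env in VecTransposeImage when it detects a channel-last image observation, using a check such as stable_baselines3.common.preprocessing.is_image_space_channels_first that looks at which of the three dimensions is smallest. That would explain why shape[0] works inside the feature extractor. The sketch below is a simplified, hypothetical re-implementation of that idea in plain NumPy, not SB3's actual code; the helper names are made up for illustration.

```python
import numpy as np


def is_channels_first(shape):
    """Hypothetical helper: guess whether a 3-D image shape is (C, H, W).

    Assumes the channel axis is the smallest of the three dimensions,
    which holds for typical RGB, grayscale, or frame-stacked images.
    """
    if len(shape) != 3:
        raise ValueError("expected a 3-D image shape")
    return int(np.argmin(shape)) == 0


def to_channels_first_shape(shape):
    """Hypothetical helper: return the (C, H, W) version of an image shape.

    If the shape does not look channel-first, it is assumed to be channel-last.
    """
    if is_channels_first(shape):
        return tuple(shape)
    height, width, channels = shape
    return (channels, height, width)


# Gymnasium Atari layout (H, W, C) vs. an already channel-first layout (C, H, W)
print(is_channels_first((210, 160, 3)))        # False
print(to_channels_first_shape((210, 160, 3)))  # (3, 210, 160)
print(is_channels_first((3, 84, 84)))          # True
```

If this guess is right, a custom env could declare either layout, and the feature extractor would always see the channel-first shape.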

Thank you for your time and help.
