Skip to content

Commit 9d35dc8

Browse files
authored
fix render popping when render is called internally for Wrappers (#2998)
* fix render popping when render is called internally * move _render workaround from Wrapper to PixelObservation
1 parent cb3df61 commit 9d35dc8

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

gym/core.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,9 @@ def reset(self, **kwargs) -> Union[ObsType, Tuple[ObsType, dict]]:
426426
"""Resets the environment with kwargs."""
427427
return self.env.reset(**kwargs)
428428

429-
def render(self, *args, **kwargs):
429+
def render(
430+
self, *args, **kwargs
431+
) -> Optional[Union[RenderFrame, List[RenderFrame]]]:
430432
"""Renders the environment."""
431433
return self.env.render(*args, **kwargs)
432434

gym/wrappers/pixel_observation.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def __init__(
8181

8282
# Avoid side-effects that occur when render_kwargs is manipulated
8383
render_kwargs = copy.deepcopy(render_kwargs)
84+
self.render_history = []
8485

8586
if render_kwargs is None:
8687
render_kwargs = {}
@@ -135,7 +136,7 @@ def __init__(
135136
self.env.reset()
136137
pixels_spaces = {}
137138
for pixel_key in pixel_keys:
138-
pixels = self.env.render(**render_kwargs[pixel_key])
139+
pixels = self._render(**render_kwargs[pixel_key])
139140
pixels: np.ndarray = pixels[-1] if isinstance(pixels, List) else pixels
140141

141142
if not hasattr(pixels, "dtype") or not hasattr(pixels, "shape"):
@@ -184,10 +185,24 @@ def _add_pixel_observation(self, wrapped_observation):
184185
observation[STATE_KEY] = wrapped_observation
185186

186187
pixel_observations = {
187-
pixel_key: self.env.render(**self._render_kwargs[pixel_key])
188+
pixel_key: self._render(**self._render_kwargs[pixel_key])
188189
for pixel_key in self._pixel_keys
189190
}
190191

191192
observation.update(pixel_observations)
192193

193194
return observation
195+
196+
def render(self, *args, **kwargs):
197+
"""Renders the environment."""
198+
render = self.env.render(*args, **kwargs)
199+
if isinstance(render, list):
200+
render = self.render_history + render
201+
self.render_history = []
202+
return render
203+
204+
def _render(self, *args, **kwargs):
205+
render = self.env.render(*args, **kwargs)
206+
if isinstance(render, list):
207+
self.render_history += render
208+
return render

0 commit comments

Comments
 (0)