@@ -81,6 +81,7 @@ def __init__(
81
81
82
82
# Avoid side-effects that occur when render_kwargs is manipulated
83
83
render_kwargs = copy .deepcopy (render_kwargs )
84
+ self .render_history = []
84
85
85
86
if render_kwargs is None :
86
87
render_kwargs = {}
@@ -135,7 +136,7 @@ def __init__(
135
136
self .env .reset ()
136
137
pixels_spaces = {}
137
138
for pixel_key in pixel_keys :
138
- pixels = self .env . render (** render_kwargs [pixel_key ])
139
+ pixels = self ._render (** render_kwargs [pixel_key ])
139
140
pixels : np .ndarray = pixels [- 1 ] if isinstance (pixels , List ) else pixels
140
141
141
142
if not hasattr (pixels , "dtype" ) or not hasattr (pixels , "shape" ):
@@ -184,10 +185,24 @@ def _add_pixel_observation(self, wrapped_observation):
184
185
observation [STATE_KEY ] = wrapped_observation
185
186
186
187
pixel_observations = {
187
- pixel_key : self .env . render (** self ._render_kwargs [pixel_key ])
188
+ pixel_key : self ._render (** self ._render_kwargs [pixel_key ])
188
189
for pixel_key in self ._pixel_keys
189
190
}
190
191
191
192
observation .update (pixel_observations )
192
193
193
194
return observation
195
+
196
+ def render (self , * args , ** kwargs ):
197
+ """Renders the environment."""
198
+ render = self .env .render (* args , ** kwargs )
199
+ if isinstance (render , list ):
200
+ render = self .render_history + render
201
+ self .render_history = []
202
+ return render
203
+
204
+ def _render (self , * args , ** kwargs ):
205
+ render = self .env .render (* args , ** kwargs )
206
+ if isinstance (render , list ):
207
+ self .render_history += render
208
+ return render
0 commit comments