@@ -263,6 +263,64 @@ def render(self, **kwargs):
263263 raise InvalidArgument (
264264 "Available render modes: human, ansi, html, ipython" )
265265
def train(self, agents=None):
    """
    Setup a lightweight training environment for a single agent.

    Note: This is designed to be a lightweight starting point which can
    be integrated with other frameworks (e.g. gym, stable-baselines).

    The environment is reset once before this method returns, so the
    returned `step` callable may be used immediately.

    Example:
        env = make("tictactoe")
        # Training agent in first position (player 1) against the default random agent.
        reset, step = env.train([None, "random"])

        obs = reset()
        done = False
        while not done:
            action = 0  # Action for the agent being trained.
            obs, reward, done, info = step(action)
        env.render()

    Args:
        agents (list): List of agents to obtain actions from while training.
            Exactly one entry, the position of the agent being trained,
            must be None. Defaults to an empty list, which is rejected.

    Returns:
        tuple: A pair `(reset, step)` of callables.
            reset(): Resets the environment, advances until it is the
                trained agent's turn, and returns that agent's observation.
            step(action): Steps the environment using `action` for the
                trained agent, advances until it is that agent's turn
                again, and returns the list
                `[observation, reward, done, info]`, where `done` is True
                whenever the trained agent's status is not "ACTIVE".

    Raises:
        InvalidArgument: If no entry, or more than one entry, of `agents`
            is None.
    """
    # Avoid a mutable default argument; an omitted list behaves like [].
    agents = [] if agents is None else agents

    # Locate the single slot reserved for the agent being trained.
    position = None
    for index, agent in enumerate(agents):
        if agent is None:
            if position is not None:
                raise InvalidArgument(
                    "Only one agent can be marked 'None'")
            position = index

    if position is None:
        raise InvalidArgument("One agent must be marked 'None' to train.")

    def advance():
        # Keep stepping while the trained agent is waiting on the others.
        # No none_action is supplied because the trained agent is INACTIVE.
        # NOTE(review): the original passed `self.agents` here, whereas
        # step() below passes the caller-supplied list. The caller's list
        # is now used in both places so the same opponents are consulted
        # on every turn; confirm `self.agents` was not intended.
        while not self.done and self.state[position].status == "INACTIVE":
            self.step(self.__get_actions(agents=agents))

    def reset():
        self.reset(len(agents))
        advance()
        return self.state[position].observation

    def step(action):
        # `action` is handed to __get_actions as the action for the None
        # slot (presumably substituted for the trained agent; verify
        # against __get_actions).
        self.step(self.__get_actions(agents=agents, none_action=action))
        advance()
        agent = self.state[position]
        return [
            agent.observation, agent.reward, agent.status != "ACTIVE", agent.info
        ]

    # Start from a fresh episode so `step` is usable straight away.
    reset()

    return (reset, step)
323+
266324 @property
267325 def name (self ):
268326 """str: The name from the specification."""
0 commit comments