
Commit 877d9ee

update epstein 3 for rl
- update epstein rl for mesa 3.0
1 parent 30a3475 commit 877d9ee

File tree

rl/epstein_civil_violence/agent.py
rl/epstein_civil_violence/model.py
rl/epstein_civil_violence/train_config.py
rl/epstein_civil_violence/utility.py

4 files changed: +15 -20 lines changed


rl/epstein_civil_violence/agent.py

Lines changed: 1 addition & 2 deletions
@@ -1,6 +1,5 @@
 from mesa.examples.advanced.epstein_civil_violence.agents import Citizen, Cop
-
-from .utility import move
+from utility import move
 
 
 class CitizenRL(Citizen):

rl/epstein_civil_violence/model.py

Lines changed: 11 additions & 14 deletions
@@ -1,11 +1,10 @@
 import gymnasium as gym
 import mesa
 import numpy as np
+from agent import CitizenRL, CopRL
 from mesa.examples.advanced.epstein_civil_violence.model import EpsteinCivilViolence
 from ray.rllib.env import MultiAgentEnv
-
-from .agent import CitizenRL, CopRL
-from .utility import create_intial_agents, grid_to_observation
+from utility import create_intial_agents, grid_to_observation
 
 
 class EpsteinCivilViolenceRL(EpsteinCivilViolence, MultiAgentEnv):
@@ -88,7 +87,7 @@ def step(self, action_dict):
         self.action_dict = action_dict
 
         # Step the model
-        self.schedule.step()
+        self.agents.shuffle_do("step")
         self.datacollector.collect(self)
 
         # Calculate rewards
@@ -104,10 +103,10 @@ def step(self, action_dict):
             ] # Get the values from the observation grid for the neighborhood cells
 
         # RL specific outputs for the environment
-        done = {a.unique_id: False for a in self.schedule.agents}
-        truncated = {a.unique_id: False for a in self.schedule.agents}
+        done = {a.unique_id: False for a in self.agents}
+        truncated = {a.unique_id: False for a in self.agents}
         truncated["__all__"] = np.all(list(truncated.values()))
-        if self.schedule.time > self.max_iters:
+        if self.time > self.max_iters:
             done["__all__"] = True
         else:
             done["__all__"] = False
@@ -116,7 +115,7 @@ def step(self, action_dict):
 
     def cal_reward(self):
         rewards = {}
-        for agent in self.schedule.agents:
+        for agent in self.agents:
             if isinstance(agent, CopRL):
                 if agent.arrest_made:
                     # Cop is rewarded for making an arrest
@@ -149,19 +148,17 @@ def reset(self, *, seed=None, options=None):
         """
 
         super().reset()
-        # Using base scheduler to maintain the order of agents
-        self.schedule = mesa.time.BaseScheduler(self)
         self.grid = mesa.space.SingleGrid(self.width, self.height, torus=True)
         create_intial_agents(self, CitizenRL, CopRL)
         grid_to_observation(self, CitizenRL)
         # Intialize action dictionary with no action
-        self.action_dict = {a.unique_id: (0, 0) for a in self.schedule.agents}
+        self.action_dict = {a.unique_id: (0, 0) for a in self.agents}
         # Update neighbors for observation space
-        for agent in self.schedule.agents:
+        for agent in self.agents:
             agent.update_neighbors()
-        self.schedule.step()
+        self.agents.shuffle_do("step")
         observation = {}
-        for agent in self.schedule.agents:
+        for agent in self.agents:
             observation[agent.unique_id] = [
                 self.obs_grid[neighbor[0]][neighbor[1]]
                 for neighbor in agent.neighborhood
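
The model.py changes above are the core of the Mesa 3.0 migration: the explicit BaseScheduler is gone, and the model's built-in agents AgentSet now handles both activation (agents.shuffle_do("step")) and iteration (for agent in self.agents). The sketch below is a minimal, self-contained illustration of that pattern, not code from this commit; the WalkerAgent and WalkerModel names are invented for the example.

import mesa


class WalkerAgent(mesa.Agent):
    """Hypothetical agent used only to illustrate the Mesa 3.0 activation pattern."""

    def step(self):
        # Agent behaviour goes here; unique_id is assigned automatically in Mesa 3.
        pass


class WalkerModel(mesa.Model):
    def __init__(self, n=10, seed=None):
        super().__init__(seed=seed)
        for _ in range(n):
            WalkerAgent(self)  # agents register themselves with the model on creation

    def step(self):
        # Replaces self.schedule.step(): activate every agent once, in random order.
        self.agents.shuffle_do("step")
        # self.agents is an AgentSet, so it can also be iterated or comprehended
        # directly, e.g. {a.unique_id: False for a in self.agents} as in the diff.

Calling WalkerModel(5).step() once is enough to see shuffle_do invoking each agent's step method; no scheduler object is created anywhere.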

rl/epstein_civil_violence/train_config.py

Lines changed: 1 addition & 2 deletions
@@ -1,10 +1,9 @@
 import os
 
+from model import EpsteinCivilViolenceRL
 from ray.rllib.algorithms.ppo import PPOConfig
 from ray.rllib.policy.policy import PolicySpec
 
-from .model import EpsteinCivilViolenceRL
-
 
 # Configuration for the PPO algorithm
 # You can change the configuration as per your requirements
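
Beyond the import fix, the rest of train_config.py is not shown here, so as orientation only: PPOConfig and PolicySpec are typically combined roughly as below to give cops and citizens separate policies in a multi-agent setup. This is a hedged sketch, not the repository's actual configuration; the policy names and the id-based mapping heuristic are assumptions.

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.policy.policy import PolicySpec

from model import EpsteinCivilViolenceRL


def policy_mapping_fn(agent_id, *args, **kwargs):
    # Placeholder heuristic: route each agent to the policy for its type.
    # How agent ids encode the type depends on the environment, so substitute
    # the real mapping used by the example here.
    return "cop" if str(agent_id).startswith("cop") else "citizen"


config = (
    PPOConfig()
    .environment(env=EpsteinCivilViolenceRL)
    .multi_agent(
        policies={"cop": PolicySpec(), "citizen": PolicySpec()},
        policy_mapping_fn=policy_mapping_fn,
    )
)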

rl/epstein_civil_violence/utility.py

Lines changed: 2 additions & 2 deletions
@@ -30,9 +30,9 @@ def create_intial_agents(self, CitizenRL, CopRL):
     # Initializing cops then citizens
     # This ensures cops act out their step before citizens
     for cop in cops:
-        self.schedule.add(cop)
+        self.add(cop)
     for citizen in citizens:
-        self.schedule.add(citizen)
+        self.add(citizen)
 
 
 def grid_to_observation(self, CitizenRL):
