Skip to content

Commit fb81c1a

Browse files
authored
Simplify solara code (#1786)
* Simplify solara code * re-introduced backend switch and removed self
1 parent ea4b213 commit fb81c1a

File tree

1 file changed

+166
-182
lines changed

1 file changed

+166
-182
lines changed

mesa/experimental/jupyter_viz.py

Lines changed: 166 additions & 182 deletions
Original file line numberDiff line numberDiff line change
@@ -13,59 +13,162 @@
1313
plt.switch_backend("agg")
1414

1515

16-
class JupyterContainer:
17-
def __init__(
18-
self,
19-
model_class,
20-
model_params,
21-
measures=None,
22-
name="Mesa Model",
23-
agent_portrayal=None,
24-
):
25-
self.model_class = model_class
26-
self.split_model_params(model_params)
27-
self.measures = measures
28-
self.name = name
29-
self.agent_portrayal = agent_portrayal
30-
self.thread = None
31-
32-
def split_model_params(self, model_params):
33-
self.model_params_input = {}
34-
self.model_params_fixed = {}
35-
for k, v in model_params.items():
36-
if self.check_param_is_fixed(v):
37-
self.model_params_fixed[k] = v
16+
@solara.component
17+
def JupyterViz(
18+
model_class,
19+
model_params,
20+
measures=None,
21+
name="Mesa Model",
22+
agent_portrayal=None,
23+
space_drawer=None,
24+
play_interval=400,
25+
):
26+
current_step, set_current_step = solara.use_state(0)
27+
28+
solara.Markdown(name)
29+
30+
# 0. Split model params
31+
model_params_input, model_params_fixed = split_model_params(model_params)
32+
33+
# 1. User inputs
34+
user_inputs = {}
35+
for k, v in model_params_input.items():
36+
user_input = solara.use_reactive(v["value"])
37+
user_inputs[k] = user_input.value
38+
make_user_input(user_input, k, v)
39+
40+
# 2. Model
41+
def make_model():
42+
return model_class(**user_inputs, **model_params_fixed)
43+
44+
model = solara.use_memo(make_model, dependencies=list(user_inputs.values()))
45+
46+
# 3. Buttons
47+
ModelController(model, play_interval, current_step, set_current_step)
48+
49+
with solara.GridFixed(columns=2):
50+
# 4. Space
51+
if space_drawer is None:
52+
make_space(model, agent_portrayal)
53+
else:
54+
space_drawer(model, agent_portrayal)
55+
# 5. Plots
56+
for measure in measures:
57+
if callable(measure):
58+
# Is a custom object
59+
measure(model)
3860
else:
39-
self.model_params_input[k] = v
61+
make_plot(model, measure)
62+
4063

41-
def check_param_is_fixed(self, param):
42-
if not isinstance(param, dict):
43-
return True
44-
if "type" not in param:
45-
return True
64+
@solara.component
65+
def ModelController(model, play_interval, current_step, set_current_step):
66+
playing = solara.use_reactive(False)
67+
thread = solara.use_reactive(None)
68+
69+
def on_value_play(change):
70+
if model.running:
71+
do_step()
72+
else:
73+
playing.value = False
4674

47-
def do_step(self):
48-
self.model.step()
49-
self.set_df(self.model.datacollector.get_model_vars_dataframe())
75+
def do_step():
76+
model.step()
77+
set_current_step(model.schedule.steps)
5078

51-
def do_play(self):
52-
self.model.running = True
53-
while self.model.running:
54-
self.do_step()
79+
def do_play():
80+
model.running = True
81+
while model.running:
82+
do_step()
5583

56-
def threaded_do_play(self):
57-
if self.thread is not None and self.thread.is_alive():
84+
def threaded_do_play():
85+
if thread is not None and thread.is_alive():
5886
return
59-
self.thread = threading.Thread(target=self.do_play)
60-
self.thread.start()
87+
thread.value = threading.Thread(target=do_play)
88+
thread.start()
6189

62-
def do_pause(self):
63-
if (self.thread is None) or (not self.thread.is_alive()):
90+
def do_pause():
91+
if (thread is None) or (not thread.is_alive()):
6492
return
65-
self.model.running = False
66-
self.thread.join()
93+
model.running = False
94+
thread.join()
6795

68-
def portray(self, g):
96+
with solara.Row():
97+
solara.Button(label="Step", color="primary", on_click=do_step)
98+
# This style is necessary so that the play widget has almost the same
99+
# height as typical Solara buttons.
100+
solara.Style(
101+
"""
102+
.widget-play {
103+
height: 30px;
104+
}
105+
"""
106+
)
107+
widgets.Play(
108+
value=0,
109+
interval=play_interval,
110+
repeat=True,
111+
show_repeat=False,
112+
on_value=on_value_play,
113+
playing=playing.value,
114+
on_playing=playing.set,
115+
)
116+
solara.Markdown(md_text=f"**Step:** {current_step}")
117+
# threaded_do_play is not used for now because it
118+
# doesn't work in Google colab. We use
119+
# ipywidgets.Play until it is fixed. The threading
120+
# version is definite a much better implementation,
121+
# if it works.
122+
# solara.Button(label="▶", color="primary", on_click=viz.threaded_do_play)
123+
# solara.Button(label="⏸︎", color="primary", on_click=viz.do_pause)
124+
# solara.Button(label="Reset", color="primary", on_click=do_reset)
125+
126+
127+
def split_model_params(model_params):
128+
model_params_input = {}
129+
model_params_fixed = {}
130+
for k, v in model_params.items():
131+
if check_param_is_fixed(v):
132+
model_params_fixed[k] = v
133+
else:
134+
model_params_input[k] = v
135+
return model_params_input, model_params_fixed
136+
137+
138+
def check_param_is_fixed(param):
139+
if not isinstance(param, dict):
140+
return True
141+
if "type" not in param:
142+
return True
143+
144+
145+
def make_user_input(user_input, k, v):
146+
if v["type"] == "SliderInt":
147+
solara.SliderInt(
148+
v.get("label", "label"),
149+
value=user_input,
150+
min=v.get("min"),
151+
max=v.get("max"),
152+
step=v.get("step"),
153+
)
154+
elif v["type"] == "SliderFloat":
155+
solara.SliderFloat(
156+
v.get("label", "label"),
157+
value=user_input,
158+
min=v.get("min"),
159+
max=v.get("max"),
160+
step=v.get("step"),
161+
)
162+
elif v["type"] == "Select":
163+
solara.Select(
164+
v.get("label", "label"),
165+
value=v.get("value"),
166+
values=v.get("values"),
167+
)
168+
169+
170+
def make_space(model, agent_portrayal):
171+
def portray(g):
69172
x = []
70173
y = []
71174
s = [] # size
@@ -79,7 +182,7 @@ def portray(self, g):
79182
# Is a single grid
80183
content = [content]
81184
for agent in content:
82-
data = self.agent_portrayal(agent)
185+
data = agent_portrayal(agent)
83186
x.append(i)
84187
y.append(j)
85188
if "size" in data:
@@ -93,159 +196,40 @@ def portray(self, g):
93196
out["c"] = c
94197
return out
95198

199+
space_fig = Figure()
200+
space_ax = space_fig.subplots()
201+
if isinstance(model.grid, mesa.space.NetworkGrid):
202+
_draw_network_grid(model, space_ax, agent_portrayal)
203+
else:
204+
space_ax.scatter(**portray(model.grid))
205+
space_ax.set_axis_off()
206+
solara.FigureMatplotlib(space_fig)
207+
96208

97-
def _draw_network_grid(viz, space_ax):
98-
graph = viz.model.grid.G
209+
def _draw_network_grid(model, space_ax, agent_portrayal):
210+
graph = model.grid.G
99211
pos = nx.spring_layout(graph, seed=0)
100212
nx.draw(
101213
graph,
102214
ax=space_ax,
103215
pos=pos,
104-
**viz.agent_portrayal(graph),
216+
**agent_portrayal(graph),
105217
)
106218

107219

108-
def make_space(viz):
109-
space_fig = Figure()
110-
space_ax = space_fig.subplots()
111-
if isinstance(viz.model.grid, mesa.space.NetworkGrid):
112-
_draw_network_grid(viz, space_ax)
113-
else:
114-
space_ax.scatter(**viz.portray(viz.model.grid))
115-
space_ax.set_axis_off()
116-
solara.FigureMatplotlib(space_fig, dependencies=[viz.model, viz.df])
117-
118-
119-
def make_plot(viz, measure):
220+
def make_plot(model, measure):
120221
fig = Figure()
121222
ax = fig.subplots()
122-
ax.plot(viz.df.loc[:, measure])
223+
df = model.datacollector.get_model_vars_dataframe()
224+
ax.plot(df.loc[:, measure])
123225
ax.set_ylabel(measure)
124226
# Set integer x axis
125227
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
126-
solara.FigureMatplotlib(fig, dependencies=[viz.model, viz.df])
228+
solara.FigureMatplotlib(fig)
127229

128230

129231
def make_text(renderer):
130-
def function(viz):
131-
solara.Markdown(renderer(viz.model))
232+
def function(model):
233+
solara.Markdown(renderer(model))
132234

133235
return function
134-
135-
136-
def make_user_input(user_input, k, v):
137-
if v["type"] == "SliderInt":
138-
solara.SliderInt(
139-
v.get("label", "label"),
140-
value=user_input,
141-
min=v.get("min"),
142-
max=v.get("max"),
143-
step=v.get("step"),
144-
)
145-
elif v["type"] == "SliderFloat":
146-
solara.SliderFloat(
147-
v.get("label", "label"),
148-
value=user_input,
149-
min=v.get("min"),
150-
max=v.get("max"),
151-
step=v.get("step"),
152-
)
153-
elif v["type"] == "Select":
154-
solara.Select(
155-
v.get("label", "label"),
156-
value=v.get("value"),
157-
values=v.get("values"),
158-
)
159-
160-
161-
@solara.component
162-
def MesaComponent(viz, space_drawer=None, play_interval=400):
163-
solara.Markdown(viz.name)
164-
165-
# 1. User inputs
166-
user_inputs = {}
167-
for k, v in viz.model_params_input.items():
168-
user_input = solara.use_reactive(v["value"])
169-
user_inputs[k] = user_input.value
170-
make_user_input(user_input, k, v)
171-
172-
# 2. Model
173-
def make_model():
174-
return viz.model_class(**user_inputs, **viz.model_params_fixed)
175-
176-
viz.model = solara.use_memo(make_model, dependencies=list(user_inputs.values()))
177-
viz.df, viz.set_df = solara.use_state(
178-
viz.model.datacollector.get_model_vars_dataframe()
179-
)
180-
181-
# 3. Buttons
182-
playing = solara.use_reactive(False)
183-
184-
def on_value_play(change):
185-
if viz.model.running:
186-
viz.do_step()
187-
else:
188-
playing.value = False
189-
190-
with solara.Row():
191-
solara.Button(label="Step", color="primary", on_click=viz.do_step)
192-
# This style is necessary so that the play widget has almost the same
193-
# height as typical Solara buttons.
194-
solara.Style(
195-
"""
196-
.widget-play {
197-
height: 30px;
198-
}
199-
"""
200-
)
201-
widgets.Play(
202-
value=0,
203-
interval=play_interval,
204-
repeat=True,
205-
show_repeat=False,
206-
on_value=on_value_play,
207-
playing=playing.value,
208-
on_playing=playing.set,
209-
)
210-
solara.Markdown(md_text=f"**Step:** {viz.model.schedule.steps}")
211-
# threaded_do_play is not used for now because it
212-
# doesn't work in Google colab. We use
213-
# ipywidgets.Play until it is fixed. The threading
214-
# version is definite a much better implementation,
215-
# if it works.
216-
# solara.Button(label="▶", color="primary", on_click=viz.threaded_do_play)
217-
# solara.Button(label="⏸︎", color="primary", on_click=viz.do_pause)
218-
# solara.Button(label="Reset", color="primary", on_click=do_reset)
219-
220-
with solara.GridFixed(columns=2):
221-
# 4. Space
222-
if space_drawer is None:
223-
make_space(viz)
224-
else:
225-
space_drawer(viz)
226-
# 5. Plots
227-
for measure in viz.measures:
228-
if callable(measure):
229-
# Is a custom object
230-
measure(viz)
231-
else:
232-
make_plot(viz, measure)
233-
234-
235-
# JupyterViz has to be a Solara component, so that each browser tabs runs in
236-
# their own, separate simulation thread. See https://github.com/projectmesa/mesa/issues/856.
237-
@solara.component
238-
def JupyterViz(
239-
model_class,
240-
model_params,
241-
measures=None,
242-
name="Mesa Model",
243-
agent_portrayal=None,
244-
space_drawer=None,
245-
play_interval=400,
246-
):
247-
return MesaComponent(
248-
JupyterContainer(model_class, model_params, measures, name, agent_portrayal),
249-
space_drawer=space_drawer,
250-
play_interval=play_interval,
251-
)

0 commit comments

Comments
 (0)