13
13
plt .switch_backend ("agg" )
14
14
15
15
16
- class JupyterContainer :
17
- def __init__ (
18
- self ,
19
- model_class ,
20
- model_params ,
21
- measures = None ,
22
- name = "Mesa Model" ,
23
- agent_portrayal = None ,
24
- ):
25
- self .model_class = model_class
26
- self .split_model_params (model_params )
27
- self .measures = measures
28
- self .name = name
29
- self .agent_portrayal = agent_portrayal
30
- self .thread = None
31
-
32
- def split_model_params (self , model_params ):
33
- self .model_params_input = {}
34
- self .model_params_fixed = {}
35
- for k , v in model_params .items ():
36
- if self .check_param_is_fixed (v ):
37
- self .model_params_fixed [k ] = v
16
+ @solara .component
17
+ def JupyterViz (
18
+ model_class ,
19
+ model_params ,
20
+ measures = None ,
21
+ name = "Mesa Model" ,
22
+ agent_portrayal = None ,
23
+ space_drawer = None ,
24
+ play_interval = 400 ,
25
+ ):
26
+ current_step , set_current_step = solara .use_state (0 )
27
+
28
+ solara .Markdown (name )
29
+
30
+ # 0. Split model params
31
+ model_params_input , model_params_fixed = split_model_params (model_params )
32
+
33
+ # 1. User inputs
34
+ user_inputs = {}
35
+ for k , v in model_params_input .items ():
36
+ user_input = solara .use_reactive (v ["value" ])
37
+ user_inputs [k ] = user_input .value
38
+ make_user_input (user_input , k , v )
39
+
40
+ # 2. Model
41
+ def make_model ():
42
+ return model_class (** user_inputs , ** model_params_fixed )
43
+
44
+ model = solara .use_memo (make_model , dependencies = list (user_inputs .values ()))
45
+
46
+ # 3. Buttons
47
+ ModelController (model , play_interval , current_step , set_current_step )
48
+
49
+ with solara .GridFixed (columns = 2 ):
50
+ # 4. Space
51
+ if space_drawer is None :
52
+ make_space (model , agent_portrayal )
53
+ else :
54
+ space_drawer (model , agent_portrayal )
55
+ # 5. Plots
56
+ for measure in measures :
57
+ if callable (measure ):
58
+ # Is a custom object
59
+ measure (model )
38
60
else :
39
- self .model_params_input [k ] = v
61
+ make_plot (model , measure )
62
+
40
63
41
- def check_param_is_fixed (self , param ):
42
- if not isinstance (param , dict ):
43
- return True
44
- if "type" not in param :
45
- return True
64
+ @solara .component
65
+ def ModelController (model , play_interval , current_step , set_current_step ):
66
+ playing = solara .use_reactive (False )
67
+ thread = solara .use_reactive (None )
68
+
69
+ def on_value_play (change ):
70
+ if model .running :
71
+ do_step ()
72
+ else :
73
+ playing .value = False
46
74
47
- def do_step (self ):
48
- self . model .step ()
49
- self . set_df ( self . model .datacollector . get_model_vars_dataframe () )
75
+ def do_step ():
76
+ model .step ()
77
+ set_current_step ( model .schedule . steps )
50
78
51
- def do_play (self ):
52
- self . model .running = True
53
- while self . model .running :
54
- self . do_step ()
79
+ def do_play ():
80
+ model .running = True
81
+ while model .running :
82
+ do_step ()
55
83
56
- def threaded_do_play (self ):
57
- if self . thread is not None and self . thread .is_alive ():
84
+ def threaded_do_play ():
85
+ if thread is not None and thread .is_alive ():
58
86
return
59
- self . thread = threading .Thread (target = self . do_play )
60
- self . thread .start ()
87
+ thread . value = threading .Thread (target = do_play )
88
+ thread .start ()
61
89
62
- def do_pause (self ):
63
- if (self . thread is None ) or (not self . thread .is_alive ()):
90
+ def do_pause ():
91
+ if (thread is None ) or (not thread .is_alive ()):
64
92
return
65
- self . model .running = False
66
- self . thread .join ()
93
+ model .running = False
94
+ thread .join ()
67
95
68
- def portray (self , g ):
96
+ with solara .Row ():
97
+ solara .Button (label = "Step" , color = "primary" , on_click = do_step )
98
+ # This style is necessary so that the play widget has almost the same
99
+ # height as typical Solara buttons.
100
+ solara .Style (
101
+ """
102
+ .widget-play {
103
+ height: 30px;
104
+ }
105
+ """
106
+ )
107
+ widgets .Play (
108
+ value = 0 ,
109
+ interval = play_interval ,
110
+ repeat = True ,
111
+ show_repeat = False ,
112
+ on_value = on_value_play ,
113
+ playing = playing .value ,
114
+ on_playing = playing .set ,
115
+ )
116
+ solara .Markdown (md_text = f"**Step:** { current_step } " )
117
+ # threaded_do_play is not used for now because it
118
+ # doesn't work in Google colab. We use
119
+ # ipywidgets.Play until it is fixed. The threading
120
+ # version is definite a much better implementation,
121
+ # if it works.
122
+ # solara.Button(label="▶", color="primary", on_click=viz.threaded_do_play)
123
+ # solara.Button(label="⏸︎", color="primary", on_click=viz.do_pause)
124
+ # solara.Button(label="Reset", color="primary", on_click=do_reset)
125
+
126
+
127
+ def split_model_params (model_params ):
128
+ model_params_input = {}
129
+ model_params_fixed = {}
130
+ for k , v in model_params .items ():
131
+ if check_param_is_fixed (v ):
132
+ model_params_fixed [k ] = v
133
+ else :
134
+ model_params_input [k ] = v
135
+ return model_params_input , model_params_fixed
136
+
137
+
138
+ def check_param_is_fixed (param ):
139
+ if not isinstance (param , dict ):
140
+ return True
141
+ if "type" not in param :
142
+ return True
143
+
144
+
145
+ def make_user_input (user_input , k , v ):
146
+ if v ["type" ] == "SliderInt" :
147
+ solara .SliderInt (
148
+ v .get ("label" , "label" ),
149
+ value = user_input ,
150
+ min = v .get ("min" ),
151
+ max = v .get ("max" ),
152
+ step = v .get ("step" ),
153
+ )
154
+ elif v ["type" ] == "SliderFloat" :
155
+ solara .SliderFloat (
156
+ v .get ("label" , "label" ),
157
+ value = user_input ,
158
+ min = v .get ("min" ),
159
+ max = v .get ("max" ),
160
+ step = v .get ("step" ),
161
+ )
162
+ elif v ["type" ] == "Select" :
163
+ solara .Select (
164
+ v .get ("label" , "label" ),
165
+ value = v .get ("value" ),
166
+ values = v .get ("values" ),
167
+ )
168
+
169
+
170
+ def make_space (model , agent_portrayal ):
171
+ def portray (g ):
69
172
x = []
70
173
y = []
71
174
s = [] # size
@@ -79,7 +182,7 @@ def portray(self, g):
79
182
# Is a single grid
80
183
content = [content ]
81
184
for agent in content :
82
- data = self . agent_portrayal (agent )
185
+ data = agent_portrayal (agent )
83
186
x .append (i )
84
187
y .append (j )
85
188
if "size" in data :
@@ -93,159 +196,40 @@ def portray(self, g):
93
196
out ["c" ] = c
94
197
return out
95
198
199
+ space_fig = Figure ()
200
+ space_ax = space_fig .subplots ()
201
+ if isinstance (model .grid , mesa .space .NetworkGrid ):
202
+ _draw_network_grid (model , space_ax , agent_portrayal )
203
+ else :
204
+ space_ax .scatter (** portray (model .grid ))
205
+ space_ax .set_axis_off ()
206
+ solara .FigureMatplotlib (space_fig )
207
+
96
208
97
- def _draw_network_grid (viz , space_ax ):
98
- graph = viz . model .grid .G
209
+ def _draw_network_grid (model , space_ax , agent_portrayal ):
210
+ graph = model .grid .G
99
211
pos = nx .spring_layout (graph , seed = 0 )
100
212
nx .draw (
101
213
graph ,
102
214
ax = space_ax ,
103
215
pos = pos ,
104
- ** viz . agent_portrayal (graph ),
216
+ ** agent_portrayal (graph ),
105
217
)
106
218
107
219
108
- def make_space (viz ):
109
- space_fig = Figure ()
110
- space_ax = space_fig .subplots ()
111
- if isinstance (viz .model .grid , mesa .space .NetworkGrid ):
112
- _draw_network_grid (viz , space_ax )
113
- else :
114
- space_ax .scatter (** viz .portray (viz .model .grid ))
115
- space_ax .set_axis_off ()
116
- solara .FigureMatplotlib (space_fig , dependencies = [viz .model , viz .df ])
117
-
118
-
119
- def make_plot (viz , measure ):
220
+ def make_plot (model , measure ):
120
221
fig = Figure ()
121
222
ax = fig .subplots ()
122
- ax .plot (viz .df .loc [:, measure ])
223
+ df = model .datacollector .get_model_vars_dataframe ()
224
+ ax .plot (df .loc [:, measure ])
123
225
ax .set_ylabel (measure )
124
226
# Set integer x axis
125
227
ax .xaxis .set_major_locator (MaxNLocator (integer = True ))
126
- solara .FigureMatplotlib (fig , dependencies = [ viz . model , viz . df ] )
228
+ solara .FigureMatplotlib (fig )
127
229
128
230
129
231
def make_text (renderer ):
130
- def function (viz ):
131
- solara .Markdown (renderer (viz . model ))
232
+ def function (model ):
233
+ solara .Markdown (renderer (model ))
132
234
133
235
return function
134
-
135
-
136
- def make_user_input (user_input , k , v ):
137
- if v ["type" ] == "SliderInt" :
138
- solara .SliderInt (
139
- v .get ("label" , "label" ),
140
- value = user_input ,
141
- min = v .get ("min" ),
142
- max = v .get ("max" ),
143
- step = v .get ("step" ),
144
- )
145
- elif v ["type" ] == "SliderFloat" :
146
- solara .SliderFloat (
147
- v .get ("label" , "label" ),
148
- value = user_input ,
149
- min = v .get ("min" ),
150
- max = v .get ("max" ),
151
- step = v .get ("step" ),
152
- )
153
- elif v ["type" ] == "Select" :
154
- solara .Select (
155
- v .get ("label" , "label" ),
156
- value = v .get ("value" ),
157
- values = v .get ("values" ),
158
- )
159
-
160
-
161
- @solara .component
162
- def MesaComponent (viz , space_drawer = None , play_interval = 400 ):
163
- solara .Markdown (viz .name )
164
-
165
- # 1. User inputs
166
- user_inputs = {}
167
- for k , v in viz .model_params_input .items ():
168
- user_input = solara .use_reactive (v ["value" ])
169
- user_inputs [k ] = user_input .value
170
- make_user_input (user_input , k , v )
171
-
172
- # 2. Model
173
- def make_model ():
174
- return viz .model_class (** user_inputs , ** viz .model_params_fixed )
175
-
176
- viz .model = solara .use_memo (make_model , dependencies = list (user_inputs .values ()))
177
- viz .df , viz .set_df = solara .use_state (
178
- viz .model .datacollector .get_model_vars_dataframe ()
179
- )
180
-
181
- # 3. Buttons
182
- playing = solara .use_reactive (False )
183
-
184
- def on_value_play (change ):
185
- if viz .model .running :
186
- viz .do_step ()
187
- else :
188
- playing .value = False
189
-
190
- with solara .Row ():
191
- solara .Button (label = "Step" , color = "primary" , on_click = viz .do_step )
192
- # This style is necessary so that the play widget has almost the same
193
- # height as typical Solara buttons.
194
- solara .Style (
195
- """
196
- .widget-play {
197
- height: 30px;
198
- }
199
- """
200
- )
201
- widgets .Play (
202
- value = 0 ,
203
- interval = play_interval ,
204
- repeat = True ,
205
- show_repeat = False ,
206
- on_value = on_value_play ,
207
- playing = playing .value ,
208
- on_playing = playing .set ,
209
- )
210
- solara .Markdown (md_text = f"**Step:** { viz .model .schedule .steps } " )
211
- # threaded_do_play is not used for now because it
212
- # doesn't work in Google colab. We use
213
- # ipywidgets.Play until it is fixed. The threading
214
- # version is definite a much better implementation,
215
- # if it works.
216
- # solara.Button(label="▶", color="primary", on_click=viz.threaded_do_play)
217
- # solara.Button(label="⏸︎", color="primary", on_click=viz.do_pause)
218
- # solara.Button(label="Reset", color="primary", on_click=do_reset)
219
-
220
- with solara .GridFixed (columns = 2 ):
221
- # 4. Space
222
- if space_drawer is None :
223
- make_space (viz )
224
- else :
225
- space_drawer (viz )
226
- # 5. Plots
227
- for measure in viz .measures :
228
- if callable (measure ):
229
- # Is a custom object
230
- measure (viz )
231
- else :
232
- make_plot (viz , measure )
233
-
234
-
235
- # JupyterViz has to be a Solara component, so that each browser tabs runs in
236
- # their own, separate simulation thread. See https://github.com/projectmesa/mesa/issues/856.
237
- @solara .component
238
- def JupyterViz (
239
- model_class ,
240
- model_params ,
241
- measures = None ,
242
- name = "Mesa Model" ,
243
- agent_portrayal = None ,
244
- space_drawer = None ,
245
- play_interval = 400 ,
246
- ):
247
- return MesaComponent (
248
- JupyterContainer (model_class , model_params , measures , name , agent_portrayal ),
249
- space_drawer = space_drawer ,
250
- play_interval = play_interval ,
251
- )
0 commit comments