Skip to content

Commit d451ef0

Browse files
authored
Add UserInputs component (#1788)
* Add UserInputs component * update tests * remove global state dependency from UserInputs * Simplify UserInputs change handler
1 parent 9572e65 commit d451ef0

File tree

2 files changed

+99
-71
lines changed

2 files changed

+99
-71
lines changed

mesa/experimental/jupyter_viz.py

Lines changed: 58 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -39,25 +39,26 @@ def JupyterViz(
3939

4040
current_step, set_current_step = solara.use_state(0)
4141

42-
solara.Markdown(name)
43-
44-
# 0. Split model params
45-
model_params_input, model_params_fixed = split_model_params(model_params)
46-
47-
# 1. User inputs
48-
user_inputs = {}
49-
for name, options in model_params_input.items():
50-
user_input = solara.use_reactive(options["value"])
51-
user_inputs[name] = user_input.value
52-
make_user_input(user_input, name, options)
42+
# 1. Set up model parameters
43+
user_params, fixed_params = split_model_params(model_params)
44+
model_parameters, set_model_parameters = solara.use_state(
45+
fixed_params | {k: v["value"] for k, v in user_params.items()}
46+
)
5347

54-
# 2. Model
48+
# 2. Set up Model
5549
def make_model():
56-
return model_class(**user_inputs, **model_params_fixed)
50+
model = model_class(**model_parameters)
51+
set_current_step(0)
52+
return model
5753

58-
model = solara.use_memo(make_model, dependencies=list(user_inputs.values()))
54+
model = solara.use_memo(make_model, dependencies=list(model_parameters.values()))
5955

60-
# 3. Buttons
56+
def handle_change_model_params(name: str, value: any):
57+
set_model_parameters(model_parameters | {name: value})
58+
59+
# 3. Set up UI
60+
solara.Markdown(name)
61+
UserInputs(user_params, on_change=handle_change_model_params)
6162
ModelController(model, play_interval, current_step, set_current_step)
6263

6364
with solara.GridFixed(columns=2):
@@ -160,44 +161,53 @@ def check_param_is_fixed(param):
160161
return True
161162

162163

163-
def make_user_input(user_input, name, options):
164-
"""Initialize a user input for configurable model parameters.
164+
@solara.component
165+
def UserInputs(user_params, on_change=None):
166+
"""Initialize user inputs for configurable model parameters.
165167
Currently supports :class:`solara.SliderInt`, :class:`solara.SliderFloat`,
166168
and :class:`solara.Select`.
167169
168-
Args:
169-
user_input: :class:`solara.reactive` object with initial value
170-
name: field name; used as fallback for label if 'label' is not in options
171-
options: dictionary with options for the input, including label,
170+
Props:
171+
user_params: dictionary with options for the input, including label,
172172
min and max values, and other fields specific to the input type.
173+
on_change: function to be called with (name, value) when the value of an input changes.
173174
"""
174-
# label for the input is "label" from options or name
175-
label = options.get("label", name)
176-
input_type = options.get("type")
177-
if input_type == "SliderInt":
178-
solara.SliderInt(
179-
label,
180-
value=user_input,
181-
min=options.get("min"),
182-
max=options.get("max"),
183-
step=options.get("step"),
184-
)
185-
elif input_type == "SliderFloat":
186-
solara.SliderFloat(
187-
label,
188-
value=user_input,
189-
min=options.get("min"),
190-
max=options.get("max"),
191-
step=options.get("step"),
192-
)
193-
elif input_type == "Select":
194-
solara.Select(
195-
label,
196-
value=options.get("value"),
197-
values=options.get("values"),
198-
)
199-
else:
200-
raise ValueError(f"{input_type} is not a supported input type")
175+
176+
for name, options in user_params.items():
177+
# label for the input is "label" from options or name
178+
label = options.get("label", name)
179+
input_type = options.get("type")
180+
181+
def change_handler(value, name=name):
182+
on_change(name, value)
183+
184+
if input_type == "SliderInt":
185+
solara.SliderInt(
186+
label,
187+
value=options.get("value"),
188+
on_value=change_handler,
189+
min=options.get("min"),
190+
max=options.get("max"),
191+
step=options.get("step"),
192+
)
193+
elif input_type == "SliderFloat":
194+
solara.SliderFloat(
195+
label,
196+
value=options.get("value"),
197+
on_value=change_handler,
198+
min=options.get("min"),
199+
max=options.get("max"),
200+
step=options.get("step"),
201+
)
202+
elif input_type == "Select":
203+
solara.Select(
204+
label,
205+
value=options.get("value"),
206+
on_value=change_handler,
207+
values=options.get("values"),
208+
)
209+
else:
210+
raise ValueError(f"{input_type} is not a supported input type")
201211

202212

203213
def make_space(model, agent_portrayal):

tests/test_jupyter_viz.py

Lines changed: 41 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,71 @@
11
import unittest
22
from unittest.mock import Mock, patch
33

4+
import ipyvuetify as vw
45
import solara
56

6-
from mesa.experimental.jupyter_viz import JupyterViz, make_user_input
7+
from mesa.experimental.jupyter_viz import JupyterViz, UserInputs
78

89

910
class TestMakeUserInput(unittest.TestCase):
1011
def test_unsupported_type(self):
12+
@solara.component
13+
def Test(user_params):
14+
UserInputs(user_params)
15+
1116
"""unsupported input type should raise ValueError"""
1217
# bogus type
1318
with self.assertRaisesRegex(ValueError, "not a supported input type"):
14-
make_user_input(10, "input", {"type": "bogus"})
19+
solara.render(Test({"mock": {"type": "bogus"}}), handle_error=False)
20+
1521
# no type is specified
1622
with self.assertRaisesRegex(ValueError, "not a supported input type"):
17-
make_user_input(10, "input", {})
23+
solara.render(Test({"mock": {}}), handle_error=False)
24+
25+
def test_slider_int(self):
26+
@solara.component
27+
def Test(user_params):
28+
UserInputs(user_params)
1829

19-
@patch("mesa.experimental.jupyter_viz.solara")
20-
def test_slider_int(self, mock_solara):
21-
value = 10
22-
name = "num_agents"
2330
options = {
2431
"type": "SliderInt",
32+
"value": 10,
2533
"label": "number of agents",
2634
"min": 10,
2735
"max": 20,
2836
"step": 1,
2937
}
30-
make_user_input(value, name, options)
31-
mock_solara.SliderInt.assert_called_with(
32-
options["label"],
33-
value=value,
34-
min=options["min"],
35-
max=options["max"],
36-
step=options["step"],
37-
)
38+
user_params = {"num_agents": options}
39+
_, rc = solara.render(Test(user_params), handle_error=False)
40+
slider_int = rc.find(vw.Slider).widget
41+
42+
assert slider_int.v_model == options["value"]
43+
assert slider_int.label == options["label"]
44+
assert slider_int.min == options["min"]
45+
assert slider_int.max == options["max"]
46+
assert slider_int.step == options["step"]
3847

39-
@patch("mesa.experimental.jupyter_viz.solara")
40-
def test_label_fallback(self, mock_solara):
48+
def test_label_fallback(self):
4149
"""name should be used as fallback label"""
42-
value = 10
43-
name = "num_agents"
50+
51+
@solara.component
52+
def Test(user_params):
53+
UserInputs(user_params)
54+
4455
options = {
4556
"type": "SliderInt",
57+
"value": 10,
4658
}
47-
make_user_input(value, name, options)
48-
mock_solara.SliderInt.assert_called_with(
49-
name, value=value, min=None, max=None, step=None
50-
)
59+
60+
user_params = {"num_agents": options}
61+
_, rc = solara.render(Test(user_params), handle_error=False)
62+
slider_int = rc.find(vw.Slider).widget
63+
64+
assert slider_int.v_model == options["value"]
65+
assert slider_int.label == "num_agents"
66+
assert slider_int.min is None
67+
assert slider_int.max is None
68+
assert slider_int.step is None
5169

5270

5371
class TestJupyterViz(unittest.TestCase):

0 commit comments

Comments
 (0)