Skip to content

Commit 41be443

Browse files
authored
Creating threads to update visualization asynchronously (#2656)
This PR introduces a thread-based mechanism to handle model updates and visualization rendering concurrently. It ensures smooth execution of simulations by separating model step execution and visualization updates into independent threads, improving performance and responsiveness during simulations. Fixes #2604 ### Motive Previously, the visualization process could become a bottleneck during rapid simulations, as rendering was tightly coupled with model updates. This caused delays and UI responsiveness issues. By separating these operations into threads, the model execution is no longer hindered by the visualization process. ### Implementation 1. Introduced separate thread for visualisation: Handles visualization updates triggered by threading.Event (visualization_pause_event) to synchronize rendering with model steps. 2. Thread Synchronization: Implemented visualization_pause_event to signal the visualization thread after each model step is completed, ensuring rendering happens efficiently without blocking the simulation. Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 1575c88 commit 41be443

File tree

1 file changed

+114
-19
lines changed

1 file changed

+114
-19
lines changed

mesa/visualization/solara_viz.py

Lines changed: 114 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525

2626
import asyncio
2727
import inspect
28+
import threading
29+
import time
2830
from collections.abc import Callable
2931
from typing import TYPE_CHECKING, Literal
3032

@@ -57,6 +59,7 @@ def SolaraViz(
5759
simulator: Simulator | None = None,
5860
model_params=None,
5961
name: str | None = None,
62+
use_threads: bool = False,
6063
):
6164
"""Solara visualization component.
6265
@@ -76,6 +79,8 @@ def SolaraViz(
7679
This controls the speed of the model's automatic stepping. Defaults to 100 ms.
7780
render_interval (int, optional): Controls how often plots are updated during a simulation,
7881
allowing users to skip intermediate steps and update graphs less frequently.
82+
use_threads: Flag for indicating whether to utilize multi-threading for model execution.
83+
When checked, the model will utilize multiple threads,adjust based on system capabilities.
7984
simulator: A simulator that controls the model (optional)
8085
model_params (dict, optional): Parameters for (re-)instantiating a model.
8186
Can include user-adjustable parameters and fixed parameters. Defaults to None.
@@ -114,6 +119,7 @@ def SolaraViz(
114119
reactive_model_parameters = solara.use_reactive({})
115120
reactive_play_interval = solara.use_reactive(play_interval)
116121
reactive_render_interval = solara.use_reactive(render_interval)
122+
reactive_use_threads = solara.use_reactive(use_threads)
117123
with solara.AppBar():
118124
solara.AppBarTitle(name if name else model.value.__class__.__name__)
119125
solara.lab.ThemeToggle()
@@ -136,12 +142,25 @@ def SolaraViz(
136142
max=100,
137143
step=2,
138144
)
145+
if reactive_use_threads.value:
146+
solara.Text("Increase play interval to avoid skipping plots")
147+
148+
def set_reactive_use_threads(value):
149+
reactive_use_threads.set(value)
150+
151+
solara.Checkbox(
152+
label="Use Threads",
153+
value=reactive_use_threads,
154+
on_value=set_reactive_use_threads,
155+
)
156+
139157
if not isinstance(simulator, Simulator):
140158
ModelController(
141159
model,
142160
model_parameters=reactive_model_parameters,
143161
play_interval=reactive_play_interval,
144162
render_interval=reactive_render_interval,
163+
use_threads=reactive_use_threads,
145164
)
146165
else:
147166
SimulatorController(
@@ -150,6 +169,7 @@ def SolaraViz(
150169
model_parameters=reactive_model_parameters,
151170
play_interval=reactive_play_interval,
152171
render_interval=reactive_render_interval,
172+
use_threads=reactive_use_threads,
153173
)
154174
with solara.Card("Model Parameters"):
155175
ModelCreator(
@@ -211,6 +231,7 @@ def ModelController(
211231
model_parameters: dict | solara.Reactive[dict] = None,
212232
play_interval: int | solara.Reactive[int] = 100,
213233
render_interval: int | solara.Reactive[int] = 1,
234+
use_threads: bool | solara.Reactive[bool] = False,
214235
):
215236
"""Create controls for model execution (step, play, pause, reset).
216237
@@ -219,37 +240,70 @@ def ModelController(
219240
model_parameters: Reactive parameters for (re-)instantiating a model.
220241
play_interval: Interval for playing the model steps in milliseconds.
221242
render_interval: Controls how often the plots are updated during simulation steps.Higher value reduce update frequency.
243+
use_threads: Flag for indicating whether to utilize multi-threading for model execution.
222244
"""
223245
playing = solara.use_reactive(False)
224246
running = solara.use_reactive(True)
247+
225248
if model_parameters is None:
226249
model_parameters = {}
227250
model_parameters = solara.use_reactive(model_parameters)
228-
229-
async def step():
230-
while playing.value and running.value:
231-
await asyncio.sleep(play_interval.value / 1000)
232-
do_step()
251+
visualization_pause_event = solara.use_memo(lambda: threading.Event(), [])
252+
253+
def step():
254+
try:
255+
while running.value and playing.value:
256+
time.sleep(play_interval.value / 1000)
257+
do_step()
258+
if use_threads.value:
259+
visualization_pause_event.set()
260+
except Exception as e:
261+
print(f"Error in step: {e}")
262+
return
263+
264+
def visualization_task():
265+
if use_threads.value:
266+
try:
267+
while playing.value and running.value:
268+
visualization_pause_event.wait()
269+
visualization_pause_event.clear()
270+
force_update()
271+
except Exception as e:
272+
print(f"Error in visualization_task: {e}")
233273

234274
solara.lab.use_task(
235-
step, dependencies=[playing.value, running.value], prefer_threaded=False
275+
step, dependencies=[playing.value, running.value], prefer_threaded=True
276+
)
277+
278+
solara.use_thread(
279+
visualization_task,
280+
dependencies=[playing.value, running.value],
236281
)
237282

238283
@function_logger(__name__)
239284
def do_step():
240285
"""Advance the model by the number of steps specified by the render_interval slider."""
241-
for _ in range(render_interval.value):
242-
model.value.step()
286+
if playing.value:
287+
for _ in range(render_interval.value):
288+
model.value.step()
289+
running.value = model.value.running
290+
if not playing.value:
291+
break
292+
if not use_threads.value:
293+
force_update()
243294

244-
running.value = model.value.running
245-
246-
force_update()
295+
else:
296+
for _ in range(render_interval.value):
297+
model.value.step()
298+
running.value = model.value.running
299+
force_update()
247300

248301
@function_logger(__name__)
249302
def do_reset():
250303
"""Reset the model to its initial state."""
251304
playing.value = False
252305
running.value = True
306+
visualization_pause_event.clear()
253307
_mesa_logger.log(
254308
10,
255309
f"creating new {model.value.__class__} instance with {model_parameters.value}",
@@ -285,6 +339,7 @@ def SimulatorController(
285339
model_parameters: dict | solara.Reactive[dict] = None,
286340
play_interval: int | solara.Reactive[int] = 100,
287341
render_interval: int | solara.Reactive[int] = 1,
342+
use_threads: bool | solara.Reactive[bool] = False,
288343
):
289344
"""Create controls for model execution (step, play, pause, reset).
290345
@@ -294,6 +349,7 @@ def SimulatorController(
294349
model_parameters: Reactive parameters for (re-)instantiating a model.
295350
play_interval: Interval for playing the model steps in milliseconds.
296351
render_interval: Controls how often the plots are updated during simulation steps.Higher values reduce update frequency.
352+
use_threads: Flag for indicating whether to utilize multi-threading for model execution.
297353
298354
Notes:
299355
The `step button` increments the step by the value specified in the `render_interval` slider.
@@ -304,27 +360,66 @@ def SimulatorController(
304360
if model_parameters is None:
305361
model_parameters = {}
306362
model_parameters = solara.use_reactive(model_parameters)
307-
308-
async def step():
309-
while playing.value and running.value:
310-
await asyncio.sleep(play_interval.value / 1000)
311-
do_step()
363+
visualization_pause_event = solara.use_memo(lambda: threading.Event(), [])
364+
pause_step_event = solara.use_memo(lambda: threading.Event(), [])
365+
366+
def step():
367+
try:
368+
while running.value and playing.value:
369+
time.sleep(play_interval.value / 1000)
370+
if use_threads.value:
371+
pause_step_event.wait()
372+
pause_step_event.clear()
373+
do_step()
374+
if use_threads.value:
375+
visualization_pause_event.set()
376+
except Exception as e:
377+
print(f"Error in step: {e}")
378+
379+
def visualization_task():
380+
if use_threads.value:
381+
try:
382+
loop = asyncio.new_event_loop()
383+
asyncio.set_event_loop(loop)
384+
pause_step_event.set()
385+
while playing.value and running.value:
386+
visualization_pause_event.wait()
387+
visualization_pause_event.clear()
388+
force_update()
389+
pause_step_event.set()
390+
except Exception as e:
391+
print(f"Error in visualization_task: {e}")
392+
return
312393

313394
solara.lab.use_task(
314395
step, dependencies=[playing.value, running.value], prefer_threaded=False
315396
)
397+
solara.lab.use_task(visualization_task, dependencies=[playing.value])
316398

317399
def do_step():
318400
"""Advance the model by the number of steps specified by the render_interval slider."""
319-
simulator.run_for(render_interval.value)
320-
running.value = model.value.running
321-
force_update()
401+
if playing.value:
402+
for _ in range(render_interval.value):
403+
simulator.run_for(1)
404+
running.value = model.value.running
405+
if not playing.value:
406+
break
407+
if not use_threads.value:
408+
force_update()
409+
410+
else:
411+
for _ in range(render_interval.value):
412+
simulator.run_for(1)
413+
running.value = model.value.running
414+
force_update()
322415

323416
def do_reset():
324417
"""Reset the model to its initial state."""
325418
playing.value = False
326419
running.value = True
327420
simulator.reset()
421+
visualization_pause_event.clear()
422+
pause_step_event.clear()
328423
model.value = model.value = model.value.__class__(
329424
simulator=simulator, **model_parameters.value
330425
)

0 commit comments

Comments
 (0)