Skip to content

Commit 5a2d0ee

Browse files
andrewjcgfacebook-github-bot
authored andcommitted
Initial support for auto-reloading (#403)
Summary: Pull Request resolved: #403 Adds initial support for hot reloading of synced code. This is roughly based on the IPython strategy of tracking mtimes of loaded modules under the workspace and reloading modules whose backing files are updated on an explicit `reload()` actor call. Reviewed By: suo Differential Revision: D77415168 fbshipit-source-id: 512678055512536677f74cf389cebc2884330331
1 parent 216063f commit 5a2d0ee

File tree

3 files changed

+323
-2
lines changed

3 files changed

+323
-2
lines changed
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import contextlib
8+
import dataclasses
9+
import importlib
10+
import importlib.abc
11+
import importlib.util
12+
import itertools
13+
import sys
14+
import threading
15+
from pathlib import Path
16+
from types import ModuleType
17+
from typing import Dict, List, Optional, Tuple
18+
19+
from monarch.actor_mesh import Actor, endpoint
20+
from monarch.code_sync import WorkspaceLocation
21+
22+
23+
class SysAuditHookGuard(contextlib.AbstractContextManager):
    """
    Handle for one registered hook; releasing it unregisters the hook.

    The guard keeps a reference to the hook table it was created from and the
    key of its own entry. The entry is dropped when the guard is closed, when
    it is used as a context manager and the `with` block exits, or when the
    guard object is deleted.
    """

    def __init__(self, hooks, idx):
        # `hooks` is the shared mapping of key -> hook; `idx` is our key in it.
        self._hooks, self._idx = hooks, idx

    def close(self):
        """Remove this guard's hook from the table (no-op if already gone)."""
        self._hooks.pop(self._idx, None)

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        # Returns None, so exceptions raised in the `with` body propagate.
        self.close()

    def __del__(self):
        self.close()
44+
45+
46+
class SysAuditHookMultiplexer:
    """
    Multiplexes a single audit hook registration to multiple hooks.

    `sys.addaudithook`s can only be added and not removed, so this class provides
    a global singleton that can be used to multiplex multiple hooks which support
    removal.
    """

    def __init__(self):
        # Monotonic key generator; keys are never reused, so a stale guard can
        # never remove somebody else's hook.
        self._idx = itertools.count()
        # key -> hook callable taking `(event, args)`.
        self._hooks = {}

    def _callback(self, event, args):
        """Forward one audit event to every currently registered hook."""
        # Iterate over a snapshot: a hook may be unregistered while we are
        # dispatching (a guard closed from inside a hook, or a guard being
        # garbage collected), and mutating the dict during iteration would
        # raise `RuntimeError` out of the audit hook.
        for hook in list(self._hooks.values()):
            hook(event, args)

    def add(self, hook) -> "SysAuditHookGuard":
        """
        Register `hook` and return a guard which unregisters it when closed.

        The caller must keep the returned guard alive for as long as the hook
        should stay installed, since the guard unregisters the hook on deletion.
        """
        idx = next(self._idx)
        self._hooks[idx] = hook
        return SysAuditHookGuard(self._hooks, idx)

    _instance_lock = threading.Lock()
    _instance = None

    @classmethod
    def singleton(cls):
        """Return the process-wide multiplexer, installing it on first use."""
        if cls._instance is None:
            with cls._instance_lock:
                if cls._instance is None:
                    instance = SysAuditHookMultiplexer()
                    # Register with the interpreter *before* publishing the
                    # instance, so callers never obtain a multiplexer that is
                    # not (or failed to be) hooked up.
                    sys.addaudithook(instance._callback)
                    cls._instance = instance
        return cls._instance
79+
80+
81+
@dataclasses.dataclass
class ThreadLocalState(threading.local):
    """Per-thread scratch state used while pairing audit events."""

    # Name carried by the most recent "import" audit event seen on this
    # thread, or None when there is no import waiting to be matched.
    last_import: Optional[str] = None
84+
85+
86+
class SysAuditImportHook:
    """
    An audit hook that processes and coalesces import/exec events and calls a
    user-defined callback with the module name and module object which was
    imported.

    The callback is invoked as `callback(module_name, module)` once an "exec"
    event is seen whose code object belongs to the file backing the module
    named by the preceding "import" event on the same thread.
    """

    def __init__(self, callback):
        self._callback = callback
        # Remembers the pending import name separately for each thread.
        self._state = ThreadLocalState()

    @classmethod
    def install(cls, callback) -> "SysAuditHookGuard":
        """
        Register a new import hook for `callback` on the global multiplexer.

        Returns the guard produced by the multiplexer; closing it removes the
        hook again.
        """
        return SysAuditHookMultiplexer.singleton().add(SysAuditImportHook(callback))

    def _py_filename(self, filename: Path) -> Path:
        """Map a `.pyc`/`.pyo` path to the `.py` path with the same stem."""
        if filename.suffix in (".pyc", ".pyo"):
            return filename.with_suffix(".py")
        return filename

    def __call__(self, event, args):
        """Handle one audit event; only "import" and "exec" are considered."""
        if event == "import":
            # While `filename` is specified as an argument to the import event,
            # it's almost always `None`, so we need to wait for a subsequent
            # exec event to get the filename.
            module, _, _, _, _ = args
            self._state.last_import = module
        elif event == "exec":
            module_name = self._state.last_import
            if module_name is None:
                return
            # We always expect an exec right after an import, so we can clear the
            # last import module name we store.
            self._state.last_import = None
            module = sys.modules.get(module_name)
            if module is None:
                return
            # Not every module has a `__file__` attribute at all; a plain
            # attribute access would raise out of the audit hook and break
            # whatever unrelated code triggered this exec event.
            module_file = getattr(module, "__file__", None)
            if module_file is None:
                return
            (code_obj,) = args
            if code_obj.co_filename is None:
                return
            # code objects store the original source name, not the pyc
            if self._py_filename(Path(module_file)) != Path(code_obj.co_filename):
                return
            self._callback(module_name, module)
132+
133+
134+
@dataclasses.dataclass(frozen=True, kw_only=True)
135+
class Fingerprint:
136+
mtime: float
137+
size: int
138+
139+
@classmethod
140+
def for_path(cls, path: Path) -> "Fingerprint":
141+
stat = path.stat()
142+
return Fingerprint(mtime=stat.st_mtime, size=stat.st_size)
143+
144+
145+
class AutoReloader:
    """
    Track changes to modules in a workspace and reload them when they change.

    Every tracked module is stored with the path of its backing file and the
    fingerprint that file had when it was last (re)loaded. `reload_changes()`
    re-stats the files and reloads the modules whose fingerprint differs.
    """

    def __init__(self, workspace: Path, reload=importlib.reload):
        """
        Args:
            workspace: Directory under which module files are tracked.
            reload: Callable invoked with the module object to reload it.
        """
        self._workspace = workspace
        self._reload = reload
        self._tracked_modules: Dict[str, Tuple[Path, Fingerprint]] = {}
        self._track_all_imported()

    def _maybe_track_module(self, name: str, module: ModuleType):
        """Start (or refresh) tracking of `module` if its file is in the workspace."""
        filename = getattr(module, "__file__", None)
        if filename is None:
            return
        filename = Path(filename)

        # Ignore modules that are not in the workspace.
        if not filename.is_relative_to(self._workspace):
            return

        try:
            fingerprint = Fingerprint.for_path(filename)
        except FileNotFoundError:
            # The backing file is gone, so there is nothing to fingerprint.
            # This can be reached from inside an audit hook, where raising
            # would abort the import in progress, so skip the module instead.
            return

        self._tracked_modules[name] = (filename, fingerprint)

    def _track_all_imported(self):
        """Consider every module that is already imported for tracking."""
        # Snapshot first: `sys.modules` may change size while we iterate
        # (e.g. imports on other threads), which would raise RuntimeError.
        for name, module in list(sys.modules.items()):
            if module is None:
                continue
            self._maybe_track_module(name, module)

    def import_callback(self, name: str, module: ModuleType):
        """
        Callback for when a module has been imported.
        """

        self._maybe_track_module(name, module)

    def reload_changes(self) -> List[str]:
        """
        Reload all modules that have changed since they were last imported.

        Returns:
            The names of the modules that were reloaded, in tracking order.
        """

        reloaded = []

        for module_name, (filename, stored_fingerprint) in list(
            self._tracked_modules.items()
        ):
            try:
                fingerprint = Fingerprint.for_path(filename)
            except FileNotFoundError:
                # The file was removed from the workspace. There is nothing to
                # reload from, but keep the stored fingerprint so the module is
                # reloaded if the file shows up again with different contents.
                continue
            if fingerprint == stored_fingerprint:
                continue
            module = sys.modules.get(module_name)
            if module is None:
                # The module has been unloaded since we started tracking it;
                # forget it rather than failing every subsequent reload.
                self._tracked_modules.pop(module_name, None)
                continue
            reloaded.append(module_name)
            self._reload(module)
            self._tracked_modules[module_name] = (filename, fingerprint)

        return reloaded
202+
203+
204+
class AutoReloadActor(Actor):
    """Actor owning an `AutoReloader` and exposing it through `reload()`."""

    def __init__(self, workspace: WorkspaceLocation):
        # Build the reloader over the resolved workspace, then subscribe it to
        # import events. The guard is stored on the actor so the hook stays
        # registered for as long as the actor holds it.
        reloader = AutoReloader(workspace.resolve())
        self._reloader = reloader
        self._hook_guard = SysAuditImportHook.install(reloader.import_callback)

    @endpoint
    async def reload(self) -> None:
        """Reload every tracked module whose file changed and print their names."""
        reloaded = self._reloader.reload_changes()
        print(f"reloaded modules: {reloaded}")

python/monarch/proc_mesh.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@
4444
)
4545
from monarch._rust_bindings.monarch_hyperactor.shape import Shape, Slice
4646
from monarch.actor_mesh import _Actor, _ActorMeshRefImpl, Actor, ActorMeshRef
47-
4847
from monarch.code_sync import RsyncMeshClient, WorkspaceLocation
48+
from monarch.code_sync.auto_reload import AutoReloadActor
4949
from monarch.common._device_utils import _local_device_count
5050
from monarch.common.shape import MeshTrait
5151
from monarch.rdma import RDMAManager
@@ -86,6 +86,7 @@ def __init__(
8686
self._mailbox: Mailbox = self._proc_mesh.client
8787
self._rdma_manager: Optional[RDMAManager] = None
8888
self._rsync_mesh_client: Optional[RsyncMeshClient] = None
89+
self._auto_reload_actor: Optional[AutoReloadActor] = None
8990
self._maybe_device_mesh: Optional[DeviceMesh] = _device_mesh
9091
if _mock_shape is None:
9192
self._rdma_manager = self._spawn_blocking("rdma_manager", RDMAManager)
@@ -213,7 +214,7 @@ def rank_tensor(self, dim: str | Sequence[str]) -> "torch.Tensor":
213214
def rank_tensors(self) -> Dict[str, "torch.Tensor"]:
214215
return self._device_mesh.ranks
215216

216-
async def sync_workspace(self) -> None:
217+
async def sync_workspace(self, auto_reload: bool = False) -> None:
217218
if self._rsync_mesh_client is None:
218219
# TODO(agallagher): We need some way to configure and pass this
219220
# in -- right now we're assuming the `gpu` dimension, which isn't
@@ -233,7 +234,16 @@ async def sync_workspace(self) -> None:
233234
local_workspace=os.getcwd(),
234235
remote_workspace=WorkspaceLocation.FromEnvVar("WORKSPACE_DIR"),
235236
)
237+
self._auto_reload_actor = self._spawn_blocking(
238+
"auto_reload",
239+
AutoReloadActor,
240+
WorkspaceLocation.FromEnvVar("WORKSPACE_DIR"),
241+
)
242+
assert self._rsync_mesh_client is not None
236243
await self._rsync_mesh_client.sync_workspace()
244+
if auto_reload:
245+
assert self._auto_reload_actor is not None
246+
await self._auto_reload_actor.reload.call()
237247

238248

239249
async def local_proc_mesh_nonblocking(
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import compileall
8+
import contextlib
9+
import importlib
10+
import os
11+
import py_compile
12+
import sys
13+
import tempfile
14+
import unittest
15+
from pathlib import Path
16+
from typing import Any, Generator
17+
18+
from monarch.code_sync.auto_reload import AutoReloader, SysAuditImportHook
19+
20+
21+
def write_text(path: Path, content: str):
    """Write `content` verbatim to `path` and push it through to the OS."""
    with open(path, "w") as f:
        f.write(content)
        # `os.fsync` only covers data that has already reached the file
        # descriptor, so the Python-level buffer has to be flushed first.
        f.flush()
        os.fsync(f.fileno())  # needed for mtimes changes to be reflected immediately
25+
26+
27+
@contextlib.contextmanager
def importable_workspace() -> Generator[Path, Any, Any]:
    """
    Context manager to add the workspace to sys.path.

    Yields a fresh temporary directory that is first on `sys.path`. On exit,
    every module loaded from inside that directory is dropped from
    `sys.modules` and the directory is taken off `sys.path` again.
    """
    with tempfile.TemporaryDirectory() as workspace:
        sys.path.insert(0, workspace)
        try:
            yield Path(workspace)
        finally:
            # Trailing separator so a sibling directory that merely shares the
            # prefix does not match.
            prefix = os.path.join(workspace, "")
            for name, module in list(sys.modules.items()):
                filename = getattr(module, "__file__", None)
                if isinstance(filename, str) and filename.startswith(prefix):
                    # Remove by the key the module is registered under, which
                    # is not necessarily the same as `module.__name__`.
                    sys.modules.pop(name, None)
            sys.path.remove(workspace)
40+
41+
42+
class TestAutoReloader(unittest.TestCase):
    """Tests for `AutoReloader` driven by real imports from a temp workspace."""

    def test_source_change(self):
        """A module imported from a `.py` file is reloaded after the file changes."""
        # A bare `import importlib` does not guarantee that the `util`
        # submodule is bound, so import it explicitly where it is used.
        import importlib.util

        with importable_workspace() as workspace:
            reloader = AutoReloader(workspace)
            with SysAuditImportHook.install(reloader.import_callback):
                filename = workspace / "test_module.py"
                write_text(filename, "foo = 1\n")

                import test_module  # pyre-ignore: Undefined import [21]

                self.assertEqual(Path(test_module.__file__), filename)
                self.assertEqual(test_module.foo, 1)

                write_text(filename, "foo = 2\nbar = 4\n")
                # Force recompile. The cached bytecode may not exist at all
                # (e.g. bytecode writing disabled), which is fine.
                Path(importlib.util.cache_from_source(filename)).unlink(
                    missing_ok=True
                )

                self.assertEqual(
                    reloader.reload_changes(),
                    ["test_module"],
                )
                self.assertEqual(test_module.foo, 2)

    def test_pyc_only_change(self):
        """A module imported from a legacy `.pyc` (no source) is reloaded too."""
        with importable_workspace() as workspace:
            reloader = AutoReloader(workspace)
            with SysAuditImportHook.install(reloader.import_callback):
                filename = workspace / "test_module.py"
                pyc = filename.with_suffix(".pyc")

                # Compile next to the source (`legacy=True`), then delete the
                # source so that only the `.pyc` is importable.
                write_text(filename, "foo = 1\n")
                compileall.compile_dir(
                    workspace,
                    legacy=True,
                    quiet=True,
                    invalidation_mode=py_compile.PycInvalidationMode.CHECKED_HASH,
                )
                filename.unlink()

                import test_module  # pyre-ignore: Undefined import [21]

                self.assertEqual(Path(test_module.__file__), pyc)
                self.assertEqual(test_module.foo, 1)

                write_text(filename, "foo = 2\nbar = 4\n")
                pyc.unlink()  # force recompile
                compileall.compile_dir(
                    workspace,
                    legacy=True,
                    quiet=True,
                    invalidation_mode=py_compile.PycInvalidationMode.CHECKED_HASH,
                )
                filename.unlink()

                self.assertEqual(
                    reloader.reload_changes(),
                    ["test_module"],
                )
                self.assertEqual(test_module.foo, 2)

0 commit comments

Comments
 (0)