Commit e0e08cf

[12/n] tensor engine, demote call_on_shard_and_fetch
Pull Request resolved: #470

`remote` is about to become an endpoint, along with the endpoint 'adverbs'. By convention, adverbs can't take additional arguments besides the endpoint's args/kwargs, so, like the `send` method, `call_on_shard_and_fetch` needs to become a free function.

Devmate says more:

**Task Completed**

Successfully converted the `call_on_shard_and_fetch` method to a free function! Here's a summary of all the changes made:

Changes Made:
-------------

### 1. Removed the method from the `Remote` class

* Deleted the `call_on_shard_and_fetch` method from the `Remote` class in `remote.py`

### 2. Created a new free function

* Added `call_on_shard_and_fetch(remote_obj: Remote, *args, shard: Dict[str, int] | None = None, **kwargs) -> Future[Any]` as a standalone function
* The function takes the `Remote` instance as its first parameter, followed by the same arguments as the original method
* Updated the function to call the internal `_call_on_shard_and_fetch` with the appropriate parameters from the `Remote` object

### 3. Updated all call sites across the codebase

* **Test files**: Updated `test_remote_functions.py` and `test_rust_backend.py` to use the new free function syntax
* **Example files**: Updated controller example and FSDP examples to use the new function
* **Module files**: Updated `opaque_module.py` and `opaque_object.py` to use the new function
* **Internal usage**: Updated the cached propagation logic in `remote.py` to use the new function

### 4. Updated imports

* Added `from monarch.common.remote import call_on_shard_and_fetch` to all files that needed it
* The function is also available through the main `monarch` module via the existing export in `__init__.py`

Syntax Change:
--------------

**Before (method):** `remote_obj.call_on_shard_and_fetch(arg1, arg2, shard={"host": 0})`

**After (free function):** `call_on_shard_and_fetch(remote_obj, arg1, arg2, shard={"host": 0})`

Files Modified:
---------------

* `remote.py` (main changes)
* `test_remote_functions.py`
* `test_rust_backend.py`
* `example.py`
* `opaque_module.py`
* `opaque_object.py`
* `fsdp_parametrization.py`
* `fsdp_setattr.py`
* `fetch.py`

All changes have been validated and no errors were introduced. The function maintains the same behavior and API, just with a different calling convention where the `Remote` object is passed as the first argument instead of being the method receiver.

ghstack-source-id: 295034203

Differential Revision: [D77967708](https://our.internmc.facebook.com/intern/diff/D77967708/)
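
To make the calling-convention change concrete, here is a minimal, self-contained sketch. It does not use monarch: the `Remote` class below is a hypothetical stand-in that runs the wrapped function locally, and the free function only mirrors the shape of the new signature (the remote object as a positional-only first parameter, `shard` as a keyword-only option). It is an illustration of the convention, not the library's implementation.

```python
from concurrent.futures import Future
from typing import Any, Callable, Dict, Optional


class Remote:
    """Hypothetical stand-in for monarch's Remote; runs the function locally."""

    def __init__(self, impl: Callable[..., Any]):
        self._impl = impl

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self._impl(*args, **kwargs)


def call_on_shard_and_fetch(
    remote_obj: Remote,
    /,
    *args: Any,
    shard: Optional[Dict[str, int]] = None,
    **kwargs: Any,
) -> Future:
    """Free-function form: the Remote is passed first, not used as the receiver."""
    # Assumption: a real implementation would run on the shard named by
    # `shard`; this sketch ignores it and evaluates the function in-process.
    fut: Future = Future()
    fut.set_result(remote_obj(*args, **kwargs))
    return fut


add = Remote(lambda a, b: a + b)

# Before (method form): add.call_on_shard_and_fetch(1, 2, shard={"host": 0})
# After (free function):
result = call_on_shard_and_fetch(add, 1, 2, shard={"host": 0}).result()
print(result)  # prints 3
```

Because `remote_obj` is positional-only, a remote function can still accept its own keyword argument named `remote_obj` without colliding with the free function's parameter.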
1 parent 51b872e commit e0e08cf

6 files changed

+41
-38
lines changed

python/monarch/common/remote.py

Lines changed: 19 additions & 13 deletions
@@ -72,6 +72,14 @@ def __init__(self, impl: Any, propagator_arg: Propagator):
     def _resolvable(self):
         return resolvable_function(self._remote_impl)

+    @property
+    def _maybe_resolvable(self):
+        return (
+            None
+            if self._remote_impl is None
+            else resolvable_function(self._remote_impl)
+        )
+
     def _propagate(self, args, kwargs, fake_args, fake_kwargs):
         if self._propagator_arg is None or self._propagator_arg == "cached":
             if self._cache is None:
@@ -104,13 +112,6 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
             stream._active,
         )

-    def call_on_shard_and_fetch(
-        self, *args, shard: Dict[str, int] | None = None, **kwargs
-    ) -> Future[R]:
-        return _call_on_shard_and_fetch(
-            self._resolvable, self._fetch_propagate, *args, shard=shard, **kwargs
-        )
-

 # This can't just be Callable because otherwise we are not
 # allowed to use type arguments in the return value.
@@ -148,21 +149,26 @@ def remote(function: Any = None, *, propagate: Propagator = None) -> Any:
     return Remote(function, propagate)


-def _call_on_shard_and_fetch(
-    rfunction: ResolvableFunction | None,
-    propagator: Any,
+remote_identity = Remote(None, lambda x: x)
+
+
+def call_on_shard_and_fetch(
+    remote_obj: Remote[P, R],
     /,
     *args: object,
     shard: dict[str, int] | None = None,
     **kwargs: object,
-) -> Future:
+) -> Future[R]:
     """
     Call `function` at the coordinates `shard` of the current device mesh, and retrieve the result as a Future.
         function - the remote function to call
         *args/**kwargs - arguments to the function
         shard - a dictionary from mesh dimension name to coordinate of the shard
             If None, this will fetch from coordinate 0 for all dimensions (useful after all_reduce/all_gather)
     """
+
+    rfunction = remote_obj._maybe_resolvable
+    propagator = remote_obj._fetch_propagate
     ambient_mesh = device_mesh._active

     if rfunction is None:
@@ -271,8 +277,8 @@ def _cached_propagation(_cache, rfunction, args, kwargs):
     if key not in _cache:
         _miss += 1
         args_no_pg, kwargs_no_pg = tree_map(_mock_pgs, (args, kwargs))
-        result_with_placeholders, output_pattern = _propagate.call_on_shard_and_fetch(
-            function=rfunction, args=args_no_pg, kwargs=kwargs_no_pg
+        result_with_placeholders, output_pattern = call_on_shard_and_fetch(
+            _propagate, function=rfunction, args=args_no_pg, kwargs=kwargs_no_pg
         ).result()

         _, unflatten_result = flatten(

python/monarch/fetch.py

Lines changed: 2 additions & 4 deletions
@@ -15,7 +15,7 @@

 from monarch.common.future import Future

-from monarch.common.remote import _call_on_shard_and_fetch
+from monarch.common.remote import call_on_shard_and_fetch, remote_identity

 T = TypeVar("T")

@@ -37,9 +37,7 @@ def fetch_shard(
         shard = {}
     shard.update(kwargs)

-    return _call_on_shard_and_fetch(
-        None, lambda *args, **kwargs: None, obj, shard=shard
-    )
+    return call_on_shard_and_fetch(remote_identity, obj, shard=shard)


 def show(obj: T, shard: dict[str, int] | None = None, **kwargs: int) -> object:
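
The `fetch_shard` change relies on a `Remote` built with no implementation (`remote_identity`) plus a `_maybe_resolvable` property that yields `None` in that case. The sketch below illustrates that sentinel pattern in isolation. Everything in it is a hypothetical stand-in: `resolvable_function` here just formats a name, `fetch` is an invented helper, and the assumption that the `None` branch returns the value unchanged is not shown in the hunk above.

```python
from typing import Any, Callable, Optional


def resolvable_function(impl: Callable[..., Any]) -> str:
    """Hypothetical stand-in: identify a callable by its qualified name."""
    return f"{impl.__module__}.{impl.__qualname__}"


class Remote:
    """Hypothetical stand-in for monarch's Remote."""

    def __init__(self, impl: Optional[Callable[..., Any]]):
        self._remote_impl = impl

    @property
    def _maybe_resolvable(self) -> Optional[str]:
        # None signals "there is no function to run, only a value to fetch".
        if self._remote_impl is None:
            return None
        return resolvable_function(self._remote_impl)


# Sentinel with no implementation, analogous to `remote_identity` in the diff.
remote_identity = Remote(None)


def fetch(remote_obj: Remote, value: Any) -> Any:
    """Invented helper: apply the remote function, or pass the value through."""
    rfunction = remote_obj._maybe_resolvable
    if rfunction is None:
        # Assumed behavior of the identity path: return the value as-is.
        return value
    return remote_obj._remote_impl(value)


print(fetch(remote_identity, 5))  # identity path
print(fetch(Remote(abs), -3))     # regular path
```

Using one sentinel object keeps the free function's signature uniform: callers always pass a `Remote`, and the "fetch only" case no longer needs a separate private entry point.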

python/monarch/opaque_module.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import torch
1010
from monarch.common.function_caching import TensorGroup, TensorGroupPattern
1111
from monarch.common.opaque_ref import OpaqueRef
12-
from monarch.common.remote import remote
12+
from monarch.common.remote import call_on_shard_and_fetch, remote
1313
from monarch.common.tensor_factory import TensorFactory
1414
from monarch.common.tree import flatten
1515
from monarch.opaque_object import _fresh_opaque_ref, OpaqueObject
@@ -144,11 +144,9 @@ def __init__(self, *args, **kwargs):
144144

145145
def parameters(self):
146146
if self._parameters is None:
147-
tensor_group_pattern = (
148-
remote(_get_parameters_shape)
149-
.call_on_shard_and_fetch(self._object)
150-
.result()
151-
)
147+
tensor_group_pattern = call_on_shard_and_fetch(
148+
remote(_get_parameters_shape), self._object
149+
).result()
152150
self._parameters = [
153151
p.requires_grad_(True)
154152
for p in remote(

python/monarch/opaque_object.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
)
1515

1616
from monarch.common.opaque_ref import OpaqueRef
17-
from monarch.common.remote import remote
17+
from monarch.common.remote import call_on_shard_and_fetch, remote
1818

1919

2020
def _invoke_method(obj: OpaqueRef, method_name: str, *args, **kwargs):
@@ -83,6 +83,6 @@ def call_method(self, method_name, propagation, *args, **kwargs):
8383
return endpoint(self, method_name, *args, **kwargs)
8484

8585
def call_method_on_shard_and_fetch(self, method_name, *args, **kwargs):
86-
return remote(_invoke_method).call_on_shard_and_fetch(
87-
self, method_name, *args, **kwargs
86+
return call_on_shard_and_fetch(
87+
remote(_invoke_method), self, method_name, *args, **kwargs
8888
)

python/tests/test_remote_functions.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from monarch.cached_remote_function import remote_autograd_function
3636
from monarch.common import remote as remote_module
3737
from monarch.common.device_mesh import DeviceMesh
38-
from monarch.common.remote import Remote
38+
from monarch.common.remote import call_on_shard_and_fetch, Remote
3939
from monarch.mesh_controller import RemoteException as NewRemoteException
4040

4141
from monarch.opaque_module import OpaqueModule
@@ -634,11 +634,10 @@ def test_fetch_preprocess(self, backend_type):
634634
with self.local_device_mesh(2, 2, backend_type):
635635
assert (
636636
"an argument processed"
637-
== remote("monarch.worker._testing_function.do_some_processing")
638-
.call_on_shard_and_fetch(
637+
== call_on_shard_and_fetch(
638+
remote("monarch.worker._testing_function.do_some_processing"),
639639
"an argument",
640-
)
641-
.result()
640+
).result()
642641
)
643642

644643
def test_cached_remote_function(self, backend_type):
@@ -733,7 +732,7 @@ def cuda_works(x):
733732

734733
with self.local_device_mesh(2, 2, backend_type):
735734
a = torch.ones(())
736-
assert check.call_on_shard_and_fetch(bar(a, a)).result()
735+
assert call_on_shard_and_fetch(check, bar(a, a)).result()
737736
# ensure we do not attempt to pickle closures
738737
close()
739738

@@ -776,7 +775,7 @@ def simple():
776775

777776
with self.local_device_mesh(1, 1, backend_type):
778777
# This should be a valid return than an exception to raise
779-
simple.call_on_shard_and_fetch().result()
778+
call_on_shard_and_fetch(simple).result()
780779

781780
def test_opaque_object(self, backend_type):
782781
with self.local_device_mesh(2, 2, backend_type):

python/tests/test_rust_backend.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import torch.utils._python_dispatch
1818
from monarch import fetch_shard, no_mesh, remote, Stream
1919
from monarch.common.device_mesh import DeviceMesh
20+
from monarch.common.remote import call_on_shard_and_fetch
2021
from monarch.rust_local_mesh import local_meshes, LoggingLocation, SocketType
2122
from torch.nn.attention import sdpa_kernel, SDPBackend
2223
from torch.nn.functional import scaled_dot_product_attention
@@ -111,9 +112,10 @@ def test_fetch_preprocess(self):
111112
with local_mesh():
112113
assert (
113114
"an argument processed"
114-
== remote("monarch.worker._testing_function.do_some_processing")
115-
.call_on_shard_and_fetch("an argument")
116-
.result()
115+
== call_on_shard_and_fetch(
116+
remote("monarch.worker._testing_function.do_some_processing"),
117+
"an argument",
118+
).result()
117119
)
118120

119121
def test_brutal_shutdown(self):
@@ -143,8 +145,8 @@ def has_nan(t):
143145
return torch.isnan(t).any().item()
144146

145147
t = torch.rand(3, 4)
146-
res = has_nan.call_on_shard_and_fetch(
147-
t, shard={"host": 0, "gpu": 0}
148+
res = call_on_shard_and_fetch(
149+
has_nan, t, shard={"host": 0, "gpu": 0}
148150
).result()
149151

150152
self.assertFalse(res)
