Skip to content

Commit 00d6b7f

Browse files
committed
[13/n] tensor engine, free functions implement call/broadcast/stream/etc., temp
Pull Request resolved: #478 This unifies call_on_shard_and_fetch with actor endpoints. It is now possible to issue a `.call` on a free-function remote and get a ValueMesh back. `call_on_shard_and_fetch` along with `fetch_shard` and `inspect` are implemented in terms of `call_one` now. ghstack-source-id: 295137297 @exported-using-ghexport Differential Revision: [D77978317](https://our.internmc.facebook.com/intern/diff/D77978317/)
1 parent ffb9824 commit 00d6b7f

File tree

7 files changed

+240
-59
lines changed

7 files changed

+240
-59
lines changed

python/monarch/_rust_bindings/monarch_hyperactor/shape.pyi

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import collections.abc
8-
from typing import Dict, final, Iterator, List, overload
8+
from typing import Dict, final, Iterator, List, overload, Sequence
99

1010
@final
1111
class Slice:
@@ -90,7 +90,7 @@ class Slice:
9090
...
9191

9292
@staticmethod
93-
def new_row_major(ranks: list[int]) -> "Slice":
93+
def new_row_major(ranks: Sequence[int]) -> "Slice":
9494
"""Returns a contiguous slice composed of ranks"""
9595
...
9696

@@ -106,7 +106,7 @@ class Shape:
106106
- `labels`: A list of strings representing the labels for each dimension.
107107
- `slice`: A Slice object representing the shape.
108108
"""
109-
def __new__(cls, labels: List[str], slice: Slice) -> "Shape": ...
109+
def __new__(cls, labels: Sequence[str], slice: Slice) -> "Shape": ...
110110
@property
111111
def ndslice(self) -> Slice: ...
112112
@property

python/monarch/_src/actor/actor_mesh.py

Lines changed: 97 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
import sys
1616
import traceback
1717

18+
from abc import ABC, abstractmethod
19+
1820
from dataclasses import dataclass
21+
from operator import mul
1922
from traceback import extract_tb, StackSummary
2023
from typing import (
2124
Any,
@@ -27,11 +30,13 @@
2730
Dict,
2831
Generic,
2932
Iterable,
33+
Iterator,
3034
List,
3135
Literal,
3236
NamedTuple,
3337
Optional,
3438
ParamSpec,
39+
Sequence,
3540
Tuple,
3641
Type,
3742
TYPE_CHECKING,
@@ -204,18 +209,32 @@ def __len__(self) -> int:
204209
return len(self._shape)
205210

206211

207-
class Endpoint(Generic[P, R]):
208-
def __init__(
212+
class Extent(NamedTuple):
213+
labels: Sequence[str]
214+
sizes: Sequence[int]
215+
216+
@property
217+
def nelements(self) -> int:
218+
return functools.reduce(mul, self.sizes, 1)
219+
220+
def __str__(self) -> str:
221+
return str(dict(zip(self.labels, self.sizes)))
222+
223+
224+
class Endpoint(ABC, Generic[P, R]):
225+
@abstractmethod
226+
def _send(
209227
self,
210-
actor_mesh_ref: _ActorMeshRefImpl,
211-
name: str,
212-
impl: Callable[Concatenate[Any, P], Awaitable[R]],
213-
mailbox: Mailbox,
214-
) -> None:
215-
self._actor_mesh = actor_mesh_ref
216-
self._name = name
217-
self._signature: inspect.Signature = inspect.signature(impl)
218-
self._mailbox = mailbox
228+
args: Tuple[Any, ...],
229+
kwargs: Dict[str, Any],
230+
port: "Optional[Port]" = None,
231+
selection: Selection = "all",
232+
) -> Extent:
233+
pass
234+
235+
@abstractmethod
236+
def _port(self, once: bool = False) -> "PortTuple[R]":
237+
pass
219238

220239
# the following are all 'adverbs' or different ways to handle the
221240
# return values of this endpoint. Adverbs should only ever take *args, **kwargs
@@ -228,46 +247,47 @@ def choose(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
228247
229248
Load balanced RPC-style entrypoint for request/response messaging.
230249
"""
231-
p: Port[R]
232-
r: PortReceiver[R]
233250
p, r = port(self, once=True)
234251
# pyre-ignore
235-
send(self, args, kwargs, port=p, selection="choose")
252+
self._send(args, kwargs, port=p, selection="choose")
236253
return r.recv()
237254

238255
def call_one(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
239-
if len(self._actor_mesh) != 1:
256+
p, r = port(self, once=True)
257+
# pyre-ignore
258+
extent = self._send(args, kwargs, port=p, selection="choose")
259+
if extent.nelements != 1:
240260
raise ValueError(
241-
f"Can only use 'call_one' on a single Actor but this actor has shape {self._actor_mesh._shape}"
261+
f"Can only use 'call_one' on a single Actor but this actor has shape {extent}"
242262
)
243-
return self.choose(*args, **kwargs)
263+
return r.recv()
244264

245265
def call(self, *args: P.args, **kwargs: P.kwargs) -> "Future[ValueMesh[R]]":
246266
p: Port[R]
247267
r: RankedPortReceiver[R]
248268
p, r = ranked_port(self)
249269
# pyre-ignore
250-
send(self, args, kwargs, port=p)
270+
extent = self._send(args, kwargs, port=p)
251271

252272
async def process() -> ValueMesh[R]:
253-
results: List[R] = [None] * len(self._actor_mesh) # pyre-fixme[9]
254-
for _ in range(len(self._actor_mesh)):
273+
results: List[R] = [None] * extent.nelements # pyre-fixme[9]
274+
for _ in range(extent.nelements):
255275
rank, value = await r.recv()
256276
results[rank] = value
257277
call_shape = Shape(
258-
self._actor_mesh._shape.labels,
259-
NDSlice.new_row_major(self._actor_mesh._shape.ndslice.sizes),
278+
extent.labels,
279+
NDSlice.new_row_major(extent.sizes),
260280
)
261281
return ValueMesh(call_shape, results)
262282

263283
def process_blocking() -> ValueMesh[R]:
264-
results: List[R] = [None] * len(self._actor_mesh) # pyre-fixme[9]
265-
for _ in range(len(self._actor_mesh)):
284+
results: List[R] = [None] * extent.nelements # pyre-fixme[9]
285+
for _ in range(extent.nelements):
266286
rank, value = r.recv().get()
267287
results[rank] = value
268288
call_shape = Shape(
269-
self._actor_mesh._shape.labels,
270-
NDSlice.new_row_major(self._actor_mesh._shape.ndslice.sizes),
289+
extent.labels,
290+
NDSlice.new_row_major(extent.sizes),
271291
)
272292
return ValueMesh(call_shape, results)
273293

@@ -282,8 +302,8 @@ async def stream(self, *args: P.args, **kwargs: P.kwargs) -> AsyncGenerator[R, R
282302
"""
283303
p, r = port(self)
284304
# pyre-ignore
285-
send(self, args, kwargs, port=p)
286-
for _ in range(len(self._actor_mesh)):
305+
extent = self._send(args, kwargs, port=p)
306+
for _ in range(extent.nelements):
287307
yield await r.recv()
288308

289309
def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> None:
@@ -298,6 +318,46 @@ def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> None:
298318
send(self, args, kwargs)
299319

300320

321+
class ActorEndpoint(Endpoint[P, R]):
322+
def __init__(
323+
self,
324+
actor_mesh_ref: _ActorMeshRefImpl,
325+
name: str,
326+
impl: Callable[Concatenate[Any, P], Awaitable[R]],
327+
mailbox: Mailbox,
328+
) -> None:
329+
self._actor_mesh = actor_mesh_ref
330+
self._name = name
331+
self._signature: inspect.Signature = inspect.signature(impl)
332+
self._mailbox = mailbox
333+
334+
def _send(
335+
self,
336+
args: Tuple[Any, ...],
337+
kwargs: Dict[str, Any],
338+
port: "Optional[Port]" = None,
339+
selection: Selection = "all",
340+
) -> Extent:
341+
"""
342+
Fire-and-forget invocation of the endpoint on the actors picked by `selection` (all actors in the mesh by default).
343+
344+
This sends the message but does not wait for any result; it returns the Extent (labels and sizes) of the actor mesh.
345+
"""
346+
self._signature.bind(None, *args, **kwargs)
347+
message = PythonMessage(
348+
self._name,
349+
_pickle((args, kwargs)),
350+
None if port is None else port._port_ref,
351+
None,
352+
)
353+
self._actor_mesh.cast(message, selection)
354+
shape = self._actor_mesh._shape
355+
return Extent(shape.labels, shape.ndslice.sizes)
356+
357+
def _port(self, once: bool = False) -> "PortTuple[R]":
358+
return PortTuple.create(self._mailbox, once)
359+
360+
301361
class Accumulator(Generic[P, R, A]):
302362
def __init__(
303363
self, endpoint: Endpoint[P, R], identity: A, combine: Callable[[A, R], A]
@@ -337,10 +397,13 @@ def item(self, **kwargs) -> R:
337397

338398
return self._values[self._ndslice.nditem(coordinates)]
339399

340-
def __iter__(self):
400+
def items(self) -> Iterable[Tuple[Point, R]]:
341401
for rank in self._shape.ranks():
342402
yield Point(rank, self._shape), self._values[rank]
343403

404+
def __iter__(self) -> Iterator[Tuple[Point, R]]:
405+
return iter(self.items())
406+
344407
def __len__(self) -> int:
345408
return len(self._shape)
346409

@@ -368,14 +431,7 @@ def send(
368431
369432
This sends the message to all actors but does not wait for any result.
370433
"""
371-
endpoint._signature.bind(None, *args, **kwargs)
372-
message = PythonMessage(
373-
endpoint._name,
374-
_pickle((args, kwargs)),
375-
None if port is None else port._port_ref,
376-
None,
377-
)
378-
endpoint._actor_mesh.cast(message, selection)
434+
endpoint._send(args, kwargs, port, selection)
379435

380436

381437
class EndpointProperty(Generic[P, R]):
@@ -447,7 +503,7 @@ def create(mailbox: Mailbox, once: bool = False) -> "PortTuple[Any]":
447503
# not part of the Endpoint API because the way it accepts arguments
448504
# and handles concerns is different.
449505
def port(endpoint: Endpoint[P, R], once: bool = False) -> "PortTuple[R]":
450-
return PortTuple.create(endpoint._mailbox, once)
506+
return endpoint._port(once)
451507

452508

453509
def ranked_port(
@@ -676,7 +732,7 @@ def __init__(
676732
setattr(
677733
self,
678734
attr_name,
679-
Endpoint(
735+
ActorEndpoint(
680736
self._actor_mesh_ref,
681737
attr_name,
682738
attr_value._method,
@@ -695,7 +751,7 @@ def __getattr__(self, name: str) -> Any:
695751
attr = getattr(self._class, name)
696752
if isinstance(attr, EndpointProperty):
697753
# Dynamically create the endpoint
698-
endpoint = Endpoint(
754+
endpoint = ActorEndpoint(
699755
self._actor_mesh_ref,
700756
name,
701757
attr._method,
@@ -718,7 +774,7 @@ def _create(
718774
async def null_func(*_args: Iterable[Any], **_kwargs: Dict[str, Any]) -> None:
719775
return None
720776

721-
ep = Endpoint(
777+
ep = ActorEndpoint(
722778
self._actor_mesh_ref,
723779
"__init__",
724780
null_func,

python/monarch/common/device_mesh.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
from torch.utils._python_dispatch import TorchDispatchMode
3434
from torch.utils._pytree import tree_map
35+
from torch.utils.weak import weakref
3536

3637
from ._tensor_to_table import tensor_to_table
3738
from .context_manager import activate_first_context_manager
@@ -170,6 +171,7 @@ def __init__(
170171
self.exit = lambda: None
171172
self.ref = None
172173
self._active_mesh_context = None
174+
self._subset_of: Optional[weakref.ReferenceType["DeviceMesh"]] = None
173175

174176
def define_remotely(self):
175177
if self.ref is None:
@@ -227,8 +229,17 @@ def _labels(self) -> Tuple[str, ...]:
227229
def _new_with_shape(self, shape: Shape) -> "DeviceMesh":
228230
mesh = DeviceMesh(self.client, shape.ndslice, tuple(shape.labels))
229231
mesh.exit = self.exit
232+
mesh._subset_of = weakref.ref(self)
230233
return mesh
231234

235+
def _is_subset_of(self, other: "DeviceMesh") -> bool:
236+
p = self
237+
while p is not None:
238+
if p is other:
239+
return True
240+
p = None if p._subset_of is None else p._subset_of()
241+
return False
242+
232243
def __call__(self, **kwargs) -> "DeviceMesh":
233244
"""
234245
device_mesh(batch=3) or device_mesh(batch=slice(3, None))

0 commit comments

Comments
 (0)