Skip to content

Commit 1ad1333

Browse files
zdevitofacebook-github-bot
authored andcommitted
tensor engine, free functions implement call/broadcast/stream/etc., temp (#498)
Summary: Pull Request resolved: #498 This unifies call_on_shard_and_fetch with actor endpoints. It is now possible to issue a `.call` on a free-function remote and get a ValueMesh back. `call_on_shard_and_fetch` along with `fetch_shard` and `inspect` are implemented in terms of `call_one` now. ghstack-source-id: 295571460 exported-using-ghexport Reviewed By: mariusae Differential Revision: D77978317 fbshipit-source-id: d56eb2d551eebc52829a4aed5579eb9f9e75f9df
1 parent 1e8e202 commit 1ad1333

File tree

7 files changed

+258
-59
lines changed

7 files changed

+258
-59
lines changed

python/monarch/_rust_bindings/monarch_hyperactor/shape.pyi

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import collections.abc
8-
from typing import Dict, final, Iterator, List, overload
8+
from typing import Dict, final, Iterator, List, overload, Sequence
99

1010
@final
1111
class Slice:
@@ -90,7 +90,7 @@ class Slice:
9090
...
9191

9292
@staticmethod
93-
def new_row_major(ranks: list[int]) -> "Slice":
93+
def new_row_major(ranks: Sequence[int]) -> "Slice":
9494
"""Returns a contiguous slice composed of ranks"""
9595
...
9696

@@ -106,7 +106,7 @@ class Shape:
106106
- `labels`: A list of strings representing the labels for each dimension.
107107
- `slice`: An Slice object representing the shape.
108108
"""
109-
def __new__(cls, labels: List[str], slice: Slice) -> "Shape": ...
109+
def __new__(cls, labels: Sequence[str], slice: Slice) -> "Shape": ...
110110
@property
111111
def ndslice(self) -> Slice: ...
112112
@property

python/monarch/_src/actor/actor_mesh.py

Lines changed: 106 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414
import random
1515
import traceback
1616

17+
from abc import ABC, abstractmethod
18+
1719
from dataclasses import dataclass
20+
from operator import mul
1821
from traceback import extract_tb, StackSummary
1922
from typing import (
2023
Any,
@@ -26,11 +29,13 @@
2629
Dict,
2730
Generic,
2831
Iterable,
32+
Iterator,
2933
List,
3034
Literal,
3135
NamedTuple,
3236
Optional,
3337
ParamSpec,
38+
Sequence,
3439
Tuple,
3540
Type,
3641
TYPE_CHECKING,
@@ -217,18 +222,41 @@ def __len__(self) -> int:
217222
return len(self._shape)
218223

219224

220-
class Endpoint(Generic[P, R]):
221-
def __init__(
225+
class Extent(NamedTuple):
226+
labels: Sequence[str]
227+
sizes: Sequence[int]
228+
229+
@property
230+
def nelements(self) -> int:
231+
return functools.reduce(mul, self.sizes, 1)
232+
233+
def __str__(self) -> str:
234+
return str(dict(zip(self.labels, self.sizes)))
235+
236+
237+
class Endpoint(ABC, Generic[P, R]):
238+
@abstractmethod
239+
def _send(
222240
self,
223-
actor_mesh_ref: _ActorMeshRefImpl,
224-
name: str,
225-
impl: Callable[Concatenate[Any, P], Awaitable[R]],
226-
mailbox: Mailbox,
227-
) -> None:
228-
self._actor_mesh = actor_mesh_ref
229-
self._name = name
230-
self._signature: inspect.Signature = inspect.signature(impl)
231-
self._mailbox = mailbox
241+
args: Tuple[Any, ...],
242+
kwargs: Dict[str, Any],
243+
port: "Optional[Port]" = None,
244+
selection: Selection = "all",
245+
) -> Extent:
246+
"""
247+
Implements sending a message to the endpoint. The return value of the endpoint will
248+
be sent to port if provided. If port is not provided, the return will be dropped,
249+
and any exception will cause the actor to fail.
250+
251+
The return value is the (multi-dimension) size of the actors that were sent a message.
252+
For ActorEndpoints this will be the actor_meshes size. For free-function endpoints,
253+
this will be the size of the currently active proc_mesh.
254+
"""
255+
pass
256+
257+
@abstractmethod
258+
def _port(self, once: bool = False) -> "PortTuple[R]":
259+
pass
232260

233261
# the following are all 'adverbs' or different ways to handle the
234262
# return values of this endpoint. Adverbs should only ever take *args, **kwargs
@@ -241,46 +269,47 @@ def choose(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
241269
242270
Load balanced RPC-style entrypoint for request/response messaging.
243271
"""
244-
p: Port[R]
245-
r: PortReceiver[R]
246272
p, r = port(self, once=True)
247273
# pyre-ignore
248-
send(self, args, kwargs, port=p, selection="choose")
274+
self._send(args, kwargs, port=p, selection="choose")
249275
return r.recv()
250276

251277
def call_one(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
252-
if len(self._actor_mesh) != 1:
278+
p, r = port(self, once=True)
279+
# pyre-ignore
280+
extent = self._send(args, kwargs, port=p, selection="choose")
281+
if extent.nelements != 1:
253282
raise ValueError(
254-
f"Can only use 'call_one' on a single Actor but this actor has shape {self._actor_mesh._shape}"
283+
f"Can only use 'call_one' on a single Actor but this actor has shape {extent}"
255284
)
256-
return self.choose(*args, **kwargs)
285+
return r.recv()
257286

258287
def call(self, *args: P.args, **kwargs: P.kwargs) -> "Future[ValueMesh[R]]":
259288
p: Port[R]
260289
r: RankedPortReceiver[R]
261290
p, r = ranked_port(self)
262291
# pyre-ignore
263-
send(self, args, kwargs, port=p)
292+
extent = self._send(args, kwargs, port=p)
264293

265294
async def process() -> ValueMesh[R]:
266-
results: List[R] = [None] * len(self._actor_mesh) # pyre-fixme[9]
267-
for _ in range(len(self._actor_mesh)):
295+
results: List[R] = [None] * extent.nelements # pyre-fixme[9]
296+
for _ in range(extent.nelements):
268297
rank, value = await r.recv()
269298
results[rank] = value
270299
call_shape = Shape(
271-
self._actor_mesh._shape.labels,
272-
NDSlice.new_row_major(self._actor_mesh._shape.ndslice.sizes),
300+
extent.labels,
301+
NDSlice.new_row_major(extent.sizes),
273302
)
274303
return ValueMesh(call_shape, results)
275304

276305
def process_blocking() -> ValueMesh[R]:
277-
results: List[R] = [None] * len(self._actor_mesh) # pyre-fixme[9]
278-
for _ in range(len(self._actor_mesh)):
306+
results: List[R] = [None] * extent.nelements # pyre-fixme[9]
307+
for _ in range(extent.nelements):
279308
rank, value = r.recv().get()
280309
results[rank] = value
281310
call_shape = Shape(
282-
self._actor_mesh._shape.labels,
283-
NDSlice.new_row_major(self._actor_mesh._shape.ndslice.sizes),
311+
extent.labels,
312+
NDSlice.new_row_major(extent.sizes),
284313
)
285314
return ValueMesh(call_shape, results)
286315

@@ -295,8 +324,8 @@ async def stream(self, *args: P.args, **kwargs: P.kwargs) -> AsyncGenerator[R, R
295324
"""
296325
p, r = port(self)
297326
# pyre-ignore
298-
send(self, args, kwargs, port=p)
299-
for _ in range(len(self._actor_mesh)):
327+
extent = self._send(args, kwargs, port=p)
328+
for _ in range(extent.nelements):
300329
yield await r.recv()
301330

302331
def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> None:
@@ -311,6 +340,46 @@ def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> None:
311340
send(self, args, kwargs)
312341

313342

343+
class ActorEndpoint(Endpoint[P, R]):
344+
def __init__(
345+
self,
346+
actor_mesh_ref: _ActorMeshRefImpl,
347+
name: str,
348+
impl: Callable[Concatenate[Any, P], Awaitable[R]],
349+
mailbox: Mailbox,
350+
) -> None:
351+
self._actor_mesh = actor_mesh_ref
352+
self._name = name
353+
self._signature: inspect.Signature = inspect.signature(impl)
354+
self._mailbox = mailbox
355+
356+
def _send(
357+
self,
358+
args: Tuple[Any, ...],
359+
kwargs: Dict[str, Any],
360+
port: "Optional[Port]" = None,
361+
selection: Selection = "all",
362+
) -> Extent:
363+
"""
364+
Fire-and-forget broadcast invocation of the endpoint across all actors in the mesh.
365+
366+
This sends the message to all actors but does not wait for any result.
367+
"""
368+
self._signature.bind(None, *args, **kwargs)
369+
message = PythonMessage(
370+
self._name,
371+
_pickle((args, kwargs)),
372+
None if port is None else port._port_ref,
373+
None,
374+
)
375+
self._actor_mesh.cast(message, selection)
376+
shape = self._actor_mesh._shape
377+
return Extent(shape.labels, shape.ndslice.sizes)
378+
379+
def _port(self, once: bool = False) -> "PortTuple[R]":
380+
return PortTuple.create(self._mailbox, once)
381+
382+
314383
class Accumulator(Generic[P, R, A]):
315384
def __init__(
316385
self, endpoint: Endpoint[P, R], identity: A, combine: Callable[[A, R], A]
@@ -350,10 +419,13 @@ def item(self, **kwargs) -> R:
350419

351420
return self._values[self._ndslice.nditem(coordinates)]
352421

353-
def __iter__(self):
422+
def items(self) -> Iterable[Tuple[Point, R]]:
354423
for rank in self._shape.ranks():
355424
yield Point(rank, self._shape), self._values[rank]
356425

426+
def __iter__(self) -> Iterator[Tuple[Point, R]]:
427+
return iter(self.items())
428+
357429
def __len__(self) -> int:
358430
return len(self._shape)
359431

@@ -381,14 +453,7 @@ def send(
381453
382454
This sends the message to all actors but does not wait for any result.
383455
"""
384-
endpoint._signature.bind(None, *args, **kwargs)
385-
message = PythonMessage(
386-
endpoint._name,
387-
_pickle((args, kwargs)),
388-
None if port is None else port._port_ref,
389-
None,
390-
)
391-
endpoint._actor_mesh.cast(message, selection)
456+
endpoint._send(args, kwargs, port, selection)
392457

393458

394459
class EndpointProperty(Generic[P, R]):
@@ -460,7 +525,7 @@ def create(mailbox: Mailbox, once: bool = False) -> "PortTuple[Any]":
460525
# not part of the Endpoint API because they way it accepts arguments
461526
# and handles concerns is different.
462527
def port(endpoint: Endpoint[P, R], once: bool = False) -> "PortTuple[R]":
463-
return PortTuple.create(endpoint._mailbox, once)
528+
return endpoint._port(once)
464529

465530

466531
def ranked_port(
@@ -705,7 +770,7 @@ def __init__(
705770
setattr(
706771
self,
707772
attr_name,
708-
Endpoint(
773+
ActorEndpoint(
709774
self._actor_mesh_ref,
710775
attr_name,
711776
attr_value._method,
@@ -724,7 +789,7 @@ def __getattr__(self, name: str) -> Any:
724789
attr = getattr(self._class, name)
725790
if isinstance(attr, EndpointProperty):
726791
# Dynamically create the endpoint
727-
endpoint = Endpoint(
792+
endpoint = ActorEndpoint(
728793
self._actor_mesh_ref,
729794
name,
730795
attr._method,
@@ -747,7 +812,7 @@ def _create(
747812
async def null_func(*_args: Iterable[Any], **_kwargs: Dict[str, Any]) -> None:
748813
return None
749814

750-
ep = Endpoint(
815+
ep = ActorEndpoint(
751816
self._actor_mesh_ref,
752817
"__init__",
753818
null_func,

python/monarch/common/device_mesh.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
from torch.utils._python_dispatch import TorchDispatchMode
3434
from torch.utils._pytree import tree_map
35+
from torch.utils.weak import weakref
3536

3637
from ._tensor_to_table import tensor_to_table
3738
from .context_manager import activate_first_context_manager
@@ -170,6 +171,7 @@ def __init__(
170171
self.exit = lambda: None
171172
self.ref = None
172173
self._active_mesh_context = None
174+
self._subset_of: Optional[weakref.ReferenceType["DeviceMesh"]] = None
173175

174176
def define_remotely(self):
175177
if self.ref is None:
@@ -227,8 +229,17 @@ def _labels(self) -> Tuple[str, ...]:
227229
def _new_with_shape(self, shape: Shape) -> "DeviceMesh":
228230
mesh = DeviceMesh(self.client, shape.ndslice, tuple(shape.labels))
229231
mesh.exit = self.exit
232+
mesh._subset_of = weakref.ref(self)
230233
return mesh
231234

235+
def _is_subset_of(self, other: "DeviceMesh") -> bool:
236+
p = self
237+
while p is not None:
238+
if p is other:
239+
return True
240+
p = None if p._subset_of is None else p._subset_of()
241+
return False
242+
232243
def __call__(self, **kwargs) -> "DeviceMesh":
233244
"""
234245
device_mesh(batch=3) or device_mesh(batch=slice(3, None))

0 commit comments

Comments
 (0)