14
14
import random
15
15
import traceback
16
16
17
+ from abc import ABC , abstractmethod
18
+
17
19
from dataclasses import dataclass
20
+ from operator import mul
18
21
from traceback import extract_tb , StackSummary
19
22
from typing import (
20
23
Any ,
26
29
Dict ,
27
30
Generic ,
28
31
Iterable ,
32
+ Iterator ,
29
33
List ,
30
34
Literal ,
31
35
NamedTuple ,
32
36
Optional ,
33
37
ParamSpec ,
38
+ Sequence ,
34
39
Tuple ,
35
40
Type ,
36
41
TYPE_CHECKING ,
@@ -217,18 +222,41 @@ def __len__(self) -> int:
217
222
return len (self ._shape )
218
223
219
224
220
- class Endpoint (Generic [P , R ]):
221
- def __init__ (
225
+ class Extent (NamedTuple ):
226
+ labels : Sequence [str ]
227
+ sizes : Sequence [int ]
228
+
229
+ @property
230
+ def nelements (self ) -> int :
231
+ return functools .reduce (mul , self .sizes , 1 )
232
+
233
+ def __str__ (self ) -> str :
234
+ return str (dict (zip (self .labels , self .sizes )))
235
+
236
+
237
+ class Endpoint (ABC , Generic [P , R ]):
238
+ @abstractmethod
239
+ def _send (
222
240
self ,
223
- actor_mesh_ref : _ActorMeshRefImpl ,
224
- name : str ,
225
- impl : Callable [Concatenate [Any , P ], Awaitable [R ]],
226
- mailbox : Mailbox ,
227
- ) -> None :
228
- self ._actor_mesh = actor_mesh_ref
229
- self ._name = name
230
- self ._signature : inspect .Signature = inspect .signature (impl )
231
- self ._mailbox = mailbox
241
+ args : Tuple [Any , ...],
242
+ kwargs : Dict [str , Any ],
243
+ port : "Optional[Port]" = None ,
244
+ selection : Selection = "all" ,
245
+ ) -> Extent :
246
+ """
247
+ Implements sending a message to the endpoint. The return value of the endpoint will
248
+ be sent to port if provided. If port is not provided, the return will be dropped,
249
+ and any exception will cause the actor to fail.
250
+
251
+ The return value is the (multi-dimension) size of the actors that were sent a message.
252
+ For ActorEndpoints this will be the actor_meshes size. For free-function endpoints,
253
+ this will be the size of the currently active proc_mesh.
254
+ """
255
+ pass
256
+
257
+ @abstractmethod
258
+ def _port (self , once : bool = False ) -> "PortTuple[R]" :
259
+ pass
232
260
233
261
# the following are all 'adverbs' or different ways to handle the
234
262
# return values of this endpoint. Adverbs should only ever take *args, **kwargs
@@ -241,46 +269,47 @@ def choose(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
241
269
242
270
Load balanced RPC-style entrypoint for request/response messaging.
243
271
"""
244
- p : Port [R ]
245
- r : PortReceiver [R ]
246
272
p , r = port (self , once = True )
247
273
# pyre-ignore
248
- send ( self , args , kwargs , port = p , selection = "choose" )
274
+ self . _send ( args , kwargs , port = p , selection = "choose" )
249
275
return r .recv ()
250
276
251
277
def call_one (self , * args : P .args , ** kwargs : P .kwargs ) -> Future [R ]:
252
- if len (self ._actor_mesh ) != 1 :
278
+ p , r = port (self , once = True )
279
+ # pyre-ignore
280
+ extent = self ._send (args , kwargs , port = p , selection = "choose" )
281
+ if extent .nelements != 1 :
253
282
raise ValueError (
254
- f"Can only use 'call_one' on a single Actor but this actor has shape { self . _actor_mesh . _shape } "
283
+ f"Can only use 'call_one' on a single Actor but this actor has shape { extent } "
255
284
)
256
- return self . choose ( * args , ** kwargs )
285
+ return r . recv ( )
257
286
258
287
def call (self , * args : P .args , ** kwargs : P .kwargs ) -> "Future[ValueMesh[R]]" :
259
288
p : Port [R ]
260
289
r : RankedPortReceiver [R ]
261
290
p , r = ranked_port (self )
262
291
# pyre-ignore
263
- send ( self , args , kwargs , port = p )
292
+ extent = self . _send ( args , kwargs , port = p )
264
293
265
294
async def process () -> ValueMesh [R ]:
266
- results : List [R ] = [None ] * len ( self . _actor_mesh ) # pyre-fixme[9]
267
- for _ in range (len ( self . _actor_mesh ) ):
295
+ results : List [R ] = [None ] * extent . nelements # pyre-fixme[9]
296
+ for _ in range (extent . nelements ):
268
297
rank , value = await r .recv ()
269
298
results [rank ] = value
270
299
call_shape = Shape (
271
- self . _actor_mesh . _shape .labels ,
272
- NDSlice .new_row_major (self . _actor_mesh . _shape . ndslice .sizes ),
300
+ extent .labels ,
301
+ NDSlice .new_row_major (extent .sizes ),
273
302
)
274
303
return ValueMesh (call_shape , results )
275
304
276
305
def process_blocking () -> ValueMesh [R ]:
277
- results : List [R ] = [None ] * len ( self . _actor_mesh ) # pyre-fixme[9]
278
- for _ in range (len ( self . _actor_mesh ) ):
306
+ results : List [R ] = [None ] * extent . nelements # pyre-fixme[9]
307
+ for _ in range (extent . nelements ):
279
308
rank , value = r .recv ().get ()
280
309
results [rank ] = value
281
310
call_shape = Shape (
282
- self . _actor_mesh . _shape .labels ,
283
- NDSlice .new_row_major (self . _actor_mesh . _shape . ndslice .sizes ),
311
+ extent .labels ,
312
+ NDSlice .new_row_major (extent .sizes ),
284
313
)
285
314
return ValueMesh (call_shape , results )
286
315
@@ -295,8 +324,8 @@ async def stream(self, *args: P.args, **kwargs: P.kwargs) -> AsyncGenerator[R, R
295
324
"""
296
325
p , r = port (self )
297
326
# pyre-ignore
298
- send ( self , args , kwargs , port = p )
299
- for _ in range (len ( self . _actor_mesh ) ):
327
+ extent = self . _send ( args , kwargs , port = p )
328
+ for _ in range (extent . nelements ):
300
329
yield await r .recv ()
301
330
302
331
def broadcast (self , * args : P .args , ** kwargs : P .kwargs ) -> None :
@@ -311,6 +340,46 @@ def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> None:
311
340
send (self , args , kwargs )
312
341
313
342
343
+ class ActorEndpoint (Endpoint [P , R ]):
344
+ def __init__ (
345
+ self ,
346
+ actor_mesh_ref : _ActorMeshRefImpl ,
347
+ name : str ,
348
+ impl : Callable [Concatenate [Any , P ], Awaitable [R ]],
349
+ mailbox : Mailbox ,
350
+ ) -> None :
351
+ self ._actor_mesh = actor_mesh_ref
352
+ self ._name = name
353
+ self ._signature : inspect .Signature = inspect .signature (impl )
354
+ self ._mailbox = mailbox
355
+
356
+ def _send (
357
+ self ,
358
+ args : Tuple [Any , ...],
359
+ kwargs : Dict [str , Any ],
360
+ port : "Optional[Port]" = None ,
361
+ selection : Selection = "all" ,
362
+ ) -> Extent :
363
+ """
364
+ Fire-and-forget broadcast invocation of the endpoint across all actors in the mesh.
365
+
366
+ This sends the message to all actors but does not wait for any result.
367
+ """
368
+ self ._signature .bind (None , * args , ** kwargs )
369
+ message = PythonMessage (
370
+ self ._name ,
371
+ _pickle ((args , kwargs )),
372
+ None if port is None else port ._port_ref ,
373
+ None ,
374
+ )
375
+ self ._actor_mesh .cast (message , selection )
376
+ shape = self ._actor_mesh ._shape
377
+ return Extent (shape .labels , shape .ndslice .sizes )
378
+
379
+ def _port (self , once : bool = False ) -> "PortTuple[R]" :
380
+ return PortTuple .create (self ._mailbox , once )
381
+
382
+
314
383
class Accumulator (Generic [P , R , A ]):
315
384
def __init__ (
316
385
self , endpoint : Endpoint [P , R ], identity : A , combine : Callable [[A , R ], A ]
@@ -350,10 +419,13 @@ def item(self, **kwargs) -> R:
350
419
351
420
return self ._values [self ._ndslice .nditem (coordinates )]
352
421
353
- def __iter__ (self ):
422
+ def items (self ) -> Iterable [ Tuple [ Point , R ]] :
354
423
for rank in self ._shape .ranks ():
355
424
yield Point (rank , self ._shape ), self ._values [rank ]
356
425
426
+ def __iter__ (self ) -> Iterator [Tuple [Point , R ]]:
427
+ return iter (self .items ())
428
+
357
429
def __len__ (self ) -> int :
358
430
return len (self ._shape )
359
431
@@ -381,14 +453,7 @@ def send(
381
453
382
454
This sends the message to all actors but does not wait for any result.
383
455
"""
384
- endpoint ._signature .bind (None , * args , ** kwargs )
385
- message = PythonMessage (
386
- endpoint ._name ,
387
- _pickle ((args , kwargs )),
388
- None if port is None else port ._port_ref ,
389
- None ,
390
- )
391
- endpoint ._actor_mesh .cast (message , selection )
456
+ endpoint ._send (args , kwargs , port , selection )
392
457
393
458
394
459
class EndpointProperty (Generic [P , R ]):
@@ -460,7 +525,7 @@ def create(mailbox: Mailbox, once: bool = False) -> "PortTuple[Any]":
460
525
# not part of the Endpoint API because they way it accepts arguments
461
526
# and handles concerns is different.
462
527
def port (endpoint : Endpoint [P , R ], once : bool = False ) -> "PortTuple[R]" :
463
- return PortTuple . create ( endpoint ._mailbox , once )
528
+ return endpoint ._port ( once )
464
529
465
530
466
531
def ranked_port (
@@ -705,7 +770,7 @@ def __init__(
705
770
setattr (
706
771
self ,
707
772
attr_name ,
708
- Endpoint (
773
+ ActorEndpoint (
709
774
self ._actor_mesh_ref ,
710
775
attr_name ,
711
776
attr_value ._method ,
@@ -724,7 +789,7 @@ def __getattr__(self, name: str) -> Any:
724
789
attr = getattr (self ._class , name )
725
790
if isinstance (attr , EndpointProperty ):
726
791
# Dynamically create the endpoint
727
- endpoint = Endpoint (
792
+ endpoint = ActorEndpoint (
728
793
self ._actor_mesh_ref ,
729
794
name ,
730
795
attr ._method ,
@@ -747,7 +812,7 @@ def _create(
747
812
async def null_func (* _args : Iterable [Any ], ** _kwargs : Dict [str , Any ]) -> None :
748
813
return None
749
814
750
- ep = Endpoint (
815
+ ep = ActorEndpoint (
751
816
self ._actor_mesh_ref ,
752
817
"__init__" ,
753
818
null_func ,
0 commit comments