15
15
import sys
16
16
import traceback
17
17
18
+ from abc import ABC , abstractmethod
19
+
18
20
from dataclasses import dataclass
21
+ from operator import mul
19
22
from traceback import extract_tb , StackSummary
20
23
from typing import (
21
24
Any ,
27
30
Dict ,
28
31
Generic ,
29
32
Iterable ,
33
+ Iterator ,
30
34
List ,
31
35
Literal ,
32
36
NamedTuple ,
33
37
Optional ,
34
38
ParamSpec ,
39
+ Sequence ,
35
40
Tuple ,
36
41
Type ,
37
42
TYPE_CHECKING ,
@@ -204,18 +209,41 @@ def __len__(self) -> int:
204
209
return len (self ._shape )
205
210
206
211
207
- class Endpoint (Generic [P , R ]):
208
- def __init__ (
212
+ class Extent (NamedTuple ):
213
+ labels : Sequence [str ]
214
+ sizes : Sequence [int ]
215
+
216
+ @property
217
+ def nelements (self ) -> int :
218
+ return functools .reduce (mul , self .sizes , 1 )
219
+
220
+ def __str__ (self ) -> str :
221
+ return str (dict (zip (self .labels , self .sizes )))
222
+
223
+
224
+ class Endpoint (ABC , Generic [P , R ]):
225
+ @abstractmethod
226
+ def _send (
209
227
self ,
210
- actor_mesh_ref : _ActorMeshRefImpl ,
211
- name : str ,
212
- impl : Callable [Concatenate [Any , P ], Awaitable [R ]],
213
- mailbox : Mailbox ,
214
- ) -> None :
215
- self ._actor_mesh = actor_mesh_ref
216
- self ._name = name
217
- self ._signature : inspect .Signature = inspect .signature (impl )
218
- self ._mailbox = mailbox
228
+ args : Tuple [Any , ...],
229
+ kwargs : Dict [str , Any ],
230
+ port : "Optional[Port]" = None ,
231
+ selection : Selection = "all" ,
232
+ ) -> Extent :
233
+ """
234
+ Implements sending a message to the endpoint. The return value of the endpoint will
235
+ be sent to port if provided. If port is not provided, the return will be dropped,
236
+ and any exception will cause the actor to fail.
237
+
238
+ The return value is the (multi-dimension) size of the actors that were sent a message.
239
+ For ActorEndpoints this will be the actor_meshes size. For free-function endpoints,
240
+ this will be the size of the currently active proc_mesh.
241
+ """
242
+ pass
243
+
244
+ @abstractmethod
245
+ def _port (self , once : bool = False ) -> "PortTuple[R]" :
246
+ pass
219
247
220
248
# the following are all 'adverbs' or different ways to handle the
221
249
# return values of this endpoint. Adverbs should only ever take *args, **kwargs
@@ -228,46 +256,47 @@ def choose(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
228
256
229
257
Load balanced RPC-style entrypoint for request/response messaging.
230
258
"""
231
- p : Port [R ]
232
- r : PortReceiver [R ]
233
259
p , r = port (self , once = True )
234
260
# pyre-ignore
235
- send ( self , args , kwargs , port = p , selection = "choose" )
261
+ self . _send ( args , kwargs , port = p , selection = "choose" )
236
262
return r .recv ()
237
263
238
264
def call_one (self , * args : P .args , ** kwargs : P .kwargs ) -> Future [R ]:
239
- if len (self ._actor_mesh ) != 1 :
265
+ p , r = port (self , once = True )
266
+ # pyre-ignore
267
+ extent = self ._send (args , kwargs , port = p , selection = "choose" )
268
+ if extent .nelements != 1 :
240
269
raise ValueError (
241
- f"Can only use 'call_one' on a single Actor but this actor has shape { self . _actor_mesh . _shape } "
270
+ f"Can only use 'call_one' on a single Actor but this actor has shape { extent } "
242
271
)
243
- return self . choose ( * args , ** kwargs )
272
+ return r . recv ( )
244
273
245
274
def call (self , * args : P .args , ** kwargs : P .kwargs ) -> "Future[ValueMesh[R]]" :
246
275
p : Port [R ]
247
276
r : RankedPortReceiver [R ]
248
277
p , r = ranked_port (self )
249
278
# pyre-ignore
250
- send ( self , args , kwargs , port = p )
279
+ extent = self . _send ( args , kwargs , port = p )
251
280
252
281
async def process () -> ValueMesh [R ]:
253
- results : List [R ] = [None ] * len ( self . _actor_mesh ) # pyre-fixme[9]
254
- for _ in range (len ( self . _actor_mesh ) ):
282
+ results : List [R ] = [None ] * extent . nelements # pyre-fixme[9]
283
+ for _ in range (extent . nelements ):
255
284
rank , value = await r .recv ()
256
285
results [rank ] = value
257
286
call_shape = Shape (
258
- self . _actor_mesh . _shape .labels ,
259
- NDSlice .new_row_major (self . _actor_mesh . _shape . ndslice .sizes ),
287
+ extent .labels ,
288
+ NDSlice .new_row_major (extent .sizes ),
260
289
)
261
290
return ValueMesh (call_shape , results )
262
291
263
292
def process_blocking () -> ValueMesh [R ]:
264
- results : List [R ] = [None ] * len ( self . _actor_mesh ) # pyre-fixme[9]
265
- for _ in range (len ( self . _actor_mesh ) ):
293
+ results : List [R ] = [None ] * extent . nelements # pyre-fixme[9]
294
+ for _ in range (extent . nelements ):
266
295
rank , value = r .recv ().get ()
267
296
results [rank ] = value
268
297
call_shape = Shape (
269
- self . _actor_mesh . _shape .labels ,
270
- NDSlice .new_row_major (self . _actor_mesh . _shape . ndslice .sizes ),
298
+ extent .labels ,
299
+ NDSlice .new_row_major (extent .sizes ),
271
300
)
272
301
return ValueMesh (call_shape , results )
273
302
@@ -282,8 +311,8 @@ async def stream(self, *args: P.args, **kwargs: P.kwargs) -> AsyncGenerator[R, R
282
311
"""
283
312
p , r = port (self )
284
313
# pyre-ignore
285
- send ( self , args , kwargs , port = p )
286
- for _ in range (len ( self . _actor_mesh ) ):
314
+ extent = self . _send ( args , kwargs , port = p )
315
+ for _ in range (extent . nelements ):
287
316
yield await r .recv ()
288
317
289
318
def broadcast (self , * args : P .args , ** kwargs : P .kwargs ) -> None :
@@ -298,6 +327,46 @@ def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> None:
298
327
send (self , args , kwargs )
299
328
300
329
330
+ class ActorEndpoint (Endpoint [P , R ]):
331
+ def __init__ (
332
+ self ,
333
+ actor_mesh_ref : _ActorMeshRefImpl ,
334
+ name : str ,
335
+ impl : Callable [Concatenate [Any , P ], Awaitable [R ]],
336
+ mailbox : Mailbox ,
337
+ ) -> None :
338
+ self ._actor_mesh = actor_mesh_ref
339
+ self ._name = name
340
+ self ._signature : inspect .Signature = inspect .signature (impl )
341
+ self ._mailbox = mailbox
342
+
343
+ def _send (
344
+ self ,
345
+ args : Tuple [Any , ...],
346
+ kwargs : Dict [str , Any ],
347
+ port : "Optional[Port]" = None ,
348
+ selection : Selection = "all" ,
349
+ ) -> Extent :
350
+ """
351
+ Fire-and-forget broadcast invocation of the endpoint across all actors in the mesh.
352
+
353
+ This sends the message to all actors but does not wait for any result.
354
+ """
355
+ self ._signature .bind (None , * args , ** kwargs )
356
+ message = PythonMessage (
357
+ self ._name ,
358
+ _pickle ((args , kwargs )),
359
+ None if port is None else port ._port_ref ,
360
+ None ,
361
+ )
362
+ self ._actor_mesh .cast (message , selection )
363
+ shape = self ._actor_mesh ._shape
364
+ return Extent (shape .labels , shape .ndslice .sizes )
365
+
366
+ def _port (self , once : bool = False ) -> "PortTuple[R]" :
367
+ return PortTuple .create (self ._mailbox , once )
368
+
369
+
301
370
class Accumulator (Generic [P , R , A ]):
302
371
def __init__ (
303
372
self , endpoint : Endpoint [P , R ], identity : A , combine : Callable [[A , R ], A ]
@@ -337,10 +406,13 @@ def item(self, **kwargs) -> R:
337
406
338
407
return self ._values [self ._ndslice .nditem (coordinates )]
339
408
340
- def __iter__ (self ):
409
+ def items (self ) -> Iterable [ Tuple [ Point , R ]] :
341
410
for rank in self ._shape .ranks ():
342
411
yield Point (rank , self ._shape ), self ._values [rank ]
343
412
413
+ def __iter__ (self ) -> Iterator [Tuple [Point , R ]]:
414
+ return iter (self .items ())
415
+
344
416
def __len__ (self ) -> int :
345
417
return len (self ._shape )
346
418
@@ -368,14 +440,7 @@ def send(
368
440
369
441
This sends the message to all actors but does not wait for any result.
370
442
"""
371
- endpoint ._signature .bind (None , * args , ** kwargs )
372
- message = PythonMessage (
373
- endpoint ._name ,
374
- _pickle ((args , kwargs )),
375
- None if port is None else port ._port_ref ,
376
- None ,
377
- )
378
- endpoint ._actor_mesh .cast (message , selection )
443
+ endpoint ._send (args , kwargs , port , selection )
379
444
380
445
381
446
class EndpointProperty (Generic [P , R ]):
@@ -447,7 +512,7 @@ def create(mailbox: Mailbox, once: bool = False) -> "PortTuple[Any]":
447
512
# not part of the Endpoint API because they way it accepts arguments
448
513
# and handles concerns is different.
449
514
def port (endpoint : Endpoint [P , R ], once : bool = False ) -> "PortTuple[R]" :
450
- return PortTuple . create ( endpoint ._mailbox , once )
515
+ return endpoint ._port ( once )
451
516
452
517
453
518
def ranked_port (
@@ -676,7 +741,7 @@ def __init__(
676
741
setattr (
677
742
self ,
678
743
attr_name ,
679
- Endpoint (
744
+ ActorEndpoint (
680
745
self ._actor_mesh_ref ,
681
746
attr_name ,
682
747
attr_value ._method ,
@@ -695,7 +760,7 @@ def __getattr__(self, name: str) -> Any:
695
760
attr = getattr (self ._class , name )
696
761
if isinstance (attr , EndpointProperty ):
697
762
# Dynamically create the endpoint
698
- endpoint = Endpoint (
763
+ endpoint = ActorEndpoint (
699
764
self ._actor_mesh_ref ,
700
765
name ,
701
766
attr ._method ,
@@ -718,7 +783,7 @@ def _create(
718
783
async def null_func (* _args : Iterable [Any ], ** _kwargs : Dict [str , Any ]) -> None :
719
784
return None
720
785
721
- ep = Endpoint (
786
+ ep = ActorEndpoint (
722
787
self ._actor_mesh_ref ,
723
788
"__init__" ,
724
789
null_func ,
0 commit comments