15
15
import sys
16
16
import traceback
17
17
18
+ from abc import ABC , abstractmethod
19
+
18
20
from dataclasses import dataclass
21
+ from operator import mul
19
22
from traceback import extract_tb , StackSummary
20
23
from typing import (
21
24
Any ,
27
30
Dict ,
28
31
Generic ,
29
32
Iterable ,
33
+ Iterator ,
30
34
List ,
31
35
Literal ,
32
36
NamedTuple ,
33
37
Optional ,
34
38
ParamSpec ,
39
+ Sequence ,
35
40
Tuple ,
36
41
Type ,
37
42
TYPE_CHECKING ,
@@ -204,18 +209,32 @@ def __len__(self) -> int:
204
209
return len (self ._shape )
205
210
206
211
207
- class Endpoint (Generic [P , R ]):
208
- def __init__ (
212
+ class Extent (NamedTuple ):
213
+ labels : Sequence [str ]
214
+ sizes : Sequence [int ]
215
+
216
+ @property
217
+ def nelements (self ) -> int :
218
+ return functools .reduce (mul , self .sizes , 1 )
219
+
220
+ def __str__ (self ) -> str :
221
+ return str (dict (zip (self .labels , self .sizes )))
222
+
223
+
224
+ class Endpoint (ABC , Generic [P , R ]):
225
+ @abstractmethod
226
+ def _send (
209
227
self ,
210
- actor_mesh_ref : _ActorMeshRefImpl ,
211
- name : str ,
212
- impl : Callable [Concatenate [Any , P ], Awaitable [R ]],
213
- mailbox : Mailbox ,
214
- ) -> None :
215
- self ._actor_mesh = actor_mesh_ref
216
- self ._name = name
217
- self ._signature : inspect .Signature = inspect .signature (impl )
218
- self ._mailbox = mailbox
228
+ args : Tuple [Any , ...],
229
+ kwargs : Dict [str , Any ],
230
+ port : "Optional[Port]" = None ,
231
+ selection : Selection = "all" ,
232
+ ) -> Extent :
233
+ pass
234
+
235
+ @abstractmethod
236
+ def _port (self , once : bool = False ) -> "PortTuple[R]" :
237
+ pass
219
238
220
239
# the following are all 'adverbs' or different ways to handle the
221
240
# return values of this endpoint. Adverbs should only ever take *args, **kwargs
@@ -228,46 +247,47 @@ def choose(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
228
247
229
248
Load balanced RPC-style entrypoint for request/response messaging.
230
249
"""
231
- p : Port [R ]
232
- r : PortReceiver [R ]
233
250
p , r = port (self , once = True )
234
251
# pyre-ignore
235
- send ( self , args , kwargs , port = p , selection = "choose" )
252
+ self . _send ( args , kwargs , port = p , selection = "choose" )
236
253
return r .recv ()
237
254
238
255
def call_one (self , * args : P .args , ** kwargs : P .kwargs ) -> Future [R ]:
239
- if len (self ._actor_mesh ) != 1 :
256
+ p , r = port (self , once = True )
257
+ # pyre-ignore
258
+ extent = self ._send (args , kwargs , port = p , selection = "choose" )
259
+ if extent .nelements != 1 :
240
260
raise ValueError (
241
- f"Can only use 'call_one' on a single Actor but this actor has shape { self . _actor_mesh . _shape } "
261
+ f"Can only use 'call_one' on a single Actor but this actor has shape { extent } "
242
262
)
243
- return self . choose ( * args , ** kwargs )
263
+ return r . recv ( )
244
264
245
265
def call (self , * args : P .args , ** kwargs : P .kwargs ) -> "Future[ValueMesh[R]]" :
246
266
p : Port [R ]
247
267
r : RankedPortReceiver [R ]
248
268
p , r = ranked_port (self )
249
269
# pyre-ignore
250
- send ( self , args , kwargs , port = p )
270
+ extent = self . _send ( args , kwargs , port = p )
251
271
252
272
async def process () -> ValueMesh [R ]:
253
- results : List [R ] = [None ] * len ( self . _actor_mesh ) # pyre-fixme[9]
254
- for _ in range (len ( self . _actor_mesh ) ):
273
+ results : List [R ] = [None ] * extent . nelements # pyre-fixme[9]
274
+ for _ in range (extent . nelements ):
255
275
rank , value = await r .recv ()
256
276
results [rank ] = value
257
277
call_shape = Shape (
258
- self . _actor_mesh . _shape .labels ,
259
- NDSlice .new_row_major (self . _actor_mesh . _shape . ndslice .sizes ),
278
+ extent .labels ,
279
+ NDSlice .new_row_major (extent .sizes ),
260
280
)
261
281
return ValueMesh (call_shape , results )
262
282
263
283
def process_blocking () -> ValueMesh [R ]:
264
- results : List [R ] = [None ] * len ( self . _actor_mesh ) # pyre-fixme[9]
265
- for _ in range (len ( self . _actor_mesh ) ):
284
+ results : List [R ] = [None ] * extent . nelements # pyre-fixme[9]
285
+ for _ in range (extent . nelements ):
266
286
rank , value = r .recv ().get ()
267
287
results [rank ] = value
268
288
call_shape = Shape (
269
- self . _actor_mesh . _shape .labels ,
270
- NDSlice .new_row_major (self . _actor_mesh . _shape . ndslice .sizes ),
289
+ extent .labels ,
290
+ NDSlice .new_row_major (extent .sizes ),
271
291
)
272
292
return ValueMesh (call_shape , results )
273
293
@@ -282,8 +302,8 @@ async def stream(self, *args: P.args, **kwargs: P.kwargs) -> AsyncGenerator[R, R
282
302
"""
283
303
p , r = port (self )
284
304
# pyre-ignore
285
- send ( self , args , kwargs , port = p )
286
- for _ in range (len ( self . _actor_mesh ) ):
305
+ extent = self . _send ( args , kwargs , port = p )
306
+ for _ in range (extent . nelements ):
287
307
yield await r .recv ()
288
308
289
309
def broadcast (self , * args : P .args , ** kwargs : P .kwargs ) -> None :
@@ -298,6 +318,46 @@ def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> None:
298
318
send (self , args , kwargs )
299
319
300
320
321
+ class ActorEndpoint (Endpoint [P , R ]):
322
+ def __init__ (
323
+ self ,
324
+ actor_mesh_ref : _ActorMeshRefImpl ,
325
+ name : str ,
326
+ impl : Callable [Concatenate [Any , P ], Awaitable [R ]],
327
+ mailbox : Mailbox ,
328
+ ) -> None :
329
+ self ._actor_mesh = actor_mesh_ref
330
+ self ._name = name
331
+ self ._signature : inspect .Signature = inspect .signature (impl )
332
+ self ._mailbox = mailbox
333
+
334
+ def _send (
335
+ self ,
336
+ args : Tuple [Any , ...],
337
+ kwargs : Dict [str , Any ],
338
+ port : "Optional[Port]" = None ,
339
+ selection : Selection = "all" ,
340
+ ) -> Extent :
341
+ """
342
+ Fire-and-forget broadcast invocation of the endpoint across all actors in the mesh.
343
+
344
+ This sends the message to all actors but does not wait for any result.
345
+ """
346
+ self ._signature .bind (None , * args , ** kwargs )
347
+ message = PythonMessage (
348
+ self ._name ,
349
+ _pickle ((args , kwargs )),
350
+ None if port is None else port ._port_ref ,
351
+ None ,
352
+ )
353
+ self ._actor_mesh .cast (message , selection )
354
+ shape = self ._actor_mesh ._shape
355
+ return Extent (shape .labels , shape .ndslice .sizes )
356
+
357
+ def _port (self , once : bool = False ) -> "PortTuple[R]" :
358
+ return PortTuple .create (self ._mailbox , once )
359
+
360
+
301
361
class Accumulator (Generic [P , R , A ]):
302
362
def __init__ (
303
363
self , endpoint : Endpoint [P , R ], identity : A , combine : Callable [[A , R ], A ]
@@ -337,10 +397,13 @@ def item(self, **kwargs) -> R:
337
397
338
398
return self ._values [self ._ndslice .nditem (coordinates )]
339
399
340
- def __iter__ (self ):
400
+ def items (self ) -> Iterable [ Tuple [ Point , R ]] :
341
401
for rank in self ._shape .ranks ():
342
402
yield Point (rank , self ._shape ), self ._values [rank ]
343
403
404
+ def __iter__ (self ) -> Iterator [Tuple [Point , R ]]:
405
+ return iter (self .items ())
406
+
344
407
def __len__ (self ) -> int :
345
408
return len (self ._shape )
346
409
@@ -368,14 +431,7 @@ def send(
368
431
369
432
This sends the message to all actors but does not wait for any result.
370
433
"""
371
- endpoint ._signature .bind (None , * args , ** kwargs )
372
- message = PythonMessage (
373
- endpoint ._name ,
374
- _pickle ((args , kwargs )),
375
- None if port is None else port ._port_ref ,
376
- None ,
377
- )
378
- endpoint ._actor_mesh .cast (message , selection )
434
+ endpoint ._send (args , kwargs , port , selection )
379
435
380
436
381
437
class EndpointProperty (Generic [P , R ]):
@@ -447,7 +503,7 @@ def create(mailbox: Mailbox, once: bool = False) -> "PortTuple[Any]":
447
503
# not part of the Endpoint API because they way it accepts arguments
448
504
# and handles concerns is different.
449
505
def port (endpoint : Endpoint [P , R ], once : bool = False ) -> "PortTuple[R]" :
450
- return PortTuple . create ( endpoint ._mailbox , once )
506
+ return endpoint ._port ( once )
451
507
452
508
453
509
def ranked_port (
@@ -676,7 +732,7 @@ def __init__(
676
732
setattr (
677
733
self ,
678
734
attr_name ,
679
- Endpoint (
735
+ ActorEndpoint (
680
736
self ._actor_mesh_ref ,
681
737
attr_name ,
682
738
attr_value ._method ,
@@ -695,7 +751,7 @@ def __getattr__(self, name: str) -> Any:
695
751
attr = getattr (self ._class , name )
696
752
if isinstance (attr , EndpointProperty ):
697
753
# Dynamically create the endpoint
698
- endpoint = Endpoint (
754
+ endpoint = ActorEndpoint (
699
755
self ._actor_mesh_ref ,
700
756
name ,
701
757
attr ._method ,
@@ -718,7 +774,7 @@ def _create(
718
774
async def null_func (* _args : Iterable [Any ], ** _kwargs : Dict [str , Any ]) -> None :
719
775
return None
720
776
721
- ep = Endpoint (
777
+ ep = ActorEndpoint (
722
778
self ._actor_mesh_ref ,
723
779
"__init__" ,
724
780
null_func ,
0 commit comments