7
7
# pyre-unsafe
8
8
9
9
import pickle
10
- from typing import List
10
+ from typing import Any , List
11
11
12
12
import monarch
13
13
import pytest
14
14
15
- from monarch ._rust_bindings .monarch_hyperactor .actor import PanicFlag , PythonMessage
15
+ from monarch ._rust_bindings .monarch_hyperactor .actor import (
16
+ PanicFlag ,
17
+ PythonMessage ,
18
+ PythonMessageKind ,
19
+ )
16
20
from monarch ._rust_bindings .monarch_hyperactor .actor_mesh import (
17
21
PythonActorMesh ,
18
22
PythonActorMeshRef ,
@@ -45,12 +49,21 @@ async def handle(
45
49
shape : Shape ,
46
50
message : PythonMessage ,
47
51
panic_flag : PanicFlag ,
52
+ local_state : List [Any ] | None = None ,
48
53
) -> None :
49
54
assert rank is not None
50
- reply_port = message .response_port
51
- assert reply_port is not None
55
+
56
+ # Extract response_port from the message kind
57
+ call_method = message .kind
58
+ assert isinstance (call_method , PythonMessageKind .CallMethod )
59
+ assert call_method .response_port is not None
60
+
61
+ reply_port = call_method .response_port
52
62
reply_port .send (
53
- mailbox , PythonMessage ("pong" , pickle .dumps (f"rank: { rank } " ), None , rank )
63
+ mailbox ,
64
+ PythonMessage (
65
+ PythonMessageKind .Result (rank ), pickle .dumps (f"rank: { rank } " )
66
+ ),
54
67
)
55
68
56
69
@@ -77,7 +90,9 @@ async def verify_cast(
77
90
handle , receiver = mailbox .open_port ()
78
91
port_ref = handle .bind ()
79
92
80
- message = PythonMessage ("echo" , pickle .dumps ("ping" ), port_ref , None )
93
+ message = PythonMessage (
94
+ PythonMessageKind .CallMethod ("echo" , port_ref ), pickle .dumps ("ping" )
95
+ )
81
96
sel = Selection .from_string ("*" )
82
97
if isinstance (actor_mesh , PythonActorMesh ):
83
98
actor_mesh .cast (sel , message )
@@ -87,7 +102,9 @@ async def verify_cast(
87
102
rcv_ranks = []
88
103
for _ in range (len (cast_ranks )):
89
104
message = await receiver .recv ()
90
- rank = message .rank
105
+ result_kind = message .kind
106
+ assert isinstance (result_kind , PythonMessageKind .Result )
107
+ rank = result_kind .rank
91
108
assert rank is not None
92
109
rcv_ranks .append (rank )
93
110
rcv_ranks .sort ()
0 commit comments