@@ -29,14 +29,15 @@ use pyo3::types::PyType;
29
29
use serde:: Deserialize ;
30
30
use serde:: Serialize ;
31
31
32
- macro_rules! setup_rdma_context {
33
- ( $self: ident, $local_proc_id: expr) => { {
34
- let proc_id: ProcId = $local_proc_id. parse( ) . unwrap( ) ;
35
- let local_owner_id = ActorId ( proc_id, "rdma_manager" . to_string( ) , 0 ) ;
36
- let local_owner_ref: ActorRef <RdmaManagerActor > = ActorRef :: attest( local_owner_id) ;
37
- let buffer = $self. buffer. clone( ) ;
38
- ( local_owner_ref, buffer)
39
- } } ;
32
+ fn setup_rdma_context (
33
+ rdma_buffer : & PyRdmaBuffer ,
34
+ local_proc_id : String ,
35
+ ) -> ( ActorRef < RdmaManagerActor > , RdmaBuffer ) {
36
+ let proc_id: ProcId = local_proc_id. parse ( ) . unwrap ( ) ;
37
+ let local_owner_id = ActorId ( proc_id, "rdma_manager" . to_string ( ) , 0 ) ;
38
+ let local_owner_ref: ActorRef < RdmaManagerActor > = ActorRef :: attest ( local_owner_id) ;
39
+ let buffer = rdma_buffer. buffer . clone ( ) ;
40
+ ( local_owner_ref, buffer)
40
41
}
41
42
42
43
#[ pyclass( name = "_RdmaBuffer" , module = "monarch._rust_bindings.rdma" ) ]
@@ -49,16 +50,16 @@ struct PyRdmaBuffer {
49
50
async fn create_rdma_buffer (
50
51
addr : usize ,
51
52
size : usize ,
52
- proc_id : String ,
53
+ proc_id : ProcId ,
53
54
client : PyMailbox ,
54
55
) -> PyResult < PyRdmaBuffer > {
55
56
// Get the owning RdmaManagerActor's ActorRef
56
- let proc_id: ProcId = proc_id. parse ( ) . unwrap ( ) ;
57
57
let owner_id = ActorId ( proc_id, "rdma_manager" . to_string ( ) , 0 ) ;
58
58
let owner_ref: ActorRef < RdmaManagerActor > = ActorRef :: attest ( owner_id) ;
59
59
60
+ let caps = client. get_inner ( ) ;
60
61
// Create the RdmaBuffer
61
- let buffer = owner_ref. request_buffer ( & client . inner , addr, size) . await ?;
62
+ let buffer = owner_ref. request_buffer ( caps , addr, size) . await ?;
62
63
Ok ( PyRdmaBuffer { buffer, owner_ref } )
63
64
}
64
65
@@ -78,7 +79,10 @@ impl PyRdmaBuffer {
78
79
"ibverbs is not supported on this system" ,
79
80
) ) ;
80
81
}
81
- signal_safe_block_on ( py, create_rdma_buffer ( addr, size, proc_id, client) ) ?
82
+ signal_safe_block_on (
83
+ py,
84
+ create_rdma_buffer ( addr, size, proc_id. parse ( ) . unwrap ( ) , client) ,
85
+ ) ?
82
86
}
83
87
84
88
#[ classmethod]
@@ -97,7 +101,7 @@ impl PyRdmaBuffer {
97
101
}
98
102
pyo3_async_runtimes:: tokio:: future_into_py (
99
103
py,
100
- create_rdma_buffer ( addr, size, proc_id, client) ,
104
+ create_rdma_buffer ( addr, size, proc_id. parse ( ) . unwrap ( ) , client) ,
101
105
)
102
106
}
103
107
@@ -133,13 +137,12 @@ impl PyRdmaBuffer {
133
137
client : PyMailbox ,
134
138
timeout : u64 ,
135
139
) -> PyResult < Bound < ' py , PyAny > > {
136
- let ( local_owner_ref, buffer) = setup_rdma_context ! ( self , local_proc_id) ;
140
+ let ( local_owner_ref, buffer) = setup_rdma_context ( self , local_proc_id) ;
137
141
pyo3_async_runtimes:: tokio:: future_into_py ( py, async move {
138
- let local_buffer = local_owner_ref
139
- . request_buffer ( & client. inner , addr, size)
140
- . await ?;
142
+ let caps = client. get_inner ( ) ;
143
+ let local_buffer = local_owner_ref. request_buffer ( caps, addr, size) . await ?;
141
144
let _result_ = local_buffer
142
- . write_from ( & client . inner , buffer, timeout)
145
+ . write_from ( caps , buffer, timeout)
143
146
. await
144
147
. map_err ( |e| PyException :: new_err ( format ! ( "failed to read into buffer: {}" , e) ) ) ?;
145
148
Ok ( ( ) )
@@ -170,13 +173,12 @@ impl PyRdmaBuffer {
170
173
client : PyMailbox ,
171
174
timeout : u64 ,
172
175
) -> PyResult < bool > {
173
- let ( local_owner_ref, buffer) = setup_rdma_context ! ( self , local_proc_id) ;
176
+ let ( local_owner_ref, buffer) = setup_rdma_context ( self , local_proc_id) ;
174
177
signal_safe_block_on ( py, async move {
175
- let local_buffer = local_owner_ref
176
- . request_buffer ( & client. inner , addr, size)
177
- . await ?;
178
+ let caps = client. get_inner ( ) ;
179
+ let local_buffer = local_owner_ref. request_buffer ( caps, addr, size) . await ?;
178
180
local_buffer
179
- . write_from ( & client . inner , buffer, timeout)
181
+ . write_from ( caps , buffer, timeout)
180
182
. await
181
183
. map_err ( |e| PyException :: new_err ( format ! ( "failed to read into buffer: {}" , e) ) )
182
184
} ) ?
@@ -204,13 +206,12 @@ impl PyRdmaBuffer {
204
206
client : PyMailbox ,
205
207
timeout : u64 ,
206
208
) -> PyResult < Bound < ' py , PyAny > > {
207
- let ( local_owner_ref, buffer) = setup_rdma_context ! ( self , local_proc_id) ;
209
+ let ( local_owner_ref, buffer) = setup_rdma_context ( self , local_proc_id) ;
208
210
pyo3_async_runtimes:: tokio:: future_into_py ( py, async move {
209
- let local_buffer = local_owner_ref
210
- . request_buffer ( & client. inner , addr, size)
211
- . await ?;
211
+ let caps = client. get_inner ( ) ;
212
+ let local_buffer = local_owner_ref. request_buffer ( caps, addr, size) . await ?;
212
213
let _result_ = local_buffer
213
- . read_into ( & client . inner , buffer, timeout)
214
+ . read_into ( caps , buffer, timeout)
214
215
. await
215
216
. map_err ( |e| PyException :: new_err ( format ! ( "failed to write from buffer: {}" , e) ) ) ?;
216
217
Ok ( ( ) )
@@ -241,13 +242,12 @@ impl PyRdmaBuffer {
241
242
client : PyMailbox ,
242
243
timeout : u64 ,
243
244
) -> PyResult < bool > {
244
- let ( local_owner_ref, buffer) = setup_rdma_context ! ( self , local_proc_id) ;
245
+ let ( local_owner_ref, buffer) = setup_rdma_context ( self , local_proc_id) ;
245
246
signal_safe_block_on ( py, async move {
246
- let local_buffer = local_owner_ref
247
- . request_buffer ( & client. inner , addr, size)
248
- . await ?;
247
+ let caps = client. get_inner ( ) ;
248
+ let local_buffer = local_owner_ref. request_buffer ( caps, addr, size) . await ?;
249
249
local_buffer
250
- . read_into ( & client . inner , buffer, timeout)
250
+ . read_into ( caps , buffer, timeout)
251
251
. await
252
252
. map_err ( |e| PyException :: new_err ( format ! ( "failed to write from buffer: {}" , e) ) )
253
253
} ) ?
0 commit comments