@@ -64,6 +64,7 @@ use hyperactor::message::Bind;
64
64
use hyperactor:: message:: Bindings ;
65
65
use hyperactor:: message:: Unbind ;
66
66
use pin_project:: pin_project;
67
+ use pin_project:: pinned_drop;
67
68
use serde:: Deserialize ;
68
69
use serde:: Serialize ;
69
70
use tokio:: io:: AsyncRead ;
@@ -94,10 +95,15 @@ pub struct OwnedReadHalf {
94
95
}
95
96
96
97
/// Wrap a `PortRef<IoMsg>` as a `AsyncWrite`.
98
+ #[ pin_project( PinnedDrop ) ]
97
99
pub struct OwnedWriteHalf < C : CanSend > {
98
100
peer : ActorId ,
101
+ #[ pin]
99
102
caps : C ,
103
+ #[ pin]
100
104
port : PortRef < Io > ,
105
+ #[ pin]
106
+ shutdown : bool ,
101
107
}
102
108
103
109
/// A duplex bytestream connection between two actors. Can generally be used like a `TcpStream`.
@@ -144,7 +150,12 @@ impl OwnedReadHalf {
144
150
145
151
impl < C : CanSend > OwnedWriteHalf < C > {
146
152
fn new ( peer : ActorId , caps : C , port : PortRef < Io > ) -> Self {
147
- Self { peer, caps, port }
153
+ Self {
154
+ peer,
155
+ caps,
156
+ port,
157
+ shutdown : false ,
158
+ }
148
159
}
149
160
150
161
pub fn peer ( & self ) -> & ActorId {
@@ -159,10 +170,13 @@ impl<C: CanSend> OwnedWriteHalf<C> {
159
170
}
160
171
}
161
172
162
- impl < C : CanSend > Drop for OwnedWriteHalf < C > {
163
- fn drop ( & mut self ) {
164
- // Send EOF on drop.
165
- let _ = self . port . send ( & self . caps , Io :: Eof ) ;
173
+ #[ pinned_drop]
174
+ impl < C : CanSend > PinnedDrop for OwnedWriteHalf < C > {
175
+ fn drop ( self : Pin < & mut Self > ) {
176
+ let this = self . project ( ) ;
177
+ if !* this. shutdown {
178
+ let _ = this. port . send ( & * this. caps , Io :: Eof ) ;
179
+ }
166
180
}
167
181
}
168
182
@@ -247,7 +261,14 @@ impl<C: CanSend> AsyncWrite for OwnedWriteHalf<C> {
247
261
_cx : & mut Context < ' _ > ,
248
262
buf : & [ u8 ] ,
249
263
) -> Poll < Result < usize , std:: io:: Error > > {
250
- match self . port . send ( & self . caps , Io :: Data ( buf. into ( ) ) ) {
264
+ let this = self . project ( ) ;
265
+ if * this. shutdown {
266
+ return Poll :: Ready ( Err ( std:: io:: Error :: new (
267
+ std:: io:: ErrorKind :: BrokenPipe ,
268
+ "write after shutdown" ,
269
+ ) ) ) ;
270
+ }
271
+ match this. port . send ( & * this. caps , Io :: Data ( buf. into ( ) ) ) {
251
272
Ok ( ( ) ) => Poll :: Ready ( Ok ( buf. len ( ) ) ) ,
252
273
Err ( e) => Poll :: Ready ( Err ( std:: io:: Error :: other ( e) ) ) ,
253
274
}
@@ -263,7 +284,11 @@ impl<C: CanSend> AsyncWrite for OwnedWriteHalf<C> {
263
284
) -> Poll < Result < ( ) , std:: io:: Error > > {
264
285
// Send EOF on shutdown.
265
286
match self . port . send ( & self . caps , Io :: Eof ) {
266
- Ok ( ( ) ) => Poll :: Ready ( Ok ( ( ) ) ) ,
287
+ Ok ( ( ) ) => {
288
+ let mut this = self . project ( ) ;
289
+ * this. shutdown = true ;
290
+ Poll :: Ready ( Ok ( ( ) ) )
291
+ }
267
292
Err ( e) => Poll :: Ready ( Err ( std:: io:: Error :: other ( e) ) ) ,
268
293
}
269
294
}
@@ -471,4 +496,36 @@ mod tests {
471
496
472
497
Ok ( ( ) )
473
498
}
499
+
500
+ #[ tokio:: test]
501
+ async fn test_no_eof_on_drop_after_shutdown ( ) -> Result < ( ) > {
502
+ let proc = Proc :: local ( ) ;
503
+ let client = proc. attach ( "client" ) ?;
504
+
505
+ let ( connect, completer) = Connect :: allocate ( client. actor_id ( ) . clone ( ) , client. clone ( ) ) ;
506
+ let ( mut rd, _) = accept ( & client, client. actor_id ( ) . clone ( ) , connect)
507
+ . await ?
508
+ . into_split ( ) ;
509
+ let ( _, mut wr) = completer. complete ( ) . await ?. into_split ( ) ;
510
+
511
+ // Write some data
512
+ let send = [ 1u8 , 2u8 , 3u8 ] ;
513
+ wr. write_all ( & send) . await ?;
514
+
515
+ // Explicitly shutdown the writer - this sends EOF and sets shutdown=true
516
+ wr. shutdown ( ) . await ?;
517
+
518
+ // Reader should receive the data and then EOF (from explicit shutdown, not from drop)
519
+ let mut recv = vec ! [ ] ;
520
+ rd. read_to_end ( & mut recv) . await ?;
521
+ assert_eq ! ( & send, recv. as_slice( ) ) ;
522
+
523
+ // Drop the writer after explicit shutdown - this should NOT send another EOF
524
+ drop ( wr) ;
525
+
526
+ // Verify we didn't see another EOF message.
527
+ assert ! ( rd. inner. into_inner( ) . port. try_recv( ) . unwrap( ) . is_none( ) ) ;
528
+
529
+ Ok ( ( ) )
530
+ }
474
531
}
0 commit comments