@@ -388,7 +388,7 @@ impl<C: Config> Client<C> {
388
388
path : & str ,
389
389
request : I ,
390
390
event_mapper : impl Fn ( eventsource_stream:: Event ) -> Result < O , OpenAIError > + Send + ' static ,
391
- ) -> OpenAIEventMappedStream < O >
391
+ ) -> OpenAIEventStream < O >
392
392
where
393
393
I : Serialize ,
394
394
O : DeserializeOwned + Send + ' static ,
@@ -402,7 +402,7 @@ impl<C: Config> Client<C> {
402
402
. eventsource ( )
403
403
. unwrap ( ) ;
404
404
405
- OpenAIEventMappedStream :: new ( event_source, event_mapper)
405
+ OpenAIEventStream :: with_event_mapping ( event_source, event_mapper)
406
406
}
407
407
408
408
/// Make HTTP GET request to receive SSE
@@ -426,115 +426,57 @@ impl<C: Config> Client<C> {
426
426
427
427
/// Request which responds with SSE.
428
428
/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
429
+
429
430
#[ pin_project]
430
- pub struct OpenAIEventStream < O : DeserializeOwned + Send + ' static > {
431
+ pub struct OpenAIEventStream < O >
432
+ where
433
+ O : DeserializeOwned + Send + ' static ,
434
+ {
431
435
#[ pin]
432
436
stream : Filter <
433
437
EventSource ,
434
438
future:: Ready < bool > ,
435
439
fn ( & Result < Event , reqwest_eventsource:: Error > ) -> future:: Ready < bool > ,
436
440
> ,
441
+ event_mapper :
442
+ Option < Box < dyn Fn ( eventsource_stream:: Event ) -> Result < O , OpenAIError > + Send + ' static > > ,
437
443
done : bool ,
438
444
_phantom_data : PhantomData < O > ,
439
445
}
440
446
441
- impl < O : DeserializeOwned + Send + ' static > OpenAIEventStream < O > {
442
- pub ( crate ) fn new ( event_source : EventSource ) -> Self {
447
+ impl < O > OpenAIEventStream < O >
448
+ where
449
+ O : DeserializeOwned + Send + ' static ,
450
+ {
451
+ pub ( crate ) fn with_event_mapping < M > ( event_source : EventSource , event_mapper : M ) -> Self
452
+ where
453
+ M : Fn ( eventsource_stream:: Event ) -> Result < O , OpenAIError > + Send + ' static ,
454
+ {
443
455
Self {
444
456
stream : event_source. filter ( |result|
445
457
// filter out the first event which is always Event::Open
446
458
future:: ready ( !( result. is_ok ( ) && result. as_ref ( ) . unwrap ( ) . eq ( & Event :: Open ) ) ) ) ,
447
459
done : false ,
460
+ event_mapper : Some ( Box :: new ( event_mapper) ) ,
448
461
_phantom_data : PhantomData ,
449
462
}
450
463
}
451
- }
452
-
453
- impl < O : DeserializeOwned + Send + ' static > Stream for OpenAIEventStream < O > {
454
- type Item = Result < O , OpenAIError > ;
455
464
456
- fn poll_next ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Option < Self :: Item > > {
457
- let this = self . project ( ) ;
458
- if * this. done {
459
- return Poll :: Ready ( None ) ;
460
- }
461
- let stream: Pin < & mut _ > = this. stream ;
462
- match stream. poll_next ( cx) {
463
- Poll :: Ready ( response) => {
464
- match response {
465
- None => Poll :: Ready ( None ) , // end of the stream
466
- Some ( result) => match result {
467
- Ok ( event) => match event {
468
- Event :: Open => unreachable ! ( ) , // it has been filtered out
469
- Event :: Message ( message) => {
470
- if message. data == "[DONE]" {
471
- * this. done = true ;
472
- Poll :: Ready ( None ) // end of the stream, defined by OpenAI
473
- } else {
474
- // deserialize the data
475
- match serde_json:: from_str :: < O > ( & message. data ) {
476
- Err ( e) => {
477
- * this. done = true ;
478
- Poll :: Ready ( Some ( Err ( map_deserialization_error (
479
- e,
480
- & message. data . as_bytes ( ) ,
481
- ) ) ) )
482
- }
483
- Ok ( output) => Poll :: Ready ( Some ( Ok ( output) ) ) ,
484
- }
485
- }
486
- }
487
- } ,
488
- Err ( e) => {
489
- * this. done = true ;
490
- Poll :: Ready ( Some ( Err ( OpenAIError :: StreamError ( e. to_string ( ) ) ) ) )
491
- }
492
- } ,
493
- }
494
- }
495
- Poll :: Pending => Poll :: Pending ,
496
- }
497
- }
498
- }
499
-
500
- #[ pin_project]
501
- pub struct OpenAIEventMappedStream < O >
502
- where
503
- O : Send + ' static ,
504
- {
505
- #[ pin]
506
- stream : Filter <
507
- EventSource ,
508
- future:: Ready < bool > ,
509
- fn ( & Result < Event , reqwest_eventsource:: Error > ) -> future:: Ready < bool > ,
510
- > ,
511
- event_mapper : Box < dyn Fn ( eventsource_stream:: Event ) -> Result < O , OpenAIError > + Send + ' static > ,
512
- done : bool ,
513
- _phantom_data : PhantomData < O > ,
514
- }
515
-
516
- impl < O > OpenAIEventMappedStream < O >
517
- where
518
- O : Send + ' static ,
519
- {
520
- pub ( crate ) fn new < M > ( event_source : EventSource , event_mapper : M ) -> Self
521
- where
522
- M : Fn ( eventsource_stream:: Event ) -> Result < O , OpenAIError > + Send + ' static ,
523
- {
465
+ pub ( crate ) fn new ( event_source : EventSource ) -> Self {
524
466
Self {
525
467
stream : event_source. filter ( |result|
526
468
// filter out the first event which is always Event::Open
527
469
future:: ready ( !( result. is_ok ( ) && result. as_ref ( ) . unwrap ( ) . eq ( & Event :: Open ) ) ) ) ,
528
470
done : false ,
529
- event_mapper : Box :: new ( event_mapper ) ,
471
+ event_mapper : None ,
530
472
_phantom_data : PhantomData ,
531
473
}
532
474
}
533
475
}
534
476
535
- impl < O > Stream for OpenAIEventMappedStream < O >
477
+ impl < O > Stream for OpenAIEventStream < O >
536
478
where
537
- O : Send + ' static ,
479
+ O : DeserializeOwned + Send + ' static ,
538
480
{
539
481
type Item = Result < O , OpenAIError > ;
540
482
@@ -552,13 +494,32 @@ where
552
494
Ok ( event) => match event {
553
495
Event :: Open => unreachable ! ( ) , // it has been filtered out
554
496
Event :: Message ( message) => {
555
- if message. data == "[DONE]" {
556
- * this. done = true ;
557
- }
558
- let response = ( this. event_mapper ) ( message) ;
559
- match response {
560
- Ok ( output) => Poll :: Ready ( Some ( Ok ( output) ) ) ,
561
- Err ( _) => Poll :: Ready ( None ) ,
497
+ if let Some ( event_mapper) = this. event_mapper . as_ref ( ) {
498
+ if message. data == "[DONE]" {
499
+ * this. done = true ;
500
+ }
501
+ let response = event_mapper ( message) ;
502
+ match response {
503
+ Ok ( output) => Poll :: Ready ( Some ( Ok ( output) ) ) ,
504
+ Err ( _) => Poll :: Ready ( None ) ,
505
+ }
506
+ } else {
507
+ if message. data == "[DONE]" {
508
+ * this. done = true ;
509
+ Poll :: Ready ( None ) // end of the stream, defined by OpenAI
510
+ } else {
511
+ // deserialize the data
512
+ match serde_json:: from_str :: < O > ( & message. data ) {
513
+ Err ( e) => {
514
+ * this. done = true ;
515
+ Poll :: Ready ( Some ( Err ( map_deserialization_error (
516
+ e,
517
+ & message. data . as_bytes ( ) ,
518
+ ) ) ) )
519
+ }
520
+ Ok ( output) => Poll :: Ready ( Some ( Ok ( output) ) ) ,
521
+ }
522
+ }
562
523
}
563
524
}
564
525
} ,
0 commit comments