Skip to content

Commit b960301

Browse files
committed
clean code
1 parent e41e8ab commit b960301

File tree

6 files changed

+52
-104
lines changed

6 files changed

+52
-104
lines changed

async-openai-macros/src/lib.rs

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,13 @@ struct BoundArgs {
1313
bounds: Vec<(String, syn::TypeParamBound)>,
1414
where_clause: Option<String>,
1515
stream: bool, // Add stream flag
16-
use_mapped_events: bool,
1716
}
1817

1918
impl Parse for BoundArgs {
2019
fn parse(input: ParseStream) -> syn::Result<Self> {
2120
let mut bounds = Vec::new();
2221
let mut where_clause = None;
2322
let mut stream = false; // Default to false
24-
let mut use_mapped_events = false; // Default to false
2523
let vars = Punctuated::<syn::MetaNameValue, Comma>::parse_terminated(input)?;
2624

2725
for var in vars {
@@ -33,9 +31,6 @@ impl Parse for BoundArgs {
3331
"stream" => {
3432
stream = var.value.into_token_stream().to_string().contains("true");
3533
}
36-
"use_mapped_events" => {
37-
use_mapped_events = var.value.into_token_stream().to_string().contains("true");
38-
}
3934
_ => {
4035
let bound: syn::TypeParamBound =
4136
syn::parse_str(&var.value.into_token_stream().to_string())?;
@@ -47,7 +42,6 @@ impl Parse for BoundArgs {
4742
bounds,
4843
where_clause,
4944
stream,
50-
use_mapped_events,
5145
})
5246
}
5347
}
@@ -130,11 +124,7 @@ pub fn byot(args: TokenStream, item: TokenStream) -> TokenStream {
130124

131125
// Generate return type based on stream flag
132126
let return_type = if bounds_args.stream {
133-
if bounds_args.use_mapped_events {
134-
quote! { Result<crate::client::OpenAIEventMappedStream<R>, OpenAIError> }
135-
} else {
136-
quote! { Result<crate::client::OpenAIEventStream<R>, OpenAIError> }
137-
}
127+
quote! { Result<crate::client::OpenAIEventStream<R>, OpenAIError> }
138128
} else {
139129
quote! { Result<R, OpenAIError> }
140130
};

async-openai-wasm/src/client.rs

Lines changed: 48 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ impl<C: Config> Client<C> {
388388
path: &str,
389389
request: I,
390390
event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
391-
) -> OpenAIEventMappedStream<O>
391+
) -> OpenAIEventStream<O>
392392
where
393393
I: Serialize,
394394
O: DeserializeOwned + Send + 'static,
@@ -402,7 +402,7 @@ impl<C: Config> Client<C> {
402402
.eventsource()
403403
.unwrap();
404404

405-
OpenAIEventMappedStream::new(event_source, event_mapper)
405+
OpenAIEventStream::with_event_mapping(event_source, event_mapper)
406406
}
407407

408408
/// Make HTTP GET request to receive SSE
@@ -426,115 +426,57 @@ impl<C: Config> Client<C> {
426426

427427
/// Request which responds with SSE.
428428
/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
429+
429430
#[pin_project]
430-
pub struct OpenAIEventStream<O: DeserializeOwned + Send + 'static> {
431+
pub struct OpenAIEventStream<O>
432+
where
433+
O: DeserializeOwned + Send + 'static,
434+
{
431435
#[pin]
432436
stream: Filter<
433437
EventSource,
434438
future::Ready<bool>,
435439
fn(&Result<Event, reqwest_eventsource::Error>) -> future::Ready<bool>,
436440
>,
441+
event_mapper:
442+
Option<Box<dyn Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static>>,
437443
done: bool,
438444
_phantom_data: PhantomData<O>,
439445
}
440446

441-
impl<O: DeserializeOwned + Send + 'static> OpenAIEventStream<O> {
442-
pub(crate) fn new(event_source: EventSource) -> Self {
447+
impl<O> OpenAIEventStream<O>
448+
where
449+
O: DeserializeOwned + Send + 'static,
450+
{
451+
pub(crate) fn with_event_mapping<M>(event_source: EventSource, event_mapper: M) -> Self
452+
where
453+
M: Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
454+
{
443455
Self {
444456
stream: event_source.filter(|result|
445457
// filter out the first event which is always Event::Open
446458
future::ready(!(result.is_ok() && result.as_ref().unwrap().eq(&Event::Open)))),
447459
done: false,
460+
event_mapper: Some(Box::new(event_mapper)),
448461
_phantom_data: PhantomData,
449462
}
450463
}
451-
}
452-
453-
impl<O: DeserializeOwned + Send + 'static> Stream for OpenAIEventStream<O> {
454-
type Item = Result<O, OpenAIError>;
455464

456-
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
457-
let this = self.project();
458-
if *this.done {
459-
return Poll::Ready(None);
460-
}
461-
let stream: Pin<&mut _> = this.stream;
462-
match stream.poll_next(cx) {
463-
Poll::Ready(response) => {
464-
match response {
465-
None => Poll::Ready(None), // end of the stream
466-
Some(result) => match result {
467-
Ok(event) => match event {
468-
Event::Open => unreachable!(), // it has been filtered out
469-
Event::Message(message) => {
470-
if message.data == "[DONE]" {
471-
*this.done = true;
472-
Poll::Ready(None) // end of the stream, defined by OpenAI
473-
} else {
474-
// deserialize the data
475-
match serde_json::from_str::<O>(&message.data) {
476-
Err(e) => {
477-
*this.done = true;
478-
Poll::Ready(Some(Err(map_deserialization_error(
479-
e,
480-
&message.data.as_bytes(),
481-
))))
482-
}
483-
Ok(output) => Poll::Ready(Some(Ok(output))),
484-
}
485-
}
486-
}
487-
},
488-
Err(e) => {
489-
*this.done = true;
490-
Poll::Ready(Some(Err(OpenAIError::StreamError(e.to_string()))))
491-
}
492-
},
493-
}
494-
}
495-
Poll::Pending => Poll::Pending,
496-
}
497-
}
498-
}
499-
500-
#[pin_project]
501-
pub struct OpenAIEventMappedStream<O>
502-
where
503-
O: Send + 'static,
504-
{
505-
#[pin]
506-
stream: Filter<
507-
EventSource,
508-
future::Ready<bool>,
509-
fn(&Result<Event, reqwest_eventsource::Error>) -> future::Ready<bool>,
510-
>,
511-
event_mapper: Box<dyn Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static>,
512-
done: bool,
513-
_phantom_data: PhantomData<O>,
514-
}
515-
516-
impl<O> OpenAIEventMappedStream<O>
517-
where
518-
O: Send + 'static,
519-
{
520-
pub(crate) fn new<M>(event_source: EventSource, event_mapper: M) -> Self
521-
where
522-
M: Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
523-
{
465+
pub(crate) fn new(event_source: EventSource) -> Self {
524466
Self {
525467
stream: event_source.filter(|result|
526468
// filter out the first event which is always Event::Open
527469
future::ready(!(result.is_ok() && result.as_ref().unwrap().eq(&Event::Open)))),
528470
done: false,
529-
event_mapper: Box::new(event_mapper),
471+
event_mapper: None,
530472
_phantom_data: PhantomData,
531473
}
532474
}
533475
}
534476

535-
impl<O> Stream for OpenAIEventMappedStream<O>
477+
impl<O> Stream for OpenAIEventStream<O>
536478
where
537-
O: Send + 'static,
479+
O: DeserializeOwned + Send + 'static,
538480
{
539481
type Item = Result<O, OpenAIError>;
540482

@@ -552,13 +494,32 @@ where
552494
Ok(event) => match event {
553495
Event::Open => unreachable!(), // it has been filtered out
554496
Event::Message(message) => {
555-
if message.data == "[DONE]" {
556-
*this.done = true;
557-
}
558-
let response = (this.event_mapper)(message);
559-
match response {
560-
Ok(output) => Poll::Ready(Some(Ok(output))),
561-
Err(_) => Poll::Ready(None),
497+
if let Some(event_mapper) = this.event_mapper.as_ref() {
498+
if message.data == "[DONE]" {
499+
*this.done = true;
500+
}
501+
let response = event_mapper(message);
502+
match response {
503+
Ok(output) => Poll::Ready(Some(Ok(output))),
504+
Err(_) => Poll::Ready(None),
505+
}
506+
} else {
507+
if message.data == "[DONE]" {
508+
*this.done = true;
509+
Poll::Ready(None) // end of the stream, defined by OpenAI
510+
} else {
511+
// deserialize the data
512+
match serde_json::from_str::<O>(&message.data) {
513+
Err(e) => {
514+
*this.done = true;
515+
Poll::Ready(Some(Err(map_deserialization_error(
516+
e,
517+
&message.data.as_bytes(),
518+
))))
519+
}
520+
Ok(output) => Poll::Ready(Some(Ok(output))),
521+
}
522+
}
562523
}
563524
}
564525
},

async-openai-wasm/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ pub use audio::Audio;
163163
pub use audit_logs::AuditLogs;
164164
pub use batches::Batches;
165165
pub use chat::Chat;
166-
pub use client::{Client, OpenAIEventMappedStream, OpenAIEventStream};
166+
pub use client::{Client, OpenAIEventStream};
167167
pub use completion::Completions;
168168
pub use embedding::Embeddings;
169169
pub use file::Files;

async-openai-wasm/src/runs.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ impl<'c, C: Config> Runs<'c, C> {
4747
T0 = serde::Serialize,
4848
R = serde::de::DeserializeOwned,
4949
stream = "true",
50-
use_mapped_events = "true",
5150
where_clause = "R: std::marker::Send + 'static + TryFrom<eventsource_stream::Event, Error = OpenAIError>"
5251
)]
5352
#[allow(unused_mut)]
@@ -134,7 +133,6 @@ impl<'c, C: Config> Runs<'c, C> {
134133
T1 = serde::Serialize,
135134
R = serde::de::DeserializeOwned,
136135
stream = "true",
137-
use_mapped_events = "true",
138136
where_clause = "R: std::marker::Send + 'static + TryFrom<eventsource_stream::Event, Error = OpenAIError>"
139137
)]
140138
#[allow(unused_mut)]

async-openai-wasm/src/threads.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ impl<'c, C: Config> Threads<'c, C> {
4646
T0 = serde::Serialize,
4747
R = serde::de::DeserializeOwned,
4848
stream = "true",
49-
use_mapped_events = "true",
5049
where_clause = "R: std::marker::Send + 'static + TryFrom<eventsource_stream::Event, Error = OpenAIError>"
5150
)]
5251
#[allow(unused_mut)]

async-openai-wasm/src/types/assistant_stream.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use serde::Deserialize;
22

3-
use crate::client::OpenAIEventMappedStream;
3+
use crate::client::OpenAIEventStream;
44
use crate::error::{ApiError, OpenAIError, map_deserialization_error};
55

66
use super::{
@@ -107,7 +107,7 @@ pub enum AssistantStreamEvent {
107107
Done(String),
108108
}
109109

110-
pub type AssistantEventStream = OpenAIEventMappedStream<AssistantStreamEvent>;
110+
pub type AssistantEventStream = OpenAIEventStream<AssistantStreamEvent>;
111111

112112
impl TryFrom<eventsource_stream::Event> for AssistantStreamEvent {
113113
type Error = OpenAIError;

0 commit comments

Comments
 (0)