Skip to content

Commit 8eaef1c

Browse files
64bitifsheldon
authored andcommitted
feat: Bring your own types (64bit#342)
* attribute proc macro to bring your own types * keep original fn as it is add new with _byot suffix * update macro * update macro * use macro in main crate + add test * byot: assistants * byot: vector_stores * add where_clause attribute arg * remove print * byot: files * byot: images * add stream arg to attribute * byot: chat * byot: completions * fix comment * fix * byot: audio * byot: embeddings * byot: Fine Tunning * add byot tests * byot: moderations * byot tests: moderations * byot: threads * byot tests: threads * byot: messages * byot tests: messages * byot: runs * byot tests: runs * byot: steps * byot tests: run steps * byot: vector store files * byot test: vector store files * byot: vector store file batches * byot test: vector store file batches * cargo fmt * byot: batches * byot tests: batches * format * remove AssistantFiles and related apis (/assistants/assistant_id/files/..) * byot: audit logs * byot tests: audit logs * keep non byot code checks * byot: invites * byot tests: invites * remove message files API * byot: project api keys * byot tests: project api keys * byot: project service accounts * byot tests: project service accounts * byot: project users * byot tests: project users * byot: projects * byot tests: projects * byot: uploads * byot tests: uploads * byot: users * byot tests: users * add example to demonstrate bring-your-own-types * update README * update doc * cargo fmt * update doc in lib.rs * tests passing * fix for complier warning * fix compiler #[allow(unused_mut)] * cargo fix * fix all warnings * add Voices * publish = false for all examples * specify versions (cherry picked from commit 638bf75)
1 parent 7dfbb0e commit 8eaef1c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+1189
-335
lines changed

Cargo.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
[workspace]
2-
members = ["async-openai-wasm", "examples/*" ]
2+
members = ["async-openai-wasm", "async-openai-*", "examples/*"]
33
# Only check / build main crates by default (check all with `--workspace`)
4-
default-members = ["async-openai-wasm"]
4+
default-members = ["async-openai-wasm", "async-openai-*"]
55
resolver = "2"
6+
7+
[workspace.package]
8+
rust-version = "1.75"

async-openai-macros/Cargo.toml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
[package]
2+
name = "async-openai-macros"
3+
version = "0.1.0"
4+
authors = ["Himanshu Neema"]
5+
keywords = ["openai", "macros", "ai"]
6+
description = "Macros for async-openai"
7+
edition = "2021"
8+
license = "MIT"
9+
homepage = "https://github.com/64bit/async-openai"
10+
repository = "https://github.com/64bit/async-openai"
11+
rust-version = { workspace = true }
12+
13+
[lib]
14+
proc-macro = true
15+
16+
[dependencies]
17+
syn = { version = "2.0", features = ["full"] }
18+
quote = "1.0"
19+
proc-macro2 = "1.0"

async-openai-macros/src/lib.rs

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
use proc_macro::TokenStream;
2+
use quote::{quote, ToTokens};
3+
use syn::{
4+
parse::{Parse, ParseStream},
5+
parse_macro_input,
6+
punctuated::Punctuated,
7+
token::Comma,
8+
FnArg, GenericParam, Generics, ItemFn, Pat, PatType, TypeParam, WhereClause,
9+
};
10+
11+
// Parse attribute arguments like #[byot(T0: Display + Debug, T1: Clone, R: Serialize)]
12+
struct BoundArgs {
13+
bounds: Vec<(String, syn::TypeParamBound)>,
14+
where_clause: Option<String>,
15+
stream: bool, // Add stream flag
16+
}
17+
18+
impl Parse for BoundArgs {
19+
fn parse(input: ParseStream) -> syn::Result<Self> {
20+
let mut bounds = Vec::new();
21+
let mut where_clause = None;
22+
let mut stream = false; // Default to false
23+
let vars = Punctuated::<syn::MetaNameValue, Comma>::parse_terminated(input)?;
24+
25+
for var in vars {
26+
let name = var.path.get_ident().unwrap().to_string();
27+
match name.as_str() {
28+
"where_clause" => {
29+
where_clause = Some(var.value.into_token_stream().to_string());
30+
}
31+
"stream" => {
32+
stream = var.value.into_token_stream().to_string().contains("true");
33+
}
34+
_ => {
35+
let bound: syn::TypeParamBound =
36+
syn::parse_str(&var.value.into_token_stream().to_string())?;
37+
bounds.push((name, bound));
38+
}
39+
}
40+
}
41+
Ok(BoundArgs {
42+
bounds,
43+
where_clause,
44+
stream,
45+
})
46+
}
47+
}
48+
49+
#[proc_macro_attribute]
50+
pub fn byot_passthrough(_args: TokenStream, item: TokenStream) -> TokenStream {
51+
item
52+
}
53+
54+
#[proc_macro_attribute]
55+
pub fn byot(args: TokenStream, item: TokenStream) -> TokenStream {
56+
let bounds_args = parse_macro_input!(args as BoundArgs);
57+
let input = parse_macro_input!(item as ItemFn);
58+
let mut new_generics = Generics::default();
59+
let mut param_count = 0;
60+
61+
// Process function arguments
62+
let mut new_params = Vec::new();
63+
let args = input
64+
.sig
65+
.inputs
66+
.iter()
67+
.map(|arg| {
68+
match arg {
69+
FnArg::Receiver(receiver) => receiver.to_token_stream(),
70+
FnArg::Typed(PatType { pat, .. }) => {
71+
if let Pat::Ident(pat_ident) = &**pat {
72+
let generic_name = format!("T{}", param_count);
73+
let generic_ident =
74+
syn::Ident::new(&generic_name, proc_macro2::Span::call_site());
75+
76+
// Create type parameter with optional bounds
77+
let mut type_param = TypeParam::from(generic_ident.clone());
78+
if let Some((_, bound)) = bounds_args
79+
.bounds
80+
.iter()
81+
.find(|(name, _)| name == &generic_name)
82+
{
83+
type_param.bounds.extend(vec![bound.clone()]);
84+
}
85+
86+
new_params.push(GenericParam::Type(type_param));
87+
param_count += 1;
88+
quote! { #pat_ident: #generic_ident }
89+
} else {
90+
arg.to_token_stream()
91+
}
92+
}
93+
}
94+
})
95+
.collect::<Vec<_>>();
96+
97+
// Add R type parameter with optional bounds
98+
let generic_r = syn::Ident::new("R", proc_macro2::Span::call_site());
99+
let mut return_type_param = TypeParam::from(generic_r.clone());
100+
if let Some((_, bound)) = bounds_args.bounds.iter().find(|(name, _)| name == "R") {
101+
return_type_param.bounds.extend(vec![bound.clone()]);
102+
}
103+
new_params.push(GenericParam::Type(return_type_param));
104+
105+
// Add all generic parameters
106+
new_generics.params.extend(new_params);
107+
108+
let fn_name = &input.sig.ident;
109+
let byot_fn_name = syn::Ident::new(&format!("{}_byot", fn_name), fn_name.span());
110+
let vis = &input.vis;
111+
let block = &input.block;
112+
let attrs = &input.attrs;
113+
let asyncness = &input.sig.asyncness;
114+
115+
// Parse where clause if provided
116+
let where_clause = if let Some(where_str) = bounds_args.where_clause {
117+
match syn::parse_str::<WhereClause>(&format!("where {}", where_str.replace("\"", ""))) {
118+
Ok(where_clause) => quote! { #where_clause },
119+
Err(e) => return TokenStream::from(e.to_compile_error()),
120+
}
121+
} else {
122+
quote! {}
123+
};
124+
125+
// Generate return type based on stream flag
126+
let return_type = if bounds_args.stream {
127+
quote! { Result<::std::pin::Pin<Box<dyn ::futures::Stream<Item = Result<R, OpenAIError>> + Send>>, OpenAIError> }
128+
} else {
129+
quote! { Result<R, OpenAIError> }
130+
};
131+
132+
let expanded = quote! {
133+
#(#attrs)*
134+
#input
135+
136+
#(#attrs)*
137+
#vis #asyncness fn #byot_fn_name #new_generics (#(#args),*) -> #return_type #where_clause #block
138+
};
139+
140+
expanded.into()
141+
}

async-openai-wasm/src/assistant_files.rs

Lines changed: 0 additions & 66 deletions
This file was deleted.

async-openai-wasm/src/assistants.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use crate::{
77
AssistantObject, CreateAssistantRequest, DeleteAssistantResponse, ListAssistantsResponse,
88
ModifyAssistantRequest,
99
},
10-
AssistantFiles, Client,
10+
Client,
1111
};
1212

1313
/// Build assistants that can call models and use tools to perform tasks.
@@ -22,12 +22,8 @@ impl<'c, C: Config> Assistants<'c, C> {
2222
Self { client }
2323
}
2424

25-
/// Assistant [AssistantFiles] API group
26-
pub fn files(&self, assistant_id: &str) -> AssistantFiles<C> {
27-
AssistantFiles::new(self.client, assistant_id)
28-
}
29-
3025
/// Create an assistant with a model and instructions.
26+
#[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
3127
pub async fn create(
3228
&self,
3329
request: CreateAssistantRequest,
@@ -36,13 +32,15 @@ impl<'c, C: Config> Assistants<'c, C> {
3632
}
3733

3834
/// Retrieves an assistant.
35+
#[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
3936
pub async fn retrieve(&self, assistant_id: &str) -> Result<AssistantObject, OpenAIError> {
4037
self.client
4138
.get(&format!("/assistants/{assistant_id}"))
4239
.await
4340
}
4441

4542
/// Modifies an assistant.
43+
#[crate::byot(T0 = std::fmt::Display, T1 = serde::Serialize, R = serde::de::DeserializeOwned)]
4644
pub async fn update(
4745
&self,
4846
assistant_id: &str,
@@ -54,17 +52,19 @@ impl<'c, C: Config> Assistants<'c, C> {
5452
}
5553

5654
/// Delete an assistant.
55+
#[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
5756
pub async fn delete(&self, assistant_id: &str) -> Result<DeleteAssistantResponse, OpenAIError> {
5857
self.client
5958
.delete(&format!("/assistants/{assistant_id}"))
6059
.await
6160
}
6261

6362
/// Returns a list of assistants.
63+
#[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
6464
pub async fn list<Q>(&self, query: &Q) -> Result<ListAssistantsResponse, OpenAIError>
6565
where
6666
Q: Serialize + ?Sized,
6767
{
68-
self.client.get_with_query("/assistants", query).await
68+
self.client.get_with_query("/assistants", &query).await
6969
}
7070
}

async-openai-wasm/src/audio.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@ impl<'c, C: Config> Audio<'c, C> {
2424
}
2525

2626
/// Transcribes audio into the input language.
27+
#[crate::byot(
28+
T0 = Clone,
29+
R = serde::de::DeserializeOwned,
30+
where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom<T0, Error = OpenAIError>",
31+
)]
2732
pub async fn transcribe(
2833
&self,
2934
request: CreateTranscriptionRequest,
@@ -34,6 +39,11 @@ impl<'c, C: Config> Audio<'c, C> {
3439
}
3540

3641
/// Transcribes audio into the input language.
42+
#[crate::byot(
43+
T0 = Clone,
44+
R = serde::de::DeserializeOwned,
45+
where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom<T0, Error = OpenAIError>",
46+
)]
3747
pub async fn transcribe_verbose_json(
3848
&self,
3949
request: CreateTranscriptionRequest,
@@ -54,6 +64,11 @@ impl<'c, C: Config> Audio<'c, C> {
5464
}
5565

5666
/// Translates audio into English.
67+
#[crate::byot(
68+
T0 = Clone,
69+
R = serde::de::DeserializeOwned,
70+
where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom<T0, Error = OpenAIError>",
71+
)]
5772
pub async fn translate(
5873
&self,
5974
request: CreateTranslationRequest,
@@ -62,6 +77,11 @@ impl<'c, C: Config> Audio<'c, C> {
6277
}
6378

6479
/// Translates audio into English.
80+
#[crate::byot(
81+
T0 = Clone,
82+
R = serde::de::DeserializeOwned,
83+
where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom<T0, Error = OpenAIError>",
84+
)]
6585
pub async fn translate_verbose_json(
6686
&self,
6787
request: CreateTranslationRequest,

async-openai-wasm/src/audit_logs.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@ impl<'c, C: Config> AuditLogs<'c, C> {
1515
}
1616

1717
/// List user actions and configuration changes within this organization.
18+
#[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
1819
pub async fn get<Q>(&self, query: &Q) -> Result<ListAuditLogsResponse, OpenAIError>
1920
where
2021
Q: Serialize + ?Sized,
2122
{
2223
self.client
23-
.get_with_query("/organization/audit_logs", query)
24+
.get_with_query("/organization/audit_logs", &query)
2425
.await
2526
}
2627
}

async-openai-wasm/src/batches.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,24 +20,28 @@ impl<'c, C: Config> Batches<'c, C> {
2020
}
2121

2222
/// Creates and executes a batch from an uploaded file of requests
23+
#[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
2324
pub async fn create(&self, request: BatchRequest) -> Result<Batch, OpenAIError> {
2425
self.client.post("/batches", request).await
2526
}
2627

2728
/// List your organization's batches.
29+
#[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
2830
pub async fn list<Q>(&self, query: &Q) -> Result<ListBatchesResponse, OpenAIError>
2931
where
3032
Q: Serialize + ?Sized,
3133
{
32-
self.client.get_with_query("/batches", query).await
34+
self.client.get_with_query("/batches", &query).await
3335
}
3436

3537
/// Retrieves a batch.
38+
#[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
3639
pub async fn retrieve(&self, batch_id: &str) -> Result<Batch, OpenAIError> {
3740
self.client.get(&format!("/batches/{batch_id}")).await
3841
}
3942

4043
/// Cancels an in-progress batch. The batch will be in status `cancelling` for up to 10 minutes, before changing to `cancelled`, where it will have partial results (if any) available in the output file.
44+
#[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
4145
pub async fn cancel(&self, batch_id: &str) -> Result<Batch, OpenAIError> {
4246
self.client
4347
.post(

0 commit comments

Comments
 (0)