Commit beb8c23

add reasoning model support
1 parent 57f60f9 commit beb8c23

File tree: 5 files changed, 203 additions & 0 deletions


async-openai-wasm/README.md

Lines changed: 3 additions & 0 deletions
```diff
@@ -36,6 +36,7 @@ a `x.y.z` version.
 - [x] Realtime (Beta) (partially implemented)
 - [x] Uploads
 - [x] **WASM support**
+- [x] Reasoning Model Support: support models like DeepSeek R1 via broader support for OpenAI-compatible endpoints, see `examples/reasoning`
 - SSE streaming on available APIs
 - Ergonomic builder pattern for all request objects.
 - Microsoft Azure OpenAI Service (only for APIs matching OpenAI spec)
@@ -49,6 +50,8 @@ maintain parity with spec of AOS. Just like `async-openai`.
 + * WASM support
 + * WASM examples
 + * Realtime API: Does not bundle with a specific WS implementation. Need to convert a client event into a WS message by yourself, which is just simple `your_ws_impl::Message::Text(some_client_event.into_text())`
++ * Broader support for OpenAI-compatible Endpoints
++ * Reasoning Model Support
 - * Tokio
 - * Non-wasm examples: please refer to the original project [async-openai](https://github.com/64bit/async-openai/).
 - * Builtin backoff retries: due to [this issue](https://github.com/ihrwein/backoff/issues/61).
```
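The mechanism behind the new README bullets is the pair of type changes in `async-openai-wasm/src/types/chat.rs` below: the chat request gains a flattened `extra_params` field for provider-specific options, and the response message, full response, stream delta, and stream response types gain flattened `return_catchall` fields that retain any keys the OpenAI spec does not define.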

async-openai-wasm/src/types/chat.rs

Lines changed: 24 additions & 0 deletions
```diff
@@ -430,6 +430,11 @@ pub struct ChatCompletionResponseMessage {
 
     /// If the audio output modality is requested, this object contains data about the audio response from the model. [Learn more](https://platform.openai.com/docs/guides/audio).
     pub audio: Option<ChatCompletionResponseMessageAudio>,
+
+    /// Catching anything else that a provider wants to provide, for example, a `reasoning` field
+    #[serde(skip_serializing_if = "Option::is_none")]
+    #[serde(flatten)]
+    pub return_catchall: Option<serde_json::Value>,
 }
 
 #[derive(Clone, Serialize, Default, Debug, Deserialize, Builder, PartialEq)]
@@ -816,6 +821,10 @@ pub struct CreateChatCompletionRequest {
     #[deprecated]
     #[serde(skip_serializing_if = "Option::is_none")]
     pub functions: Option<Vec<ChatCompletionFunctions>>,
+
+    #[serde(skip_serializing_if = "Option::is_none")]
+    #[serde(flatten)]
+    pub extra_params: Option<serde_json::Value>,
 }
 
 /// Options for streaming response. Only set this when you set `stream: true`.
@@ -899,6 +908,11 @@ pub struct CreateChatCompletionResponse {
     /// The object type, which is always `chat.completion`.
     pub object: String,
     pub usage: Option<CompletionUsage>,
+
+    /// Catching anything else that a provider wants to provide
+    #[serde(skip_serializing_if = "Option::is_none")]
+    #[serde(flatten)]
+    pub return_catchall: Option<serde_json::Value>,
 }
 
 /// Parsed server side events stream until an \[DONE\] is received from server.
@@ -939,6 +953,11 @@ pub struct ChatCompletionStreamResponseDelta {
     pub role: Option<Role>,
     /// The refusal message generated by the model.
     pub refusal: Option<String>,
+
+    /// Catching anything else that a provider wants to provide, for example, a `reasoning` field
+    #[serde(skip_serializing_if = "Option::is_none")]
+    #[serde(flatten)]
+    pub return_catchall: Option<serde_json::Value>,
 }
 
 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
@@ -984,4 +1003,9 @@ pub struct CreateChatCompletionStreamResponse {
     /// An optional field that will only be present when you set `stream_options: {"include_usage": true}` in your request.
     /// When present, it contains a null value except for the last chunk which contains the token usage statistics for the entire request.
     pub usage: Option<CompletionUsage>,
+
+    /// Catching anything else that a provider wants to provide
+    #[serde(skip_serializing_if = "Option::is_none")]
+    #[serde(flatten)]
+    pub return_catchall: Option<serde_json::Value>,
 }
```
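The pattern is easiest to see in isolation. Below is a minimal, self-contained sketch (not part of the commit; `Request` and `Message` are hypothetical stand-ins that only mirror the flattened fields added above) showing both directions: `extra_params` merging into the request body on serialization, and unknown provider keys collecting into `return_catchall` on deserialization.

```rust
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};

#[derive(Serialize)]
struct Request {
    model: String,
    // Mirrors `CreateChatCompletionRequest::extra_params`: extra keys are
    // merged into the top level of the request body, not nested.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(flatten)]
    extra_params: Option<Value>,
}

#[derive(Deserialize)]
struct Message {
    content: String,
    // Mirrors `ChatCompletionResponseMessage::return_catchall`: any keys this
    // struct does not declare are collected here on deserialization.
    #[serde(flatten)]
    return_catchall: Option<Value>,
}

fn main() {
    let req = Request {
        model: "deepseek/deepseek-r1".into(),
        extra_params: Some(json!({ "include_reasoning": true })),
    };
    // `include_reasoning` ends up beside `model` in the serialized body.
    assert_eq!(
        serde_json::to_value(&req).unwrap(),
        json!({ "model": "deepseek/deepseek-r1", "include_reasoning": true })
    );

    // A provider-specific field such as DeepSeek's `reasoning_content`
    // survives deserialization inside the catchall.
    let msg: Message = serde_json::from_value(json!({
        "content": "Hello!",
        "reasoning_content": "The user greeted me."
    }))
    .unwrap();
    assert_eq!(msg.content, "Hello!");
    let catchall = msg.return_catchall.unwrap();
    assert_eq!(catchall["reasoning_content"], "The user greeted me.");
}
```

The trade-off of flattening into `Option<serde_json::Value>` is that no per-provider variants are needed: any extension key round-trips untyped, at the cost of callers fishing values out of JSON by hand, as the example and tests below do with `reasoning`/`reasoning_content`.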
Lines changed: 101 additions & 0 deletions
```rust
use async_openai_wasm::config::OpenAIConfig;
use async_openai_wasm::types::{
    ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs,
};
use async_openai_wasm::Client;
use futures::StreamExt;
use serde_json::json;

const OPENROUTER_REASONING_KEY: &str = "reasoning";
const OPENROUTER_BASEURL: &str = "https://openrouter.ai/api/v1";
const DEEPSEEK_REASONING_KEY: &str = "reasoning_content";
const DEEPSEEK_BASEURL: &str = "https://api.deepseek.com";

#[tokio::test]
async fn test_chat_completion_reasoning() {
    let test_key = std::env::var("TEST_API_KEY").unwrap();
    let use_deepseek = std::env::var("USE_DEEPSEEK").is_ok();
    let (reasoning_key, base_url) = if use_deepseek {
        (DEEPSEEK_REASONING_KEY, DEEPSEEK_BASEURL)
    } else {
        (OPENROUTER_REASONING_KEY, OPENROUTER_BASEURL)
    };
    let client = Client::with_config(
        OpenAIConfig::new()
            .with_api_base(base_url)
            .with_api_key(test_key),
    );
    let request = CreateChatCompletionRequestArgs::default()
        .messages(vec![ChatCompletionRequestUserMessageArgs::default()
            .content("Hello! Do you know the Rust programming language?")
            .build()
            .unwrap()
            .into()])
        .model("deepseek/deepseek-r1")
        // The extra params that OpenRouter requires to get reasoning content
        // See https://openrouter.ai/docs/api-reference/parameters#include-reasoning
        .extra_params(json!({
            "include_reasoning" : true
        }))
        .build()
        .unwrap();
    let result = client.chat().create(request).await.unwrap();
    // Get the reasoning field in the response
    let catch_all_result = result.choices[0].message.return_catchall.as_ref().unwrap();
    let reasoning = catch_all_result
        .get(reasoning_key)
        .unwrap()
        .as_str()
        .unwrap();
    assert!(reasoning.len() > 0);
    println!("Reasoning: {reasoning}");
}

#[tokio::test]
async fn test_chat_completion_reasoning_stream() {
    let test_key = std::env::var("TEST_API_KEY").unwrap();
    let use_deepseek = std::env::var("USE_DEEPSEEK").is_ok();
    let (reasoning_key, base_url) = if use_deepseek {
        (DEEPSEEK_REASONING_KEY, DEEPSEEK_BASEURL)
    } else {
        (OPENROUTER_REASONING_KEY, OPENROUTER_BASEURL)
    };
    let client = Client::with_config(
        OpenAIConfig::new()
            .with_api_base(base_url)
            .with_api_key(test_key),
    );
    let request = CreateChatCompletionRequestArgs::default()
        .messages(vec![ChatCompletionRequestUserMessageArgs::default()
            .content("Hello! Do you know the Rust programming language?")
            .build()
            .unwrap()
            .into()])
        .model("deepseek/deepseek-r1")
        // The extra params that OpenRouter requires to get reasoning content
        // See https://openrouter.ai/docs/api-reference/parameters#include-reasoning
        .extra_params(json!({
            "include_reasoning" : true
        }))
        .build()
        .unwrap();

    let mut result = client.chat().create_stream(request).await.unwrap();
    let mut reasoning = String::new();

    while let Some(result) = result.next().await {
        if let Ok(r) = result {
            // Get the reasoning field in the response
            let catch_all_return = r.choices[0].delta.return_catchall.as_ref();
            let reasoning_part = catch_all_return
                .and_then(|val| val.get(reasoning_key))
                .and_then(|r| r.as_str());
            if let Some(reasoning_part) = reasoning_part {
                reasoning.push_str(reasoning_part);
                println!("Reasoning Part: {reasoning_part}")
            }
        }
    }
    assert!(reasoning.len() > 0);
    println!("Reasoning:\n{reasoning}");
}
```
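Note that both tests hit live endpoints rather than fixtures: they read `TEST_API_KEY` from the environment and call OpenRouter by default, or DeepSeek when `USE_DEEPSEEK` is set, so they need network access and a valid key for whichever provider is targeted.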

examples/reasoning/Cargo.toml

Lines changed: 13 additions & 0 deletions
```toml
[package]
name = "reasoning-example"
version = "0.1.0"
edition = "2021"
publish = false

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
async-openai-wasm = { path = "../../async-openai-wasm", features = ["realtime"] }
serde_json = "1.0.135"
futures = "0.3"
tokio = { version = "1.43", features = ["fs", "macros"] }
```

examples/reasoning/src/main.rs

Lines changed: 62 additions & 0 deletions
```rust
use async_openai_wasm::config::OpenAIConfig;
use async_openai_wasm::types::{
    ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs,
};
use async_openai_wasm::Client;
use futures::StreamExt;
use serde_json::json;

const OPENROUTER_REASONING_KEY: &str = "reasoning";
const OPENROUTER_BASEURL: &str = "https://openrouter.ai/api/v1";
const DEEPSEEK_REASONING_KEY: &str = "reasoning_content";
const DEEPSEEK_BASEURL: &str = "https://api.deepseek.com";

#[tokio::main(flavor = "current_thread")]
async fn main() {
    let test_key = std::env::var("TEST_API_KEY").unwrap();
    let use_deepseek = std::env::var("USE_DEEPSEEK").is_ok();
    let (reasoning_key, base_url) = if use_deepseek {
        (DEEPSEEK_REASONING_KEY, DEEPSEEK_BASEURL)
    } else {
        (OPENROUTER_REASONING_KEY, OPENROUTER_BASEURL)
    };
    let client = Client::with_config(
        OpenAIConfig::new()
            .with_api_base(base_url)
            .with_api_key(test_key),
    );
    let request = CreateChatCompletionRequestArgs::default()
        .messages(vec![ChatCompletionRequestUserMessageArgs::default()
            .content("Hello! Do you know the Rust programming language?")
            .build()
            .unwrap()
            .into()])
        .model("deepseek/deepseek-r1")
        // The extra params that OpenRouter requires to get reasoning content
        // See https://openrouter.ai/docs/api-reference/parameters#include-reasoning
        .extra_params(json!({
            "include_reasoning" : true
        }))
        .build()
        .unwrap();

    let mut result = client.chat().create_stream(request).await.unwrap();
    let mut reasoning = String::new();

    while let Some(result) = result.next().await {
        if let Ok(r) = result {
            // Get the reasoning field in the response
            let catch_all_return = r.choices[0].delta.return_catchall.as_ref();
            let reasoning_part = catch_all_return
                .and_then(|val| val.get(reasoning_key))
                .and_then(|r| r.as_str());
            if let Some(reasoning_part) = reasoning_part {
                reasoning.push_str(reasoning_part);
                println!("Reasoning Part: {reasoning_part}")
            }
        }
    }
    assert!(reasoning.len() > 0);
    println!("Reasoning:\n{reasoning}");
}
```
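Presumably the example is run from `examples/reasoning` with something like `TEST_API_KEY=<your-key> cargo run`, adding `USE_DEEPSEEK=1` to target DeepSeek instead of OpenRouter; the key must belong to whichever provider is selected.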
