Skip to content

Commit bf9b5d5

Browse files
committed
feat: allow responses api to use tool
1 parent 6174180 commit bf9b5d5

File tree

9 files changed

+183
-87
lines changed

9 files changed

+183
-87
lines changed

async-openai/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ secrecy = { version = "0.10.3", features = ["serde"] }
5050
bytes = "1.9.0"
5151
eventsource-stream = "0.2.3"
5252
tokio-tungstenite = { version = "0.26.1", optional = true, default-features = false }
53-
schemars = "0.8.22"
53+
schemars = "0.9.0"
5454

5555
[dev-dependencies]
5656
tokio-test = "0.4.4"

async-openai/src/tools.rs

Lines changed: 90 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use serde::{Deserialize, Serialize};
1212
use serde_json::json;
1313

1414
use crate::types::{
15+
responses::{self, Function, ToolDefinition},
1516
ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk,
1617
ChatCompletionRequestToolMessage, ChatCompletionTool, ChatCompletionToolType, FunctionCall,
1718
FunctionObject,
@@ -29,7 +30,7 @@ pub trait Tool: Send + Sync {
2930

3031
/// Returns the name of the tool.
3132
fn name() -> String {
32-
Self::Args::schema_name()
33+
Self::Args::schema_name().to_string()
3334
}
3435

3536
/// Returns an optional description of the tool.
@@ -42,8 +43,8 @@ pub trait Tool: Send + Sync {
4243
None
4344
}
4445

45-
/// Creates a ChatCompletionTool definition for the tool.
46-
fn definition() -> ChatCompletionTool {
46+
/// Returns the tool's definition for chat.
47+
fn definition_for_chat() -> ChatCompletionTool {
4748
ChatCompletionTool {
4849
r#type: ChatCompletionToolType::Function,
4950
function: FunctionObject {
@@ -55,6 +56,16 @@ pub trait Tool: Send + Sync {
5556
}
5657
}
5758

59+
/// Returns the tool's definition for responses.
60+
fn definition_for_responses() -> ToolDefinition {
61+
ToolDefinition::Function(Function {
62+
name: Self::name(),
63+
description: Self::description(),
64+
parameters: json!(schema_for!(Self::Args)),
65+
strict: Self::strict().unwrap_or(false),
66+
})
67+
}
68+
5869
/// Executes the tool with the given arguments.
5970
/// Returns a Future that resolves to either the tool's output or an error.
6071
fn call(
@@ -66,8 +77,14 @@ pub trait Tool: Send + Sync {
6677
/// A dynamic trait for tools that allows for runtime tool management.
6778
/// This trait provides a way to work with tools without knowing their concrete types at compile time.
6879
pub trait ToolDyn: Send + Sync {
69-
/// Returns the tool's definition as a ChatCompletionTool.
70-
fn definition(&self) -> ChatCompletionTool;
80+
/// Returns the tool's name.
81+
fn name(&self) -> String;
82+
83+
/// Returns the tool's definition for chat.
84+
fn definition_for_chat(&self) -> ChatCompletionTool;
85+
86+
/// Returns the tool's definition for responses.
87+
fn definition_for_responses(&self) -> ToolDefinition;
7188

7289
/// Executes the tool with the given JSON string arguments.
7390
/// Returns a Future that resolves to either a JSON string output or an error string.
@@ -79,8 +96,16 @@ pub trait ToolDyn: Send + Sync {
7996

8097
// Implementation of ToolDyn for any type that implements Tool
8198
impl<T: Tool> ToolDyn for T {
82-
fn definition(&self) -> ChatCompletionTool {
83-
T::definition()
99+
fn name(&self) -> String {
100+
T::name()
101+
}
102+
103+
fn definition_for_chat(&self) -> ChatCompletionTool {
104+
T::definition_for_chat()
105+
}
106+
107+
fn definition_for_responses(&self) -> ToolDefinition {
108+
T::definition_for_responses()
84109
}
85110

86111
fn call(
@@ -125,29 +150,39 @@ impl ToolManager {
125150

126151
/// Adds a new tool to the manager.
127152
pub fn add_tool<T: Tool + 'static>(&mut self, tool: T) {
128-
self.tools
129-
.insert(T::name(), Arc::new(tool));
153+
self.tools.insert(T::name(), Arc::new(tool));
130154
}
131155

132156
/// Adds a new tool with an Arc to the manager.
133157
///
134158
/// Use this if you want to access this tool after being added to the manager.
135159
pub fn add_tool_dyn(&mut self, tool: Arc<dyn ToolDyn>) {
136-
self.tools.insert(tool.definition().function.name, tool);
160+
self.tools.insert(tool.name(), tool);
137161
}
138162

139163
/// Removes a tool from the manager.
140164
pub fn remove_tool(&mut self, name: &str) -> bool {
141165
self.tools.remove(name).is_some()
142166
}
143167

144-
/// Returns the definitions of all tools in the manager.
145-
pub fn get_tools(&self) -> Vec<ChatCompletionTool> {
146-
self.tools.values().map(|tool| tool.definition()).collect()
168+
/// Returns the definitions of all tools for chat in the manager.
169+
pub fn get_tools_for_chat(&self) -> Vec<ChatCompletionTool> {
170+
self.tools
171+
.values()
172+
.map(|tool| tool.definition_for_chat())
173+
.collect()
174+
}
175+
176+
/// Returns the definitions of all tools for responses in the manager.
177+
pub fn get_tools_for_responses(&self) -> Vec<ToolDefinition> {
178+
self.tools
179+
.values()
180+
.map(|tool| tool.definition_for_responses())
181+
.collect()
147182
}
148183

149-
/// Executes multiple tool calls concurrently and returns their results.
150-
pub async fn call(
184+
/// Executes multiple tool calls concurrently and returns their results for chat.
185+
pub async fn call_for_chat(
151186
&self,
152187
calls: impl IntoIterator<Item = ChatCompletionMessageToolCall>,
153188
) -> Vec<ChatCompletionRequestToolMessage> {
@@ -183,6 +218,46 @@ impl ToolManager {
183218
}
184219
outputs
185220
}
221+
222+
/// Executes multiple tool calls concurrently and returns their results for responses.
223+
pub async fn call_for_responses(
224+
&self,
225+
calls: impl IntoIterator<Item = responses::FunctionCall>,
226+
) -> Vec<responses::InputItem> {
227+
let mut handles = Vec::new();
228+
let mut outputs = Vec::new();
229+
230+
// Spawn a task for each tool call
231+
for call in calls {
232+
if let Some(tool) = self.tools.get(&call.name).cloned() {
233+
let handle = tokio::spawn(async move { tool.call(call.arguments).await });
234+
handles.push((call.call_id, handle));
235+
} else {
236+
outputs.push(responses::InputItem::Custom(json!({
237+
"type": "function_call_output",
238+
"call_id": call.call_id,
239+
"output": "Tool call failed: tool not found",
240+
})));
241+
}
242+
}
243+
244+
// Collect results from all spawned tasks
245+
for (id, handle) in handles {
246+
let output = match handle.await {
247+
Ok(Ok(output)) => output,
248+
Ok(Err(e)) => {
249+
format!("Tool call failed: {}", e)
250+
}
251+
Err(_) => "Tool call failed: runtime error".to_string(),
252+
};
253+
outputs.push(responses::InputItem::Custom(json!({
254+
"type": "function_call_output",
255+
"call_id": id,
256+
"output": output,
257+
})));
258+
}
259+
outputs
260+
}
186261
}
187262

188263
/// A manager for handling streaming tool calls.

async-openai/src/types/impls.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use crate::{
1313

1414
use bytes::Bytes;
1515

16+
#[allow(deprecated)]
1617
use super::{
1718
responses::{CodeInterpreterContainer, Input, InputContent, Role as ResponsesRole},
1819
AddUploadPartRequest, AudioInput, AudioResponseFormat, ChatCompletionFunctionCall,
@@ -531,6 +532,7 @@ impl From<String> for ChatCompletionToolChoiceOption {
531532
}
532533
}
533534

535+
#[allow(deprecated)]
534536
impl From<(String, serde_json::Value)> for ChatCompletionFunctions {
535537
fn from(value: (String, serde_json::Value)) -> Self {
536538
Self {

examples/responses-function-call/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,5 @@ async-openai = {path = "../../async-openai"}
99
serde_json = "1.0.135"
1010
tokio = { version = "1.43.0", features = ["full"] }
1111
serde = { version = "1.0.219", features = ["derive"] }
12+
schemars = "0.9.0"
13+
rand = "0.9.1"
Lines changed: 82 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,23 @@
11
use async_openai::{
2+
tools::{Tool, ToolManager},
23
types::responses::{
3-
CreateResponseArgs, FunctionArgs, FunctionCall, Input, InputItem, InputMessageArgs,
4-
OutputContent, Role, ToolDefinition,
4+
CreateResponseArgs, FunctionCall, Input, InputItem, InputMessageArgs, OutputContent, Role,
55
},
66
Client,
77
};
8-
use serde::Deserialize;
8+
use rand::{rng, seq::IndexedRandom, Rng};
9+
use schemars::JsonSchema;
10+
use serde::{Deserialize, Serialize};
911
use std::error::Error;
1012

11-
#[derive(Debug, Deserialize)]
12-
struct WeatherFunctionArgs {
13-
location: String,
14-
units: String,
15-
}
16-
17-
fn check_weather(location: String, units: String) -> String {
18-
format!("The weather in {location} is 25 {units}")
19-
}
20-
2113
#[tokio::main]
2214
async fn main() -> Result<(), Box<dyn Error>> {
2315
let client = Client::new();
2416

25-
let tools = vec![ToolDefinition::Function(
26-
FunctionArgs::default()
27-
.name("get_weather")
28-
.description("Retrieves current weather for the given location")
29-
.parameters(serde_json::json!(
30-
{
31-
"type": "object",
32-
"properties": {
33-
"location": {
34-
"type": "string",
35-
"description": "City and country e.g. Bogotá, Colombia"
36-
},
37-
"units": {
38-
"type": "string",
39-
"enum": [
40-
"celsius",
41-
"fahrenheit"
42-
],
43-
"description": "Units the temperature will be returned in."
44-
}
45-
},
46-
"required": [
47-
"location",
48-
"units"
49-
],
50-
"additionalProperties": false
51-
}
52-
))
53-
.build()?,
54-
)];
17+
let weather_tool = WeatherTool;
18+
let mut tool_manager = ToolManager::new();
19+
tool_manager.add_tool(weather_tool);
20+
let tools = tool_manager.get_tools_for_responses();
5521

5622
let mut input_messages = vec![InputItem::Message(
5723
InputMessageArgs::default()
@@ -71,40 +37,33 @@ async fn main() -> Result<(), Box<dyn Error>> {
7137

7238
let response = client.responses().create(request).await?;
7339

40+
for output_content in response.output.clone() {
41+
input_messages.push(InputItem::Custom(serde_json::to_value(output_content)?));
42+
}
43+
7444
// the model might ask for us to do a function call
75-
let function_call_request: Option<FunctionCall> =
76-
response.output.into_iter().find_map(|output_content| {
45+
let function_call_request: Vec<FunctionCall> = response
46+
.output
47+
.into_iter()
48+
.filter_map(|output_content| {
7749
if let OutputContent::FunctionCall(inner) = output_content {
7850
Some(inner)
7951
} else {
8052
None
8153
}
82-
});
54+
})
55+
.collect();
8356

84-
let Some(function_call_request) = function_call_request else {
57+
if function_call_request.is_empty() {
8558
println!("No function_call request found");
8659
return Ok(());
8760
};
8861

89-
let function_result = match function_call_request.name.as_str() {
90-
"get_weather" => {
91-
let args: WeatherFunctionArgs = serde_json::from_str(&function_call_request.arguments)?;
92-
check_weather(args.location, args.units)
93-
}
94-
_ => {
95-
println!("Unknown function {}", function_call_request.name);
96-
return Ok(());
97-
}
98-
};
62+
let function_result = tool_manager
63+
.call_for_responses(function_call_request.clone())
64+
.await;
9965

100-
input_messages.push(InputItem::Custom(serde_json::to_value(
101-
&OutputContent::FunctionCall(function_call_request.clone()),
102-
)?));
103-
input_messages.push(InputItem::Custom(serde_json::json!({
104-
"type": "function_call_output",
105-
"call_id": function_call_request.call_id,
106-
"output": function_result,
107-
})));
66+
input_messages.extend(function_result);
10867

10968
let request = CreateResponseArgs::default()
11069
.max_output_tokens(512u32)
@@ -121,3 +80,61 @@ async fn main() -> Result<(), Box<dyn Error>> {
12180

12281
Ok(())
12382
}
83+
84+
#[derive(Debug, JsonSchema, Deserialize, Serialize)]
85+
enum Unit {
86+
Fahrenheit,
87+
Celsius,
88+
}
89+
90+
#[derive(Debug, JsonSchema, Deserialize)]
91+
struct WeatherRequest {
92+
/// The city and state, e.g. San Francisco, CA
93+
location: String,
94+
unit: Unit,
95+
}
96+
97+
#[derive(Debug, Serialize)]
98+
struct WeatherResponse {
99+
location: String,
100+
temperature: String,
101+
unit: Unit,
102+
forecast: String,
103+
}
104+
105+
struct WeatherTool;
106+
107+
impl Tool for WeatherTool {
108+
type Args = WeatherRequest;
109+
type Output = WeatherResponse;
110+
type Error = String;
111+
112+
fn name() -> String {
113+
"get_current_weather".to_string()
114+
}
115+
116+
fn description() -> Option<String> {
117+
Some("Get the current weather in a given location".to_string())
118+
}
119+
120+
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
121+
let mut rng = rng();
122+
123+
let temperature: i32 = rng.random_range(20..=55);
124+
125+
let forecasts = [
126+
"sunny", "cloudy", "overcast", "rainy", "windy", "foggy", "snowy",
127+
];
128+
129+
let forecast = forecasts.choose(&mut rng).unwrap_or(&"sunny");
130+
131+
let weather_info = WeatherResponse {
132+
location: args.location,
133+
temperature: temperature.to_string(),
134+
unit: args.unit,
135+
forecast: forecast.to_string(),
136+
};
137+
138+
Ok(weather_info)
139+
}
140+
}

examples/tool-call-stream/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@ serde = "1.0"
1313
serde_json = "1.0.135"
1414
tokio = { version = "1.43.0", features = ["full"] }
1515
futures = "0.3.31"
16-
schemars = "0.8.22"
16+
schemars = "0.9.0"

0 commit comments

Comments
 (0)