Skip to content

Commit e96d582

Browse files
committed
toolcall, image fix & stream parser wip
1 parent 3cbc715 commit e96d582

File tree

5 files changed

+163
-87
lines changed

5 files changed

+163
-87
lines changed

llm/bedrock/src/client.rs

Lines changed: 47 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ impl BedrockClient {
126126
})?;
127127

128128
trace!("Initializing SSE stream");
129-
129+
trace!("Response: {:?}", response.headers().clone());
130130
EventSource::new(response)
131131
.map_err(|err| from_event_source_error("Failed to create SSE stream", err))
132132
}
@@ -288,34 +288,45 @@ pub enum Role {
288288
}
289289

290290
#[derive(Debug, Clone, Serialize, Deserialize)]
291-
#[serde(tag = "type")]
291+
#[serde(untagged)]
292292
pub enum ContentBlock {
293-
#[serde(rename = "text")]
294293
Text { text: String },
295-
#[serde(rename = "image")]
296294
Image {
297-
#[serde(rename = "format")]
298-
format: ImageFormat,
299-
#[serde(rename = "source")]
300-
source: ImageSource,
295+
image: ImageBlock,
301296
},
302-
#[serde(rename = "toolUse")]
303297
ToolUse {
304-
#[serde(rename = "toolUseId")]
305-
tool_use_id: String,
306-
name: String,
307-
input: Value,
298+
#[serde(rename = "toolUse")]
299+
tool_use: ToolUseBlock,
308300
},
309-
#[serde(rename = "toolResult")]
310301
ToolResult {
311-
#[serde(rename = "toolUseId")]
312-
tool_use_id: String,
313-
content: Vec<ToolResultContentBlock>,
314-
#[serde(skip_serializing_if = "Option::is_none")]
315-
status: Option<ToolResultStatus>,
302+
#[serde(rename = "toolResult")]
303+
tool_result: ToolResultBlock,
316304
},
317305
}
318306

307+
#[derive(Debug, Clone, Serialize, Deserialize)]
308+
pub struct ImageBlock {
309+
pub format: ImageFormat,
310+
pub source: ImageSource,
311+
}
312+
313+
#[derive(Debug, Clone, Serialize, Deserialize)]
314+
pub struct ToolUseBlock {
315+
#[serde(rename = "toolUseId")]
316+
pub tool_use_id: String,
317+
pub name: String,
318+
pub input: Value,
319+
}
320+
321+
#[derive(Debug, Clone, Serialize, Deserialize)]
322+
pub struct ToolResultBlock {
323+
#[serde(rename = "toolUseId")]
324+
pub tool_use_id: String,
325+
pub content: Vec<ToolResultContentBlock>,
326+
#[serde(skip_serializing_if = "Option::is_none")]
327+
pub status: Option<ToolResultStatus>,
328+
}
329+
319330
#[derive(Debug, Clone, Serialize, Deserialize)]
320331
pub enum ImageFormat {
321332
#[serde(rename = "png")]
@@ -329,7 +340,6 @@ pub enum ImageFormat {
329340
}
330341

331342
#[derive(Debug, Clone, Serialize, Deserialize)]
332-
#[serde(tag = "bytes")]
333343
pub struct ImageSource {
334344
pub bytes: String,
335345
}
@@ -395,18 +405,25 @@ pub struct ToolSpec {
395405
pub name: String,
396406
pub description: String,
397407
#[serde(rename = "inputSchema")]
398-
pub input_schema: Value,
408+
pub input_schema: ToolInputSchema,
399409
}
400410

401411
#[derive(Debug, Clone, Serialize, Deserialize)]
402-
#[serde(tag = "type")]
412+
pub struct ToolInputSchema {
413+
pub json: Value,
414+
}
415+
416+
#[derive(Debug, Clone, Serialize, Deserialize)]
417+
#[serde(untagged)]
403418
pub enum ToolChoice {
404-
#[serde(rename = "auto")]
405-
Auto,
406-
#[serde(rename = "any")]
407-
Any,
408-
#[serde(rename = "tool")]
409-
Tool { name: String },
419+
Auto { auto: serde_json::Value },
420+
Any { any: serde_json::Value },
421+
Tool { tool: ToolChoiceTool },
422+
}
423+
424+
#[derive(Debug, Clone, Serialize, Deserialize)]
425+
pub struct ToolChoiceTool {
426+
pub name: String,
410427
}
411428

412429
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -420,20 +437,15 @@ pub struct GuardrailConfig {
420437

421438
#[derive(Debug, Clone, Serialize, Deserialize)]
422439
pub struct ConverseResponse {
423-
#[serde(rename = "responseMetadata")]
424-
pub response_metadata: ResponseMetadata,
425440
pub output: Output,
426441
#[serde(rename = "stopReason")]
427442
pub stop_reason: StopReason,
428443
pub usage: Usage,
429444
pub metrics: Metrics,
445+
#[serde(rename = "additionalModelResponseFields", skip_serializing_if = "Option::is_none")]
446+
pub additional_model_response_fields: Option<Value>,
430447
}
431448

432-
#[derive(Debug, Clone, Serialize, Deserialize)]
433-
pub struct ResponseMetadata {
434-
#[serde(rename = "requestId")]
435-
pub request_id: String,
436-
}
437449

438450
#[derive(Debug, Clone, Serialize, Deserialize)]
439451
pub struct Output {

llm/bedrock/src/conversions.rs

Lines changed: 112 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,16 @@ use crate::client::{
22
ContentBlock, ConverseRequest, ConverseResponse, ImageFormat, ImageSource as ClientImageSource,
33
InferenceConfig, Message as ClientMessage, Role as ClientRole, StopReason, SystemContentBlock,
44
Tool, ToolChoice, ToolConfig, ToolSpec, ToolResultContentBlock, ToolResultStatus,
5+
ImageBlock, ToolUseBlock, ToolResultBlock, ToolInputSchema, ToolChoiceTool,
56
};
67
use base64::{engine::general_purpose, Engine as _};
78
use golem_llm::golem::llm::llm::{
89
ChatEvent, CompleteResponse, Config, ContentPart, Error, ErrorCode, FinishReason,
910
ImageReference, ImageSource, Message, ResponseMetadata, Role, ToolCall,
1011
ToolDefinition, ToolResult, Usage,
1112
};
12-
use std::collections::HashMap;
13+
use reqwest::{Client, Url};
14+
use std::{collections::HashMap, fs, path::Path};
1315

1416
pub fn messages_to_request(
1517
messages: Vec<Message>,
@@ -69,10 +71,14 @@ pub fn messages_to_request(
6971

7072
let tool_choice = config.tool_choice.map(convert_tool_choice);
7173

72-
Some(ToolConfig {
73-
tools,
74-
tool_choice,
75-
})
74+
if tools.is_empty() {
75+
None
76+
} else {
77+
Some(ToolConfig {
78+
tools,
79+
tool_choice,
80+
})
81+
}
7682
};
7783

7884
Ok(ConverseRequest {
@@ -91,11 +97,19 @@ pub fn messages_to_request(
9197
}
9298

9399
fn convert_tool_choice(tool_name: String) -> ToolChoice {
100+
use serde_json::Value;
101+
94102
match tool_name.as_str() {
95-
"auto" => ToolChoice::Auto,
96-
"any" => ToolChoice::Any,
103+
"auto" => ToolChoice::Auto {
104+
auto: Value::Object(serde_json::Map::new()),
105+
},
106+
"any" => ToolChoice::Any {
107+
any: Value::Object(serde_json::Map::new()),
108+
},
97109
name => ToolChoice::Tool {
98-
name: name.to_string(),
110+
tool: ToolChoiceTool {
111+
name: name.to_string(),
112+
},
99113
},
100114
}
101115
}
@@ -107,10 +121,10 @@ pub fn process_response(response: ConverseResponse) -> ChatEvent {
107121
for content in response.output.message.content {
108122
match content {
109123
ContentBlock::Text { text } => contents.push(ContentPart::Text(text)),
110-
ContentBlock::Image { format, source } => {
111-
match general_purpose::STANDARD.decode(&source.bytes) {
124+
ContentBlock::Image { image } => {
125+
match general_purpose::STANDARD.decode(&image.source.bytes) {
112126
Ok(decoded_data) => {
113-
let mime_type = match format {
127+
let mime_type = match image.format {
114128
ImageFormat::Jpeg => "image/jpeg",
115129
ImageFormat::Png => "image/png",
116130
ImageFormat::Gif => "image/gif",
@@ -133,14 +147,10 @@ pub fn process_response(response: ConverseResponse) -> ChatEvent {
133147
}
134148
}
135149
}
136-
ContentBlock::ToolUse {
137-
tool_use_id,
138-
name,
139-
input,
140-
} => tool_calls.push(ToolCall {
141-
id: tool_use_id,
142-
name,
143-
arguments_json: serde_json::to_string(&input).unwrap(),
150+
ContentBlock::ToolUse { tool_use } => tool_calls.push(ToolCall {
151+
id: tool_use.tool_use_id,
152+
name: tool_use.name,
153+
arguments_json: serde_json::to_string(&tool_use.input).unwrap(),
144154
}),
145155
ContentBlock::ToolResult { .. } => {}
146156
}
@@ -149,7 +159,7 @@ pub fn process_response(response: ConverseResponse) -> ChatEvent {
149159
if contents.is_empty() && !tool_calls.is_empty() {
150160
ChatEvent::ToolRequest(tool_calls)
151161
} else {
152-
let request_id = response.response_metadata.request_id.clone();
162+
let request_id = "bedrock-response".to_string();
153163

154164
let metadata = ResponseMetadata {
155165
finish_reason: Some(stop_reason_to_finish_reason(response.stop_reason)),
@@ -176,9 +186,11 @@ pub fn tool_results_to_messages(
176186
for (tool_call, tool_result) in tool_results {
177187
messages.push(ClientMessage {
178188
content: vec![ContentBlock::ToolUse {
179-
tool_use_id: tool_call.id.clone(),
180-
name: tool_call.name,
181-
input: serde_json::from_str(&tool_call.arguments_json).unwrap(),
189+
tool_use: ToolUseBlock {
190+
tool_use_id: tool_call.id.clone(),
191+
name: tool_call.name,
192+
input: serde_json::from_str(&tool_call.arguments_json).unwrap(),
193+
},
182194
}],
183195
role: ClientRole::Assistant,
184196
});
@@ -200,9 +212,11 @@ pub fn tool_results_to_messages(
200212

201213
messages.push(ClientMessage {
202214
content: vec![ContentBlock::ToolResult {
203-
tool_use_id: tool_call.id,
204-
content,
205-
status,
215+
tool_result: ToolResultBlock {
216+
tool_use_id: tool_call.id,
217+
content,
218+
status,
219+
},
206220
}],
207221
role: ClientRole::User,
208222
});
@@ -239,13 +253,45 @@ fn message_to_content(message: &Message) -> Result<Vec<ContentBlock>, Error> {
239253
text: text.clone(),
240254
}),
241255
ContentPart::Image(image_reference) => match image_reference {
242-
ImageReference::Url(_image_url) => {
243-
return Err(Error {
244-
code: ErrorCode::InvalidRequest,
245-
message: "Bedrock API does not support image URLs, only base64 encoded images".to_string(),
246-
provider_error_json: None,
256+
ImageReference::Url(image_url) => {
257+
let url = &image_url.url;
258+
let mut format = ImageFormat::Png;
259+
let bytes = if Url::parse(url).is_ok() {
260+
let client = Client::new();
261+
let response = client.get(url).send().map_err(|e| Error {
262+
code: ErrorCode::InvalidRequest,
263+
message: format!("Failed to fetch image from URL: {}", e),
264+
provider_error_json: None,
265+
});
266+
response.map(|r| {
267+
format = match r.headers().get("Content-Type").unwrap().to_str().unwrap() {
268+
"image/jpeg" => ImageFormat::Jpeg,
269+
"image/png" => ImageFormat::Png,
270+
"image/gif" => ImageFormat::Gif,
271+
"image/webp" => ImageFormat::Webp,
272+
_ => ImageFormat::Jpeg,
273+
};
274+
r.bytes().unwrap().to_vec()
275+
})
276+
} else {
277+
let path = Path::new(url);
278+
fs::read(path).map_err(|e| Error {
279+
code: ErrorCode::InvalidRequest,
280+
message: format!("Failed to read image from path: {}", e),
281+
provider_error_json: None,
282+
})
283+
};
284+
285+
let base64_data = general_purpose::STANDARD.encode(&bytes.unwrap());
286+
result.push(ContentBlock::Image {
287+
image: ImageBlock {
288+
format: ImageFormat::Png,
289+
source: ClientImageSource {
290+
bytes: base64_data,
291+
},
292+
},
247293
});
248-
}
294+
},
249295
ImageReference::Inline(image_source) => {
250296
let base64_data = general_purpose::STANDARD.encode(&image_source.data);
251297
let format = match image_source.mime_type.as_str() {
@@ -257,9 +303,11 @@ fn message_to_content(message: &Message) -> Result<Vec<ContentBlock>, Error> {
257303
};
258304

259305
result.push(ContentBlock::Image {
260-
format,
261-
source: ClientImageSource {
262-
bytes: base64_data,
306+
image: ImageBlock {
307+
format,
308+
source: ClientImageSource {
309+
bytes: base64_data,
310+
},
263311
},
264312
});
265313
}
@@ -286,18 +334,34 @@ fn message_to_system_content(message: &Message) -> Vec<SystemContentBlock> {
286334
}
287335

288336
fn tool_definition_to_tool(tool: &ToolDefinition) -> Result<Tool, Error> {
289-
match serde_json::from_str(&tool.parameters_schema) {
290-
Ok(json_schema) => Ok(Tool {
291-
tool_spec: ToolSpec {
292-
name: tool.name.clone(),
293-
description: tool.description.clone().unwrap_or_default(),
294-
input_schema: json_schema,
337+
use serde_json::Value;
338+
339+
let schema_value = if tool.parameters_schema.trim().is_empty() {
340+
serde_json::json!({
341+
"type": "object",
342+
"properties": {},
343+
"additionalProperties": false
344+
})
345+
} else {
346+
match serde_json::from_str::<Value>(&tool.parameters_schema) {
347+
Ok(value) => value,
348+
Err(error) => {
349+
return Err(Error {
350+
code: ErrorCode::InternalError,
351+
message: format!("Failed to parse tool parameters for {}: {error}", tool.name),
352+
provider_error_json: None,
353+
});
354+
}
355+
}
356+
};
357+
358+
Ok(Tool {
359+
tool_spec: ToolSpec {
360+
name: tool.name.clone(),
361+
description: tool.description.clone().unwrap_or_default(),
362+
input_schema: ToolInputSchema {
363+
json: schema_value,
295364
},
296-
}),
297-
Err(error) => Err(Error {
298-
code: ErrorCode::InternalError,
299-
message: format!("Failed to parse tool parameters for {}: {error}", tool.name),
300-
provider_error_json: None,
301-
}),
302-
}
365+
},
366+
})
303367
}

llm/llm/src/event_source/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ fn check_response(response: Response) -> Result<Response, Error> {
140140
matches!(
141141
(mime_type.type_(), mime_type.subtype()),
142142
(mime::TEXT, mime::EVENT_STREAM)
143-
) || mime_type.subtype().as_str().contains("ndjson")
143+
) || mime_type.subtype().as_str().contains("ndjson") || content_type.to_str().unwrap_or("").contains("vnd.amazon.eventstream")
144144
})
145145
.unwrap_or(false)
146146
{

0 commit comments

Comments
 (0)