Skip to content

Commit 0a34c8c

Browse files
committed
stream fix
1 parent e96d582 commit 0a34c8c

File tree

6 files changed

+386
-45
lines changed

6 files changed

+386
-45
lines changed

Cargo.lock

Lines changed: 79 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

llm/bedrock/src/lib.rs

Lines changed: 111 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ use golem_llm::golem::llm::llm::{
1717
use golem_llm::LOGGING_STATE;
1818
use golem_rust::wasm_rpc::Pollable;
1919
use log::trace;
20+
use serde::Deserialize;
2021
use serde_json::Value;
2122
use std::cell::{Ref, RefCell, RefMut};
2223

@@ -27,6 +28,73 @@ struct BedrockChatStream {
2728
response_metadata: RefCell<ResponseMetadata>,
2829
}
2930

31+
32+
/// [2025-06-29T18:11:10.458Z] [TRACE ] [golem_llm_bedrock] llm/bedrock/src/lib.rs:84: Received raw stream event:
33+
/// {
34+
/// "contentBlockIndex":1,
35+
/// "delta":{
36+
/// "toolUse":{
37+
/// "input":" 10
38+
/// }"}},
39+
/// "p":"abcdefghijklmnopqrstuvwxyzAB"
40+
/// }
41+
/// {
42+
/// "contentBlockIndex":0,
43+
/// "delta":
44+
/// {
45+
/// "text":" German"
46+
/// },
47+
/// "p":"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWX"
48+
/// }
49+
///
50+
51+
#[derive(Debug, Deserialize)]
52+
#[serde(untagged)]
53+
pub enum Delta {
54+
ToolUse {
55+
#[serde(rename = "toolUse")]
56+
tool_use: ToolUse,
57+
},
58+
Text {
59+
text: String,
60+
},
61+
}
62+
63+
#[derive(Debug, Deserialize)]
64+
pub struct ToolUse {
65+
pub input: String,
66+
}
67+
68+
// Additional structs for different message types
69+
#[derive(Debug, Deserialize)]
70+
pub struct MessageStart {
71+
pub p: String,
72+
pub role: String,
73+
}
74+
75+
#[derive(Debug, Deserialize)]
76+
pub struct MessageStop {
77+
pub p: String,
78+
#[serde(rename = "stopReason")]
79+
pub stop_reason: String,
80+
}
81+
82+
#[derive(Debug, Deserialize)]
83+
pub struct MetadataMessage {
84+
pub p: String,
85+
pub usage: Option<crate::client::Usage>,
86+
pub metrics: Option<serde_json::Value>,
87+
}
88+
89+
#[derive(Debug, Deserialize)]
90+
pub struct EventContentBlock {
91+
#[serde(rename = "contentBlockIndex")]
92+
pub content_block_index: u32,
93+
pub delta : Delta,
94+
pub p: String,
95+
}
96+
97+
3098
impl BedrockChatStream {
3199
pub fn new(stream: EventSource) -> LlmChatStream<Self> {
32100
LlmChatStream::new(BedrockChatStream {
@@ -85,53 +153,47 @@ impl LlmChatStreamState for BedrockChatStream {
85153

86154
let json: Value = serde_json::from_str(raw)
87155
.map_err(|err| format!("Failed to deserialize stream event: {err}"))?;
88-
89-
if let Some(content_block_delta) = json.get("contentBlockDelta") {
90-
if let Some(delta) = content_block_delta.get("delta") {
91-
if let Some(text) = delta.get("text").and_then(|v| v.as_str()) {
92-
return Ok(Some(StreamEvent::Delta(StreamDelta {
93-
content: Some(vec![ContentPart::Text(text.to_string())]),
94-
tool_calls: None,
95-
})));
96-
}
97-
}
98-
}
99-
100-
if let Some(content_block_start) = json.get("contentBlockStart") {
101-
if let Some(start) = content_block_start.get("start") {
102-
if let Some(tool_use) = start.get("toolUse") {
103-
if let (Some(tool_use_id), Some(name)) = (
104-
tool_use.get("toolUseId").and_then(|v| v.as_str()),
105-
tool_use.get("name").and_then(|v| v.as_str()),
106-
) {
107-
if let Some(input) = tool_use.get("input") {
156+
157+
// 1. Handle content block delta messages (contentBlockIndex + delta)
158+
if json.get("contentBlockIndex").is_some() && json.get("delta").is_some() {
159+
match serde_json::from_value::<EventContentBlock>(json.clone()) {
160+
Ok(event_content_block) => {
161+
match event_content_block.delta {
162+
Delta::Text { text } => {
163+
return Ok(Some(StreamEvent::Delta(StreamDelta {
164+
content: Some(vec![ContentPart::Text(text)]),
165+
tool_calls: None,
166+
})));
167+
}
168+
Delta::ToolUse { tool_use } => {
169+
// Handle tool use delta - this would need tool call ID and name from earlier message
170+
// For now, just return the input as text
108171
return Ok(Some(StreamEvent::Delta(StreamDelta {
109-
content: None,
110-
tool_calls: Some(vec![ToolCall {
111-
id: tool_use_id.to_string(),
112-
name: name.to_string(),
113-
arguments_json: serde_json::to_string(input).unwrap(),
114-
}]),
172+
content: Some(vec![ContentPart::Text(tool_use.input)]),
173+
tool_calls: None,
115174
})));
116175
}
117176
}
118177
}
178+
Err(err) => {
179+
trace!("Failed to parse as EventContentBlock: {}", err);
180+
// Continue to other parsing attempts
181+
}
119182
}
120183
}
121184

122-
if let Some(metadata) = json.get("metadata") {
123-
if let Some(usage) = metadata.get("usage") {
124-
if let Ok(bedrock_usage) =
125-
serde_json::from_value::<crate::client::Usage>(usage.clone())
126-
{
127-
self.response_metadata.borrow_mut().usage = Some(convert_usage(bedrock_usage));
128-
}
185+
// 3. Handle message start (role + p)
186+
if json.get("role").is_some() {
187+
if let Ok(_message_start) = serde_json::from_value::<MessageStart>(json.clone()) {
188+
// Message start event - just metadata, no content to return
189+
return Ok(None);
129190
}
130191
}
131192

132-
if let Some(message_stop) = json.get("messageStop") {
133-
if let Some(stop_reason) = message_stop.get("stopReason").and_then(|v| v.as_str()) {
134-
let stop_reason = match stop_reason {
193+
// 4. Handle message stop with stopReason
194+
if json.get("stopReason").is_some() {
195+
if let Ok(message_stop) = serde_json::from_value::<MessageStop>(json.clone()) {
196+
let stop_reason = match message_stop.stop_reason.as_str() {
135197
"end_turn" => crate::client::StopReason::EndTurn,
136198
"tool_use" => crate::client::StopReason::ToolUse,
137199
"max_tokens" => crate::client::StopReason::MaxTokens,
@@ -142,12 +204,22 @@ impl LlmChatStreamState for BedrockChatStream {
142204
};
143205
self.response_metadata.borrow_mut().finish_reason =
144206
Some(stop_reason_to_finish_reason(stop_reason));
145-
}
146207

147-
let response_metadata = self.response_metadata.borrow().clone();
148-
return Ok(Some(StreamEvent::Finish(response_metadata)));
208+
let response_metadata = self.response_metadata.borrow().clone();
209+
return Ok(Some(StreamEvent::Finish(response_metadata)));
210+
}
149211
}
150212

213+
// 5. Handle metadata messages with usage/metrics
214+
if json.get("usage").is_some() || json.get("metrics").is_some() {
215+
if let Ok(metadata) = serde_json::from_value::<MetadataMessage>(json.clone()) {
216+
if let Some(usage) = metadata.usage {
217+
self.response_metadata.borrow_mut().usage = Some(convert_usage(usage));
218+
}
219+
// Metadata processed, no event to return
220+
return Ok(None);
221+
}
222+
}
151223
Ok(None)
152224
}
153225
}

llm/llm/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ path = "src/lib.rs"
1212
crate-type = ["rlib"]
1313

1414
[dependencies]
15+
aws-smithy-eventstream = "0.60.9"
16+
aws-smithy-types = "1.3.2"
17+
base64 = "0.22.1"
1518
golem-rust = { workspace = true }
1619
log = { workspace = true }
1720
mime = "0.3.17"

0 commit comments

Comments
 (0)