
Commit 4d25753

fix: multi-turn needs output type for openai component
Parent: 0ad73e1

6 files changed: +235, -51 lines

llm/openai/src/client.rs

Lines changed: 1 addition & 0 deletions
@@ -195,6 +195,7 @@ pub enum InputItem {
 pub enum InnerInput {
     TextInput(String),
     List(Vec<InnerInputItem>),
+    OutputList(Vec<OutputMessageContent>)
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
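
The new OutputList variant carries output-shaped content (OutputMessageContent) so that earlier assistant turns can be replayed in a multi-turn request, presumably because the API expects assistant messages to use output-style content parts rather than the input-shaped InnerInputItem list. A minimal sketch of how the two shapes are used after this commit; the variant and field names come from the diffs on this page, while the literal "user"/"assistant" role strings are assumed to be what to_openai_role_name returns:

// Sketch only (not part of the commit): a user turn keeps the input content shape,
// while a replayed assistant turn now uses the output content shape.
let user_turn = InputItem::InputMessage {
    role: "user".to_string(), // assumed role string
    content: InnerInput::List(vec![InnerInputItem::TextInput {
        text: "Do you know what a haiku is?".to_string(),
    }]),
};

let assistant_turn = InputItem::InputMessage {
    role: "assistant".to_string(), // assumed role string
    content: InnerInput::OutputList(vec![OutputMessageContent::Text {
        text: "Yes, a three-line 5/7/5-syllable poem.".to_string(),
    }]),
};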

llm/openai/src/conversions.rs

Lines changed: 52 additions & 32 deletions
@@ -43,16 +43,12 @@ pub fn create_request(
 pub fn messages_to_input_items(messages: Vec<Message>) -> Vec<InputItem> {
     let mut items = Vec::new();
     for message in messages {
-        let role = to_openai_role_name(message.role).to_string();
-        let mut input_items = Vec::new();
-        for content_part in message.content {
-            input_items.push(content_part_to_inner_input_item(content_part));
-        }
+        let item = match message.role {
+            Role::Assistant => llm_message_to_assistant_message(message),
+            _ => llm_message_to_other_message(message),
+        };
 
-        items.push(InputItem::InputMessage {
-            role,
-            content: InnerInput::List(input_items),
-        });
+        items.push(item);
     }
     items
 }
@@ -122,35 +118,59 @@ pub fn to_openai_role_name(role: Role) -> &'static str {
     }
 }
 
-pub fn content_part_to_inner_input_item(content_part: ContentPart) -> InnerInputItem {
-    match content_part {
-        ContentPart::Text(msg) => InnerInputItem::TextInput { text: msg },
-        ContentPart::Image(image_reference) => match image_reference {
-            ImageReference::Url(image_url) => InnerInputItem::ImageInput {
-                image_url: image_url.url,
-                detail: match image_url.detail {
-                    Some(ImageDetail::Auto) => Detail::Auto,
-                    Some(ImageDetail::Low) => Detail::Low,
-                    Some(ImageDetail::High) => Detail::High,
-                    None => Detail::default(),
-                },
-            },
-            ImageReference::Inline(image_source) => {
-                let base64_data = general_purpose::STANDARD.encode(&image_source.data);
-                let mime_type = &image_source.mime_type; // This is already a string
-                let data_url = format!("data:{};base64,{}", mime_type, base64_data);
-
-                InnerInputItem::ImageInput {
-                    image_url: data_url,
-                    detail: match image_source.detail {
+pub fn llm_message_to_other_message(message: Message) -> InputItem {
+    let mut items = Vec::new();
+    for content_part in message.content {
+        let item = match content_part {
+            ContentPart::Text(msg) => InnerInputItem::TextInput { text: msg },
+            ContentPart::Image(image_reference) => match image_reference {
+                ImageReference::Url(image_url) => InnerInputItem::ImageInput {
+                    image_url: image_url.url,
+                    detail: match image_url.detail {
                         Some(ImageDetail::Auto) => Detail::Auto,
                         Some(ImageDetail::Low) => Detail::Low,
                         Some(ImageDetail::High) => Detail::High,
                         None => Detail::default(),
                     },
+                },
+                ImageReference::Inline(image_source) => {
+                    let base64_data = general_purpose::STANDARD.encode(&image_source.data);
+                    let mime_type = &image_source.mime_type; // This is already a string
+                    let data_url = format!("data:{};base64,{}", mime_type, base64_data);
+
+                    InnerInputItem::ImageInput {
+                        image_url: data_url,
+                        detail: match image_source.detail {
+                            Some(ImageDetail::Auto) => Detail::Auto,
+                            Some(ImageDetail::Low) => Detail::Low,
+                            Some(ImageDetail::High) => Detail::High,
+                            None => Detail::default(),
+                        },
+                    }
                 }
-            }
-        },
+            },
+        };
+        items.push(item);
+    }
+
+    InputItem::InputMessage {
+        role: to_openai_role_name(message.role).to_string(),
+        content: InnerInput::List(items),
+    }
+}
+
+pub fn llm_message_to_assistant_message(message: Message) -> InputItem {
+    let mut items = Vec::new();
+
+    for content_part in message.content {
+        if let ContentPart::Text(msg) = content_part {
+            items.push(OutputMessageContent::Text { text: msg });
+        }
+    }
+
+    InputItem::InputMessage {
+        role: to_openai_role_name(message.role).to_string(),
+        content: InnerInput::OutputList(items),
     }
 }
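
With this split, messages_to_input_items routes each message by role: assistant turns go through llm_message_to_assistant_message (output-shaped content), everything else through llm_message_to_other_message (input-shaped content). An illustrative round trip, not part of the commit, assuming Message carries the role/name/content fields the test component uses:

// Hypothetical two-turn history; the exact Message field set is an assumption.
let history = vec![
    Message {
        role: Role::User,
        name: None,
        content: vec![ContentPart::Text("Do you know what a haiku is?".to_string())],
    },
    Message {
        role: Role::Assistant,
        name: None,
        content: vec![ContentPart::Text("Yes, a 5/7/5-syllable poem.".to_string())],
    },
];

let items = messages_to_input_items(history);
// items[0] is InputItem::InputMessage with content = InnerInput::List(..)       (user turn)
// items[1] is InputItem::InputMessage with content = InnerInput::OutputList(..) (assistant turn)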

test/components-rust/test-llm/Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -37,8 +37,8 @@ path = "wit-generated"
 
 [package.metadata.component.target.dependencies]
 "golem:llm" = { path = "wit-generated/deps/golem-llm" }
-"wasi:clocks" = { path = "wit-generated/deps/clocks" }
 "wasi:io" = { path = "wit-generated/deps/io" }
+"wasi:clocks" = { path = "wit-generated/deps/clocks" }
 "golem:rpc" = { path = "wit-generated/deps/golem-rpc" }
 "test:helper-client" = { path = "wit-generated/deps/test_helper-client" }
 "test:llm-exports" = { path = "wit-generated/deps/test_llm-exports" }

test/components-rust/test-llm/src/lib.rs

Lines changed: 126 additions & 18 deletions
@@ -1,11 +1,13 @@
 #[allow(static_mut_refs)]
 mod bindings;
 
-use golem_rust::atomically;
 use crate::bindings::exports::test::llm_exports::test_llm_api::*;
 use crate::bindings::golem::llm::llm;
 use crate::bindings::golem::llm::llm::StreamEvent;
 use crate::bindings::test::helper_client::test_helper_client::TestHelperApi;
+use golem_rust::atomically;
+
+mod utils;
 
 struct Component;
 
@@ -17,7 +19,7 @@ const MODEL: &'static str = "claude-3-7-sonnet-20250219";
 const MODEL: &'static str = "grok-3-beta";
 #[cfg(feature = "openrouter")]
 const MODEL: &'static str = "openrouter/auto";
-#[cfg(feature = "ollama")]
+#[cfg(feature = "ollama")]
 const MODEL: &'static str = "qwen3:1.7b";
 
 #[cfg(feature = "openai")]
@@ -28,7 +30,7 @@ const IMAGE_MODEL: &'static str = "claude-3-7-sonnet-20250219";
 const IMAGE_MODEL: &'static str = "grok-2-vision-latest";
 #[cfg(feature = "openrouter")]
 const IMAGE_MODEL: &'static str = "openrouter/auto";
-#[cfg(feature = "ollama")]
+#[cfg(feature = "ollama")]
 const IMAGE_MODEL: &'static str = "gemma3:4b";
 
 impl Guest for Component {
@@ -67,9 +69,14 @@ impl Guest for Component {
                 .map(|content| match content {
                     llm::ContentPart::Text(txt) => txt,
                     llm::ContentPart::Image(image_ref) => match image_ref {
-                        llm::ImageReference::Url(url_data) => format!("[IMAGE URL: {}]", url_data.url),
-                        llm::ImageReference::Inline(inline_data) => format!("[INLINE IMAGE: {} bytes, mime: {}]", inline_data.data.len(), inline_data.mime_type),
-                    }
+                        llm::ImageReference::Url(url_data) =>
+                            format!("[IMAGE URL: {}]", url_data.url),
+                        llm::ImageReference::Inline(inline_data) => format!(
+                            "[INLINE IMAGE: {} bytes, mime: {}]",
+                            inline_data.data.len(),
+                            inline_data.mime_type
+                        ),
+                    },
                 })
                 .collect::<Vec<_>>()
                 .join(", ")
@@ -154,7 +161,7 @@
                 vec![]
             }
         };
-
+
         if !tool_request.is_empty() {
            let mut calls = Vec::new();
            for call in tool_request {
@@ -385,9 +392,14 @@ impl Guest for Component {
                 .map(|content| match content {
                     llm::ContentPart::Text(txt) => txt,
                     llm::ContentPart::Image(image_ref) => match image_ref {
-                        llm::ImageReference::Url(url_data) => format!("[IMAGE URL: {}]", url_data.url),
-                        llm::ImageReference::Inline(inline_data) => format!("[INLINE IMAGE: {} bytes, mime: {}]", inline_data.data.len(), inline_data.mime_type),
-                    }
+                        llm::ImageReference::Url(url_data) =>
+                            format!("[IMAGE URL: {}]", url_data.url),
+                        llm::ImageReference::Inline(inline_data) => format!(
+                            "[INLINE IMAGE: {} bytes, mime: {}]",
+                            inline_data.data.len(),
+                            inline_data.mime_type
+                        ),
+                    },
                 })
                 .collect::<Vec<_>>()
                 .join(", ")
@@ -407,7 +419,7 @@ impl Guest for Component {
         }
     }
 
-    /// test6 simulates a crash during a streaming LLM response, but only first time.
+    /// test6 simulates a crash during a streaming LLM response, but only first time.
     /// after the automatic recovery it will continue and finish the request successfully.
     fn test6() -> String {
         let config = llm::Config {
@@ -456,12 +468,20 @@ impl Guest for Component {
                             }
                             llm::ContentPart::Image(image_ref) => match image_ref {
                                 llm::ImageReference::Url(url_data) => {
-                                    result.push_str(&format!("IMAGE URL: {} ({:?})\n", url_data.url, url_data.detail));
+                                    result.push_str(&format!(
+                                        "IMAGE URL: {} ({:?})\n",
+                                        url_data.url, url_data.detail
+                                    ));
                                 }
                                 llm::ImageReference::Inline(inline_data) => {
-                                    result.push_str(&format!("INLINE IMAGE: {} bytes, mime: {}, detail: {:?}\n", inline_data.data.len(), inline_data.mime_type, inline_data.detail));
+                                    result.push_str(&format!(
+                                        "INLINE IMAGE: {} bytes, mime: {}, detail: {:?}\n",
+                                        inline_data.data.len(),
+                                        inline_data.mime_type,
+                                        inline_data.detail
+                                    ));
                                 }
-                            }
+                            },
                         }
                     }
                 }
@@ -528,7 +548,10 @@ impl Guest for Component {
             role: llm::Role::User,
             name: None,
             content: vec![
-                llm::ContentPart::Text("Please describe this cat image in detail. What breed might it be?".to_string()),
+                llm::ContentPart::Text(
+                    "Please describe this cat image in detail. What breed might it be?"
+                        .to_string(),
+                ),
                 llm::ContentPart::Image(llm::ImageReference::Inline(llm::ImageSource {
                     data: buffer,
                     mime_type: "image/png".to_string(),
@@ -549,9 +572,14 @@ impl Guest for Component {
                 .map(|content| match content {
                     llm::ContentPart::Text(txt) => txt,
                     llm::ContentPart::Image(image_ref) => match image_ref {
-                        llm::ImageReference::Url(url_data) => format!("[IMAGE URL: {}]", url_data.url),
-                        llm::ImageReference::Inline(inline_data) => format!("[INLINE IMAGE: {} bytes, mime: {}]", inline_data.data.len(), inline_data.mime_type),
-                    }
+                        llm::ImageReference::Url(url_data) =>
+                            format!("[IMAGE URL: {}]", url_data.url),
+                        llm::ImageReference::Inline(inline_data) => format!(
+                            "[INLINE IMAGE: {} bytes, mime: {}]",
+                            inline_data.data.len(),
+                            inline_data.mime_type
+                        ),
+                    },
                 })
                 .collect::<Vec<_>>()
                 .join(", ")
@@ -570,6 +598,86 @@ impl Guest for Component {
             }
         }
     }
+    fn test8() -> String {
+        let config = llm::Config {
+            model: MODEL.to_string(),
+            temperature: Some(0.2),
+            max_tokens: None,
+            stop_sequences: None,
+            tools: vec![],
+            tool_choice: None,
+            provider_options: vec![],
+        };
+
+        let mut messages = vec![llm::Message {
+            role: llm::Role::User,
+            name: Some("vigoo".to_string()),
+            content: vec![llm::ContentPart::Text(
+                "Do you know what a haiku is?".to_string(),
+            )],
+        }];
+
+        let stream = llm::stream(&messages, &config);
+
+        let mut result = String::new();
+
+        let name = std::env::var("GOLEM_WORKER_NAME").unwrap();
+
+        loop {
+            match utils::consume_next_event(&stream) {
+                Some(delta) => {
+                    result.push_str(&delta);
+                }
+                None => break,
+            }
+        }
+
+        messages.push(llm::Message {
+            role: llm::Role::Assistant,
+            name: Some("assistant".to_string()),
+            content: vec![llm::ContentPart::Text(result)],
+        });
+
+        messages.push(llm::Message {
+            role: llm::Role::User,
+            name: Some("vigoo".to_string()),
+            content: vec![llm::ContentPart::Text(
+                "Can you write one for me?".to_string(),
+            )],
+        });
+
+        println!("Message: {messages:?}");
+
+        let stream = llm::stream(&messages, &config);
+
+        let mut result = String::new();
+
+        let name = std::env::var("GOLEM_WORKER_NAME").unwrap();
+        let mut round = 0;
+
+        loop {
+            match utils::consume_next_event(&stream) {
+                Some(delta) => {
+                    result.push_str(&delta);
+                }
+                None => break,
+            }
+
+            if round == 2 {
+                atomically(|| {
+                    let client = TestHelperApi::new(&name);
+                    let answer = client.blocking_inc_and_get();
+                    if answer == 1 {
+                        panic!("Simulating crash")
+                    }
+                });
+            }
+
+            round += 1;
+        }
+
+        result
+    }
 }
 
 bindings::export!(Component with_types_in bindings);
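
test8 exercises the new multi-turn path end to end: it streams a first reply, pushes it back into the history as a Role::Assistant message (which, for the openai component, is now converted via InnerInput::OutputList), asks a follow-up question, and panics once during the second streamed response to exercise automatic recovery. Deltas are read through utils::consume_next_event from the new test-llm/src/utils.rs module, presumably one of the two changed files not shown on this page. Below is a hypothetical sketch of such a helper, compatible with how test8 calls it; the blocking_get_next accessor and the StreamEvent/delta shapes are assumptions about the golem:llm bindings, not taken from this diff:

// Hypothetical sketch; the real helper lives in test-llm/src/utils.rs (not shown here).
use crate::bindings::golem::llm::llm::{self, ChatStream, StreamEvent};

pub fn consume_next_event(stream: &ChatStream) -> Option<String> {
    // Assumed contract: an empty batch of events means the stream is exhausted.
    let events = stream.blocking_get_next();
    if events.is_empty() {
        return None;
    }

    let mut delta = String::new();
    for event in events {
        if let StreamEvent::Delta(d) = event {
            // Assumed: the delta carries an optional list of content parts.
            for part in d.content.unwrap_or_default() {
                if let llm::ContentPart::Text(txt) = part {
                    delta.push_str(&txt);
                }
            }
        }
    }
    Some(delta)
}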
