From 3ba4cb7a17fd2bad293283c1900d262b8e39360e Mon Sep 17 00:00:00 2001 From: neo773 Date: Tue, 1 Jul 2025 02:02:04 +0530 Subject: [PATCH] fix: OpenAI Multi-turn support --- llm/anthropic/src/bindings.rs | 11 +- llm/anthropic/src/conversions.rs | 2 +- llm/grok/src/bindings.rs | 11 +- llm/grok/src/conversions.rs | 2 +- llm/llm/src/event_source/ndjson_stream.rs | 2 +- llm/llm/src/event_source/stream.rs | 6 +- llm/ollama/src/bindings.rs | 11 +- llm/ollama/src/client.rs | 2 +- llm/ollama/src/conversions.rs | 2 +- llm/openai/src/bindings.rs | 11 +- llm/openai/src/client.rs | 326 ++++++++-------- llm/openai/src/conversions.rs | 369 +++++++++--------- llm/openai/src/lib.rs | 306 +++++++++------ llm/openrouter/src/bindings.rs | 11 +- llm/openrouter/src/conversions.rs | 2 +- test/components-rust/test-llm/Cargo.toml | 5 - test/components-rust/test-llm/src/lib.rs | 107 +++++ .../components-rust/test-llm/wit/test-llm.wit | 1 + 18 files changed, 692 insertions(+), 495 deletions(-) diff --git a/llm/anthropic/src/bindings.rs b/llm/anthropic/src/bindings.rs index 70c5f1fd..1a54d616 100644 --- a/llm/anthropic/src/bindings.rs +++ b/llm/anthropic/src/bindings.rs @@ -1,12 +1,15 @@ -// Generated by `wit-bindgen` 0.36.0. DO NOT EDIT! +// Generated by `wit-bindgen` 0.41.0. DO NOT EDIT! // Options used: // * runtime_path: "wit_bindgen_rt" // * with "golem:llm/llm@1.0.0" = "golem_llm::golem::llm::llm" // * generate_unused_types use golem_llm::golem::llm::llm as __with_name0; #[cfg(target_arch = "wasm32")] -#[link_section = "component-type:wit-bindgen:0.36.0:golem:llm-anthropic@1.0.0:llm-library:encoded world"] +#[unsafe( + link_section = "component-type:wit-bindgen:0.41.0:golem:llm-anthropic@1.0.0:llm-library:encoded world" +)] #[doc(hidden)] +#[allow(clippy::octal_escapes)] pub static __WIT_BINDGEN_COMPONENT_TYPE: [u8; 1762] = *b"\ \0asm\x0d\0\x01\0\0\x19\x16wit-component-encoding\x04\0\x07\xe0\x0c\x01A\x02\x01\ A\x02\x01BO\x01m\x04\x04user\x09assistant\x06system\x04tool\x04\0\x04role\x03\0\0\ @@ -43,8 +46,8 @@ ng-get-next\x01B\x01p\x15\x01@\x02\x08messages\xc3\0\x06config)\06\x04\0\x04send \0\x06config)\06\x04\0\x08continue\x01G\x01i=\x01@\x02\x08messages\xc3\0\x06conf\ ig)\0\xc8\0\x04\0\x06stream\x01I\x04\0\x13golem:llm/llm@1.0.0\x05\0\x04\0%golem:\ llm-anthropic/llm-library@1.0.0\x04\0\x0b\x11\x01\0\x0bllm-library\x03\0\0\0G\x09\ -producers\x01\x0cprocessed-by\x02\x0dwit-component\x070.220.0\x10wit-bindgen-rus\ -t\x060.36.0"; +producers\x01\x0cprocessed-by\x02\x0dwit-component\x070.227.1\x10wit-bindgen-rus\ +t\x060.41.0"; #[inline(never)] #[doc(hidden)] pub fn __link_custom_section_describing_imports() { diff --git a/llm/anthropic/src/conversions.rs b/llm/anthropic/src/conversions.rs index e332f139..e7d3175a 100644 --- a/llm/anthropic/src/conversions.rs +++ b/llm/anthropic/src/conversions.rs @@ -130,7 +130,7 @@ pub fn process_response(response: MessagesResponse) -> ChatEvent { Err(e) => { return ChatEvent::Error(Error { code: ErrorCode::InvalidRequest, - message: format!("Failed to decode base64 image data: {}", e), + message: format!("Failed to decode base64 image data: {e}"), provider_error_json: None, }); } diff --git a/llm/grok/src/bindings.rs b/llm/grok/src/bindings.rs index 2a101583..c2f60134 100644 --- a/llm/grok/src/bindings.rs +++ b/llm/grok/src/bindings.rs @@ -1,12 +1,15 @@ -// Generated by `wit-bindgen` 0.36.0. DO NOT EDIT! +// Generated by `wit-bindgen` 0.41.0. DO NOT EDIT! 
// Options used: // * runtime_path: "wit_bindgen_rt" // * with "golem:llm/llm@1.0.0" = "golem_llm::golem::llm::llm" // * generate_unused_types use golem_llm::golem::llm::llm as __with_name0; #[cfg(target_arch = "wasm32")] -#[link_section = "component-type:wit-bindgen:0.36.0:golem:llm-grok@1.0.0:llm-library:encoded world"] +#[unsafe( + link_section = "component-type:wit-bindgen:0.41.0:golem:llm-grok@1.0.0:llm-library:encoded world" +)] #[doc(hidden)] +#[allow(clippy::octal_escapes)] pub static __WIT_BINDGEN_COMPONENT_TYPE: [u8; 1757] = *b"\ \0asm\x0d\0\x01\0\0\x19\x16wit-component-encoding\x04\0\x07\xdb\x0c\x01A\x02\x01\ A\x02\x01BO\x01m\x04\x04user\x09assistant\x06system\x04tool\x04\0\x04role\x03\0\0\ @@ -43,8 +46,8 @@ ng-get-next\x01B\x01p\x15\x01@\x02\x08messages\xc3\0\x06config)\06\x04\0\x04send \0\x06config)\06\x04\0\x08continue\x01G\x01i=\x01@\x02\x08messages\xc3\0\x06conf\ ig)\0\xc8\0\x04\0\x06stream\x01I\x04\0\x13golem:llm/llm@1.0.0\x05\0\x04\0\x20gol\ em:llm-grok/llm-library@1.0.0\x04\0\x0b\x11\x01\0\x0bllm-library\x03\0\0\0G\x09p\ -roducers\x01\x0cprocessed-by\x02\x0dwit-component\x070.220.0\x10wit-bindgen-rust\ -\x060.36.0"; +roducers\x01\x0cprocessed-by\x02\x0dwit-component\x070.227.1\x10wit-bindgen-rust\ +\x060.41.0"; #[inline(never)] #[doc(hidden)] pub fn __link_custom_section_describing_imports() { diff --git a/llm/grok/src/conversions.rs b/llm/grok/src/conversions.rs index 68a5d570..129c128a 100644 --- a/llm/grok/src/conversions.rs +++ b/llm/grok/src/conversions.rs @@ -183,7 +183,7 @@ fn convert_content_parts(contents: Vec) -> crate::client::Content { let media_type = &image_source.mime_type; // This is already a string result.push(crate::client::ContentPart::ImageInput { image_url: crate::client::ImageUrl { - url: format!("data:{};base64,{}", media_type, base64_data), + url: format!("data:{media_type};base64,{base64_data}"), detail: image_source.detail.map(|d| d.into()), }, }); diff --git a/llm/llm/src/event_source/ndjson_stream.rs b/llm/llm/src/event_source/ndjson_stream.rs index e2f4cc1b..1b8ef377 100644 --- a/llm/llm/src/event_source/ndjson_stream.rs +++ b/llm/llm/src/event_source/ndjson_stream.rs @@ -126,7 +126,7 @@ fn try_parse_line( return Ok(None); } - trace!("Parsed NDJSON line: {}", line); + trace!("Parsed NDJSON line: {line}"); // Create a MessageEvent with the JSON line as data let event = MessageEvent { diff --git a/llm/llm/src/event_source/stream.rs b/llm/llm/src/event_source/stream.rs index 8f293367..13a5eeb5 100644 --- a/llm/llm/src/event_source/stream.rs +++ b/llm/llm/src/event_source/stream.rs @@ -56,9 +56,9 @@ where { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Self::Utf8(err) => f.write_fmt(format_args!("UTF8 error: {}", err)), - Self::Parser(err) => f.write_fmt(format_args!("Parse error: {}", err)), - Self::Transport(err) => f.write_fmt(format_args!("Transport error: {}", err)), + Self::Utf8(err) => f.write_fmt(format_args!("UTF8 error: {err}")), + Self::Parser(err) => f.write_fmt(format_args!("Parse error: {err}")), + Self::Transport(err) => f.write_fmt(format_args!("Transport error: {err}")), } } } diff --git a/llm/ollama/src/bindings.rs b/llm/ollama/src/bindings.rs index dbb70470..269cd07f 100644 --- a/llm/ollama/src/bindings.rs +++ b/llm/ollama/src/bindings.rs @@ -1,12 +1,15 @@ -// Generated by `wit-bindgen` 0.36.0. DO NOT EDIT! +// Generated by `wit-bindgen` 0.41.0. DO NOT EDIT! 
// Options used: // * runtime_path: "wit_bindgen_rt" // * with "golem:llm/llm@1.0.0" = "golem_llm::golem::llm::llm" // * generate_unused_types use golem_llm::golem::llm::llm as __with_name0; #[cfg(target_arch = "wasm32")] -#[link_section = "component-type:wit-bindgen:0.36.0:golem:llm-ollama@1.0.0:llm-library:encoded world"] +#[unsafe( + link_section = "component-type:wit-bindgen:0.41.0:golem:llm-ollama@1.0.0:llm-library:encoded world" +)] #[doc(hidden)] +#[allow(clippy::octal_escapes)] pub static __WIT_BINDGEN_COMPONENT_TYPE: [u8; 1759] = *b"\ \0asm\x0d\0\x01\0\0\x19\x16wit-component-encoding\x04\0\x07\xdd\x0c\x01A\x02\x01\ A\x02\x01BO\x01m\x04\x04user\x09assistant\x06system\x04tool\x04\0\x04role\x03\0\0\ @@ -43,8 +46,8 @@ ng-get-next\x01B\x01p\x15\x01@\x02\x08messages\xc3\0\x06config)\06\x04\0\x04send \0\x06config)\06\x04\0\x08continue\x01G\x01i=\x01@\x02\x08messages\xc3\0\x06conf\ ig)\0\xc8\0\x04\0\x06stream\x01I\x04\0\x13golem:llm/llm@1.0.0\x05\0\x04\0\"golem\ :llm-ollama/llm-library@1.0.0\x04\0\x0b\x11\x01\0\x0bllm-library\x03\0\0\0G\x09p\ -roducers\x01\x0cprocessed-by\x02\x0dwit-component\x070.220.0\x10wit-bindgen-rust\ -\x060.36.0"; +roducers\x01\x0cprocessed-by\x02\x0dwit-component\x070.227.1\x10wit-bindgen-rust\ +\x060.41.0"; #[inline(never)] #[doc(hidden)] pub fn __link_custom_section_describing_imports() { diff --git a/llm/ollama/src/client.rs b/llm/ollama/src/client.rs index e9514a8d..e2901e70 100644 --- a/llm/ollama/src/client.rs +++ b/llm/ollama/src/client.rs @@ -335,7 +335,7 @@ pub fn image_to_base64(source: &str) -> Result Error { Error { code: ErrorCode::InternalError, - message: format!("{}: {}", context, err), + message: format!("{context}: {err}"), provider_error_json: None, } } diff --git a/llm/ollama/src/conversions.rs b/llm/ollama/src/conversions.rs index b1db65c6..8d64e954 100644 --- a/llm/ollama/src/conversions.rs +++ b/llm/ollama/src/conversions.rs @@ -214,7 +214,7 @@ pub fn process_response(response: CompletionsResponse) -> ChatEvent { }; ChatEvent::Message(CompleteResponse { - id: format!("ollama-{}", timestamp), + id: format!("ollama-{timestamp}"), content, tool_calls, metadata, diff --git a/llm/openai/src/bindings.rs b/llm/openai/src/bindings.rs index c960248a..6d0a7728 100644 --- a/llm/openai/src/bindings.rs +++ b/llm/openai/src/bindings.rs @@ -1,12 +1,15 @@ -// Generated by `wit-bindgen` 0.36.0. DO NOT EDIT! +// Generated by `wit-bindgen` 0.41.0. DO NOT EDIT! 
 // Options used:
 //   * runtime_path: "wit_bindgen_rt"
 //   * with "golem:llm/llm@1.0.0" = "golem_llm::golem::llm::llm"
 //   * generate_unused_types
 use golem_llm::golem::llm::llm as __with_name0;
 #[cfg(target_arch = "wasm32")]
-#[link_section = "component-type:wit-bindgen:0.36.0:golem:llm-openai@1.0.0:llm-library:encoded world"]
+#[unsafe(
+    link_section = "component-type:wit-bindgen:0.41.0:golem:llm-openai@1.0.0:llm-library:encoded world"
+)]
 #[doc(hidden)]
+#[allow(clippy::octal_escapes)]
 pub static __WIT_BINDGEN_COMPONENT_TYPE: [u8; 1759] = *b"\
 \0asm\x0d\0\x01\0\0\x19\x16wit-component-encoding\x04\0\x07\xdd\x0c\x01A\x02\x01\
 A\x02\x01BO\x01m\x04\x04user\x09assistant\x06system\x04tool\x04\0\x04role\x03\0\0\
@@ -43,8 +46,8 @@ ng-get-next\x01B\x01p\x15\x01@\x02\x08messages\xc3\0\x06config)\06\x04\0\x04send
 \0\x06config)\06\x04\0\x08continue\x01G\x01i=\x01@\x02\x08messages\xc3\0\x06conf\
 ig)\0\xc8\0\x04\0\x06stream\x01I\x04\0\x13golem:llm/llm@1.0.0\x05\0\x04\0\"golem\
 :llm-openai/llm-library@1.0.0\x04\0\x0b\x11\x01\0\x0bllm-library\x03\0\0\0G\x09p\
-roducers\x01\x0cprocessed-by\x02\x0dwit-component\x070.220.0\x10wit-bindgen-rust\
-\x060.36.0";
+roducers\x01\x0cprocessed-by\x02\x0dwit-component\x070.227.1\x10wit-bindgen-rust\
+\x060.41.0";
 #[inline(never)]
 #[doc(hidden)]
 pub fn __link_custom_section_describing_imports() {
diff --git a/llm/openai/src/client.rs b/llm/openai/src/client.rs
index 688939a3..ea4fa758 100644
--- a/llm/openai/src/client.rs
+++ b/llm/openai/src/client.rs
@@ -10,15 +10,15 @@ use std::fmt::Debug;
 
 const BASE_URL: &str = "https://api.openai.com";
 
-/// The OpenAI API client for creating model responses.
+/// The OpenAI Chat Completions API client.
 ///
-/// Based on https://platform.openai.com/docs/api-reference/responses/create
-pub struct ResponsesApi {
+/// Based on https://platform.openai.com/docs/api-reference/chat/create
+pub struct CompletionsApi {
     openai_api_key: String,
     client: Client,
 }
 
-impl ResponsesApi {
+impl CompletionsApi {
     pub fn new(openai_api_key: String) -> Self {
         let client = Client::builder()
             .build()
@@ -29,15 +29,12 @@ impl ResponsesApi {
         }
     }
 
-    pub fn create_model_response(
-        &self,
-        request: CreateModelResponseRequest,
-    ) -> Result<CreateModelResponseResponse, Error> {
+    pub fn send_messages(&self, request: CompletionsRequest) -> Result<CompletionsResponse, Error> {
         trace!("Sending request to OpenAI API: {request:?}");
 
         let response: Response = self
             .client
-            .request(Method::POST, format!("{BASE_URL}/v1/responses"))
+            .request(Method::POST, format!("{BASE_URL}/v1/chat/completions"))
             .bearer_auth(&self.openai_api_key)
             .json(&request)
             .send()
@@ -46,15 +43,12 @@ impl ResponsesApi {
         parse_response(response)
     }
 
-    pub fn stream_model_response(
-        &self,
-        request: CreateModelResponseRequest,
-    ) -> Result<EventSource, Error> {
+    pub fn stream_send_messages(&self, request: CompletionsRequest) -> Result<EventSource, Error> {
         trace!("Sending request to OpenAI API: {request:?}");
 
         let response: Response = self
             .client
-            .request(Method::POST, format!("{BASE_URL}/v1/responses"))
+            .request(Method::POST, format!("{BASE_URL}/v1/chat/completions"))
             .bearer_auth(&self.openai_api_key)
             .header(
                 reqwest::header::ACCEPT,
@@ -72,18 +66,33 @@ impl ResponsesApi {
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct CreateModelResponseRequest {
-    pub input: Input,
+pub struct CompletionsRequest {
+    pub messages: Vec<Message>,
     pub model: String,
     #[serde(skip_serializing_if = "Option::is_none")]
+    pub frequency_penalty: Option<f32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub max_completion_tokens: Option<u32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub n: Option<u32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub presence_penalty: Option<f32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub seed: Option<u64>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub stop: Option<Vec<String>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub stream: Option<bool>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub stream_options: Option<StreamOptions>,
+    #[serde(skip_serializing_if = "Option::is_none")]
     pub temperature: Option<f32>,
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub max_output_tokens: Option<u32>,
+    pub tool_choice: Option<String>,
     #[serde(skip_serializing_if = "Vec::is_empty")]
     pub tools: Vec<Tool>,
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub tool_choice: Option<String>,
-    pub stream: bool,
+    pub top_logprobs: Option<u32>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub top_p: Option<f32>,
     #[serde(skip_serializing_if = "Option::is_none")]
@@ -91,123 +100,80 @@ pub struct CreateModelResponseRequest {
     pub user: Option<String>,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct CreateModelResponseResponse {
-    pub id: String,
-    pub created_at: u64,
-    pub error: Option<ErrorObject>,
-    pub incomplete_details: Option<IncompleteDetailsObject>,
-    pub status: Status,
-    pub output: Vec<OutputItem>,
-    pub usage: Option<Usage>,
-    pub metadata: Option<serde_json::Value>,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-#[serde(tag = "type")]
-pub enum OutputItem {
-    #[serde(rename = "message")]
-    Message {
-        id: String,
-        content: Vec<OutputMessageContent>,
-        role: String,
-        status: Status,
-    },
-    #[serde(rename = "function_call")]
-    ToolCall {
-        arguments: String,
-        call_id: String,
-        name: String,
-        id: String,
-        status: Status,
-    },
+pub struct StreamOptions {
+    pub include_usage: bool,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
 #[serde(tag = "type")]
-pub enum OutputMessageContent {
-    #[serde(rename = "output_text")]
-    Text { text: String },
-    #[serde(rename = "refusal")]
-    Refusal { refusal: String },
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct ErrorObject {
-    pub code: String,
-    pub message: String,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub enum Status {
-    #[serde(rename = "completed")]
-    Completed,
-    #[serde(rename = "failed")]
-    Failed,
-    #[serde(rename = "in_progress")]
-    InProgress,
-    #[serde(rename = "incomplete")]
-    Incomplete,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct IncompleteDetailsObject {
-    pub reason: String,
+pub enum Tool {
+    #[serde(rename = "function")]
+    Function { function: Function },
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
-#[serde(untagged)]
-pub enum Input {
-    TextInput(String),
-    List(Vec<InputItem>),
+pub struct Function {
+    pub name: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub description: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub parameters: Option<serde_json::Value>,
 }
 
-/// An item representing part of the context for the response to be generated by the model.
 #[derive(Debug, Clone, Serialize, Deserialize)]
-#[serde(tag = "type")]
-pub enum InputItem {
-    #[serde(rename = "message")]
-    InputMessage {
-        /// A list of one or many input items to the model, containing different content types.
-        content: InnerInput,
-        /// The role of the message input. One of user, system, or developer.
-        role: String,
+#[serde(tag = "role")]
+pub enum Message {
+    #[serde(rename = "system")]
+    System {
+        content: Content,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        name: Option<String>,
     },
-    #[serde(rename = "function_call")]
-    ToolCall {
-        /// A JSON string of the arguments to pass to the function.
-        arguments: String,
-        /// The unique ID of the function tool call generated by the model.
-        call_id: String,
-        /// https://platform.openai.com/docs/api-reference/responses/create
-        name: String,
+    #[serde(rename = "user")]
+    User {
+        content: Content,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        name: Option<String>,
     },
-    #[serde(rename = "function_call_output")]
-    ToolResult {
-        /// The unique ID of the function tool call generated by the model.
-        call_id: String,
-        /// A JSON string of the output of the function tool call.
-        output: String,
+    #[serde(rename = "assistant")]
+    Assistant {
+        #[serde(skip_serializing_if = "Option::is_none")]
+        content: Option<Content>,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        name: Option<String>,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        tool_calls: Option<Vec<ToolCall>>,
+    },
+    #[serde(rename = "tool")]
+    Tool {
+        content: Content,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        name: Option<String>,
+        tool_call_id: String,
     },
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
 #[serde(untagged)]
-pub enum InnerInput {
+pub enum Content {
     TextInput(String),
-    List(Vec<InnerInputItem>),
+    List(Vec<ContentPart>),
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
 #[serde(tag = "type")]
-pub enum InnerInputItem {
-    #[serde(rename = "input_text")]
+pub enum ContentPart {
+    #[serde(rename = "text")]
     TextInput { text: String },
-    #[serde(rename = "input_image")]
-    ImageInput {
-        image_url: String,
-        #[serde(default)]
-        detail: Detail,
-    },
+    #[serde(rename = "image_url")]
+    ImageInput { image_url: ImageUrl },
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ImageUrl {
+    pub url: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub detail: Option<Detail>,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize, Default)]
@@ -223,72 +189,116 @@ pub enum Detail {
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
 #[serde(tag = "type")]
-pub enum Tool {
+pub enum ToolCall {
     #[serde(rename = "function")]
     Function {
-        name: String,
-        #[serde(skip_serializing_if = "Option::is_none")]
-        description: Option<String>,
+        function: FunctionCall,
+        id: String,
         #[serde(skip_serializing_if = "Option::is_none")]
-        parameters: Option<serde_json::Value>,
-        strict: bool,
+        index: Option<u32>,
     },
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct Usage {
-    pub input_tokens: u32,
-    pub input_tokens_details: InputTokensDetails,
-    pub output_tokens: u32,
-    pub output_tokens_details: OutputTokensDetails,
-    pub total_tokens: u32,
+pub struct FunctionCall {
+    pub arguments: String,
+    pub name: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct CompletionsResponse {
+    pub choices: Vec<Choice>,
+    pub created: u64,
+    pub id: String,
+    pub model: String,
+    pub system_fingerprint: Option<String>,
+    pub usage: Option<Usage>,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct InputTokensDetails {
-    pub cached_tokens: u32,
+pub struct Choice {
+    pub finish_reason: Option<FinishReason>,
+    pub index: u32,
+    pub message: ResponseMessage,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct OutputTokensDetails {
-    pub reasoning_tokens: u32,
+pub enum FinishReason {
+    #[serde(rename = "stop")]
+    Stop,
+    #[serde(rename = "length")]
+    Length,
+    #[serde(rename = "tool_calls")]
+    ToolCalls,
+    #[serde(rename = "content_filter")]
+    ContentFilter,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct ResponseOutputTextDelta {
-    pub content_index: u32,
-    pub delta: String,
-    pub item_id: String,
-    pub output_index: u32,
+pub struct ResponseMessage {
+    pub content: Option<String>,
+    pub refusal: Option<String>,
+    pub role: String,
+    pub tool_calls: Option<Vec<ToolCall>>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Usage {
+    pub completion_tokens: u32,
+    pub prompt_tokens: u32,
+    pub total_tokens: u32,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct ResponseOutputItemDone {
-    pub item: OutputItem,
-    pub output_index: u32,
+pub struct ChatCompletionChunk {
+    pub id: String,
+    pub created: u64,
+    pub model: String,
+    pub choices: Vec<ChoiceChunk>,
+    pub usage: Option<Usage>,
+    pub system_fingerprint: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ChoiceChunk {
+    pub index: u32,
+    pub delta: ChoiceDelta,
+    pub finish_reason: Option<FinishReason>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ChoiceDelta {
+    pub content: Option<String>,
+    pub tool_calls: Option<Vec<ToolCall>>,
+    pub role: Option<String>,
 }
 
 fn parse_response<T: DeserializeOwned + Debug>(response: Response) -> Result<T, Error> {
-    let status = response.status();
-    if status.is_success() {
-        let body = response
-            .json::<T>()
-            .map_err(|err| from_reqwest_error("Failed to decode response body", err))?;
-
-        trace!("Received response from OpenAI API: {body:?}");
-
-        Ok(body)
-    } else {
-        let body = response
-            .text()
-            .map_err(|err| from_reqwest_error("Failed to receive error response body", err))?;
-
-        trace!("Received {status} response from OpenAI API: {body:?}");
-
-        Err(Error {
-            code: error_code_from_status(status),
-            message: format!("Request failed with {status}"),
-            provider_error_json: Some(body),
-        })
+    trace!(
+        "Received response from OpenAI API, status: {}",
+        response.status()
+    );
+
+    if !response.status().is_success() {
+        return Err(Error {
+            code: error_code_from_status(response.status()),
+            message: format!("OpenAI API error: HTTP {}", response.status()),
+            provider_error_json: None,
+        });
     }
+
+    let body = response
+        .text()
+        .map_err(|err| from_reqwest_error("Failed to read response body", err))?;
+
+    trace!("Response body: {body}");
+
+    let result: T = serde_json::from_str(&body).map_err(|err| Error {
+        code: golem_llm::golem::llm::llm::ErrorCode::InternalError,
+        message: format!("Failed to parse OpenAI API response: {err}"),
+        provider_error_json: Some(body),
+    })?;
+
+    trace!("Parsed response: {result:?}");
+    Ok(result)
 }
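Note: the request/response model above maps straight onto the Chat Completions wire format. A minimal standalone sketch (not part of the patch; a trimmed mirror of the `Message` enum above, assuming `serde`/`serde_json`) of what the `#[serde(tag = "role")]` representation puts on the wire:

```rust
use serde::Serialize;

// Trimmed mirror of the patch's `Message` enum; `rename_all = "lowercase"`
// stands in for the per-variant `#[serde(rename = "...")]` attributes.
#[derive(Serialize)]
#[serde(tag = "role", rename_all = "lowercase")]
enum Message {
    User { content: String },
    Assistant {
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_calls: Option<Vec<serde_json::Value>>,
    },
}

fn main() {
    let history = vec![
        Message::User { content: "What is 2+2?".into() },
        Message::Assistant { content: Some("4".into()), tool_calls: None },
    ];
    // The variant tag is inlined as the "role" field:
    // [{"role":"user","content":"What is 2+2?"},{"role":"assistant","content":"4"}]
    println!("{}", serde_json::to_string(&history).unwrap());
}
```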
diff --git a/llm/openai/src/conversions.rs b/llm/openai/src/conversions.rs
index 43694c0f..de9f9bff 100644
--- a/llm/openai/src/conversions.rs
+++ b/llm/openai/src/conversions.rs
@@ -1,216 +1,154 @@
-use crate::client::{
-    CreateModelResponseRequest, CreateModelResponseResponse, Detail, InnerInput, InnerInputItem,
-    Input, InputItem, OutputItem, OutputMessageContent, Tool,
-};
+use crate::client::{CompletionsRequest, CompletionsResponse, Detail, Function, Tool};
 use base64::{engine::general_purpose, Engine as _};
-use golem_llm::error::error_code_from_status;
 use golem_llm::golem::llm::llm::{
-    ChatEvent, CompleteResponse, Config, ContentPart, Error, ErrorCode, ImageDetail,
+    ChatEvent, CompleteResponse, Config, ContentPart, Error, ErrorCode, FinishReason, ImageDetail,
     ImageReference, Message, ResponseMetadata, Role, ToolCall, ToolDefinition, ToolResult, Usage,
 };
-use reqwest::StatusCode;
 use std::collections::HashMap;
-use std::str::FromStr;
 
-pub fn create_request(
-    items: Vec<InputItem>,
-    config: Config,
-    tools: Vec<Tool>,
-) -> CreateModelResponseRequest {
+pub fn create_request(messages: Vec<Message>, config: Config) -> Result<CompletionsRequest, Error> {
     let options = config
         .provider_options
         .into_iter()
         .map(|kv| (kv.key, kv.value))
         .collect::<HashMap<_, _>>();
 
-    CreateModelResponseRequest {
-        input: Input::List(items),
+    let mut completion_messages = Vec::new();
+    for message in messages {
+        match message.role {
+            Role::User => completion_messages.push(crate::client::Message::User {
+                name: message.name,
+                content: convert_content_parts(message.content),
+            }),
+            Role::Assistant => completion_messages.push(crate::client::Message::Assistant {
+                name: message.name,
+                content: Some(convert_content_parts(message.content)),
+                tool_calls: None,
+            }),
+            Role::System => completion_messages.push(crate::client::Message::System {
+                name: message.name,
+                content: convert_content_parts(message.content),
+            }),
+            Role::Tool => completion_messages.push(crate::client::Message::Tool {
+                name: message.name,
+                content: convert_content_parts(message.content),
+                tool_call_id: "unknown".to_string(), // This should be set properly in tool_results_to_messages
+            }),
+        }
+    }
+
+    let mut tools = Vec::new();
+    for tool in config.tools {
+        tools.push(tool_definition_to_tool(tool)?)
+    }
+
+    Ok(CompletionsRequest {
+        messages: completion_messages,
         model: config.model,
+        frequency_penalty: options
+            .get("frequency_penalty")
+            .and_then(|fp_s| fp_s.parse::<f32>().ok()),
+        max_completion_tokens: config.max_tokens,
+        n: options.get("n").and_then(|n_s| n_s.parse::<u32>().ok()),
+        presence_penalty: options
+            .get("presence_penalty")
+            .and_then(|pp_s| pp_s.parse::<f32>().ok()),
+        seed: options
+            .get("seed")
+            .and_then(|seed_s| seed_s.parse::<u64>().ok()),
+        stop: config.stop_sequences,
+        stream: Some(false),
+        stream_options: None,
         temperature: config.temperature,
-        max_output_tokens: config.max_tokens,
-        tools,
         tool_choice: config.tool_choice,
-        stream: false,
+        tools,
+        top_logprobs: options
+            .get("top_logprobs")
+            .and_then(|top_logprobs_s| top_logprobs_s.parse::<u32>().ok()),
         top_p: options
             .get("top_p")
             .and_then(|top_p_s| top_p_s.parse::<f32>().ok()),
-        user: options
-            .get("user")
-            .and_then(|user_s| user_s.parse::<String>().ok()),
-    }
-}
-
-pub fn messages_to_input_items(messages: Vec<Message>) -> Vec<InputItem> {
-    let mut items = Vec::new();
-    for message in messages {
-        let role = to_openai_role_name(message.role).to_string();
-        let mut input_items = Vec::new();
-        for content_part in message.content {
-            input_items.push(content_part_to_inner_input_item(content_part));
-        }
-
-        items.push(InputItem::InputMessage {
-            role,
-            content: InnerInput::List(input_items),
-        });
-    }
-    items
+        user: options.get("user_id").cloned(),
+    })
 }
 
-pub fn tool_results_to_input_items(tool_results: Vec<(ToolCall, ToolResult)>) -> Vec<InputItem> {
-    let mut items = Vec::new();
+pub fn tool_results_to_messages(
+    tool_results: Vec<(ToolCall, ToolResult)>,
+) -> Vec<crate::client::Message> {
+    let mut messages = Vec::new();
     for (tool_call, tool_result) in tool_results {
-        let tool_call = InputItem::ToolCall {
-            arguments: tool_call.arguments_json,
-            call_id: tool_call.id,
-            name: tool_call.name,
-        };
-        let tool_result = match tool_result {
-            ToolResult::Success(success) => InputItem::ToolResult {
-                call_id: success.id,
-                output: format!(r#"{{ "success": {} }}"#, success.result_json),
+        messages.push(crate::client::Message::Assistant {
+            content: None,
+            name: None,
+            tool_calls: Some(vec![crate::client::ToolCall::Function {
+                function: crate::client::FunctionCall {
+                    arguments: tool_call.arguments_json,
+                    name: tool_call.name,
+                },
+                id: tool_call.id.clone(),
+                index: None,
+            }]),
+        });
+        let content = match tool_result {
+            ToolResult::Success(success) => crate::client::ContentPart::TextInput {
+                text: success.result_json,
             },
-            ToolResult::Error(error) => InputItem::ToolResult {
-                call_id: error.id,
-                output: format!(
-                    r#"{{ "error": {{ "code": {}, "message": {} }} }}"#,
-                    error.error_code.unwrap_or_default(),
-                    error.error_message
-                ),
+            ToolResult::Error(failure) => crate::client::ContentPart::TextInput {
+                text: failure.error_message,
             },
         };
-        items.push(tool_call);
-        items.push(tool_result);
-    }
-    items
-}
-
-pub fn tool_defs_to_tools(tool_definitions: &[ToolDefinition]) -> Result<Vec<Tool>, Error> {
-    let mut tools = Vec::new();
-    for tool_def in tool_definitions {
-        match serde_json::from_str(&tool_def.parameters_schema) {
-            Ok(value) => {
-                let tool = Tool::Function {
-                    name: tool_def.name.clone(),
-                    description: tool_def.description.clone(),
-                    parameters: Some(value),
-                    strict: true,
-                };
-                tools.push(tool);
-            }
-            Err(error) => {
-                Err(Error {
-                    code: ErrorCode::InternalError,
-                    message: format!(
-                        "Failed to parse tool parameters for {}: {error}",
-                        tool_def.name
-                    ),
-                    provider_error_json: None,
-                })?;
-            }
-        }
-    }
-    Ok(tools)
-}
-
-pub fn to_openai_role_name(role: Role) -> &'static str {
-    match role {
-        Role::User => "user",
-        Role::Assistant => "assistant",
-        Role::System => "system",
-        Role::Tool => "tool",
+        messages.push(crate::client::Message::Tool {
+            name: None,
+            content: crate::client::Content::List(vec![content]),
+            tool_call_id: tool_call.id,
+        });
     }
+    messages
 }
 
-pub fn content_part_to_inner_input_item(content_part: ContentPart) -> InnerInputItem {
-    match content_part {
-        ContentPart::Text(msg) => InnerInputItem::TextInput { text: msg },
-        ContentPart::Image(image_reference) => match image_reference {
-            ImageReference::Url(image_url) => InnerInputItem::ImageInput {
-                image_url: image_url.url,
-                detail: match image_url.detail {
-                    Some(ImageDetail::Auto) => Detail::Auto,
-                    Some(ImageDetail::Low) => Detail::Low,
-                    Some(ImageDetail::High) => Detail::High,
-                    None => Detail::default(),
-                },
+fn tool_definition_to_tool(tool: ToolDefinition) -> Result<Tool, Error> {
+    match serde_json::from_str(&tool.parameters_schema) {
+        Ok(value) => Ok(Tool::Function {
+            function: Function {
+                name: tool.name,
+                description: tool.description,
+                parameters: Some(value),
             },
-            ImageReference::Inline(image_source) => {
-                let base64_data = general_purpose::STANDARD.encode(&image_source.data);
-                let mime_type = &image_source.mime_type; // This is already a string
-                let data_url = format!("data:{};base64,{}", mime_type, base64_data);
-
-                InnerInputItem::ImageInput {
-                    image_url: data_url,
-                    detail: match image_source.detail {
-                        Some(ImageDetail::Auto) => Detail::Auto,
-                        Some(ImageDetail::Low) => Detail::Low,
-                        Some(ImageDetail::High) => Detail::High,
-                        None => Detail::default(),
-                    },
-                }
-            }
-        },
-    }
-}
-
-pub fn parse_error_code(code: String) -> ErrorCode {
-    if let Some(code) = <u16 as FromStr>::from_str(&code)
-        .ok()
-        .and_then(|code| StatusCode::from_u16(code).ok())
-    {
-        error_code_from_status(code)
-    } else {
-        ErrorCode::InternalError
+        }),
+        Err(error) => Err(Error {
+            code: ErrorCode::InternalError,
+            message: format!("Failed to parse tool parameters for {}: {error}", tool.name),
+            provider_error_json: None,
+        }),
     }
 }
 
-pub fn process_model_response(response: CreateModelResponseResponse) -> ChatEvent {
-    if let Some(error) = response.error {
-        ChatEvent::Error(Error {
-            code: parse_error_code(error.code),
-            message: error.message,
-            provider_error_json: None,
-        })
-    } else {
+pub fn process_response(response: CompletionsResponse) -> ChatEvent {
+    let choice = response.choices.first();
+    if let Some(choice) = choice {
         let mut contents = Vec::new();
         let mut tool_calls = Vec::new();
 
-        let metadata = create_response_metadata(&response);
-
-        for output_item in response.output {
-            match output_item {
-                OutputItem::Message { content, .. } => {
-                    for content in content {
-                        match content {
-                            OutputMessageContent::Text { text, .. } => {
-                                contents.push(ContentPart::Text(text));
-                            }
-                            OutputMessageContent::Refusal { refusal, .. } => {
-                                contents.push(ContentPart::Text(format!("Refusal: {refusal}")));
-                            }
-                        }
-                    }
-                }
-                OutputItem::ToolCall {
-                    arguments,
-                    call_id,
-                    name,
-                    ..
-                } => {
-                    let tool_call = ToolCall {
-                        id: call_id,
-                        name,
-                        arguments_json: arguments,
-                    };
-                    tool_calls.push(tool_call);
-                }
-            }
+        if let Some(content) = &choice.message.content {
+            contents.push(ContentPart::Text(content.clone()));
+        }
+
+        let empty = Vec::new();
+        for tool_call in choice.message.tool_calls.as_ref().unwrap_or(&empty) {
+            tool_calls.push(convert_tool_call(tool_call));
         }
 
-        if contents.is_empty() {
+        if contents.is_empty() && !tool_calls.is_empty() {
             ChatEvent::ToolRequest(tool_calls)
         } else {
+            let metadata = ResponseMetadata {
+                finish_reason: choice.finish_reason.as_ref().map(convert_finish_reason),
+                usage: response.usage.as_ref().map(convert_usage),
+                provider_id: Some(response.id.clone()),
+                timestamp: Some(response.created.to_string()),
+                provider_metadata_json: None,
+            };
+
             ChatEvent::Message(CompleteResponse {
                 id: response.id,
                 content: contents,
@@ -218,19 +156,78 @@ pub fn process_model_response(response: CreateModelResponseResponse) -> ChatEven
                 metadata,
             })
         }
+    } else {
+        ChatEvent::Error(Error {
+            code: ErrorCode::InternalError,
+            message: "No choices in response".to_string(),
+            provider_error_json: None,
+        })
     }
 }
 
-pub fn create_response_metadata(response: &CreateModelResponseResponse) -> ResponseMetadata {
-    ResponseMetadata {
-        finish_reason: None,
-        usage: response.usage.as_ref().map(|usage| Usage {
-            input_tokens: Some(usage.input_tokens),
-            output_tokens: Some(usage.output_tokens),
-            total_tokens: Some(usage.total_tokens),
-        }),
-        provider_id: Some(response.id.clone()),
-        timestamp: Some(response.created_at.to_string()),
-        provider_metadata_json: response.metadata.as_ref().map(|m| m.to_string()),
+pub fn convert_tool_call(tool_call: &crate::client::ToolCall) -> ToolCall {
+    match tool_call {
+        crate::client::ToolCall::Function { function, id, .. } => ToolCall {
+            id: id.clone(),
+            name: function.name.clone(),
+            arguments_json: function.arguments.clone(),
+        },
+    }
+}
+
+fn convert_content_parts(contents: Vec<ContentPart>) -> crate::client::Content {
+    let mut result = Vec::new();
+    for content in contents {
+        match content {
+            ContentPart::Text(text) => result.push(crate::client::ContentPart::TextInput { text }),
+            ContentPart::Image(image_reference) => match image_reference {
+                ImageReference::Url(image_url) => {
+                    result.push(crate::client::ContentPart::ImageInput {
+                        image_url: crate::client::ImageUrl {
+                            url: image_url.url,
+                            detail: image_url.detail.map(|d| d.into()),
+                        },
+                    })
+                }
+                ImageReference::Inline(image_source) => {
+                    let base64_data = general_purpose::STANDARD.encode(&image_source.data);
+                    let media_type = &image_source.mime_type;
+                    result.push(crate::client::ContentPart::ImageInput {
+                        image_url: crate::client::ImageUrl {
+                            url: format!("data:{media_type};base64,{base64_data}"),
+                            detail: image_source.detail.map(|d| d.into()),
+                        },
+                    });
+                }
+            },
+        }
+    }
+    crate::client::Content::List(result)
+}
+
+impl From<ImageDetail> for Detail {
+    fn from(value: ImageDetail) -> Self {
+        match value {
+            ImageDetail::Auto => Self::Auto,
+            ImageDetail::Low => Self::Low,
+            ImageDetail::High => Self::High,
+        }
+    }
+}
+
+pub fn convert_finish_reason(value: &crate::client::FinishReason) -> FinishReason {
+    match value {
+        crate::client::FinishReason::Stop => FinishReason::Stop,
+        crate::client::FinishReason::Length => FinishReason::Length,
+        crate::client::FinishReason::ToolCalls => FinishReason::ToolCalls,
+        crate::client::FinishReason::ContentFilter => FinishReason::ContentFilter,
+    }
+}
+
+pub fn convert_usage(value: &crate::client::Usage) -> Usage {
+    Usage {
+        input_tokens: Some(value.prompt_tokens),
+        output_tokens: Some(value.completion_tokens),
+        total_tokens: Some(value.total_tokens),
    }
 }
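Note: the multi-turn fix hinges on `tool_results_to_messages` above. The Chat Completions API expects every `tool` message's `tool_call_id` to match a `tool_calls` entry on a preceding assistant message, so each result is paired with a synthetic assistant turn carrying the original call. Illustratively (field values made up), one `(ToolCall, ToolResult)` pair ends up on the wire as this two-message sequence:

```rust
use serde_json::json;

fn main() {
    // The assistant turn that "claims" the call, followed by the tool turn
    // that answers it; the ids must line up for the provider to accept it.
    let pair = json!([
        {
            "role": "assistant",
            "tool_calls": [{
                "type": "function",
                "id": "call_abc123",
                "function": { "name": "get_weather", "arguments": "{\"city\":\"Paris\"}" }
            }]
        },
        {
            "role": "tool",
            "tool_call_id": "call_abc123",
            "content": [{ "type": "text", "text": "{\"temp_c\":21}" }]
        }
    ]);
    println!("{pair:#}");
}
```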
diff --git a/llm/openai/src/lib.rs b/llm/openai/src/lib.rs
index 31a1fc12..e3fb2c1a 100644
--- a/llm/openai/src/lib.rs
+++ b/llm/openai/src/lib.rs
@@ -1,31 +1,38 @@
-use crate::client::{
-    CreateModelResponseResponse, InputItem, OutputItem, ResponseOutputItemDone,
-    ResponseOutputTextDelta, ResponsesApi,
-};
+mod client;
+mod conversions;
+
+use crate::client::{ChatCompletionChunk, CompletionsApi, CompletionsRequest};
 use crate::conversions::{
-    create_request, create_response_metadata, messages_to_input_items, parse_error_code,
-    process_model_response, tool_defs_to_tools, tool_results_to_input_items,
+    convert_finish_reason, convert_usage, create_request, process_response,
+    tool_results_to_messages,
 };
 use golem_llm::chat_stream::{LlmChatStream, LlmChatStreamState};
 use golem_llm::config::with_config_key;
 use golem_llm::durability::{DurableLLM, ExtendedGuest};
 use golem_llm::event_source::EventSource;
 use golem_llm::golem::llm::llm::{
-    ChatEvent, ChatStream, Config, ContentPart, Error, ErrorCode, Guest, Message, StreamDelta,
-    StreamEvent, ToolCall, ToolResult,
+    ChatEvent, ChatStream, Config, ContentPart, Error, FinishReason, Guest, Message,
+    ResponseMetadata, Role, StreamDelta, StreamEvent, ToolCall, ToolResult,
 };
 use golem_llm::LOGGING_STATE;
 use golem_rust::wasm_rpc::Pollable;
 use log::trace;
 use std::cell::{Ref, RefCell, RefMut};
+use std::collections::HashMap;
 
-mod client;
-mod conversions;
+#[derive(Default)]
+struct JsonFragment {
+    id: String,
+    name: String,
+    json: String,
+}
 
 struct OpenAIChatStream {
     stream: RefCell<Option<EventSource>>,
     failure: Option<Error>,
     finished: RefCell<bool>,
+    finish_reason: RefCell<Option<FinishReason>>,
+    json_fragments: RefCell<HashMap<u32, JsonFragment>>,
 }
 
 impl OpenAIChatStream {
@@ -34,6 +41,8 @@ impl OpenAIChatStream {
             stream: RefCell::new(Some(stream)),
             failure: None,
             finished: RefCell::new(false),
+            finish_reason: RefCell::new(None),
+            json_fragments: RefCell::new(HashMap::new()),
         })
     }
 
@@ -42,8 +51,22 @@ impl OpenAIChatStream {
             stream: RefCell::new(None),
             failure: Some(error),
             finished: RefCell::new(false),
+            finish_reason: RefCell::new(None),
+            json_fragments: RefCell::new(HashMap::new()),
         })
     }
+
+    fn set_finished(&self) {
+        *self.finished.borrow_mut() = true;
+    }
+
+    fn set_finish_reason(&self, finish_reason: FinishReason) {
+        *self.finish_reason.borrow_mut() = Some(finish_reason);
+    }
+
+    fn get_finish_reason(&self) -> Option<FinishReason> {
+        *self.finish_reason.borrow()
+    }
 }
 
 impl LlmChatStreamState for OpenAIChatStream {
@@ -69,90 +92,93 @@ impl LlmChatStreamState for OpenAIChatStream {
     fn decode_message(&self, raw: &str) -> Result<Option<StreamEvent>, String> {
         trace!("Received raw stream event: {raw}");
-        let json: serde_json::Value = serde_json::from_str(raw)
+
+        if raw.starts_with("data: [DONE]") {
+            self.set_finished();
+            return Ok(None);
+        }
+
+        if !raw.starts_with("data: ") {
+            return Ok(None);
+        }
+
+        let json_str = &raw[6..];
+        let json: serde_json::Value = serde_json::from_str(json_str)
+            .map_err(|err| format!("Failed to parse stream event JSON: {err}"))?;
+
+        let chunk: ChatCompletionChunk = serde_json::from_value(json)
             .map_err(|err| format!("Failed to deserialize stream event: {err}"))?;
-        let typ = json
-            .as_object()
-            .and_then(|obj| obj.get("type"))
-            .and_then(|v| v.as_str());
-        match typ {
-            Some("response.failed") => {
-                let response = json
-                    .as_object()
-                    .and_then(|obj| obj.get("response"))
-                    .ok_or_else(|| {
-                        "Unexpected stream event format, does not have 'response' field".to_string()
-                    })?;
-                let decoded =
-                    serde_json::from_value::<CreateModelResponseResponse>(response.clone())
-                        .map_err(|err| {
-                            format!("Failed to deserialize stream event's response field: {err}")
-                        })?;
-
-                if let Some(error) = decoded.error {
-                    Ok(Some(StreamEvent::Error(Error {
-                        code: parse_error_code(error.code),
-                        message: error.message,
-                        provider_error_json: None,
-                    })))
-                } else {
-                    Ok(Some(StreamEvent::Error(Error {
-                        code: ErrorCode::InternalError,
-                        message: "Unknown error".to_string(),
-                        provider_error_json: None,
-                    })))
-                }
-            }
-            Some("response.completed") => {
-                let response = json
-                    .as_object()
-                    .and_then(|obj| obj.get("response"))
-                    .ok_or_else(|| {
-                        "Unexpected stream event format, does not have 'response' field".to_string()
-                    })?;
-                let decoded =
-                    serde_json::from_value::<CreateModelResponseResponse>(response.clone())
-                        .map_err(|err| {
-                            format!("Failed to deserialize stream event's response field: {err}")
-                        })?;
-                Ok(Some(StreamEvent::Finish(create_response_metadata(
-                    &decoded,
-                ))))
+        if let Some(choice) = chunk.choices.into_iter().next() {
+            if let Some(finish_reason) = choice.finish_reason {
+                self.set_finish_reason(convert_finish_reason(&finish_reason));
             }
-            Some("response.output_text.delta") => {
-                let decoded = serde_json::from_value::<ResponseOutputTextDelta>(json)
-                    .map_err(|err| format!("Failed to deserialize stream event: {err}"))?;
-                Ok(Some(StreamEvent::Delta(StreamDelta {
-                    content: Some(vec![ContentPart::Text(decoded.delta)]),
+
+            let delta = &choice.delta;
+
+            if let Some(content) = &delta.content {
+                return Ok(Some(StreamEvent::Delta(StreamDelta {
+                    content: Some(vec![ContentPart::Text(content.clone())]),
                     tool_calls: None,
-                })))
+                })));
             }
-            Some("response.output_item.done") => {
-                let decoded = serde_json::from_value::<ResponseOutputItemDone>(json)
-                    .map_err(|err| format!("Failed to deserialize stream event: {err}"))?;
-                if let OutputItem::ToolCall {
-                    arguments,
-                    call_id,
-                    name,
-                    ..
-                } = decoded.item
-                {
-                    Ok(Some(StreamEvent::Delta(StreamDelta {
+
+        if let Some(tool_calls) = &delta.tool_calls {
+            let mut fragments = self.json_fragments.borrow_mut();
+            let mut result_tool_calls = Vec::new();
+
+            for tool_call in tool_calls {
+                match tool_call {
+                    crate::client::ToolCall::Function {
+                        function,
+                        id,
+                        index,
+                    } => {
+                        let idx = index.unwrap_or(0);
+
+                        let fragment = fragments.entry(idx).or_insert_with(|| JsonFragment {
+                            id: id.clone(),
+                            name: function.name.clone(),
+                            json: String::new(),
+                        });
+
+                        if !function.arguments.is_empty() {
+                            fragment.json.push_str(&function.arguments);
+                        }
+
+                        // Only emit when we have content to add
+                        if !fragment.id.is_empty() && !fragment.name.is_empty() {
+                            result_tool_calls.push(ToolCall {
+                                id: fragment.id.clone(),
+                                name: fragment.name.clone(),
+                                arguments_json: fragment.json.clone(),
+                            });
+                        }
+                    }
+                }
+            }
+
+            if !result_tool_calls.is_empty() {
+                return Ok(Some(StreamEvent::Delta(StreamDelta {
                     content: None,
-                    tool_calls: Some(vec![ToolCall {
-                        id: call_id,
-                        name,
-                        arguments_json: arguments,
-                    }]),
-                })))
-            } else {
-                Ok(None)
+                    tool_calls: Some(result_tool_calls),
+                })));
             }
         }
-            Some(_) => Ok(None),
-            None => Err("Unexpected stream event format, does not have 'type' field".to_string()),
-        }
+
+        if let Some(usage) = chunk.usage {
+            let finish_reason = self.get_finish_reason();
+            return Ok(Some(StreamEvent::Finish(ResponseMetadata {
+                finish_reason,
+                usage: Some(convert_usage(&usage)),
+                provider_id: Some(chunk.id),
+                timestamp: Some(chunk.created.to_string()),
+                provider_metadata_json: None,
+            })));
+        }
+
+        Ok(None)
     }
 }
 
@@ -161,33 +187,20 @@ struct OpenAIComponent;
 
 impl OpenAIComponent {
     const ENV_VAR_NAME: &'static str = "OPENAI_API_KEY";
 
-    fn request(client: ResponsesApi, items: Vec<InputItem>, config: Config) -> ChatEvent {
-        match tool_defs_to_tools(&config.tools) {
-            Ok(tools) => {
-                let request = create_request(items, config, tools);
-                match client.create_model_response(request) {
-                    Ok(response) => process_model_response(response),
-                    Err(error) => ChatEvent::Error(error),
-                }
-            }
+    fn request(client: CompletionsApi, request: CompletionsRequest) -> ChatEvent {
+        match client.send_messages(request) {
+            Ok(response) => process_response(response),
             Err(error) => ChatEvent::Error(error),
         }
     }
 
     fn streaming_request(
-        client: ResponsesApi,
-        items: Vec<InputItem>,
-        config: Config,
+        client: CompletionsApi,
+        mut request: CompletionsRequest,
     ) -> LlmChatStream<OpenAIChatStream> {
-        match tool_defs_to_tools(&config.tools) {
-            Ok(tools) => {
-                let mut request = create_request(items, config, tools);
-                request.stream = true;
-                match client.stream_model_response(request) {
-                    Ok(stream) => OpenAIChatStream::new(stream),
-                    Err(error) => OpenAIChatStream::failed(error),
-                }
-            }
+        request.stream = Some(true);
+        match client.stream_send_messages(request) {
+            Ok(stream) => OpenAIChatStream::new(stream),
             Err(error) => OpenAIChatStream::failed(error),
         }
     }
@@ -200,10 +213,12 @@ impl Guest for OpenAIComponent {
         LOGGING_STATE.with_borrow_mut(|state| state.init());
 
         with_config_key(Self::ENV_VAR_NAME, ChatEvent::Error, |openai_api_key| {
-            let client = ResponsesApi::new(openai_api_key);
+            let client = CompletionsApi::new(openai_api_key);
 
-            let items = messages_to_input_items(messages);
-            Self::request(client, items, config)
+            match create_request(messages, config) {
+                Ok(request) => Self::request(client, request),
+                Err(err) => ChatEvent::Error(err),
+            }
         })
     }
 
@@ -215,11 +230,17 @@ impl Guest for OpenAIComponent {
         LOGGING_STATE.with_borrow_mut(|state| state.init());
 
         with_config_key(Self::ENV_VAR_NAME, ChatEvent::Error, |openai_api_key| {
-            let client = ResponsesApi::new(openai_api_key);
+            let client = CompletionsApi::new(openai_api_key);
 
-            let mut items = messages_to_input_items(messages);
-            items.extend(tool_results_to_input_items(tool_results));
-            Self::request(client, items, config)
+            match create_request(messages, config) {
+                Ok(mut request) => {
+                    request
+                        .messages
+                        .extend(tool_results_to_messages(tool_results));
+                    Self::request(client, request)
+                }
+                Err(err) => ChatEvent::Error(err),
+            }
         })
     }
 
@@ -229,21 +250,72 @@ impl Guest for OpenAIComponent {
     }
 }
 
 impl ExtendedGuest for OpenAIComponent {
-    fn unwrapped_stream(messages: Vec<Message>, config: Config) -> Self::ChatStream {
+    fn unwrapped_stream(messages: Vec<Message>, config: Config) -> LlmChatStream<OpenAIChatStream> {
        LOGGING_STATE.with_borrow_mut(|state| state.init());
 
         with_config_key(
             Self::ENV_VAR_NAME,
             OpenAIChatStream::failed,
             |openai_api_key| {
-                let client = ResponsesApi::new(openai_api_key);
+                let client = CompletionsApi::new(openai_api_key);
 
-                let items = messages_to_input_items(messages);
-                Self::streaming_request(client, items, config)
+                match create_request(messages, config) {
+                    Ok(request) => Self::streaming_request(client, request),
+                    Err(err) => OpenAIChatStream::failed(err),
+                }
             },
         )
     }
 
+    fn retry_prompt(original_messages: &[Message], partial_result: &[StreamDelta]) -> Vec<Message> {
+        let mut extended_messages = Vec::new();
+        extended_messages.push(Message {
+            role: Role::System,
+            name: None,
+            content: vec![
+                ContentPart::Text(
+                    "You were asked the same question previously, but the response was interrupted before completion. \
+                     Please continue your response from where you left off. \
+                     Do not include the part of the response that was already seen.".to_string()),
+            ],
+        });
+        extended_messages.push(Message {
+            role: Role::User,
+            name: None,
+            content: vec![ContentPart::Text(
+                "Here is the original question:".to_string(),
+            )],
+        });
+        extended_messages.extend_from_slice(original_messages);
+
+        let mut partial_result_as_content = Vec::new();
+        for delta in partial_result {
+            if let Some(contents) = &delta.content {
+                partial_result_as_content.extend_from_slice(contents);
+            }
+            if let Some(tool_calls) = &delta.tool_calls {
+                for tool_call in tool_calls {
+                    partial_result_as_content.push(ContentPart::Text(format!(
+                        "<tool-call id=\"{}\" name=\"{}\" arguments=\"{}\"/>",
+                        tool_call.id, tool_call.name, tool_call.arguments_json,
+                    )));
+                }
+            }
+        }
+
+        extended_messages.push(Message {
+            role: Role::User,
+            name: None,
+            content: vec![ContentPart::Text(
+                "Here is the partial response that was successfully received:".to_string(),
+            )]
+            .into_iter()
+            .chain(partial_result_as_content)
+            .collect(),
+        });
+        extended_messages
+    }
+
     fn subscribe(stream: &Self::ChatStream) -> Pollable {
         stream.subscribe()
     }
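Note: in streaming responses, function-call arguments arrive as string fragments spread over many chunks, keyed by the tool call's `index`; `decode_message` above concatenates them in the `json_fragments` map. A standalone sketch of that accumulation (simplified types, not the patch's exact code):

```rust
use std::collections::HashMap;

#[derive(Default)]
struct JsonFragment {
    id: String,
    name: String,
    json: String,
}

fn main() {
    // (index, id, name, argument fragment) as chunks might arrive over the
    // stream; typically only the first chunk for a call carries id and name.
    let chunks = [
        (0u32, "call_1", "get_weather", "{\"ci"),
        (0u32, "", "", "ty\":\"Paris\"}"),
    ];

    let mut fragments: HashMap<u32, JsonFragment> = HashMap::new();
    for (index, id, name, args) in chunks {
        // First chunk for an index creates the entry; later chunks only append.
        let f = fragments.entry(index).or_insert_with(|| JsonFragment {
            id: id.to_string(),
            name: name.to_string(),
            json: String::new(),
        });
        f.json.push_str(args);
    }

    let f = &fragments[&0];
    assert_eq!(f.json, "{\"city\":\"Paris\"}");
    println!("{} {} {}", f.id, f.name, f.json);
}
```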
diff --git a/llm/openrouter/src/bindings.rs b/llm/openrouter/src/bindings.rs
index ba2accf7..1300cde9 100644
--- a/llm/openrouter/src/bindings.rs
+++ b/llm/openrouter/src/bindings.rs
@@ -1,12 +1,15 @@
-// Generated by `wit-bindgen` 0.36.0. DO NOT EDIT!
+// Generated by `wit-bindgen` 0.41.0. DO NOT EDIT!
// Options used: // * runtime_path: "wit_bindgen_rt" // * with "golem:llm/llm@1.0.0" = "golem_llm::golem::llm::llm" // * generate_unused_types use golem_llm::golem::llm::llm as __with_name0; #[cfg(target_arch = "wasm32")] -#[link_section = "component-type:wit-bindgen:0.36.0:golem:llm-openrouter@1.0.0:llm-library:encoded world"] +#[unsafe( + link_section = "component-type:wit-bindgen:0.41.0:golem:llm-openrouter@1.0.0:llm-library:encoded world" +)] #[doc(hidden)] +#[allow(clippy::octal_escapes)] pub static __WIT_BINDGEN_COMPONENT_TYPE: [u8; 1763] = *b"\ \0asm\x0d\0\x01\0\0\x19\x16wit-component-encoding\x04\0\x07\xe1\x0c\x01A\x02\x01\ A\x02\x01BO\x01m\x04\x04user\x09assistant\x06system\x04tool\x04\0\x04role\x03\0\0\ @@ -43,8 +46,8 @@ ng-get-next\x01B\x01p\x15\x01@\x02\x08messages\xc3\0\x06config)\06\x04\0\x04send \0\x06config)\06\x04\0\x08continue\x01G\x01i=\x01@\x02\x08messages\xc3\0\x06conf\ ig)\0\xc8\0\x04\0\x06stream\x01I\x04\0\x13golem:llm/llm@1.0.0\x05\0\x04\0&golem:\ llm-openrouter/llm-library@1.0.0\x04\0\x0b\x11\x01\0\x0bllm-library\x03\0\0\0G\x09\ -producers\x01\x0cprocessed-by\x02\x0dwit-component\x070.220.0\x10wit-bindgen-rus\ -t\x060.36.0"; +producers\x01\x0cprocessed-by\x02\x0dwit-component\x070.227.1\x10wit-bindgen-rus\ +t\x060.41.0"; #[inline(never)] #[doc(hidden)] pub fn __link_custom_section_describing_imports() { diff --git a/llm/openrouter/src/conversions.rs b/llm/openrouter/src/conversions.rs index d4db2d34..61b5f973 100644 --- a/llm/openrouter/src/conversions.rs +++ b/llm/openrouter/src/conversions.rs @@ -184,7 +184,7 @@ fn convert_content_parts(contents: Vec) -> crate::client::Content { let media_type = &image_source.mime_type; // This is already a string result.push(crate::client::ContentPart::ImageInput { image_url: crate::client::ImageUrl { - url: format!("data:{};base64,{}", media_type, base64_data), + url: format!("data:{media_type};base64,{base64_data}"), detail: image_source.detail.map(|d| d.into()), }, }); diff --git a/test/components-rust/test-llm/Cargo.toml b/test/components-rust/test-llm/Cargo.toml index 7f624287..b927cded 100644 --- a/test/components-rust/test-llm/Cargo.toml +++ b/test/components-rust/test-llm/Cargo.toml @@ -30,11 +30,6 @@ wit-bindgen-rt = { workspace = true } [package.metadata.component.target] path = "wit-generated" -[package.metadata.component.bindings.with] -"wasi:io/poll@0.2.0" = "golem_rust::wasm_rpc::wasi::io::poll" -"wasi:clocks/wall-clock@0.2.0" = "golem_rust::wasm_rpc::wasi::clocks::wall_clock" -"golem:rpc/types@0.2.0" = "golem_rust::wasm_rpc::golem_rpc_0_2_x::types" - [package.metadata.component.target.dependencies] "golem:llm" = { path = "wit-generated/deps/golem-llm" } "wasi:clocks" = { path = "wit-generated/deps/clocks" } diff --git a/test/components-rust/test-llm/src/lib.rs b/test/components-rust/test-llm/src/lib.rs index fa11684d..593f3dd3 100644 --- a/test/components-rust/test-llm/src/lib.rs +++ b/test/components-rust/test-llm/src/lib.rs @@ -570,6 +570,113 @@ impl Guest for Component { } } } + + /// test8 demonstrates multi-turn conversations and crash recovery during streaming + fn test8() -> String { + let config = llm::Config { + model: MODEL.to_string(), + temperature: Some(0.2), + max_tokens: None, + stop_sequences: None, + tools: vec![], + tool_choice: None, + provider_options: vec![], + }; + + let mut messages = vec![llm::Message { + role: llm::Role::User, + name: Some("vigoo".to_string()), + content: vec![llm::ContentPart::Text( + "Do you know what a haiku is?".to_string(), + )], + }]; + + let stream = 
llm::stream(&messages, &config); + + let mut result = String::new(); + + let name = std::env::var("GOLEM_WORKER_NAME").unwrap(); + + loop { + let events = stream.blocking_get_next(); + if events.is_empty() { + break; + } + + for event in events { + match event { + llm::StreamEvent::Delta(delta) => { + for content in delta.content.unwrap_or_default() { + if let llm::ContentPart::Text(txt) = content { + result.push_str(&txt); + } + } + } + llm::StreamEvent::Finish(_) => {} + llm::StreamEvent::Error(error) => { + result.push_str(&format!("ERROR: {}", error.message)); + } + } + } + } + + messages.push(llm::Message { + role: llm::Role::Assistant, + name: Some("assistant".to_string()), + content: vec![llm::ContentPart::Text(result)], + }); + + messages.push(llm::Message { + role: llm::Role::User, + name: Some("vigoo".to_string()), + content: vec![llm::ContentPart::Text( + "Can you write one for me?".to_string(), + )], + }); + + let stream = llm::stream(&messages, &config); + + let mut result = String::new(); + let mut round = 0; + + loop { + let events = stream.blocking_get_next(); + if events.is_empty() { + break; + } + + for event in events { + match event { + llm::StreamEvent::Delta(delta) => { + for content in delta.content.unwrap_or_default() { + if let llm::ContentPart::Text(txt) = content { + result.push_str(&txt); + } + } + } + llm::StreamEvent::Finish(_) => {} + llm::StreamEvent::Error(error) => { + result.push_str(&format!("ERROR: {}", error.message)); + } + } + } + + round += 1; + if round == 2 { + atomically(|| { + let client = TestHelperApi::new(&name); + let answer = client.blocking_inc_and_get(); + if answer == 1 { + panic!("Simulating crash") + } + }); + } + } + + format!("Multi-turn conversation completed successfully: {}", result) + } } + + bindings::export!(Component with_types_in bindings); diff --git a/test/components-rust/test-llm/wit/test-llm.wit b/test/components-rust/test-llm/wit/test-llm.wit index 37b4f419..97edad98 100644 --- a/test/components-rust/test-llm/wit/test-llm.wit +++ b/test/components-rust/test-llm/wit/test-llm.wit @@ -10,6 +10,7 @@ interface test-llm-api { test5: func() -> string; test6: func() -> string; test7: func() -> string; + test8: func() -> string; } world test-llm {
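
Note: test8 above exercises plain multi-turn conversation plus crash recovery. For reference, a hypothetical companion test sketching the tool-calling round trip the fix enables — send, receive a `ToolRequest`, then continue the same conversation with the results. It is written in the style of test8 against the same `llm` bindings; `MODEL` is the test component's constant, and the tool schema, result JSON, and record field names (which follow the golem:llm WIT as used elsewhere in these tests) are assumptions, not part of the patch:

```rust
/// Hypothetical test9: one tool-use round trip over send/continue.
fn tool_round_trip() -> String {
    let config = llm::Config {
        model: MODEL.to_string(),
        temperature: Some(0.0),
        max_tokens: None,
        stop_sequences: None,
        tools: vec![llm::ToolDefinition {
            name: "get_weather".to_string(),
            description: Some("Get the current weather for a city".to_string()),
            parameters_schema:
                r#"{"type":"object","properties":{"city":{"type":"string"}},"required":["city"]}"#
                    .to_string(),
        }],
        tool_choice: None,
        provider_options: vec![],
    };

    let messages = vec![llm::Message {
        role: llm::Role::User,
        name: None,
        content: vec![llm::ContentPart::Text(
            "What's the weather in Paris?".to_string(),
        )],
    }];

    match llm::send(&messages, &config) {
        llm::ChatEvent::ToolRequest(calls) => {
            // Answer every requested call with a canned result, keeping the
            // original ToolCall so the provider can match ids.
            let results = calls
                .into_iter()
                .map(|call| {
                    let result = llm::ToolResult::Success(llm::ToolSuccess {
                        id: call.id.clone(),
                        name: call.name.clone(),
                        result_json: r#"{"temp_c":21}"#.to_string(),
                        execution_time_ms: None,
                    });
                    (call, result)
                })
                .collect::<Vec<_>>();
            // Continue the same conversation; the component appends the
            // assistant/tool message pair built by tool_results_to_messages.
            match llm::continue_(&messages, &results, &config) {
                llm::ChatEvent::Message(msg) => format!("{:?}", msg.content),
                other => format!("unexpected: {other:?}"),
            }
        }
        other => format!("unexpected: {other:?}"),
    }
}
```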