From 7d950bed822bc76307fe7e5ab36ee8c9f7ef25d0 Mon Sep 17 00:00:00 2001 From: Rutik7066 Date: Wed, 25 Jun 2025 21:12:42 +0000 Subject: [PATCH 1/8] wip: bedrock implementation --- Cargo.lock | 111 +++++ Cargo.toml | 5 + Makefile.toml | 4 +- llm/Makefile.toml | 32 +- llm/anthropic/src/bindings.rs | 11 +- llm/bedrock/Cargo.toml | 47 ++ llm/bedrock/src/bindings.rs | 55 +++ llm/bedrock/src/client.rs | 455 +++++++++++++++++++ llm/bedrock/src/conversions.rs | 301 ++++++++++++ llm/bedrock/src/lib.rs | 315 +++++++++++++ llm/bedrock/wit/bedrock.wit | 7 + llm/bedrock/wit/deps/golem-llm/golem-llm.wit | 194 ++++++++ llm/bedrock/wit/deps/wasi:io/error.wit | 34 ++ llm/bedrock/wit/deps/wasi:io/poll.wit | 47 ++ llm/bedrock/wit/deps/wasi:io/streams.wit | 290 ++++++++++++ llm/bedrock/wit/deps/wasi:io/world.wit | 10 + llm/grok/src/bindings.rs | 11 +- llm/ollama/src/bindings.rs | 11 +- llm/openai/src/bindings.rs | 11 +- llm/openrouter/src/bindings.rs | 11 +- test/components-rust/test-llm/Cargo.toml | 1 + test/components-rust/test-llm/golem.yaml | 54 +++ 22 files changed, 1994 insertions(+), 23 deletions(-) create mode 100644 llm/bedrock/Cargo.toml create mode 100644 llm/bedrock/src/bindings.rs create mode 100644 llm/bedrock/src/client.rs create mode 100644 llm/bedrock/src/conversions.rs create mode 100644 llm/bedrock/src/lib.rs create mode 100644 llm/bedrock/wit/bedrock.wit create mode 100644 llm/bedrock/wit/deps/golem-llm/golem-llm.wit create mode 100644 llm/bedrock/wit/deps/wasi:io/error.wit create mode 100644 llm/bedrock/wit/deps/wasi:io/poll.wit create mode 100644 llm/bedrock/wit/deps/wasi:io/streams.wit create mode 100644 llm/bedrock/wit/deps/wasi:io/world.wit diff --git a/Cargo.lock b/Cargo.lock index 0865d6ade..4168db8a9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -65,6 +65,15 @@ version = "2.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967" +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "bumpalo" version = "3.17.0" @@ -145,6 +154,15 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + [[package]] name = "crc32fast" version = "1.4.2" @@ -154,6 +172,27 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", + "subtle", +] + [[package]] name = "displaydoc" version = "0.2.5" @@ -300,6 +339,16 @@ dependencies = [ "slab", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getrandom" version = "0.3.3" @@ -360,6 +409,24 @@ dependencies = [ "wit-bindgen-rt 0.40.0", ] +[[package]] +name = "golem-llm-bedrock" +version = "0.0.0" +dependencies = [ + "base64 0.22.1", + "chrono", + "golem-llm", + "golem-rust", + "hex", + "hmac", + "log", + "reqwest", + "serde", + "serde_json", + "sha2", + "wit-bindgen-rt 0.40.0", +] + [[package]] name = "golem-llm-grok" version = "0.0.0" @@ -481,6 +548,21 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + [[package]] name = "http" version = "1.3.1" @@ -902,6 +984,17 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbfa15b3dddfee50a0fff136974b3e1bde555604ba463834a7eb7deb6417705d" +[[package]] +name = "sha2" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "shlex" version = "1.3.0" @@ -938,6 +1031,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + [[package]] name = "syn" version = "2.0.101" @@ -1011,6 +1110,12 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" +[[package]] +name = "typenum" +version = "1.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" + [[package]] name = "unicase" version = "2.8.1" @@ -1065,6 +1170,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + [[package]] name = "wasi" version = "0.14.2+wasi-0.2.4" diff --git a/Cargo.toml b/Cargo.toml index 7bea1e1e5..cb5c7e7ce 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ resolver = "2" members = [ "llm/llm", "llm/anthropic", + "llm/bedrock", "llm/grok", "llm/ollama", "llm/openai", @@ -26,3 +27,7 @@ serde = { version = "1.0", features = ["derive"] } serde_json = { version = "1.0" } wit-bindgen-rt = { version = "0.40.0", features = ["bitflags"] } base64 = { version = "0.22.1" } +hex = { version = "0.4" } +hmac = { version = "0.12" } +sha2 = { version = "0.10" } +chrono = { version = "0.4", features = ["serde"] } diff --git a/Makefile.toml b/Makefile.toml index cc443bc6a..bfc2c6acb 100644 --- a/Makefile.toml +++ b/Makefile.toml @@ -137,7 +137,7 @@ script = ''' is_portable = eq ${1} "--portable" -targets = array llm_openai llm_anthropic llm_grok llm_openrouter llm_ollama +targets = array llm_openai llm_anthropic llm_grok llm_openrouter llm_ollama llm_bedrock for target in ${targets} if is_portable cp target/wasm32-wasip1/debug/golem_${target}.wasm components/debug/golem_${target}-portable.wasm @@ -153,7 +153,7 @@ script = ''' is_portable = eq ${1} "--portable" -targets = array llm_openai llm_anthropic llm_grok llm_openrouter llm_ollama +targets = array llm_openai llm_anthropic llm_grok llm_openrouter llm_ollama llm_bedrock for target in ${targets} if is_portable cp target/wasm32-wasip1/release/golem_${target}.wasm components/release/golem_${target}-portable.wasm diff --git a/llm/Makefile.toml b/llm/Makefile.toml index 5b7b92a89..15ee75be7 100644 --- a/llm/Makefile.toml +++ b/llm/Makefile.toml @@ -5,6 +5,7 @@ skip_core_tasks = true [tasks.build] run_task = { name = [ "build-anthropic", + "build-bedrock", "build-grok", "build-openai", "build-openrouter", @@ -14,6 +15,7 @@ run_task = { name = [ [tasks.build-portable] run_task = { name = [ "build-anthropic-portable", + "build-bedrock-portable", "build-grok-portable", "build-openai-portable", "build-openrouter-portable", @@ -23,6 +25,7 @@ run_task = { name = [ [tasks.release-build] run_task = { name = [ "release-build-anthropic", + "release-build-bedrock", "release-build-grok", "release-build-openai", "release-build-openrouter", @@ -32,6 +35,7 @@ run_task = { name = [ [tasks.release-build-portable] run_task = { name = [ "release-build-anthropic-portable", + "release-build-bedrock-portable", "release-build-grok-portable", "release-build-openai-portable", "release-build-openrouter-portable", @@ -60,6 +64,16 @@ install_crate = { crate_name = "cargo-component", version = "0.20.0" } command = "cargo-component" args = ["build", "-p", "golem-llm-anthropic", "--no-default-features"] +[tasks.build-bedrock] +install_crate = { crate_name = "cargo-component", version = "0.20.0" } +command = "cargo-component" +args = ["build", "-p", "golem-llm-bedrock"] + +[tasks.build-bedrock-portable] +install_crate = { crate_name = "cargo-component", version = "0.20.0" } +command = "cargo-component" +args = ["build", "-p", "golem-llm-bedrock", "--no-default-features"] + [tasks.build-grok] install_crate = { crate_name = "cargo-component", version = "0.20.0" } command = "cargo-component" @@ -117,6 +131,22 @@ args = [ "--no-default-features", ] +[tasks.release-build-bedrock] +install_crate = { crate_name = "cargo-component", version = "0.20.0" } +command = "cargo-component" +args = ["build", "-p", "golem-llm-bedrock", "--release"] + +[tasks.release-build-bedrock-portable] +install_crate = { crate_name = "cargo-component", version = "0.20.0" } +command = "cargo-component" +args = [ + "build", + "-p", + "golem-llm-bedrock", + "--release", + "--no-default-features", +] + [tasks.release-build-grok] install_crate = { crate_name = "cargo-component", version = "0.20.0" } command = "cargo-component" @@ -163,7 +193,7 @@ dependencies = ["wit-update"] script_runner = "@duckscript" script = """ -modules = array llm openai anthropic grok openrouter ollama +modules = array llm openai anthropic bedrock grok openrouter ollama for module in ${modules} rm -r ${module}/wit/deps diff --git a/llm/anthropic/src/bindings.rs b/llm/anthropic/src/bindings.rs index 70c5f1fd5..1a54d6167 100644 --- a/llm/anthropic/src/bindings.rs +++ b/llm/anthropic/src/bindings.rs @@ -1,12 +1,15 @@ -// Generated by `wit-bindgen` 0.36.0. DO NOT EDIT! +// Generated by `wit-bindgen` 0.41.0. DO NOT EDIT! // Options used: // * runtime_path: "wit_bindgen_rt" // * with "golem:llm/llm@1.0.0" = "golem_llm::golem::llm::llm" // * generate_unused_types use golem_llm::golem::llm::llm as __with_name0; #[cfg(target_arch = "wasm32")] -#[link_section = "component-type:wit-bindgen:0.36.0:golem:llm-anthropic@1.0.0:llm-library:encoded world"] +#[unsafe( + link_section = "component-type:wit-bindgen:0.41.0:golem:llm-anthropic@1.0.0:llm-library:encoded world" +)] #[doc(hidden)] +#[allow(clippy::octal_escapes)] pub static __WIT_BINDGEN_COMPONENT_TYPE: [u8; 1762] = *b"\ \0asm\x0d\0\x01\0\0\x19\x16wit-component-encoding\x04\0\x07\xe0\x0c\x01A\x02\x01\ A\x02\x01BO\x01m\x04\x04user\x09assistant\x06system\x04tool\x04\0\x04role\x03\0\0\ @@ -43,8 +46,8 @@ ng-get-next\x01B\x01p\x15\x01@\x02\x08messages\xc3\0\x06config)\06\x04\0\x04send \0\x06config)\06\x04\0\x08continue\x01G\x01i=\x01@\x02\x08messages\xc3\0\x06conf\ ig)\0\xc8\0\x04\0\x06stream\x01I\x04\0\x13golem:llm/llm@1.0.0\x05\0\x04\0%golem:\ llm-anthropic/llm-library@1.0.0\x04\0\x0b\x11\x01\0\x0bllm-library\x03\0\0\0G\x09\ -producers\x01\x0cprocessed-by\x02\x0dwit-component\x070.220.0\x10wit-bindgen-rus\ -t\x060.36.0"; +producers\x01\x0cprocessed-by\x02\x0dwit-component\x070.227.1\x10wit-bindgen-rus\ +t\x060.41.0"; #[inline(never)] #[doc(hidden)] pub fn __link_custom_section_describing_imports() { diff --git a/llm/bedrock/Cargo.toml b/llm/bedrock/Cargo.toml new file mode 100644 index 000000000..a59411a4b --- /dev/null +++ b/llm/bedrock/Cargo.toml @@ -0,0 +1,47 @@ +[package] +name = "golem-llm-bedrock" +version = "0.0.0" +edition = "2021" +license = "Apache-2.0" +homepage = "https://golem.cloud" +repository = "https://github.com/golemcloud/golem-llm" +description = "WebAssembly component for working with AWS Bedrock APIs, with special support for Golem Cloud" + +[lib] +path = "src/lib.rs" +crate-type = ["cdylib"] + +[features] +default = ["durability"] +durability = ["golem-rust/durability", "golem-llm/durability"] + +[dependencies] +golem-llm = { workspace = true } + +golem-rust = { workspace = true } +log = { workspace = true } +reqwest = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +wit-bindgen-rt = { workspace = true } +base64 = { workspace = true } +hex = { workspace = true } +hmac = { workspace = true } +sha2 = { workspace = true } +chrono = { workspace = true } + +[package.metadata.component] +package = "golem:llm-bedrock" + +[package.metadata.component.bindings] +generate_unused_types = true + +[package.metadata.component.bindings.with] +"golem:llm/llm@1.0.0" = "golem_llm::golem::llm::llm" + +[package.metadata.component.target] +path = "wit" + +[package.metadata.component.target.dependencies] +"golem:llm" = { path = "wit/deps/golem-llm" } +"wasi:io" = { path = "wit/deps/wasi:io" } \ No newline at end of file diff --git a/llm/bedrock/src/bindings.rs b/llm/bedrock/src/bindings.rs new file mode 100644 index 000000000..c173c3169 --- /dev/null +++ b/llm/bedrock/src/bindings.rs @@ -0,0 +1,55 @@ +// Generated by `wit-bindgen` 0.41.0. DO NOT EDIT! +// Options used: +// * runtime_path: "wit_bindgen_rt" +// * with "golem:llm/llm@1.0.0" = "golem_llm::golem::llm::llm" +// * generate_unused_types +use golem_llm::golem::llm::llm as __with_name0; +#[cfg(target_arch = "wasm32")] +#[unsafe( + link_section = "component-type:wit-bindgen:0.41.0:golem:llm-bedrock@1.0.0:llm-library:encoded world" +)] +#[doc(hidden)] +#[allow(clippy::octal_escapes)] +pub static __WIT_BINDGEN_COMPONENT_TYPE: [u8; 1760] = *b"\ +\0asm\x0d\0\x01\0\0\x19\x16wit-component-encoding\x04\0\x07\xde\x0c\x01A\x02\x01\ +A\x02\x01BO\x01m\x04\x04user\x09assistant\x06system\x04tool\x04\0\x04role\x03\0\0\ +\x01m\x06\x0finvalid-request\x15authentication-failed\x13rate-limit-exceeded\x0e\ +internal-error\x0bunsupported\x07unknown\x04\0\x0aerror-code\x03\0\x02\x01m\x06\x04\ +stop\x06length\x0atool-calls\x0econtent-filter\x05error\x05other\x04\0\x0dfinish\ +-reason\x03\0\x04\x01m\x03\x03low\x04high\x04auto\x04\0\x0cimage-detail\x03\0\x06\ +\x01k\x07\x01r\x02\x03urls\x06detail\x08\x04\0\x09image-url\x03\0\x09\x01p}\x01r\ +\x03\x04data\x0b\x09mime-types\x06detail\x08\x04\0\x0cimage-source\x03\0\x0c\x01\ +q\x02\x03url\x01\x0a\0\x06inline\x01\x0d\0\x04\0\x0fimage-reference\x03\0\x0e\x01\ +q\x02\x04text\x01s\0\x05image\x01\x0f\0\x04\0\x0ccontent-part\x03\0\x10\x01ks\x01\ +p\x11\x01r\x03\x04role\x01\x04name\x12\x07content\x13\x04\0\x07message\x03\0\x14\ +\x01r\x03\x04names\x0bdescription\x12\x11parameters-schemas\x04\0\x0ftool-defini\ +tion\x03\0\x16\x01r\x03\x02ids\x04names\x0earguments-jsons\x04\0\x09tool-call\x03\ +\0\x18\x01ky\x01r\x04\x02ids\x04names\x0bresult-jsons\x11execution-time-ms\x1a\x04\ +\0\x0ctool-success\x03\0\x1b\x01r\x04\x02ids\x04names\x0derror-messages\x0aerror\ +-code\x12\x04\0\x0ctool-failure\x03\0\x1d\x01q\x02\x07success\x01\x1c\0\x05error\ +\x01\x1e\0\x04\0\x0btool-result\x03\0\x1f\x01r\x02\x03keys\x05values\x04\0\x02kv\ +\x03\0!\x01kv\x01ps\x01k$\x01p\x17\x01p\"\x01r\x07\x05models\x0btemperature#\x0a\ +max-tokens\x1a\x0estop-sequences%\x05tools&\x0btool-choice\x12\x10provider-optio\ +ns'\x04\0\x06config\x03\0(\x01r\x03\x0cinput-tokens\x1a\x0doutput-tokens\x1a\x0c\ +total-tokens\x1a\x04\0\x05usage\x03\0*\x01k\x05\x01k+\x01r\x05\x0dfinish-reason,\ +\x05usage-\x0bprovider-id\x12\x09timestamp\x12\x16provider-metadata-json\x12\x04\ +\0\x11response-metadata\x03\0.\x01p\x19\x01r\x04\x02ids\x07content\x13\x0atool-c\ +alls0\x08metadata/\x04\0\x11complete-response\x03\01\x01r\x03\x04code\x03\x07mes\ +sages\x13provider-error-json\x12\x04\0\x05error\x03\03\x01q\x03\x07message\x012\0\ +\x0ctool-request\x010\0\x05error\x014\0\x04\0\x0achat-event\x03\05\x01k\x13\x01k\ +0\x01r\x02\x07content7\x0atool-calls8\x04\0\x0cstream-delta\x03\09\x01q\x03\x05d\ +elta\x01:\0\x06finish\x01/\0\x05error\x014\0\x04\0\x0cstream-event\x03\0;\x04\0\x0b\ +chat-stream\x03\x01\x01h=\x01p<\x01k?\x01@\x01\x04self>\0\xc0\0\x04\0\x1c[method\ +]chat-stream.get-next\x01A\x01@\x01\x04self>\0?\x04\0%[method]chat-stream.blocki\ +ng-get-next\x01B\x01p\x15\x01@\x02\x08messages\xc3\0\x06config)\06\x04\0\x04send\ +\x01D\x01o\x02\x19\x20\x01p\xc5\0\x01@\x03\x08messages\xc3\0\x0ctool-results\xc6\ +\0\x06config)\06\x04\0\x08continue\x01G\x01i=\x01@\x02\x08messages\xc3\0\x06conf\ +ig)\0\xc8\0\x04\0\x06stream\x01I\x04\0\x13golem:llm/llm@1.0.0\x05\0\x04\0#golem:\ +llm-bedrock/llm-library@1.0.0\x04\0\x0b\x11\x01\0\x0bllm-library\x03\0\0\0G\x09p\ +roducers\x01\x0cprocessed-by\x02\x0dwit-component\x070.227.1\x10wit-bindgen-rust\ +\x060.41.0"; +#[inline(never)] +#[doc(hidden)] +pub fn __link_custom_section_describing_imports() { + wit_bindgen_rt::maybe_link_cabi_realloc(); +} diff --git a/llm/bedrock/src/client.rs b/llm/bedrock/src/client.rs new file mode 100644 index 000000000..1e5a82da4 --- /dev/null +++ b/llm/bedrock/src/client.rs @@ -0,0 +1,455 @@ +use golem_llm::error::{error_code_from_status, from_event_source_error, from_reqwest_error}; +use golem_llm::event_source::EventSource; +use golem_llm::golem::llm::llm::{Error, ErrorCode}; +use log::trace; +use reqwest::{Client, Method, Response}; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::fmt::Debug; +use std::collections::HashMap; +use chrono::Utc; +use hmac::{Hmac, Mac}; +use sha2::{Sha256, Digest}; + +type HmacSha256 = Hmac; + +/// AWS Bedrock client for creating model responses +pub struct BedrockClient { + access_key_id: String, + secret_access_key: String, + region: String, + client: Client, +} + +impl BedrockClient { + pub fn new(access_key_id: String, secret_access_key: String, region: String) -> Self { + let client = Client::builder() + .build() + .expect("Failed to initialize HTTP client"); + Self { + access_key_id, + secret_access_key, + region, + client, + } + } + + pub fn converse(&self, model_id: &str, request: ConverseRequest) -> Result { + trace!("Sending request to Bedrock API: {request:?}"); + + let body = serde_json::to_string(&request) + .map_err(|err| Error { + code: ErrorCode::InvalidRequest, + message: format!("Failed to serialize request: {err}"), + provider_error_json: None, + })?; + + let headers = self.sign_request(&body, model_id, false)?; + let url = format!("https://bedrock-runtime.{}.amazonaws.com/model/{}/converse", self.region, model_id); + + let mut request_builder = self.client.request(Method::POST, url); + for (key, value) in headers { + request_builder = request_builder.header(key, value); + } + + let response: Response = request_builder + .body(body) + .send() + .map_err(|err| from_reqwest_error("Request failed", err))?; + + parse_response(response) + } + + pub fn converse_stream(&self, model_id: &str, request: ConverseRequest) -> Result { + trace!("Sending streaming request to Bedrock API: {request:?}"); + + let body = serde_json::to_string(&request) + .map_err(|err| Error { + code: ErrorCode::InvalidRequest, + message: format!("Failed to serialize request: {err}"), + provider_error_json: None, + })?; + + let headers = self.sign_request(&body, model_id, true)?; + let url = format!("https://bedrock-runtime.{}.amazonaws.com/model/{}/converse-stream", self.region, model_id); + + let mut request_builder = self.client.request(Method::POST, url); + for (key, value) in headers { + request_builder = request_builder.header(key, value); + } + + let response: Response = request_builder + .body(body) + .send() + .map_err(|err| from_reqwest_error("Request failed", err))?; + + trace!("Initializing SSE stream"); + + EventSource::new(response) + .map_err(|err| from_event_source_error("Failed to create SSE stream", err)) + } + + fn sign_request(&self, body: &str, model_id: &str, is_stream: bool) -> Result, Error> { + let now = Utc::now(); + let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string(); + let date_stamp = now.format("%Y%m%d").to_string(); + + let service = "bedrock"; + let endpoint = if is_stream { "converse-stream" } else { "converse" }; + let canonical_uri = format!("/model/{}/{}", model_id, endpoint); + + let canonical_headers = format!( + "host:bedrock-runtime.{}.amazonaws.com\nx-amz-date:{}\n", + self.region, amz_date + ); + let signed_headers = "host;x-amz-date"; + + let payload_hash = hex::encode(Sha256::digest(body.as_bytes())); + + let canonical_request = format!( + "POST\n{}\n\n{}\n{}\n{}", + canonical_uri, canonical_headers, signed_headers, payload_hash + ); + + let algorithm = "AWS4-HMAC-SHA256"; + let credential_scope = format!("{}/{}/{}/aws4_request", date_stamp, self.region, service); + let string_to_sign = format!( + "{}\n{}\n{}\n{}", + algorithm, + amz_date, + credential_scope, + hex::encode(Sha256::digest(canonical_request.as_bytes())) + ); + + let signing_key = self.get_signature_key(&date_stamp, service)?; + let signature = hex::encode( + HmacSha256::new_from_slice(&signing_key) + .map_err(|_| Error { + code: ErrorCode::InternalError, + message: "Failed to create HMAC".to_string(), + provider_error_json: None, + })? + .chain_update(string_to_sign.as_bytes()) + .finalize() + .into_bytes() + ); + + let authorization_header = format!( + "{} Credential={}/{}, SignedHeaders={}, Signature={}", + algorithm, self.access_key_id, credential_scope, signed_headers, signature + ); + + let mut headers = HashMap::new(); + headers.insert("Authorization".to_string(), authorization_header); + headers.insert("X-Amz-Date".to_string(), amz_date); + headers.insert("Host".to_string(), format!("bedrock-runtime.{}.amazonaws.com", self.region)); + headers.insert("Content-Type".to_string(), "application/json".to_string()); + + Ok(headers) + } + + fn get_signature_key(&self, date_stamp: &str, service: &str) -> Result, Error> { + let k_date = HmacSha256::new_from_slice(format!("AWS4{}", self.secret_access_key).as_bytes()) + .map_err(|_| Error { + code: ErrorCode::InternalError, + message: "Failed to create HMAC for date".to_string(), + provider_error_json: None, + })? + .chain_update(date_stamp.as_bytes()) + .finalize() + .into_bytes(); + + let k_region = HmacSha256::new_from_slice(&k_date) + .map_err(|_| Error { + code: ErrorCode::InternalError, + message: "Failed to create HMAC for region".to_string(), + provider_error_json: None, + })? + .chain_update(self.region.as_bytes()) + .finalize() + .into_bytes(); + + let k_service = HmacSha256::new_from_slice(&k_region) + .map_err(|_| Error { + code: ErrorCode::InternalError, + message: "Failed to create HMAC for service".to_string(), + provider_error_json: None, + })? + .chain_update(service.as_bytes()) + .finalize() + .into_bytes(); + + let k_signing = HmacSha256::new_from_slice(&k_service) + .map_err(|_| Error { + code: ErrorCode::InternalError, + message: "Failed to create HMAC for signing".to_string(), + provider_error_json: None, + })? + .chain_update(b"aws4_request") + .finalize() + .into_bytes(); + + Ok(k_signing.to_vec()) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ConverseRequest { + #[serde(rename = "modelId")] + pub model_id: String, + pub messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub system: Option>, + #[serde(rename = "inferenceConfig", skip_serializing_if = "Option::is_none")] + pub inference_config: Option, + #[serde(rename = "toolConfig", skip_serializing_if = "Option::is_none")] + pub tool_config: Option, + #[serde(rename = "guardrailConfig", skip_serializing_if = "Option::is_none")] + pub guardrail_config: Option, + #[serde(rename = "additionalModelRequestFields", skip_serializing_if = "Option::is_none")] + pub additional_model_request_fields: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Message { + pub role: Role, + pub content: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum Role { + #[serde(rename = "user")] + User, + #[serde(rename = "assistant")] + Assistant, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum ContentBlock { + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "image")] + Image { + #[serde(rename = "format")] + format: ImageFormat, + #[serde(rename = "source")] + source: ImageSource, + }, + #[serde(rename = "toolUse")] + ToolUse { + #[serde(rename = "toolUseId")] + tool_use_id: String, + name: String, + input: Value, + }, + #[serde(rename = "toolResult")] + ToolResult { + #[serde(rename = "toolUseId")] + tool_use_id: String, + content: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + status: Option, + }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum ImageFormat { + #[serde(rename = "png")] + Png, + #[serde(rename = "jpeg")] + Jpeg, + #[serde(rename = "gif")] + Gif, + #[serde(rename = "webp")] + Webp, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "bytes")] +pub struct ImageSource { + pub bytes: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum ToolResultContentBlock { + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "image")] + Image { + #[serde(rename = "format")] + format: ImageFormat, + #[serde(rename = "source")] + source: ImageSource, + }, + #[serde(rename = "json")] + Json { json: Value }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum ToolResultStatus { + #[serde(rename = "success")] + Success, + #[serde(rename = "error")] + Error, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum SystemContentBlock { + #[serde(rename = "text")] + Text { text: String }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct InferenceConfig { + #[serde(rename = "maxTokens", skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(rename = "topP", skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(rename = "stopSequences", skip_serializing_if = "Option::is_none")] + pub stop_sequences: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolConfig { + pub tools: Vec, + #[serde(rename = "toolChoice", skip_serializing_if = "Option::is_none")] + pub tool_choice: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum Tool { + #[serde(rename = "toolSpec")] + ToolSpec { + name: String, + description: String, + #[serde(rename = "inputSchema")] + input_schema: ToolInputSchema, + }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolInputSchema { + pub json: Value, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum ToolChoice { + #[serde(rename = "auto")] + Auto, + #[serde(rename = "any")] + Any, + #[serde(rename = "tool")] + Tool { name: String }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GuardrailConfig { + #[serde(rename = "guardrailIdentifier")] + pub guardrail_identifier: String, + #[serde(rename = "guardrailVersion")] + pub guardrail_version: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub trace: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum GuardrailTrace { + #[serde(rename = "enabled")] + Enabled, + #[serde(rename = "disabled")] + Disabled, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ConverseResponse { + #[serde(rename = "responseMetadata")] + pub response_metadata: ResponseMetadata, + pub output: Output, + #[serde(rename = "stopReason")] + pub stop_reason: StopReason, + pub usage: Usage, + pub metrics: Metrics, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ResponseMetadata { + #[serde(rename = "requestId")] + pub request_id: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Output { + pub message: Message, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum StopReason { + #[serde(rename = "end_turn")] + EndTurn, + #[serde(rename = "tool_use")] + ToolUse, + #[serde(rename = "max_tokens")] + MaxTokens, + #[serde(rename = "stop_sequence")] + StopSequence, + #[serde(rename = "guardrail_intervened")] + GuardrailIntervened, + #[serde(rename = "content_filtered")] + ContentFiltered, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Usage { + #[serde(rename = "inputTokens")] + pub input_tokens: u32, + #[serde(rename = "outputTokens")] + pub output_tokens: u32, + #[serde(rename = "totalTokens")] + pub total_tokens: u32, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Metrics { + #[serde(rename = "latencyMs")] + pub latency_ms: u64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ErrorResponse { + pub message: String, + #[serde(rename = "type")] + pub error_type: String, +} + +fn parse_response(response: Response) -> Result { + let status = response.status(); + if status.is_success() { + let body = response + .json::() + .map_err(|err| from_reqwest_error("Failed to decode response body", err))?; + + trace!("Received response from Bedrock API: {body:?}"); + + Ok(body) + } else { + let error_body = response + .json::() + .map_err(|err| from_reqwest_error("Failed to receive error response body", err))?; + + trace!("Received {status} response from Bedrock API: {error_body:?}"); + + Err(Error { + code: error_code_from_status(status), + message: format!("Request failed with {status}: {}", error_body.message), + provider_error_json: Some(serde_json::to_string(&error_body).unwrap()), + }) + } +} \ No newline at end of file diff --git a/llm/bedrock/src/conversions.rs b/llm/bedrock/src/conversions.rs new file mode 100644 index 000000000..c530c4b02 --- /dev/null +++ b/llm/bedrock/src/conversions.rs @@ -0,0 +1,301 @@ +use crate::client::{ + ContentBlock, ConverseRequest, ConverseResponse, ImageFormat, ImageSource as ClientImageSource, + InferenceConfig, Message as ClientMessage, Role as ClientRole, StopReason, SystemContentBlock, + Tool, ToolChoice, ToolConfig, ToolInputSchema, ToolResultContentBlock, ToolResultStatus, +}; +use base64::{engine::general_purpose, Engine as _}; +use golem_llm::golem::llm::llm::{ + ChatEvent, CompleteResponse, Config, ContentPart, Error, ErrorCode, FinishReason, + ImageReference, ImageSource, Message, ResponseMetadata, Role, ToolCall, + ToolDefinition, ToolResult, Usage, +}; +use std::collections::HashMap; + +pub fn messages_to_request( + messages: Vec, + config: Config, +) -> Result { + let options = config + .provider_options + .into_iter() + .map(|kv| (kv.key, kv.value)) + .collect::>(); + + let mut bedrock_messages = Vec::new(); + let mut system_messages = Vec::new(); + + for message in &messages { + match message.role { + Role::System => { + system_messages.extend(message_to_system_content(message)); + } + Role::User | Role::Assistant | Role::Tool => { + bedrock_messages.push(ClientMessage { + role: match &message.role { + Role::User | Role::Tool => ClientRole::User, + Role::Assistant => ClientRole::Assistant, + Role::System => unreachable!(), + }, + content: message_to_content(message)?, + }); + } + } + } + + let inference_config = if config.max_tokens.is_some() + || config.temperature.is_some() + || config.stop_sequences.is_some() + || options.contains_key("top_p") + { + Some(InferenceConfig { + max_tokens: config.max_tokens, + temperature: config.temperature, + top_p: options + .get("top_p") + .and_then(|top_p_s| top_p_s.parse::().ok()), + stop_sequences: config.stop_sequences, + }) + } else { + None + }; + + let tool_config = if config.tools.is_empty() { + None + } else { + let mut tools = Vec::new(); + for tool in &config.tools { + tools.push(tool_definition_to_tool(tool)?); + } + + let tool_choice = config.tool_choice.map(convert_tool_choice); + + Some(ToolConfig { + tools, + tool_choice, + }) + }; + + Ok(ConverseRequest { + model_id: config.model.clone(), + messages: bedrock_messages, + system: if system_messages.is_empty() { + None + } else { + Some(system_messages) + }, + inference_config, + tool_config, + guardrail_config: None, + additional_model_request_fields: None, + }) +} + +fn convert_tool_choice(tool_name: String) -> ToolChoice { + match tool_name.as_str() { + "auto" => ToolChoice::Auto, + "any" => ToolChoice::Any, + name => ToolChoice::Tool { + name: name.to_string(), + }, + } +} + +pub fn process_response(response: ConverseResponse) -> ChatEvent { + let mut contents = Vec::new(); + let mut tool_calls = Vec::new(); + + for content in response.output.message.content { + match content { + ContentBlock::Text { text } => contents.push(ContentPart::Text(text)), + ContentBlock::Image { format, source } => { + match general_purpose::STANDARD.decode(&source.bytes) { + Ok(decoded_data) => { + let mime_type = match format { + ImageFormat::Jpeg => "image/jpeg", + ImageFormat::Png => "image/png", + ImageFormat::Gif => "image/gif", + ImageFormat::Webp => "image/webp", + }; + contents.push(ContentPart::Image(ImageReference::Inline( + ImageSource { + data: decoded_data, + mime_type: mime_type.to_string(), + detail: None, + }, + ))); + } + Err(e) => { + return ChatEvent::Error(Error { + code: ErrorCode::InvalidRequest, + message: format!("Failed to decode base64 image data: {}", e), + provider_error_json: None, + }); + } + } + } + ContentBlock::ToolUse { + tool_use_id, + name, + input, + } => tool_calls.push(ToolCall { + id: tool_use_id, + name, + arguments_json: serde_json::to_string(&input).unwrap(), + }), + ContentBlock::ToolResult { .. } => {} + } + } + + if contents.is_empty() && !tool_calls.is_empty() { + ChatEvent::ToolRequest(tool_calls) + } else { + let request_id = response.response_metadata.request_id.clone(); + + let metadata = ResponseMetadata { + finish_reason: Some(stop_reason_to_finish_reason(response.stop_reason)), + usage: Some(convert_usage(response.usage)), + provider_id: Some(request_id.clone()), + timestamp: None, + provider_metadata_json: None, + }; + + ChatEvent::Message(CompleteResponse { + id: request_id, + content: contents, + tool_calls, + metadata, + }) + } +} + +pub fn tool_results_to_messages( + tool_results: Vec<(ToolCall, ToolResult)>, +) -> Vec { + let mut messages = Vec::new(); + + for (tool_call, tool_result) in tool_results { + messages.push(ClientMessage { + content: vec![ContentBlock::ToolUse { + tool_use_id: tool_call.id.clone(), + name: tool_call.name, + input: serde_json::from_str(&tool_call.arguments_json).unwrap(), + }], + role: ClientRole::Assistant, + }); + + let (content, status) = match tool_result { + ToolResult::Success(success) => ( + vec![ToolResultContentBlock::Text { + text: success.result_json, + }], + Some(ToolResultStatus::Success), + ), + ToolResult::Error(error) => ( + vec![ToolResultContentBlock::Text { + text: error.error_message, + }], + Some(ToolResultStatus::Error), + ), + }; + + messages.push(ClientMessage { + content: vec![ContentBlock::ToolResult { + tool_use_id: tool_call.id, + content, + status, + }], + role: ClientRole::User, + }); + } + + messages +} + +pub fn stop_reason_to_finish_reason(stop_reason: StopReason) -> FinishReason { + match stop_reason { + StopReason::EndTurn => FinishReason::Stop, + StopReason::ToolUse => FinishReason::ToolCalls, + StopReason::MaxTokens => FinishReason::Length, + StopReason::StopSequence => FinishReason::Stop, + StopReason::GuardrailIntervened => FinishReason::ContentFilter, + StopReason::ContentFiltered => FinishReason::ContentFilter, + } +} + +pub fn convert_usage(usage: crate::client::Usage) -> Usage { + Usage { + input_tokens: Some(usage.input_tokens), + output_tokens: Some(usage.output_tokens), + total_tokens: Some(usage.total_tokens), + } +} + +fn message_to_content(message: &Message) -> Result, Error> { + let mut result = Vec::new(); + + for content_part in &message.content { + match content_part { + ContentPart::Text(text) => result.push(ContentBlock::Text { + text: text.clone(), + }), + ContentPart::Image(image_reference) => match image_reference { + ImageReference::Url(_image_url) => { + return Err(Error { + code: ErrorCode::InvalidRequest, + message: "Bedrock API does not support image URLs, only base64 encoded images".to_string(), + provider_error_json: None, + }); + } + ImageReference::Inline(image_source) => { + let base64_data = general_purpose::STANDARD.encode(&image_source.data); + let format = match image_source.mime_type.as_str() { + "image/jpeg" => ImageFormat::Jpeg, + "image/png" => ImageFormat::Png, + "image/gif" => ImageFormat::Gif, + "image/webp" => ImageFormat::Webp, + _ => ImageFormat::Jpeg, + }; + + result.push(ContentBlock::Image { + format, + source: ClientImageSource { + bytes: base64_data, + }, + }); + } + }, + } + } + + Ok(result) +} + +fn message_to_system_content(message: &Message) -> Vec { + let mut result = Vec::new(); + + for content_part in &message.content { + match content_part { + ContentPart::Text(text) => result.push(SystemContentBlock::Text { + text: text.clone(), + }), + ContentPart::Image(_) => {} + } + } + + result +} + +fn tool_definition_to_tool(tool: &ToolDefinition) -> Result { + match serde_json::from_str(&tool.parameters_schema) { + Ok(json_schema) => Ok(Tool::ToolSpec { + name: tool.name.clone(), + description: tool.description.clone().unwrap_or_default(), + input_schema: ToolInputSchema { json: json_schema }, + }), + Err(error) => Err(Error { + code: ErrorCode::InternalError, + message: format!("Failed to parse tool parameters for {}: {error}", tool.name), + provider_error_json: None, + }), + } +} \ No newline at end of file diff --git a/llm/bedrock/src/lib.rs b/llm/bedrock/src/lib.rs new file mode 100644 index 000000000..e533903c8 --- /dev/null +++ b/llm/bedrock/src/lib.rs @@ -0,0 +1,315 @@ +mod client; +mod conversions; + +use crate::client::{BedrockClient, ConverseRequest}; +use crate::conversions::{ + convert_usage, messages_to_request, process_response, stop_reason_to_finish_reason, + tool_results_to_messages, +}; +use golem_llm::chat_stream::{LlmChatStream, LlmChatStreamState}; +use golem_llm::durability::{DurableLLM, ExtendedGuest}; +use golem_llm::event_source::EventSource; +use golem_llm::golem::llm::llm::{ + ChatEvent, ChatStream, Config, ContentPart, Error, ErrorCode, Guest, Message, ResponseMetadata, + Role, StreamDelta, StreamEvent, ToolCall, ToolResult, +}; +use golem_llm::LOGGING_STATE; +use golem_rust::wasm_rpc::Pollable; +use log::trace; +use serde_json::Value; +use std::cell::{Ref, RefCell, RefMut}; + +struct BedrockChatStream { + stream: RefCell>, + failure: Option, + finished: RefCell, + response_metadata: RefCell, +} + +impl BedrockChatStream { + pub fn new(stream: EventSource) -> LlmChatStream { + LlmChatStream::new(BedrockChatStream { + stream: RefCell::new(Some(stream)), + failure: None, + finished: RefCell::new(false), + response_metadata: RefCell::new(ResponseMetadata { + finish_reason: None, + usage: None, + provider_id: None, + timestamp: None, + provider_metadata_json: None, + }), + }) + } + + pub fn failed(error: Error) -> LlmChatStream { + LlmChatStream::new(BedrockChatStream { + stream: RefCell::new(None), + failure: Some(error), + finished: RefCell::new(false), + response_metadata: RefCell::new(ResponseMetadata { + finish_reason: None, + usage: None, + provider_id: None, + timestamp: None, + provider_metadata_json: None, + }), + }) + } +} + +impl LlmChatStreamState for BedrockChatStream { + fn failure(&self) -> &Option { + &self.failure + } + + fn is_finished(&self) -> bool { + *self.finished.borrow() + } + + fn set_finished(&self) { + *self.finished.borrow_mut() = true; + } + + fn stream(&self) -> Ref> { + self.stream.borrow() + } + + fn stream_mut(&self) -> RefMut> { + self.stream.borrow_mut() + } + + fn decode_message(&self, raw: &str) -> Result, String> { + trace!("Received raw stream event: {raw}"); + + let json: Value = serde_json::from_str(raw) + .map_err(|err| format!("Failed to deserialize stream event: {err}"))?; + + if let Some(content_block_delta) = json.get("contentBlockDelta") { + if let Some(delta) = content_block_delta.get("delta") { + if let Some(text) = delta.get("text").and_then(|v| v.as_str()) { + return Ok(Some(StreamEvent::Delta(StreamDelta { + content: Some(vec![ContentPart::Text(text.to_string())]), + tool_calls: None, + }))); + } + } + } + + if let Some(content_block_start) = json.get("contentBlockStart") { + if let Some(start) = content_block_start.get("start") { + if let Some(tool_use) = start.get("toolUse") { + if let (Some(tool_use_id), Some(name)) = ( + tool_use.get("toolUseId").and_then(|v| v.as_str()), + tool_use.get("name").and_then(|v| v.as_str()), + ) { + if let Some(input) = tool_use.get("input") { + return Ok(Some(StreamEvent::Delta(StreamDelta { + content: None, + tool_calls: Some(vec![ToolCall { + id: tool_use_id.to_string(), + name: name.to_string(), + arguments_json: serde_json::to_string(input).unwrap(), + }]), + }))); + } + } + } + } + } + + if let Some(metadata) = json.get("metadata") { + if let Some(usage) = metadata.get("usage") { + if let Ok(bedrock_usage) = serde_json::from_value::(usage.clone()) { + self.response_metadata.borrow_mut().usage = Some(convert_usage(bedrock_usage)); + } + } + } + + if let Some(message_stop) = json.get("messageStop") { + if let Some(stop_reason) = message_stop.get("stopReason").and_then(|v| v.as_str()) { + let stop_reason = match stop_reason { + "end_turn" => crate::client::StopReason::EndTurn, + "tool_use" => crate::client::StopReason::ToolUse, + "max_tokens" => crate::client::StopReason::MaxTokens, + "stop_sequence" => crate::client::StopReason::StopSequence, + "guardrail_intervened" => crate::client::StopReason::GuardrailIntervened, + "content_filtered" => crate::client::StopReason::ContentFiltered, + _ => crate::client::StopReason::EndTurn, + }; + self.response_metadata.borrow_mut().finish_reason = Some(stop_reason_to_finish_reason(stop_reason)); + } + + let response_metadata = self.response_metadata.borrow().clone(); + return Ok(Some(StreamEvent::Finish(response_metadata))); + } + + Ok(None) + } +} + +struct BedrockComponent; + +impl BedrockComponent { + const ACCESS_KEY_ID_ENV_VAR: &'static str = "AWS_ACCESS_KEY_ID"; + const SECRET_ACCESS_KEY_ENV_VAR: &'static str = "AWS_SECRET_ACCESS_KEY"; + const REGION_ENV_VAR: &'static str = "AWS_REGION"; + + fn get_client() -> Result { + let access_key_id = std::env::var(Self::ACCESS_KEY_ID_ENV_VAR) + .map_err(|_| Error { + code: ErrorCode::AuthenticationFailed, + message: format!("Missing environment variable: {}", Self::ACCESS_KEY_ID_ENV_VAR), + provider_error_json: None, + })?; + + let secret_access_key = std::env::var(Self::SECRET_ACCESS_KEY_ENV_VAR) + .map_err(|_| Error { + code: ErrorCode::AuthenticationFailed, + message: format!("Missing environment variable: {}", Self::SECRET_ACCESS_KEY_ENV_VAR), + provider_error_json: None, + })?; + + let region = std::env::var(Self::REGION_ENV_VAR) + .unwrap_or_else(|_| "us-east-1".to_string()); + + Ok(BedrockClient::new(access_key_id, secret_access_key, region)) + } + + fn request(client: BedrockClient, model_id: &str, request: ConverseRequest) -> ChatEvent { + match client.converse(model_id, request) { + Ok(response) => process_response(response), + Err(err) => ChatEvent::Error(err), + } + } + + fn streaming_request( + client: BedrockClient, + model_id: &str, + request: ConverseRequest, + ) -> LlmChatStream { + match client.converse_stream(model_id, request) { + Ok(stream) => BedrockChatStream::new(stream), + Err(err) => BedrockChatStream::failed(err), + } + } +} + +impl Guest for BedrockComponent { + type ChatStream = LlmChatStream; + + fn send(messages: Vec, config: Config) -> ChatEvent { + LOGGING_STATE.with_borrow_mut(|state| state.init()); + + let client = match Self::get_client() { + Ok(client) => client, + Err(err) => return ChatEvent::Error(err), + }; + + match messages_to_request(messages, config.clone()) { + Ok(request) => Self::request(client, &config.model, request), + Err(err) => ChatEvent::Error(err), + } + } + + fn continue_( + messages: Vec, + tool_results: Vec<(ToolCall, ToolResult)>, + config: Config, + ) -> ChatEvent { + LOGGING_STATE.with_borrow_mut(|state| state.init()); + + let client = match Self::get_client() { + Ok(client) => client, + Err(err) => return ChatEvent::Error(err), + }; + + match messages_to_request(messages, config.clone()) { + Ok(mut request) => { + request.messages.extend(tool_results_to_messages(tool_results)); + Self::request(client, &config.model, request) + } + Err(err) => ChatEvent::Error(err), + } + } + + fn stream(messages: Vec, config: Config) -> ChatStream { + ChatStream::new(Self::unwrapped_stream(messages, config)) + } +} + +impl ExtendedGuest for BedrockComponent { + fn unwrapped_stream( + messages: Vec, + config: Config, + ) -> LlmChatStream { + LOGGING_STATE.with_borrow_mut(|state| state.init()); + + let client = match Self::get_client() { + Ok(client) => client, + Err(err) => return BedrockChatStream::failed(err), + }; + + match messages_to_request(messages, config.clone()) { + Ok(request) => Self::streaming_request(client, &config.model, request), + Err(err) => BedrockChatStream::failed(err), + } + } + + fn retry_prompt(original_messages: &[Message], partial_result: &[StreamDelta]) -> Vec { + let mut extended_messages = Vec::new(); + extended_messages.push(Message { + role: Role::System, + name: None, + content: vec![ + ContentPart::Text( + "You were asked the same question previously, but the response was interrupted before completion. \ + Please continue your response from where you left off. \ + Do not include the part of the response that was already seen.".to_string()), + ], + }); + extended_messages.push(Message { + role: Role::User, + name: None, + content: vec![ContentPart::Text( + "Here is the original question:".to_string(), + )], + }); + extended_messages.extend_from_slice(original_messages); + + let mut partial_result_as_content = Vec::new(); + for delta in partial_result { + if let Some(contents) = &delta.content { + partial_result_as_content.extend_from_slice(contents); + } + if let Some(tool_calls) = &delta.tool_calls { + for tool_call in tool_calls { + partial_result_as_content.push(ContentPart::Text(format!( + "", + tool_call.id, tool_call.name, tool_call.arguments_json, + ))); + } + } + } + + extended_messages.push(Message { + role: Role::User, + name: None, + content: vec![ContentPart::Text( + "Here is the partial response that was successfully received:".to_string(), + )] + .into_iter() + .chain(partial_result_as_content) + .collect(), + }); + extended_messages + } + + fn subscribe(stream: &Self::ChatStream) -> Pollable { + stream.subscribe() + } +} + +type DurableBedrockComponent = DurableLLM; + +golem_llm::export_llm!(DurableBedrockComponent with_types_in golem_llm); \ No newline at end of file diff --git a/llm/bedrock/wit/bedrock.wit b/llm/bedrock/wit/bedrock.wit new file mode 100644 index 000000000..266ba7293 --- /dev/null +++ b/llm/bedrock/wit/bedrock.wit @@ -0,0 +1,7 @@ +package golem:llm-bedrock@1.0.0; + +world llm-library { + include golem:llm/llm-library@1.0.0; + + +} \ No newline at end of file diff --git a/llm/bedrock/wit/deps/golem-llm/golem-llm.wit b/llm/bedrock/wit/deps/golem-llm/golem-llm.wit new file mode 100644 index 000000000..67854470a --- /dev/null +++ b/llm/bedrock/wit/deps/golem-llm/golem-llm.wit @@ -0,0 +1,194 @@ +package golem:llm@1.0.0; + +interface llm { + // --- Roles, Error Codes, Finish Reasons --- + + enum role { + user, + assistant, + system, + tool, + } + + enum error-code { + invalid-request, + authentication-failed, + rate-limit-exceeded, + internal-error, + unsupported, + unknown, + } + + enum finish-reason { + stop, + length, + tool-calls, + content-filter, + error, + other, + } + + enum image-detail { + low, + high, + auto, + } + + // --- Message Content --- + + record image-url { + url: string, + detail: option, + } + + record image-source { + data: list, + mime-type: string, + detail: option, + } + + variant image-reference { + url(image-url), + inline(image-source), + } + + variant content-part { + text(string), + image(image-reference), + } + + record message { + role: role, + name: option, + content: list, + } + + // --- Tooling --- + + record tool-definition { + name: string, + description: option, + parameters-schema: string, + } + + record tool-call { + id: string, + name: string, + arguments-json: string, + } + + record tool-success { + id: string, + name: string, + result-json: string, + execution-time-ms: option, + } + + record tool-failure { + id: string, + name: string, + error-message: string, + error-code: option, + } + + variant tool-result { + success(tool-success), + error(tool-failure), + } + + // --- Configuration --- + + record kv { + key: string, + value: string, + } + + record config { + model: string, + temperature: option, + max-tokens: option, + stop-sequences: option>, + tools: list, + tool-choice: option, + provider-options: list, + } + + // --- Usage / Metadata --- + + record usage { + input-tokens: option, + output-tokens: option, + total-tokens: option, + } + + record response-metadata { + finish-reason: option, + usage: option, + provider-id: option, + timestamp: option, + provider-metadata-json: option, + } + + record complete-response { + id: string, + content: list, + tool-calls: list, + metadata: response-metadata, + } + + // --- Error Handling --- + + record error { + code: error-code, + message: string, + provider-error-json: option, + } + + // --- Chat Response Variants --- + + variant chat-event { + message(complete-response), + tool-request(list), + error(error), + } + + // --- Streaming --- + + record stream-delta { + content: option>, + tool-calls: option>, + } + + variant stream-event { + delta(stream-delta), + finish(response-metadata), + error(error), + } + + resource chat-stream { + get-next: func() -> option>; + blocking-get-next: func() -> list; + } + + // --- Core Functions --- + + send: func( + messages: list, + config: config + ) -> chat-event; + + continue: func( + messages: list, + tool-results: list>, + config: config + ) -> chat-event; + + %stream: func( + messages: list, + config: config + ) -> chat-stream; +} + +world llm-library { + export llm; +} diff --git a/llm/bedrock/wit/deps/wasi:io/error.wit b/llm/bedrock/wit/deps/wasi:io/error.wit new file mode 100644 index 000000000..97c606877 --- /dev/null +++ b/llm/bedrock/wit/deps/wasi:io/error.wit @@ -0,0 +1,34 @@ +package wasi:io@0.2.3; + +@since(version = 0.2.0) +interface error { + /// A resource which represents some error information. + /// + /// The only method provided by this resource is `to-debug-string`, + /// which provides some human-readable information about the error. + /// + /// In the `wasi:io` package, this resource is returned through the + /// `wasi:io/streams/stream-error` type. + /// + /// To provide more specific error information, other interfaces may + /// offer functions to "downcast" this error into more specific types. For example, + /// errors returned from streams derived from filesystem types can be described using + /// the filesystem's own error-code type. This is done using the function + /// `wasi:filesystem/types/filesystem-error-code`, which takes a `borrow` + /// parameter and returns an `option`. + /// + /// The set of functions which can "downcast" an `error` into a more + /// concrete type is open. + @since(version = 0.2.0) + resource error { + /// Returns a string that is suitable to assist humans in debugging + /// this error. + /// + /// WARNING: The returned string should not be consumed mechanically! + /// It may change across platforms, hosts, or other implementation + /// details. Parsing this string is a major platform-compatibility + /// hazard. + @since(version = 0.2.0) + to-debug-string: func() -> string; + } +} diff --git a/llm/bedrock/wit/deps/wasi:io/poll.wit b/llm/bedrock/wit/deps/wasi:io/poll.wit new file mode 100644 index 000000000..9bcbe8e03 --- /dev/null +++ b/llm/bedrock/wit/deps/wasi:io/poll.wit @@ -0,0 +1,47 @@ +package wasi:io@0.2.3; + +/// A poll API intended to let users wait for I/O events on multiple handles +/// at once. +@since(version = 0.2.0) +interface poll { + /// `pollable` represents a single I/O event which may be ready, or not. + @since(version = 0.2.0) + resource pollable { + + /// Return the readiness of a pollable. This function never blocks. + /// + /// Returns `true` when the pollable is ready, and `false` otherwise. + @since(version = 0.2.0) + ready: func() -> bool; + + /// `block` returns immediately if the pollable is ready, and otherwise + /// blocks until ready. + /// + /// This function is equivalent to calling `poll.poll` on a list + /// containing only this pollable. + @since(version = 0.2.0) + block: func(); + } + + /// Poll for completion on a set of pollables. + /// + /// This function takes a list of pollables, which identify I/O sources of + /// interest, and waits until one or more of the events is ready for I/O. + /// + /// The result `list` contains one or more indices of handles in the + /// argument list that is ready for I/O. + /// + /// This function traps if either: + /// - the list is empty, or: + /// - the list contains more elements than can be indexed with a `u32` value. + /// + /// A timeout can be implemented by adding a pollable from the + /// wasi-clocks API to the list. + /// + /// This function does not return a `result`; polling in itself does not + /// do any I/O so it doesn't fail. If any of the I/O sources identified by + /// the pollables has an error, it is indicated by marking the source as + /// being ready for I/O. + @since(version = 0.2.0) + poll: func(in: list>) -> list; +} diff --git a/llm/bedrock/wit/deps/wasi:io/streams.wit b/llm/bedrock/wit/deps/wasi:io/streams.wit new file mode 100644 index 000000000..0de084629 --- /dev/null +++ b/llm/bedrock/wit/deps/wasi:io/streams.wit @@ -0,0 +1,290 @@ +package wasi:io@0.2.3; + +/// WASI I/O is an I/O abstraction API which is currently focused on providing +/// stream types. +/// +/// In the future, the component model is expected to add built-in stream types; +/// when it does, they are expected to subsume this API. +@since(version = 0.2.0) +interface streams { + @since(version = 0.2.0) + use error.{error}; + @since(version = 0.2.0) + use poll.{pollable}; + + /// An error for input-stream and output-stream operations. + @since(version = 0.2.0) + variant stream-error { + /// The last operation (a write or flush) failed before completion. + /// + /// More information is available in the `error` payload. + /// + /// After this, the stream will be closed. All future operations return + /// `stream-error::closed`. + last-operation-failed(error), + /// The stream is closed: no more input will be accepted by the + /// stream. A closed output-stream will return this error on all + /// future operations. + closed + } + + /// An input bytestream. + /// + /// `input-stream`s are *non-blocking* to the extent practical on underlying + /// platforms. I/O operations always return promptly; if fewer bytes are + /// promptly available than requested, they return the number of bytes promptly + /// available, which could even be zero. To wait for data to be available, + /// use the `subscribe` function to obtain a `pollable` which can be polled + /// for using `wasi:io/poll`. + @since(version = 0.2.0) + resource input-stream { + /// Perform a non-blocking read from the stream. + /// + /// When the source of a `read` is binary data, the bytes from the source + /// are returned verbatim. When the source of a `read` is known to the + /// implementation to be text, bytes containing the UTF-8 encoding of the + /// text are returned. + /// + /// This function returns a list of bytes containing the read data, + /// when successful. The returned list will contain up to `len` bytes; + /// it may return fewer than requested, but not more. The list is + /// empty when no bytes are available for reading at this time. The + /// pollable given by `subscribe` will be ready when more bytes are + /// available. + /// + /// This function fails with a `stream-error` when the operation + /// encounters an error, giving `last-operation-failed`, or when the + /// stream is closed, giving `closed`. + /// + /// When the caller gives a `len` of 0, it represents a request to + /// read 0 bytes. If the stream is still open, this call should + /// succeed and return an empty list, or otherwise fail with `closed`. + /// + /// The `len` parameter is a `u64`, which could represent a list of u8 which + /// is not possible to allocate in wasm32, or not desirable to allocate as + /// as a return value by the callee. The callee may return a list of bytes + /// less than `len` in size while more bytes are available for reading. + @since(version = 0.2.0) + read: func( + /// The maximum number of bytes to read + len: u64 + ) -> result, stream-error>; + + /// Read bytes from a stream, after blocking until at least one byte can + /// be read. Except for blocking, behavior is identical to `read`. + @since(version = 0.2.0) + blocking-read: func( + /// The maximum number of bytes to read + len: u64 + ) -> result, stream-error>; + + /// Skip bytes from a stream. Returns number of bytes skipped. + /// + /// Behaves identical to `read`, except instead of returning a list + /// of bytes, returns the number of bytes consumed from the stream. + @since(version = 0.2.0) + skip: func( + /// The maximum number of bytes to skip. + len: u64, + ) -> result; + + /// Skip bytes from a stream, after blocking until at least one byte + /// can be skipped. Except for blocking behavior, identical to `skip`. + @since(version = 0.2.0) + blocking-skip: func( + /// The maximum number of bytes to skip. + len: u64, + ) -> result; + + /// Create a `pollable` which will resolve once either the specified stream + /// has bytes available to read or the other end of the stream has been + /// closed. + /// The created `pollable` is a child resource of the `input-stream`. + /// Implementations may trap if the `input-stream` is dropped before + /// all derived `pollable`s created with this function are dropped. + @since(version = 0.2.0) + subscribe: func() -> pollable; + } + + + /// An output bytestream. + /// + /// `output-stream`s are *non-blocking* to the extent practical on + /// underlying platforms. Except where specified otherwise, I/O operations also + /// always return promptly, after the number of bytes that can be written + /// promptly, which could even be zero. To wait for the stream to be ready to + /// accept data, the `subscribe` function to obtain a `pollable` which can be + /// polled for using `wasi:io/poll`. + /// + /// Dropping an `output-stream` while there's still an active write in + /// progress may result in the data being lost. Before dropping the stream, + /// be sure to fully flush your writes. + @since(version = 0.2.0) + resource output-stream { + /// Check readiness for writing. This function never blocks. + /// + /// Returns the number of bytes permitted for the next call to `write`, + /// or an error. Calling `write` with more bytes than this function has + /// permitted will trap. + /// + /// When this function returns 0 bytes, the `subscribe` pollable will + /// become ready when this function will report at least 1 byte, or an + /// error. + @since(version = 0.2.0) + check-write: func() -> result; + + /// Perform a write. This function never blocks. + /// + /// When the destination of a `write` is binary data, the bytes from + /// `contents` are written verbatim. When the destination of a `write` is + /// known to the implementation to be text, the bytes of `contents` are + /// transcoded from UTF-8 into the encoding of the destination and then + /// written. + /// + /// Precondition: check-write gave permit of Ok(n) and contents has a + /// length of less than or equal to n. Otherwise, this function will trap. + /// + /// returns Err(closed) without writing if the stream has closed since + /// the last call to check-write provided a permit. + @since(version = 0.2.0) + write: func( + contents: list + ) -> result<_, stream-error>; + + /// Perform a write of up to 4096 bytes, and then flush the stream. Block + /// until all of these operations are complete, or an error occurs. + /// + /// This is a convenience wrapper around the use of `check-write`, + /// `subscribe`, `write`, and `flush`, and is implemented with the + /// following pseudo-code: + /// + /// ```text + /// let pollable = this.subscribe(); + /// while !contents.is_empty() { + /// // Wait for the stream to become writable + /// pollable.block(); + /// let Ok(n) = this.check-write(); // eliding error handling + /// let len = min(n, contents.len()); + /// let (chunk, rest) = contents.split_at(len); + /// this.write(chunk ); // eliding error handling + /// contents = rest; + /// } + /// this.flush(); + /// // Wait for completion of `flush` + /// pollable.block(); + /// // Check for any errors that arose during `flush` + /// let _ = this.check-write(); // eliding error handling + /// ``` + @since(version = 0.2.0) + blocking-write-and-flush: func( + contents: list + ) -> result<_, stream-error>; + + /// Request to flush buffered output. This function never blocks. + /// + /// This tells the output-stream that the caller intends any buffered + /// output to be flushed. the output which is expected to be flushed + /// is all that has been passed to `write` prior to this call. + /// + /// Upon calling this function, the `output-stream` will not accept any + /// writes (`check-write` will return `ok(0)`) until the flush has + /// completed. The `subscribe` pollable will become ready when the + /// flush has completed and the stream can accept more writes. + @since(version = 0.2.0) + flush: func() -> result<_, stream-error>; + + /// Request to flush buffered output, and block until flush completes + /// and stream is ready for writing again. + @since(version = 0.2.0) + blocking-flush: func() -> result<_, stream-error>; + + /// Create a `pollable` which will resolve once the output-stream + /// is ready for more writing, or an error has occurred. When this + /// pollable is ready, `check-write` will return `ok(n)` with n>0, or an + /// error. + /// + /// If the stream is closed, this pollable is always ready immediately. + /// + /// The created `pollable` is a child resource of the `output-stream`. + /// Implementations may trap if the `output-stream` is dropped before + /// all derived `pollable`s created with this function are dropped. + @since(version = 0.2.0) + subscribe: func() -> pollable; + + /// Write zeroes to a stream. + /// + /// This should be used precisely like `write` with the exact same + /// preconditions (must use check-write first), but instead of + /// passing a list of bytes, you simply pass the number of zero-bytes + /// that should be written. + @since(version = 0.2.0) + write-zeroes: func( + /// The number of zero-bytes to write + len: u64 + ) -> result<_, stream-error>; + + /// Perform a write of up to 4096 zeroes, and then flush the stream. + /// Block until all of these operations are complete, or an error + /// occurs. + /// + /// This is a convenience wrapper around the use of `check-write`, + /// `subscribe`, `write-zeroes`, and `flush`, and is implemented with + /// the following pseudo-code: + /// + /// ```text + /// let pollable = this.subscribe(); + /// while num_zeroes != 0 { + /// // Wait for the stream to become writable + /// pollable.block(); + /// let Ok(n) = this.check-write(); // eliding error handling + /// let len = min(n, num_zeroes); + /// this.write-zeroes(len); // eliding error handling + /// num_zeroes -= len; + /// } + /// this.flush(); + /// // Wait for completion of `flush` + /// pollable.block(); + /// // Check for any errors that arose during `flush` + /// let _ = this.check-write(); // eliding error handling + /// ``` + @since(version = 0.2.0) + blocking-write-zeroes-and-flush: func( + /// The number of zero-bytes to write + len: u64 + ) -> result<_, stream-error>; + + /// Read from one stream and write to another. + /// + /// The behavior of splice is equivalent to: + /// 1. calling `check-write` on the `output-stream` + /// 2. calling `read` on the `input-stream` with the smaller of the + /// `check-write` permitted length and the `len` provided to `splice` + /// 3. calling `write` on the `output-stream` with that read data. + /// + /// Any error reported by the call to `check-write`, `read`, or + /// `write` ends the splice and reports that error. + /// + /// This function returns the number of bytes transferred; it may be less + /// than `len`. + @since(version = 0.2.0) + splice: func( + /// The stream to read from + src: borrow, + /// The number of bytes to splice + len: u64, + ) -> result; + + /// Read from one stream and write to another, with blocking. + /// + /// This is similar to `splice`, except that it blocks until the + /// `output-stream` is ready for writing, and the `input-stream` + /// is ready for reading, before performing the `splice`. + @since(version = 0.2.0) + blocking-splice: func( + /// The stream to read from + src: borrow, + /// The number of bytes to splice + len: u64, + ) -> result; + } +} diff --git a/llm/bedrock/wit/deps/wasi:io/world.wit b/llm/bedrock/wit/deps/wasi:io/world.wit new file mode 100644 index 000000000..f1d2102dc --- /dev/null +++ b/llm/bedrock/wit/deps/wasi:io/world.wit @@ -0,0 +1,10 @@ +package wasi:io@0.2.3; + +@since(version = 0.2.0) +world imports { + @since(version = 0.2.0) + import streams; + + @since(version = 0.2.0) + import poll; +} diff --git a/llm/grok/src/bindings.rs b/llm/grok/src/bindings.rs index 2a101583e..c2f601347 100644 --- a/llm/grok/src/bindings.rs +++ b/llm/grok/src/bindings.rs @@ -1,12 +1,15 @@ -// Generated by `wit-bindgen` 0.36.0. DO NOT EDIT! +// Generated by `wit-bindgen` 0.41.0. DO NOT EDIT! // Options used: // * runtime_path: "wit_bindgen_rt" // * with "golem:llm/llm@1.0.0" = "golem_llm::golem::llm::llm" // * generate_unused_types use golem_llm::golem::llm::llm as __with_name0; #[cfg(target_arch = "wasm32")] -#[link_section = "component-type:wit-bindgen:0.36.0:golem:llm-grok@1.0.0:llm-library:encoded world"] +#[unsafe( + link_section = "component-type:wit-bindgen:0.41.0:golem:llm-grok@1.0.0:llm-library:encoded world" +)] #[doc(hidden)] +#[allow(clippy::octal_escapes)] pub static __WIT_BINDGEN_COMPONENT_TYPE: [u8; 1757] = *b"\ \0asm\x0d\0\x01\0\0\x19\x16wit-component-encoding\x04\0\x07\xdb\x0c\x01A\x02\x01\ A\x02\x01BO\x01m\x04\x04user\x09assistant\x06system\x04tool\x04\0\x04role\x03\0\0\ @@ -43,8 +46,8 @@ ng-get-next\x01B\x01p\x15\x01@\x02\x08messages\xc3\0\x06config)\06\x04\0\x04send \0\x06config)\06\x04\0\x08continue\x01G\x01i=\x01@\x02\x08messages\xc3\0\x06conf\ ig)\0\xc8\0\x04\0\x06stream\x01I\x04\0\x13golem:llm/llm@1.0.0\x05\0\x04\0\x20gol\ em:llm-grok/llm-library@1.0.0\x04\0\x0b\x11\x01\0\x0bllm-library\x03\0\0\0G\x09p\ -roducers\x01\x0cprocessed-by\x02\x0dwit-component\x070.220.0\x10wit-bindgen-rust\ -\x060.36.0"; +roducers\x01\x0cprocessed-by\x02\x0dwit-component\x070.227.1\x10wit-bindgen-rust\ +\x060.41.0"; #[inline(never)] #[doc(hidden)] pub fn __link_custom_section_describing_imports() { diff --git a/llm/ollama/src/bindings.rs b/llm/ollama/src/bindings.rs index dbb704704..269cd07fb 100644 --- a/llm/ollama/src/bindings.rs +++ b/llm/ollama/src/bindings.rs @@ -1,12 +1,15 @@ -// Generated by `wit-bindgen` 0.36.0. DO NOT EDIT! +// Generated by `wit-bindgen` 0.41.0. DO NOT EDIT! // Options used: // * runtime_path: "wit_bindgen_rt" // * with "golem:llm/llm@1.0.0" = "golem_llm::golem::llm::llm" // * generate_unused_types use golem_llm::golem::llm::llm as __with_name0; #[cfg(target_arch = "wasm32")] -#[link_section = "component-type:wit-bindgen:0.36.0:golem:llm-ollama@1.0.0:llm-library:encoded world"] +#[unsafe( + link_section = "component-type:wit-bindgen:0.41.0:golem:llm-ollama@1.0.0:llm-library:encoded world" +)] #[doc(hidden)] +#[allow(clippy::octal_escapes)] pub static __WIT_BINDGEN_COMPONENT_TYPE: [u8; 1759] = *b"\ \0asm\x0d\0\x01\0\0\x19\x16wit-component-encoding\x04\0\x07\xdd\x0c\x01A\x02\x01\ A\x02\x01BO\x01m\x04\x04user\x09assistant\x06system\x04tool\x04\0\x04role\x03\0\0\ @@ -43,8 +46,8 @@ ng-get-next\x01B\x01p\x15\x01@\x02\x08messages\xc3\0\x06config)\06\x04\0\x04send \0\x06config)\06\x04\0\x08continue\x01G\x01i=\x01@\x02\x08messages\xc3\0\x06conf\ ig)\0\xc8\0\x04\0\x06stream\x01I\x04\0\x13golem:llm/llm@1.0.0\x05\0\x04\0\"golem\ :llm-ollama/llm-library@1.0.0\x04\0\x0b\x11\x01\0\x0bllm-library\x03\0\0\0G\x09p\ -roducers\x01\x0cprocessed-by\x02\x0dwit-component\x070.220.0\x10wit-bindgen-rust\ -\x060.36.0"; +roducers\x01\x0cprocessed-by\x02\x0dwit-component\x070.227.1\x10wit-bindgen-rust\ +\x060.41.0"; #[inline(never)] #[doc(hidden)] pub fn __link_custom_section_describing_imports() { diff --git a/llm/openai/src/bindings.rs b/llm/openai/src/bindings.rs index c960248a8..6d0a77280 100644 --- a/llm/openai/src/bindings.rs +++ b/llm/openai/src/bindings.rs @@ -1,12 +1,15 @@ -// Generated by `wit-bindgen` 0.36.0. DO NOT EDIT! +// Generated by `wit-bindgen` 0.41.0. DO NOT EDIT! // Options used: // * runtime_path: "wit_bindgen_rt" // * with "golem:llm/llm@1.0.0" = "golem_llm::golem::llm::llm" // * generate_unused_types use golem_llm::golem::llm::llm as __with_name0; #[cfg(target_arch = "wasm32")] -#[link_section = "component-type:wit-bindgen:0.36.0:golem:llm-openai@1.0.0:llm-library:encoded world"] +#[unsafe( + link_section = "component-type:wit-bindgen:0.41.0:golem:llm-openai@1.0.0:llm-library:encoded world" +)] #[doc(hidden)] +#[allow(clippy::octal_escapes)] pub static __WIT_BINDGEN_COMPONENT_TYPE: [u8; 1759] = *b"\ \0asm\x0d\0\x01\0\0\x19\x16wit-component-encoding\x04\0\x07\xdd\x0c\x01A\x02\x01\ A\x02\x01BO\x01m\x04\x04user\x09assistant\x06system\x04tool\x04\0\x04role\x03\0\0\ @@ -43,8 +46,8 @@ ng-get-next\x01B\x01p\x15\x01@\x02\x08messages\xc3\0\x06config)\06\x04\0\x04send \0\x06config)\06\x04\0\x08continue\x01G\x01i=\x01@\x02\x08messages\xc3\0\x06conf\ ig)\0\xc8\0\x04\0\x06stream\x01I\x04\0\x13golem:llm/llm@1.0.0\x05\0\x04\0\"golem\ :llm-openai/llm-library@1.0.0\x04\0\x0b\x11\x01\0\x0bllm-library\x03\0\0\0G\x09p\ -roducers\x01\x0cprocessed-by\x02\x0dwit-component\x070.220.0\x10wit-bindgen-rust\ -\x060.36.0"; +roducers\x01\x0cprocessed-by\x02\x0dwit-component\x070.227.1\x10wit-bindgen-rust\ +\x060.41.0"; #[inline(never)] #[doc(hidden)] pub fn __link_custom_section_describing_imports() { diff --git a/llm/openrouter/src/bindings.rs b/llm/openrouter/src/bindings.rs index ba2accf7e..1300cde97 100644 --- a/llm/openrouter/src/bindings.rs +++ b/llm/openrouter/src/bindings.rs @@ -1,12 +1,15 @@ -// Generated by `wit-bindgen` 0.36.0. DO NOT EDIT! +// Generated by `wit-bindgen` 0.41.0. DO NOT EDIT! // Options used: // * runtime_path: "wit_bindgen_rt" // * with "golem:llm/llm@1.0.0" = "golem_llm::golem::llm::llm" // * generate_unused_types use golem_llm::golem::llm::llm as __with_name0; #[cfg(target_arch = "wasm32")] -#[link_section = "component-type:wit-bindgen:0.36.0:golem:llm-openrouter@1.0.0:llm-library:encoded world"] +#[unsafe( + link_section = "component-type:wit-bindgen:0.41.0:golem:llm-openrouter@1.0.0:llm-library:encoded world" +)] #[doc(hidden)] +#[allow(clippy::octal_escapes)] pub static __WIT_BINDGEN_COMPONENT_TYPE: [u8; 1763] = *b"\ \0asm\x0d\0\x01\0\0\x19\x16wit-component-encoding\x04\0\x07\xe1\x0c\x01A\x02\x01\ A\x02\x01BO\x01m\x04\x04user\x09assistant\x06system\x04tool\x04\0\x04role\x03\0\0\ @@ -43,8 +46,8 @@ ng-get-next\x01B\x01p\x15\x01@\x02\x08messages\xc3\0\x06config)\06\x04\0\x04send \0\x06config)\06\x04\0\x08continue\x01G\x01i=\x01@\x02\x08messages\xc3\0\x06conf\ ig)\0\xc8\0\x04\0\x06stream\x01I\x04\0\x13golem:llm/llm@1.0.0\x05\0\x04\0&golem:\ llm-openrouter/llm-library@1.0.0\x04\0\x0b\x11\x01\0\x0bllm-library\x03\0\0\0G\x09\ -producers\x01\x0cprocessed-by\x02\x0dwit-component\x070.220.0\x10wit-bindgen-rus\ -t\x060.36.0"; +producers\x01\x0cprocessed-by\x02\x0dwit-component\x070.227.1\x10wit-bindgen-rus\ +t\x060.41.0"; #[inline(never)] #[doc(hidden)] pub fn __link_custom_section_describing_imports() { diff --git a/test/components-rust/test-llm/Cargo.toml b/test/components-rust/test-llm/Cargo.toml index 7f6242874..52f77b3a7 100644 --- a/test/components-rust/test-llm/Cargo.toml +++ b/test/components-rust/test-llm/Cargo.toml @@ -15,6 +15,7 @@ grok = [] openai = [] openrouter = [] ollama = [] +bedrock = [] [dependencies] # To use common shared libs, use the following: diff --git a/test/components-rust/test-llm/golem.yaml b/test/components-rust/test-llm/golem.yaml index 6efa177c7..7076791d1 100644 --- a/test/components-rust/test-llm/golem.yaml +++ b/test/components-rust/test-llm/golem.yaml @@ -139,6 +139,32 @@ components: clean: - src/bindings.rs + bedrock-debug: + files: + - sourcePath: ../../data/cat.png + targetPath: /data/cat.png + permissions: read-only + build: + - command: cargo component build --no-default-features --features bedrock + sources: + - src + - wit-generated + - ../../common-rust + targets: + - ../../target/wasm32-wasip1/debug/test_llm.wasm + - command: wac plug --plug ../../../target/wasm32-wasip1/debug/golem_llm_bedrock.wasm ../../target/wasm32-wasip1/debug/test_llm.wasm -o ../../target/wasm32-wasip1/debug/test_bedrock_plugged.wasm + sources: + - ../../target/wasm32-wasip1/debug/test_llm.wasm + - ../../../target/wasm32-wasip1/debug/golem_llm_bedrock.wasm + targets: + - ../../target/wasm32-wasip1/debug/test_bedrock_plugged.wasm + sourceWit: wit + generatedWit: wit-generated + componentWasm: ../../target/wasm32-wasip1/debug/test_bedrock_plugged.wasm + linkedWasm: ../../golem-temp/components/test_bedrock_debug.wasm + clean: + - src/bindings.rs + # RELEASE PROFILES openai-release: files: @@ -270,6 +296,34 @@ components: clean: - src/bindings.rs + + bedrock-release: + files: + - sourcePath: ../../data/cat.png + targetPath: /data/cat.png + permissions: read-only + build: + - command: cargo component build --release --no-default-features --features bedrock + sources: + - src + - wit-generated + - ../../common-rust + targets: + - ../../target/wasm32-wasip1/release/test_llm.wasm + - command: wac plug --plug ../../../target/wasm32-wasip1/release/golem_llm_bedrock.wasm ../../target/wasm32-wasip1/release/test_llm.wasm -o ../../target/wasm32-wasip1/release/test_bedrock_plugged.wasm + sources: + - ../../target/wasm32-wasip1/release/test_llm.wasm + - ../../../target/wasm32-wasip1/release/golem_llm_bedrock.wasm + targets: + - ../../target/wasm32-wasip1/release/test_bedrock_plugged.wasm + sourceWit: wit + generatedWit: wit-generated + componentWasm: ../../target/wasm32-wasip1/release/test_bedrock_plugged.wasm + linkedWasm: ../../golem-temp/components/test_bedrock_release.wasm + clean: + - src/bindings.rs + + defaultProfile: openai-debug dependencies: From 47cfffa9a659f989a6857438695b6a9fb6e5ef93 Mon Sep 17 00:00:00 2001 From: Rutik7066 Date: Fri, 27 Jun 2025 10:35:07 +0000 Subject: [PATCH 2/8] sigv4 --- Cargo.lock | 61 +++- Cargo.toml | 6 +- llm/Makefile.toml | 2 + llm/bedrock/Cargo.toml | 7 +- llm/bedrock/src/client.rs | 359 +++++++++++++---------- llm/bedrock/src/lib.rs | 131 +++++---- llm/llm/src/config.rs | 30 +- test/components-rust/test-llm/Cargo.toml | 2 +- test/components-rust/test-llm/src/lib.rs | 4 + 9 files changed, 372 insertions(+), 230 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4168db8a9..ab1942ae9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -182,6 +182,15 @@ dependencies = [ "typenum", ] +[[package]] +name = "deranged" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c9e6a11ca8224451684bc0d7d5a7adbf8f2fd6887261a1cfc3c0432f9d4068e" +dependencies = [ + "powerfmt", +] + [[package]] name = "digest" version = "0.10.7" @@ -414,16 +423,15 @@ name = "golem-llm-bedrock" version = "0.0.0" dependencies = [ "base64 0.22.1", - "chrono", "golem-llm", "golem-rust", - "hex", "hmac", "log", "reqwest", "serde", "serde_json", "sha2", + "time", "wit-bindgen-rt 0.40.0", ] @@ -548,12 +556,6 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" -[[package]] -name = "hex" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" - [[package]] name = "hmac" version = "0.12.1" @@ -815,6 +817,12 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + [[package]] name = "num-traits" version = "0.2.19" @@ -857,6 +865,12 @@ dependencies = [ "zerovec", ] +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + [[package]] name = "prettyplease" version = "0.2.32" @@ -1088,6 +1102,37 @@ dependencies = [ "syn", ] +[[package]] +name = "time" +version = "0.3.41" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a7619e19bc266e0f9c5e6686659d394bc57973859340060a69221e57dbc0c40" +dependencies = [ + "deranged", + "itoa", + "num-conv", + "powerfmt", + "serde", + "time-core", + "time-macros", +] + +[[package]] +name = "time-core" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c9e9a38711f559d9e3ce1cdb06dd7c5b8ea546bc90052da6d06bb76da74bb07c" + +[[package]] +name = "time-macros" +version = "0.2.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3526739392ec93fd8b359c8e98514cb3e8e021beb4e5f597b00a0221f8ed8a49" +dependencies = [ + "num-conv", + "time-core", +] + [[package]] name = "tinystr" version = "0.8.1" diff --git a/Cargo.toml b/Cargo.toml index cb5c7e7ce..a71042e4f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,8 +26,4 @@ reqwest = { git = "https://github.com/golemcloud/reqwest", branch = "update-may- serde = { version = "1.0", features = ["derive"] } serde_json = { version = "1.0" } wit-bindgen-rt = { version = "0.40.0", features = ["bitflags"] } -base64 = { version = "0.22.1" } -hex = { version = "0.4" } -hmac = { version = "0.12" } -sha2 = { version = "0.10" } -chrono = { version = "0.4", features = ["serde"] } +base64 = { version = "0.22.1" } \ No newline at end of file diff --git a/llm/Makefile.toml b/llm/Makefile.toml index 15ee75be7..5ca26320a 100644 --- a/llm/Makefile.toml +++ b/llm/Makefile.toml @@ -230,6 +230,8 @@ golem-cli --version golem-cli app clean golem-cli app build -b anthropic-debug golem-cli app clean +golem-cli app build -b bedrock-debug +golem-cli app clean golem-cli app build -b grok-debug golem-cli app clean golem-cli app build -b openai-debug diff --git a/llm/bedrock/Cargo.toml b/llm/bedrock/Cargo.toml index a59411a4b..8add2714b 100644 --- a/llm/bedrock/Cargo.toml +++ b/llm/bedrock/Cargo.toml @@ -25,10 +25,9 @@ serde = { workspace = true } serde_json = { workspace = true } wit-bindgen-rt = { workspace = true } base64 = { workspace = true } -hex = { workspace = true } -hmac = { workspace = true } -sha2 = { workspace = true } -chrono = { workspace = true } +hmac = "0.12" +sha2 = "0.10" +time = { version = "0.3", features = ["formatting"] } [package.metadata.component] package = "golem:llm-bedrock" diff --git a/llm/bedrock/src/client.rs b/llm/bedrock/src/client.rs index 1e5a82da4..37eea2051 100644 --- a/llm/bedrock/src/client.rs +++ b/llm/bedrock/src/client.rs @@ -1,18 +1,16 @@ use golem_llm::error::{error_code_from_status, from_event_source_error, from_reqwest_error}; use golem_llm::event_source::EventSource; use golem_llm::golem::llm::llm::{Error, ErrorCode}; +use hmac::{Hmac, Mac}; use log::trace; use reqwest::{Client, Method, Response}; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use serde_json::Value; +use sha2::{Digest, Sha256}; use std::fmt::Debug; -use std::collections::HashMap; -use chrono::Utc; -use hmac::{Hmac, Mac}; -use sha2::{Sha256, Digest}; - -type HmacSha256 = Hmac; +use std::time::{SystemTime, UNIX_EPOCH}; +use time::OffsetDateTime; /// AWS Bedrock client for creating model responses pub struct BedrockClient { @@ -24,9 +22,7 @@ pub struct BedrockClient { impl BedrockClient { pub fn new(access_key_id: String, secret_access_key: String, region: String) -> Self { - let client = Client::builder() - .build() - .expect("Failed to initialize HTTP client"); + let client = Client::new(); Self { access_key_id, secret_access_key, @@ -35,168 +31,219 @@ impl BedrockClient { } } - pub fn converse(&self, model_id: &str, request: ConverseRequest) -> Result { + pub fn converse( + &self, + model_id: &str, + request: ConverseRequest, + ) -> Result { trace!("Sending request to Bedrock API: {request:?}"); + let url = format!( + "https://bedrock-runtime.{}.amazonaws.com/model/{}/converse", + self.region, model_id + ); - let body = serde_json::to_string(&request) - .map_err(|err| Error { - code: ErrorCode::InvalidRequest, - message: format!("Failed to serialize request: {err}"), - provider_error_json: None, - })?; - - let headers = self.sign_request(&body, model_id, false)?; - let url = format!("https://bedrock-runtime.{}.amazonaws.com/model/{}/converse", self.region, model_id); - - let mut request_builder = self.client.request(Method::POST, url); + let body = serde_json::to_string(&request).map_err(|err| Error { + code: ErrorCode::InternalError, + message: "Failed to serialize request".to_string(), + provider_error_json: Some(err.to_string()), + })?; + + let host = format!("bedrock-runtime.{}.amazonaws.com", self.region); + let headers = generate_sigv4_headers( + &self.access_key_id, + &self.secret_access_key, + &self.region, + "bedrock", + "POST", + &format!("/model/{}/converse", model_id), + &host, + &body, + ) + .map_err(|err| Error { + code: ErrorCode::InternalError, + message: "Failed to sign headers".to_string(), + provider_error_json: Some(err.to_string()), + })?; + + let mut request_builder = self.client.request(Method::POST, &url); + request_builder = request_builder.header("content-type", "application/json"); for (key, value) in headers { request_builder = request_builder.header(key, value); } - let response: Response = request_builder - .body(body) - .send() - .map_err(|err| from_reqwest_error("Request failed", err))?; + let response: Response = request_builder.body(body).send().map_err(|err| { + trace!("HTTP request failed with error: {:?}", err); + from_reqwest_error("Request failed", err) + })?; + + trace!("Received response from Bedrock API: {:?}", response); parse_response(response) } - pub fn converse_stream(&self, model_id: &str, request: ConverseRequest) -> Result { + pub fn converse_stream( + &self, + model_id: &str, + request: ConverseRequest, + ) -> Result { trace!("Sending streaming request to Bedrock API: {request:?}"); + let url = format!( + "https://bedrock-runtime.{}.amazonaws.com/model/{}/converse-stream", + self.region, model_id + ); - let body = serde_json::to_string(&request) - .map_err(|err| Error { - code: ErrorCode::InvalidRequest, - message: format!("Failed to serialize request: {err}"), - provider_error_json: None, - })?; - - let headers = self.sign_request(&body, model_id, true)?; - let url = format!("https://bedrock-runtime.{}.amazonaws.com/model/{}/converse-stream", self.region, model_id); - - let mut request_builder = self.client.request(Method::POST, url); + let body = serde_json::to_string(&request).map_err(|err| Error { + code: ErrorCode::InternalError, + message: "Failed to serialize request".to_string(), + provider_error_json: Some(err.to_string()), + })?; + + let host = format!("bedrock-runtime.{}.amazonaws.com", self.region); + let headers = generate_sigv4_headers( + &self.access_key_id, + &self.secret_access_key, + &self.region, + "bedrock", + "POST", + &format!("/model/{}/converse-stream", model_id), + &host, + &body, + ) + .map_err(|err| Error { + code: ErrorCode::InternalError, + message: "Failed to sign headers".to_string(), + provider_error_json: Some(err.to_string()), + })?; + + let mut request_builder = self.client.request(Method::POST, &url); + request_builder = request_builder.header("content-type", "application/json"); for (key, value) in headers { request_builder = request_builder.header(key, value); } - let response: Response = request_builder - .body(body) - .send() - .map_err(|err| from_reqwest_error("Request failed", err))?; + trace!("Sending streaming HTTP request to Bedrock..."); + let response: Response = request_builder.body(body).send().map_err(|err| { + trace!("HTTP request failed with error: {:?}", err); + from_reqwest_error("Request failed", err) + })?; trace!("Initializing SSE stream"); EventSource::new(response) .map_err(|err| from_event_source_error("Failed to create SSE stream", err)) } +} - fn sign_request(&self, body: &str, model_id: &str, is_stream: bool) -> Result, Error> { - let now = Utc::now(); - let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string(); - let date_stamp = now.format("%Y%m%d").to_string(); - - let service = "bedrock"; - let endpoint = if is_stream { "converse-stream" } else { "converse" }; - let canonical_uri = format!("/model/{}/{}", model_id, endpoint); - - let canonical_headers = format!( - "host:bedrock-runtime.{}.amazonaws.com\nx-amz-date:{}\n", - self.region, amz_date - ); - let signed_headers = "host;x-amz-date"; - - let payload_hash = hex::encode(Sha256::digest(body.as_bytes())); - - let canonical_request = format!( - "POST\n{}\n\n{}\n{}\n{}", - canonical_uri, canonical_headers, signed_headers, payload_hash - ); - - let algorithm = "AWS4-HMAC-SHA256"; - let credential_scope = format!("{}/{}/{}/aws4_request", date_stamp, self.region, service); - let string_to_sign = format!( - "{}\n{}\n{}\n{}", - algorithm, - amz_date, - credential_scope, - hex::encode(Sha256::digest(canonical_request.as_bytes())) - ); - - let signing_key = self.get_signature_key(&date_stamp, service)?; - let signature = hex::encode( - HmacSha256::new_from_slice(&signing_key) - .map_err(|_| Error { - code: ErrorCode::InternalError, - message: "Failed to create HMAC".to_string(), - provider_error_json: None, - })? - .chain_update(string_to_sign.as_bytes()) - .finalize() - .into_bytes() - ); - - let authorization_header = format!( - "{} Credential={}/{}, SignedHeaders={}, Signature={}", - algorithm, self.access_key_id, credential_scope, signed_headers, signature - ); - - let mut headers = HashMap::new(); - headers.insert("Authorization".to_string(), authorization_header); - headers.insert("X-Amz-Date".to_string(), amz_date); - headers.insert("Host".to_string(), format!("bedrock-runtime.{}.amazonaws.com", self.region)); - headers.insert("Content-Type".to_string(), "application/json".to_string()); - - Ok(headers) - } - - fn get_signature_key(&self, date_stamp: &str, service: &str) -> Result, Error> { - let k_date = HmacSha256::new_from_slice(format!("AWS4{}", self.secret_access_key).as_bytes()) - .map_err(|_| Error { - code: ErrorCode::InternalError, - message: "Failed to create HMAC for date".to_string(), - provider_error_json: None, - })? - .chain_update(date_stamp.as_bytes()) - .finalize() - .into_bytes(); - - let k_region = HmacSha256::new_from_slice(&k_date) - .map_err(|_| Error { - code: ErrorCode::InternalError, - message: "Failed to create HMAC for region".to_string(), - provider_error_json: None, - })? - .chain_update(self.region.as_bytes()) - .finalize() - .into_bytes(); - - let k_service = HmacSha256::new_from_slice(&k_region) - .map_err(|_| Error { - code: ErrorCode::InternalError, - message: "Failed to create HMAC for service".to_string(), - provider_error_json: None, - })? - .chain_update(service.as_bytes()) - .finalize() - .into_bytes(); - - let k_signing = HmacSha256::new_from_slice(&k_service) - .map_err(|_| Error { - code: ErrorCode::InternalError, - message: "Failed to create HMAC for signing".to_string(), - provider_error_json: None, - })? - .chain_update(b"aws4_request") - .finalize() - .into_bytes(); - - Ok(k_signing.to_vec()) - } +pub fn generate_sigv4_headers( + access_key: &str, + secret_key: &str, + region: &str, + service: &str, + method: &str, + uri: &str, + host: &str, + body: &str, +) -> Result, Box> { + let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap(); + let timestamp = OffsetDateTime::from_unix_timestamp(now.as_secs() as i64).unwrap(); + + let date_str = format!( + "{:04}{:02}{:02}", + timestamp.year(), + timestamp.month() as u8, + timestamp.day() + ); + let datetime_str = format!( + "{:04}{:02}{:02}T{:02}{:02}{:02}Z", + timestamp.year(), + timestamp.month() as u8, + timestamp.day(), + timestamp.hour(), + timestamp.minute(), + timestamp.second() + ); + + // Create canonical request + let path = if uri.starts_with('/') { uri } else { "/" }; + let query = ""; + + // Create canonical headers + let mut headers: Vec<(String, String)> = vec![ + ("host".to_string(), host.to_string()), + ("x-amz-date".to_string(), datetime_str.clone()), + ]; + headers.sort_by(|a, b| a.0.cmp(&b.0)); + + let canonical_headers = headers + .iter() + .map(|(k, v)| format!("{}:{}", k, v)) + .collect::>() + .join("\n") + + "\n"; + + let signed_headers = headers + .iter() + .map(|(k, _)| k.as_str()) + .collect::>() + .join(";"); + + // Hash payload + let payload_hash = format!("{:x}", Sha256::digest(body.as_bytes())); + + let canonical_request = format!( + "{}\n{}\n{}\n{}\n{}\n{}", + method, path, query, canonical_headers, signed_headers, payload_hash + ); + + // Create string to sign + let credential_scope = format!("{}/{}/{}/aws4_request", date_str, region, service); + let string_to_sign = format!( + "AWS4-HMAC-SHA256\n{}\n{}\n{:x}", + datetime_str, + credential_scope, + Sha256::digest(canonical_request.as_bytes()) + ); + + // Calculate signature + type HmacSha256 = Hmac; + + let mut mac = HmacSha256::new_from_slice(format!("AWS4{}", secret_key).as_bytes())?; + mac.update(date_str.as_bytes()); + let date_key = mac.finalize().into_bytes(); + + let mut mac = HmacSha256::new_from_slice(&date_key)?; + mac.update(region.as_bytes()); + let region_key = mac.finalize().into_bytes(); + + let mut mac = HmacSha256::new_from_slice(®ion_key)?; + mac.update(service.as_bytes()); + let service_key = mac.finalize().into_bytes(); + + let mut mac = HmacSha256::new_from_slice(&service_key)?; + mac.update(b"aws4_request"); + let signing_key = mac.finalize().into_bytes(); + + let mut mac = HmacSha256::new_from_slice(&signing_key)?; + mac.update(string_to_sign.as_bytes()); + let signature = format!("{:x}", mac.finalize().into_bytes()); + + // Create authorization header + let auth_header = format!( + "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}", + access_key, credential_scope, signed_headers, signature + ); + + let mut result_headers = vec![ + ("authorization".to_string(), auth_header), + ("x-amz-date".to_string(), datetime_str), + ]; + + Ok(result_headers) } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ConverseRequest { - #[serde(rename = "modelId")] + #[serde(skip_serializing, rename = "modelId")] pub model_id: String, pub messages: Vec, #[serde(skip_serializing_if = "Option::is_none")] @@ -207,7 +254,10 @@ pub struct ConverseRequest { pub tool_config: Option, #[serde(rename = "guardrailConfig", skip_serializing_if = "Option::is_none")] pub guardrail_config: Option, - #[serde(rename = "additionalModelRequestFields", skip_serializing_if = "Option::is_none")] + #[serde( + rename = "additionalModelRequestFields", + skip_serializing_if = "Option::is_none" + )] pub additional_model_request_fields: Option, } @@ -231,7 +281,7 @@ pub enum ContentBlock { #[serde(rename = "text")] Text { text: String }, #[serde(rename = "image")] - Image { + Image { #[serde(rename = "format")] format: ImageFormat, #[serde(rename = "source")] @@ -278,7 +328,7 @@ pub enum ToolResultContentBlock { #[serde(rename = "text")] Text { text: String }, #[serde(rename = "image")] - Image { + Image { #[serde(rename = "format")] format: ImageFormat, #[serde(rename = "source")] @@ -440,16 +490,17 @@ fn parse_response(response: Response) -> Result() + let body = response + .text() .map_err(|err| from_reqwest_error("Failed to receive error response body", err))?; - - trace!("Received {status} response from Bedrock API: {error_body:?}"); + trace!("Received {status} response from Bedrock API: {body:?}"); Err(Error { code: error_code_from_status(status), - message: format!("Request failed with {status}: {}", error_body.message), - provider_error_json: Some(serde_json::to_string(&error_body).unwrap()), + message: format!("Request failed with {status}: {}", body), + provider_error_json: Some(body), }) } -} \ No newline at end of file +} + + diff --git a/llm/bedrock/src/lib.rs b/llm/bedrock/src/lib.rs index e533903c8..2aec6450f 100644 --- a/llm/bedrock/src/lib.rs +++ b/llm/bedrock/src/lib.rs @@ -7,10 +7,11 @@ use crate::conversions::{ tool_results_to_messages, }; use golem_llm::chat_stream::{LlmChatStream, LlmChatStreamState}; +use golem_llm::config::with_config_keys; use golem_llm::durability::{DurableLLM, ExtendedGuest}; use golem_llm::event_source::EventSource; use golem_llm::golem::llm::llm::{ - ChatEvent, ChatStream, Config, ContentPart, Error, ErrorCode, Guest, Message, ResponseMetadata, + ChatEvent, ChatStream, Config, ContentPart, Error, Guest, Message, ResponseMetadata, Role, StreamDelta, StreamEvent, ToolCall, ToolResult, }; use golem_llm::LOGGING_STATE; @@ -81,7 +82,7 @@ impl LlmChatStreamState for BedrockChatStream { fn decode_message(&self, raw: &str) -> Result, String> { trace!("Received raw stream event: {raw}"); - + let json: Value = serde_json::from_str(raw) .map_err(|err| format!("Failed to deserialize stream event: {err}"))?; @@ -120,7 +121,9 @@ impl LlmChatStreamState for BedrockChatStream { if let Some(metadata) = json.get("metadata") { if let Some(usage) = metadata.get("usage") { - if let Ok(bedrock_usage) = serde_json::from_value::(usage.clone()) { + if let Ok(bedrock_usage) = + serde_json::from_value::(usage.clone()) + { self.response_metadata.borrow_mut().usage = Some(convert_usage(bedrock_usage)); } } @@ -137,7 +140,8 @@ impl LlmChatStreamState for BedrockChatStream { "content_filtered" => crate::client::StopReason::ContentFiltered, _ => crate::client::StopReason::EndTurn, }; - self.response_metadata.borrow_mut().finish_reason = Some(stop_reason_to_finish_reason(stop_reason)); + self.response_metadata.borrow_mut().finish_reason = + Some(stop_reason_to_finish_reason(stop_reason)); } let response_metadata = self.response_metadata.borrow().clone(); @@ -155,27 +159,6 @@ impl BedrockComponent { const SECRET_ACCESS_KEY_ENV_VAR: &'static str = "AWS_SECRET_ACCESS_KEY"; const REGION_ENV_VAR: &'static str = "AWS_REGION"; - fn get_client() -> Result { - let access_key_id = std::env::var(Self::ACCESS_KEY_ID_ENV_VAR) - .map_err(|_| Error { - code: ErrorCode::AuthenticationFailed, - message: format!("Missing environment variable: {}", Self::ACCESS_KEY_ID_ENV_VAR), - provider_error_json: None, - })?; - - let secret_access_key = std::env::var(Self::SECRET_ACCESS_KEY_ENV_VAR) - .map_err(|_| Error { - code: ErrorCode::AuthenticationFailed, - message: format!("Missing environment variable: {}", Self::SECRET_ACCESS_KEY_ENV_VAR), - provider_error_json: None, - })?; - - let region = std::env::var(Self::REGION_ENV_VAR) - .unwrap_or_else(|_| "us-east-1".to_string()); - - Ok(BedrockClient::new(access_key_id, secret_access_key, region)) - } - fn request(client: BedrockClient, model_id: &str, request: ConverseRequest) -> ChatEvent { match client.converse(model_id, request) { Ok(response) => process_response(response), @@ -200,16 +183,26 @@ impl Guest for BedrockComponent { fn send(messages: Vec, config: Config) -> ChatEvent { LOGGING_STATE.with_borrow_mut(|state| state.init()); - - let client = match Self::get_client() { - Ok(client) => client, - Err(err) => return ChatEvent::Error(err), - }; - - match messages_to_request(messages, config.clone()) { - Ok(request) => Self::request(client, &config.model, request), - Err(err) => ChatEvent::Error(err), - } + with_config_keys( + &[ + Self::ACCESS_KEY_ID_ENV_VAR, + Self::SECRET_ACCESS_KEY_ENV_VAR, + Self::REGION_ENV_VAR, + ], + ChatEvent::Error, + |bedrock_api_keys| { + let client = BedrockClient::new( + bedrock_api_keys[Self::ACCESS_KEY_ID_ENV_VAR].clone(), + bedrock_api_keys[Self::SECRET_ACCESS_KEY_ENV_VAR].clone(), + bedrock_api_keys[Self::REGION_ENV_VAR].clone(), + ); + + match messages_to_request(messages, config.clone()) { + Ok(request) => Self::request(client, &config.model, request), + Err(err) => ChatEvent::Error(err), + } + }, + ) } fn continue_( @@ -219,18 +212,31 @@ impl Guest for BedrockComponent { ) -> ChatEvent { LOGGING_STATE.with_borrow_mut(|state| state.init()); - let client = match Self::get_client() { - Ok(client) => client, - Err(err) => return ChatEvent::Error(err), - }; - - match messages_to_request(messages, config.clone()) { - Ok(mut request) => { - request.messages.extend(tool_results_to_messages(tool_results)); - Self::request(client, &config.model, request) - } - Err(err) => ChatEvent::Error(err), - } + with_config_keys( + &[ + Self::ACCESS_KEY_ID_ENV_VAR, + Self::SECRET_ACCESS_KEY_ENV_VAR, + Self::REGION_ENV_VAR, + ], + ChatEvent::Error, + |bedrock_api_keys| { + let client = BedrockClient::new( + bedrock_api_keys[Self::ACCESS_KEY_ID_ENV_VAR].clone(), + bedrock_api_keys[Self::SECRET_ACCESS_KEY_ENV_VAR].clone(), + bedrock_api_keys[Self::REGION_ENV_VAR].clone(), + ); + + match messages_to_request(messages, config.clone()) { + Ok(mut request) => { + request + .messages + .extend(tool_results_to_messages(tool_results)); + Self::request(client, &config.model, request) + } + Err(err) => ChatEvent::Error(err), + } + }, + ) } fn stream(messages: Vec, config: Config) -> ChatStream { @@ -245,15 +251,26 @@ impl ExtendedGuest for BedrockComponent { ) -> LlmChatStream { LOGGING_STATE.with_borrow_mut(|state| state.init()); - let client = match Self::get_client() { - Ok(client) => client, - Err(err) => return BedrockChatStream::failed(err), - }; - - match messages_to_request(messages, config.clone()) { - Ok(request) => Self::streaming_request(client, &config.model, request), - Err(err) => BedrockChatStream::failed(err), - } + with_config_keys( + &[ + Self::ACCESS_KEY_ID_ENV_VAR, + Self::SECRET_ACCESS_KEY_ENV_VAR, + Self::REGION_ENV_VAR, + ], + BedrockChatStream::failed, + |bedrock_api_keys| { + let client = BedrockClient::new( + bedrock_api_keys[Self::ACCESS_KEY_ID_ENV_VAR].clone(), + bedrock_api_keys[Self::SECRET_ACCESS_KEY_ENV_VAR].clone(), + bedrock_api_keys[Self::REGION_ENV_VAR].clone(), + ); + + match messages_to_request(messages, config.clone()) { + Ok(request) => Self::streaming_request(client, &config.model, request), + Err(err) => BedrockChatStream::failed(err), + } + }, + ) } fn retry_prompt(original_messages: &[Message], partial_result: &[StreamDelta]) -> Vec { @@ -312,4 +329,4 @@ impl ExtendedGuest for BedrockComponent { type DurableBedrockComponent = DurableLLM; -golem_llm::export_llm!(DurableBedrockComponent with_types_in golem_llm); \ No newline at end of file +golem_llm::export_llm!(DurableBedrockComponent with_types_in golem_llm); diff --git a/llm/llm/src/config.rs b/llm/llm/src/config.rs index de9822c36..43584ad32 100644 --- a/llm/llm/src/config.rs +++ b/llm/llm/src/config.rs @@ -1,5 +1,5 @@ use crate::golem::llm::llm::{Error, ErrorCode}; -use std::ffi::OsStr; +use std::{collections::HashMap, ffi::OsStr}; /// Gets an expected configuration value from the environment, and fails if its is not found /// using the `fail` function. Otherwise, it runs `succeed` with the configuration value. @@ -21,3 +21,31 @@ pub fn with_config_key( } } } + +/// Gets multiple expected configuration values from the environment, and fails if any is not found +/// using the `fail` function. Otherwise, it runs `succeed` with all configuration values. +pub fn with_config_keys( + keys: &[&str], + fail: impl FnOnce(Error) -> R, + succeed: impl FnOnce(HashMap) -> R, +) -> R { + let mut values = HashMap::new(); + + for key in keys { + match std::env::var(key) { + Ok(value) => { + values.insert(key.to_string(), value); + } + Err(_) => { + let error = Error { + code: ErrorCode::InternalError, + message: format!("Missing config key: {key}"), + provider_error_json: None, + }; + return fail(error); + } + } + } + + succeed(values) +} diff --git a/test/components-rust/test-llm/Cargo.toml b/test/components-rust/test-llm/Cargo.toml index 52f77b3a7..d4acefb4f 100644 --- a/test/components-rust/test-llm/Cargo.toml +++ b/test/components-rust/test-llm/Cargo.toml @@ -38,8 +38,8 @@ path = "wit-generated" [package.metadata.component.target.dependencies] "golem:llm" = { path = "wit-generated/deps/golem-llm" } -"wasi:clocks" = { path = "wit-generated/deps/clocks" } "wasi:io" = { path = "wit-generated/deps/io" } +"wasi:clocks" = { path = "wit-generated/deps/clocks" } "golem:rpc" = { path = "wit-generated/deps/golem-rpc" } "test:helper-client" = { path = "wit-generated/deps/test_helper-client" } "test:llm-exports" = { path = "wit-generated/deps/test_llm-exports" } diff --git a/test/components-rust/test-llm/src/lib.rs b/test/components-rust/test-llm/src/lib.rs index fa11684de..2779d84ee 100644 --- a/test/components-rust/test-llm/src/lib.rs +++ b/test/components-rust/test-llm/src/lib.rs @@ -19,6 +19,8 @@ const MODEL: &'static str = "grok-3-beta"; const MODEL: &'static str = "openrouter/auto"; #[cfg(feature = "ollama")] const MODEL: &'static str = "qwen3:1.7b"; +#[cfg(feature = "bedrock")] +const MODEL: &'static str = "amazon.nova-lite-v1:0"; #[cfg(feature = "openai")] const IMAGE_MODEL: &'static str = "gpt-4o-mini"; @@ -30,6 +32,8 @@ const IMAGE_MODEL: &'static str = "grok-2-vision-latest"; const IMAGE_MODEL: &'static str = "openrouter/auto"; #[cfg(feature = "ollama")] const IMAGE_MODEL: &'static str = "gemma3:4b"; +#[cfg(feature = "bedrock")] +const IMAGE_MODEL: &'static str = "amazon.nova-lite-v1:0"; impl Guest for Component { /// test1 demonstrates a simple, non-streaming text question-answer interaction with the LLM. From 3cbc715625f9bb606e67f14226b51495a2cdab2f Mon Sep 17 00:00:00 2001 From: Rutik7066 Date: Fri, 27 Jun 2025 20:09:39 +0000 Subject: [PATCH 3/8] sigv4 fixes --- Cargo.lock | 1 + llm/bedrock/Cargo.toml | 1 + llm/bedrock/src/client.rs | 214 ++++++++++++++++++----- llm/bedrock/src/conversions.rs | 12 +- test/components-rust/test-llm/Cargo.toml | 2 +- 5 files changed, 179 insertions(+), 51 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ab1942ae9..624fce7d3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -427,6 +427,7 @@ dependencies = [ "golem-rust", "hmac", "log", + "percent-encoding", "reqwest", "serde", "serde_json", diff --git a/llm/bedrock/Cargo.toml b/llm/bedrock/Cargo.toml index 8add2714b..1ca88260b 100644 --- a/llm/bedrock/Cargo.toml +++ b/llm/bedrock/Cargo.toml @@ -28,6 +28,7 @@ base64 = { workspace = true } hmac = "0.12" sha2 = "0.10" time = { version = "0.3", features = ["formatting"] } +percent-encoding = "2.3" [package.metadata.component] package = "golem:llm-bedrock" diff --git a/llm/bedrock/src/client.rs b/llm/bedrock/src/client.rs index 37eea2051..a9620ce24 100644 --- a/llm/bedrock/src/client.rs +++ b/llm/bedrock/src/client.rs @@ -66,7 +66,6 @@ impl BedrockClient { })?; let mut request_builder = self.client.request(Method::POST, &url); - request_builder = request_builder.header("content-type", "application/json"); for (key, value) in headers { request_builder = request_builder.header(key, value); } @@ -116,7 +115,6 @@ impl BedrockClient { })?; let mut request_builder = self.client.request(Method::POST, &url); - request_builder = request_builder.header("content-type", "application/json"); for (key, value) in headers { request_builder = request_builder.header(key, value); } @@ -144,6 +142,8 @@ pub fn generate_sigv4_headers( host: &str, body: &str, ) -> Result, Box> { + use std::collections::BTreeMap; + let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap(); let timestamp = OffsetDateTime::from_unix_timestamp(now.as_secs() as i64).unwrap(); @@ -163,48 +163,60 @@ pub fn generate_sigv4_headers( timestamp.second() ); - // Create canonical request - let path = if uri.starts_with('/') { uri } else { "/" }; - let query = ""; - - // Create canonical headers - let mut headers: Vec<(String, String)> = vec![ - ("host".to_string(), host.to_string()), - ("x-amz-date".to_string(), datetime_str.clone()), - ]; - headers.sort_by(|a, b| a.0.cmp(&b.0)); + let (canonical_uri, canonical_query_string) = if let Some(query_pos) = uri.find('?') { + let path = &uri[..query_pos]; + let query = &uri[query_pos + 1..]; + + let encoded_path = if path.contains(':') { + path.replace(':', "%3A") + } else { + path.to_string() + }; + + let mut query_params: Vec<&str> = query.split('&').collect(); + query_params.sort(); + (encoded_path, query_params.join("&")) + } else { + let encoded_path = if uri.contains(':') { + uri.replace(':', "%3A") + } else { + uri.to_string() + }; + (encoded_path, String::new()) + }; + + let mut headers = BTreeMap::new(); + headers.insert("content-type", "application/x-amz-json-1.0"); + headers.insert("host", host); + headers.insert("x-amz-date", &datetime_str); let canonical_headers = headers .iter() - .map(|(k, v)| format!("{}:{}", k, v)) + .map(|(k, v)| format!("{}:{}", k.to_lowercase().trim(), v.trim())) .collect::>() .join("\n") + "\n"; let signed_headers = headers - .iter() - .map(|(k, _)| k.as_str()) + .keys() + .map(|k| k.to_lowercase()) .collect::>() .join(";"); - // Hash payload let payload_hash = format!("{:x}", Sha256::digest(body.as_bytes())); let canonical_request = format!( "{}\n{}\n{}\n{}\n{}\n{}", - method, path, query, canonical_headers, signed_headers, payload_hash + method, canonical_uri, canonical_query_string, canonical_headers, signed_headers, payload_hash ); - // Create string to sign let credential_scope = format!("{}/{}/{}/aws4_request", date_str, region, service); + let canonical_request_hash = format!("{:x}", Sha256::digest(canonical_request.as_bytes())); let string_to_sign = format!( - "AWS4-HMAC-SHA256\n{}\n{}\n{:x}", - datetime_str, - credential_scope, - Sha256::digest(canonical_request.as_bytes()) + "AWS4-HMAC-SHA256\n{}\n{}\n{}", + datetime_str, credential_scope, canonical_request_hash ); - // Calculate signature type HmacSha256 = Hmac; let mut mac = HmacSha256::new_from_slice(format!("AWS4{}", secret_key).as_bytes())?; @@ -227,15 +239,15 @@ pub fn generate_sigv4_headers( mac.update(string_to_sign.as_bytes()); let signature = format!("{:x}", mac.finalize().into_bytes()); - // Create authorization header let auth_header = format!( "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}", access_key, credential_scope, signed_headers, signature ); - let mut result_headers = vec![ + let result_headers = vec![ ("authorization".to_string(), auth_header), ("x-amz-date".to_string(), datetime_str), + ("content-type".to_string(), "application/x-amz-json-1.0".to_string()), ]; Ok(result_headers) @@ -373,20 +385,17 @@ pub struct ToolConfig { } #[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(tag = "type")] -pub enum Tool { +pub struct Tool { #[serde(rename = "toolSpec")] - ToolSpec { - name: String, - description: String, - #[serde(rename = "inputSchema")] - input_schema: ToolInputSchema, - }, + pub tool_spec: ToolSpec, } #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ToolInputSchema { - pub json: Value, +pub struct ToolSpec { + pub name: String, + pub description: String, + #[serde(rename = "inputSchema")] + pub input_schema: Value, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -406,16 +415,7 @@ pub struct GuardrailConfig { pub guardrail_identifier: String, #[serde(rename = "guardrailVersion")] pub guardrail_version: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub trace: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum GuardrailTrace { - #[serde(rename = "enabled")] - Enabled, - #[serde(rename = "disabled")] - Disabled, + pub trace: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -503,4 +503,128 @@ fn parse_response(response: Response) -> Result = headers.into_iter().collect(); + + assert!(header_map.contains_key("authorization")); + assert!(header_map.contains_key("x-amz-date")); + assert!(header_map.contains_key("content-type")); + + let auth_header = &header_map["authorization"]; + assert!(auth_header.starts_with("AWS4-HMAC-SHA256 Credential=")); + assert!(auth_header.contains("SignedHeaders=")); + assert!(auth_header.contains("Signature=")); + + assert_eq!(header_map["content-type"], "application/x-amz-json-1.0"); + + let date_header = &header_map["x-amz-date"]; + assert!(date_header.ends_with('Z')); + assert!(date_header.contains('T')); + } + + #[test] + fn test_canonical_headers_ordering() { + let access_key = "AKIAIOSFODNN7EXAMPLE"; + let secret_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"; + let region = "us-east-1"; + let service = "bedrock"; + let method = "POST"; + let uri = "/model/test/converse"; + let host = "bedrock-runtime.us-east-1.amazonaws.com"; + let body = "{}"; + + let result = generate_sigv4_headers( + access_key, + secret_key, + region, + service, + method, + uri, + host, + body, + ); + + assert!(result.is_ok()); + let headers = result.unwrap(); + let header_map: std::collections::HashMap = headers.into_iter().collect(); + + let auth_header = &header_map["authorization"]; + + assert!(auth_header.contains("SignedHeaders=content-type;host;x-amz-date")); + } + + #[test] + fn test_bedrock_client_integration() { + let client = BedrockClient::new( + "AKIAIOSFODNN7EXAMPLE".to_string(), + "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(), + "us-east-1".to_string(), + ); + + let request = ConverseRequest { + model_id: "anthropic.claude-3-sonnet-20240229-v1:0".to_string(), + messages: vec![Message { + role: Role::User, + content: vec![ContentBlock::Text { + text: "Hello, how are you?".to_string(), + }], + }], + system: None, + inference_config: None, + tool_config: None, + guardrail_config: None, + additional_model_request_fields: None, + }; + + + let body = serde_json::to_string(&request).expect("Failed to serialize request"); + let host = format!("bedrock-runtime.{}.amazonaws.com", client.region); + + let headers_result = generate_sigv4_headers( + &client.access_key_id, + &client.secret_access_key, + &client.region, + "bedrock", + "POST", + "/model/anthropic.claude-3-sonnet-20240229-v1:0/converse", + &host, + &body, + ); + + assert!(headers_result.is_ok()); + let headers = headers_result.unwrap(); + let header_names: Vec<&str> = headers.iter().map(|(k, _)| k.as_str()).collect(); + assert!(header_names.contains(&"authorization")); + assert!(header_names.contains(&"x-amz-date")); + assert!(header_names.contains(&"content-type")); + } +} \ No newline at end of file diff --git a/llm/bedrock/src/conversions.rs b/llm/bedrock/src/conversions.rs index c530c4b02..e11a3a621 100644 --- a/llm/bedrock/src/conversions.rs +++ b/llm/bedrock/src/conversions.rs @@ -1,7 +1,7 @@ use crate::client::{ ContentBlock, ConverseRequest, ConverseResponse, ImageFormat, ImageSource as ClientImageSource, InferenceConfig, Message as ClientMessage, Role as ClientRole, StopReason, SystemContentBlock, - Tool, ToolChoice, ToolConfig, ToolInputSchema, ToolResultContentBlock, ToolResultStatus, + Tool, ToolChoice, ToolConfig, ToolSpec, ToolResultContentBlock, ToolResultStatus, }; use base64::{engine::general_purpose, Engine as _}; use golem_llm::golem::llm::llm::{ @@ -287,10 +287,12 @@ fn message_to_system_content(message: &Message) -> Vec { fn tool_definition_to_tool(tool: &ToolDefinition) -> Result { match serde_json::from_str(&tool.parameters_schema) { - Ok(json_schema) => Ok(Tool::ToolSpec { - name: tool.name.clone(), - description: tool.description.clone().unwrap_or_default(), - input_schema: ToolInputSchema { json: json_schema }, + Ok(json_schema) => Ok(Tool { + tool_spec: ToolSpec { + name: tool.name.clone(), + description: tool.description.clone().unwrap_or_default(), + input_schema: json_schema, + }, }), Err(error) => Err(Error { code: ErrorCode::InternalError, diff --git a/test/components-rust/test-llm/Cargo.toml b/test/components-rust/test-llm/Cargo.toml index d4acefb4f..52f77b3a7 100644 --- a/test/components-rust/test-llm/Cargo.toml +++ b/test/components-rust/test-llm/Cargo.toml @@ -38,8 +38,8 @@ path = "wit-generated" [package.metadata.component.target.dependencies] "golem:llm" = { path = "wit-generated/deps/golem-llm" } -"wasi:io" = { path = "wit-generated/deps/io" } "wasi:clocks" = { path = "wit-generated/deps/clocks" } +"wasi:io" = { path = "wit-generated/deps/io" } "golem:rpc" = { path = "wit-generated/deps/golem-rpc" } "test:helper-client" = { path = "wit-generated/deps/test_helper-client" } "test:llm-exports" = { path = "wit-generated/deps/test_llm-exports" } From e96d5822134a4d79e7dc5b7fc9809a7d6ad416a8 Mon Sep 17 00:00:00 2001 From: Rutik7066 Date: Sun, 29 Jun 2025 07:08:50 +0000 Subject: [PATCH 4/8] toolcall, image fix & stream parser wip --- llm/bedrock/src/client.rs | 82 +++++++----- llm/bedrock/src/conversions.rs | 160 ++++++++++++++++------- llm/llm/src/event_source/mod.rs | 2 +- test/components-rust/test-llm/Cargo.toml | 2 +- test/components-rust/test-llm/src/lib.rs | 4 +- 5 files changed, 163 insertions(+), 87 deletions(-) diff --git a/llm/bedrock/src/client.rs b/llm/bedrock/src/client.rs index a9620ce24..e8455c60f 100644 --- a/llm/bedrock/src/client.rs +++ b/llm/bedrock/src/client.rs @@ -126,7 +126,7 @@ impl BedrockClient { })?; trace!("Initializing SSE stream"); - + trace!("Response: {:?}", response.headers().clone()); EventSource::new(response) .map_err(|err| from_event_source_error("Failed to create SSE stream", err)) } @@ -288,34 +288,45 @@ pub enum Role { } #[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(tag = "type")] +#[serde(untagged)] pub enum ContentBlock { - #[serde(rename = "text")] Text { text: String }, - #[serde(rename = "image")] Image { - #[serde(rename = "format")] - format: ImageFormat, - #[serde(rename = "source")] - source: ImageSource, + image: ImageBlock, }, - #[serde(rename = "toolUse")] ToolUse { - #[serde(rename = "toolUseId")] - tool_use_id: String, - name: String, - input: Value, + #[serde(rename = "toolUse")] + tool_use: ToolUseBlock, }, - #[serde(rename = "toolResult")] ToolResult { - #[serde(rename = "toolUseId")] - tool_use_id: String, - content: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - status: Option, + #[serde(rename = "toolResult")] + tool_result: ToolResultBlock, }, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ImageBlock { + pub format: ImageFormat, + pub source: ImageSource, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolUseBlock { + #[serde(rename = "toolUseId")] + pub tool_use_id: String, + pub name: String, + pub input: Value, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolResultBlock { + #[serde(rename = "toolUseId")] + pub tool_use_id: String, + pub content: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub status: Option, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub enum ImageFormat { #[serde(rename = "png")] @@ -329,7 +340,6 @@ pub enum ImageFormat { } #[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(tag = "bytes")] pub struct ImageSource { pub bytes: String, } @@ -395,18 +405,25 @@ pub struct ToolSpec { pub name: String, pub description: String, #[serde(rename = "inputSchema")] - pub input_schema: Value, + pub input_schema: ToolInputSchema, } #[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(tag = "type")] +pub struct ToolInputSchema { + pub json: Value, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] pub enum ToolChoice { - #[serde(rename = "auto")] - Auto, - #[serde(rename = "any")] - Any, - #[serde(rename = "tool")] - Tool { name: String }, + Auto { auto: serde_json::Value }, + Any { any: serde_json::Value }, + Tool { tool: ToolChoiceTool }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolChoiceTool { + pub name: String, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -420,20 +437,15 @@ pub struct GuardrailConfig { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ConverseResponse { - #[serde(rename = "responseMetadata")] - pub response_metadata: ResponseMetadata, pub output: Output, #[serde(rename = "stopReason")] pub stop_reason: StopReason, pub usage: Usage, pub metrics: Metrics, + #[serde(rename = "additionalModelResponseFields", skip_serializing_if = "Option::is_none")] + pub additional_model_response_fields: Option, } -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ResponseMetadata { - #[serde(rename = "requestId")] - pub request_id: String, -} #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Output { diff --git a/llm/bedrock/src/conversions.rs b/llm/bedrock/src/conversions.rs index e11a3a621..2d9215c8d 100644 --- a/llm/bedrock/src/conversions.rs +++ b/llm/bedrock/src/conversions.rs @@ -2,6 +2,7 @@ use crate::client::{ ContentBlock, ConverseRequest, ConverseResponse, ImageFormat, ImageSource as ClientImageSource, InferenceConfig, Message as ClientMessage, Role as ClientRole, StopReason, SystemContentBlock, Tool, ToolChoice, ToolConfig, ToolSpec, ToolResultContentBlock, ToolResultStatus, + ImageBlock, ToolUseBlock, ToolResultBlock, ToolInputSchema, ToolChoiceTool, }; use base64::{engine::general_purpose, Engine as _}; use golem_llm::golem::llm::llm::{ @@ -9,7 +10,8 @@ use golem_llm::golem::llm::llm::{ ImageReference, ImageSource, Message, ResponseMetadata, Role, ToolCall, ToolDefinition, ToolResult, Usage, }; -use std::collections::HashMap; +use reqwest::{Client, Url}; +use std::{collections::HashMap, fs, path::Path}; pub fn messages_to_request( messages: Vec, @@ -69,10 +71,14 @@ pub fn messages_to_request( let tool_choice = config.tool_choice.map(convert_tool_choice); - Some(ToolConfig { - tools, - tool_choice, - }) + if tools.is_empty() { + None + } else { + Some(ToolConfig { + tools, + tool_choice, + }) + } }; Ok(ConverseRequest { @@ -91,11 +97,19 @@ pub fn messages_to_request( } fn convert_tool_choice(tool_name: String) -> ToolChoice { + use serde_json::Value; + match tool_name.as_str() { - "auto" => ToolChoice::Auto, - "any" => ToolChoice::Any, + "auto" => ToolChoice::Auto { + auto: Value::Object(serde_json::Map::new()), + }, + "any" => ToolChoice::Any { + any: Value::Object(serde_json::Map::new()), + }, name => ToolChoice::Tool { - name: name.to_string(), + tool: ToolChoiceTool { + name: name.to_string(), + }, }, } } @@ -107,10 +121,10 @@ pub fn process_response(response: ConverseResponse) -> ChatEvent { for content in response.output.message.content { match content { ContentBlock::Text { text } => contents.push(ContentPart::Text(text)), - ContentBlock::Image { format, source } => { - match general_purpose::STANDARD.decode(&source.bytes) { + ContentBlock::Image { image } => { + match general_purpose::STANDARD.decode(&image.source.bytes) { Ok(decoded_data) => { - let mime_type = match format { + let mime_type = match image.format { ImageFormat::Jpeg => "image/jpeg", ImageFormat::Png => "image/png", ImageFormat::Gif => "image/gif", @@ -133,14 +147,10 @@ pub fn process_response(response: ConverseResponse) -> ChatEvent { } } } - ContentBlock::ToolUse { - tool_use_id, - name, - input, - } => tool_calls.push(ToolCall { - id: tool_use_id, - name, - arguments_json: serde_json::to_string(&input).unwrap(), + ContentBlock::ToolUse { tool_use } => tool_calls.push(ToolCall { + id: tool_use.tool_use_id, + name: tool_use.name, + arguments_json: serde_json::to_string(&tool_use.input).unwrap(), }), ContentBlock::ToolResult { .. } => {} } @@ -149,7 +159,7 @@ pub fn process_response(response: ConverseResponse) -> ChatEvent { if contents.is_empty() && !tool_calls.is_empty() { ChatEvent::ToolRequest(tool_calls) } else { - let request_id = response.response_metadata.request_id.clone(); + let request_id = "bedrock-response".to_string(); let metadata = ResponseMetadata { finish_reason: Some(stop_reason_to_finish_reason(response.stop_reason)), @@ -176,9 +186,11 @@ pub fn tool_results_to_messages( for (tool_call, tool_result) in tool_results { messages.push(ClientMessage { content: vec![ContentBlock::ToolUse { - tool_use_id: tool_call.id.clone(), - name: tool_call.name, - input: serde_json::from_str(&tool_call.arguments_json).unwrap(), + tool_use: ToolUseBlock { + tool_use_id: tool_call.id.clone(), + name: tool_call.name, + input: serde_json::from_str(&tool_call.arguments_json).unwrap(), + }, }], role: ClientRole::Assistant, }); @@ -200,9 +212,11 @@ pub fn tool_results_to_messages( messages.push(ClientMessage { content: vec![ContentBlock::ToolResult { - tool_use_id: tool_call.id, - content, - status, + tool_result: ToolResultBlock { + tool_use_id: tool_call.id, + content, + status, + }, }], role: ClientRole::User, }); @@ -239,13 +253,45 @@ fn message_to_content(message: &Message) -> Result, Error> { text: text.clone(), }), ContentPart::Image(image_reference) => match image_reference { - ImageReference::Url(_image_url) => { - return Err(Error { - code: ErrorCode::InvalidRequest, - message: "Bedrock API does not support image URLs, only base64 encoded images".to_string(), - provider_error_json: None, + ImageReference::Url(image_url) => { + let url = &image_url.url; + let mut format = ImageFormat::Png; + let bytes = if Url::parse(url).is_ok() { + let client = Client::new(); + let response = client.get(url).send().map_err(|e| Error { + code: ErrorCode::InvalidRequest, + message: format!("Failed to fetch image from URL: {}", e), + provider_error_json: None, + }); + response.map(|r| { + format = match r.headers().get("Content-Type").unwrap().to_str().unwrap() { + "image/jpeg" => ImageFormat::Jpeg, + "image/png" => ImageFormat::Png, + "image/gif" => ImageFormat::Gif, + "image/webp" => ImageFormat::Webp, + _ => ImageFormat::Jpeg, + }; + r.bytes().unwrap().to_vec() + }) + } else { + let path = Path::new(url); + fs::read(path).map_err(|e| Error { + code: ErrorCode::InvalidRequest, + message: format!("Failed to read image from path: {}", e), + provider_error_json: None, + }) + }; + + let base64_data = general_purpose::STANDARD.encode(&bytes.unwrap()); + result.push(ContentBlock::Image { + image: ImageBlock { + format: ImageFormat::Png, + source: ClientImageSource { + bytes: base64_data, + }, + }, }); - } + }, ImageReference::Inline(image_source) => { let base64_data = general_purpose::STANDARD.encode(&image_source.data); let format = match image_source.mime_type.as_str() { @@ -257,9 +303,11 @@ fn message_to_content(message: &Message) -> Result, Error> { }; result.push(ContentBlock::Image { - format, - source: ClientImageSource { - bytes: base64_data, + image: ImageBlock { + format, + source: ClientImageSource { + bytes: base64_data, + }, }, }); } @@ -286,18 +334,34 @@ fn message_to_system_content(message: &Message) -> Vec { } fn tool_definition_to_tool(tool: &ToolDefinition) -> Result { - match serde_json::from_str(&tool.parameters_schema) { - Ok(json_schema) => Ok(Tool { - tool_spec: ToolSpec { - name: tool.name.clone(), - description: tool.description.clone().unwrap_or_default(), - input_schema: json_schema, + use serde_json::Value; + + let schema_value = if tool.parameters_schema.trim().is_empty() { + serde_json::json!({ + "type": "object", + "properties": {}, + "additionalProperties": false + }) + } else { + match serde_json::from_str::(&tool.parameters_schema) { + Ok(value) => value, + Err(error) => { + return Err(Error { + code: ErrorCode::InternalError, + message: format!("Failed to parse tool parameters for {}: {error}", tool.name), + provider_error_json: None, + }); + } + } + }; + + Ok(Tool { + tool_spec: ToolSpec { + name: tool.name.clone(), + description: tool.description.clone().unwrap_or_default(), + input_schema: ToolInputSchema { + json: schema_value, }, - }), - Err(error) => Err(Error { - code: ErrorCode::InternalError, - message: format!("Failed to parse tool parameters for {}: {error}", tool.name), - provider_error_json: None, - }), - } + }, + }) } \ No newline at end of file diff --git a/llm/llm/src/event_source/mod.rs b/llm/llm/src/event_source/mod.rs index 2e92d4dd8..18c94dc83 100644 --- a/llm/llm/src/event_source/mod.rs +++ b/llm/llm/src/event_source/mod.rs @@ -140,7 +140,7 @@ fn check_response(response: Response) -> Result { matches!( (mime_type.type_(), mime_type.subtype()), (mime::TEXT, mime::EVENT_STREAM) - ) || mime_type.subtype().as_str().contains("ndjson") + ) || mime_type.subtype().as_str().contains("ndjson") || content_type.to_str().unwrap_or("").contains("vnd.amazon.eventstream") }) .unwrap_or(false) { diff --git a/test/components-rust/test-llm/Cargo.toml b/test/components-rust/test-llm/Cargo.toml index 52f77b3a7..d4acefb4f 100644 --- a/test/components-rust/test-llm/Cargo.toml +++ b/test/components-rust/test-llm/Cargo.toml @@ -38,8 +38,8 @@ path = "wit-generated" [package.metadata.component.target.dependencies] "golem:llm" = { path = "wit-generated/deps/golem-llm" } -"wasi:clocks" = { path = "wit-generated/deps/clocks" } "wasi:io" = { path = "wit-generated/deps/io" } +"wasi:clocks" = { path = "wit-generated/deps/clocks" } "golem:rpc" = { path = "wit-generated/deps/golem-rpc" } "test:helper-client" = { path = "wit-generated/deps/test_helper-client" } "test:llm-exports" = { path = "wit-generated/deps/test_llm-exports" } diff --git a/test/components-rust/test-llm/src/lib.rs b/test/components-rust/test-llm/src/lib.rs index 2779d84ee..e0c7a8a0b 100644 --- a/test/components-rust/test-llm/src/lib.rs +++ b/test/components-rust/test-llm/src/lib.rs @@ -20,7 +20,7 @@ const MODEL: &'static str = "openrouter/auto"; #[cfg(feature = "ollama")] const MODEL: &'static str = "qwen3:1.7b"; #[cfg(feature = "bedrock")] -const MODEL: &'static str = "amazon.nova-lite-v1:0"; +const MODEL: &'static str = "anthropic.claude-3-haiku-20240307-v1:0"; #[cfg(feature = "openai")] const IMAGE_MODEL: &'static str = "gpt-4o-mini"; @@ -33,7 +33,7 @@ const IMAGE_MODEL: &'static str = "openrouter/auto"; #[cfg(feature = "ollama")] const IMAGE_MODEL: &'static str = "gemma3:4b"; #[cfg(feature = "bedrock")] -const IMAGE_MODEL: &'static str = "amazon.nova-lite-v1:0"; +const IMAGE_MODEL: &'static str = "anthropic.claude-3-haiku-20240307-v1:0"; impl Guest for Component { /// test1 demonstrates a simple, non-streaming text question-answer interaction with the LLM. From 0a34c8ce2868059042975f98293668ca320e55aa Mon Sep 17 00:00:00 2001 From: Rutik7066 Date: Sun, 29 Jun 2025 20:24:05 +0000 Subject: [PATCH 5/8] stream fix --- Cargo.lock | 79 +++++++++ llm/bedrock/src/lib.rs | 150 ++++++++++++----- llm/llm/Cargo.toml | 3 + llm/llm/src/event_source/aws_eventstream.rs | 175 ++++++++++++++++++++ llm/llm/src/event_source/mod.rs | 19 ++- llm/llm/src/event_source/stream.rs | 5 +- 6 files changed, 386 insertions(+), 45 deletions(-) create mode 100644 llm/llm/src/event_source/aws_eventstream.rs diff --git a/Cargo.lock b/Cargo.lock index 624fce7d3..d5243af1a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -47,6 +47,35 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +[[package]] +name = "aws-smithy-eventstream" +version = "0.60.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "338a3642c399c0a5d157648426110e199ca7fd1c689cc395676b81aa563700c4" +dependencies = [ + "aws-smithy-types", + "bytes", + "crc32fast", +] + +[[package]] +name = "aws-smithy-types" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d498595448e43de7f4296b7b7a18a8a02c61ec9349128c80a368f7c3b4ab11a8" +dependencies = [ + "base64-simd", + "bytes", + "bytes-utils", + "itoa", + "num-integer", + "pin-project-lite", + "pin-utils", + "ryu", + "serde", + "time", +] + [[package]] name = "base64" version = "0.21.7" @@ -59,6 +88,16 @@ version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "base64-simd" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "339abbe78e73178762e23bea9dfd08e697eb3f3301cd4be981c0f78ba5859195" +dependencies = [ + "outref", + "vsimd", +] + [[package]] name = "bitflags" version = "2.9.1" @@ -86,6 +125,16 @@ version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" +[[package]] +name = "bytes-utils" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dafe3a8757b027e2be6e4e5601ed563c55989fcf1546e933c66c8eb3a058d35" +dependencies = [ + "bytes", + "either", +] + [[package]] name = "camino" version = "1.1.9" @@ -213,6 +262,12 @@ dependencies = [ "syn", ] +[[package]] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" + [[package]] name = "encoding_rs" version = "0.8.35" @@ -394,6 +449,9 @@ dependencies = [ name = "golem-llm" version = "0.0.0" dependencies = [ + "aws-smithy-eventstream", + "aws-smithy-types", + "base64 0.22.1", "golem-rust", "log", "mime", @@ -824,6 +882,15 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -839,6 +906,12 @@ version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +[[package]] +name = "outref" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a80800c0488c3a21695ea981a54918fbb37abf04f4d0720c453632255e2ff0e" + [[package]] name = "percent-encoding" version = "2.3.1" @@ -1222,6 +1295,12 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "vsimd" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c3082ca00d5a5ef149bb8b555a72ae84c9c59f7250f013ac822ac2e49b19c64" + [[package]] name = "wasi" version = "0.14.2+wasi-0.2.4" diff --git a/llm/bedrock/src/lib.rs b/llm/bedrock/src/lib.rs index 2aec6450f..7c16f2032 100644 --- a/llm/bedrock/src/lib.rs +++ b/llm/bedrock/src/lib.rs @@ -17,6 +17,7 @@ use golem_llm::golem::llm::llm::{ use golem_llm::LOGGING_STATE; use golem_rust::wasm_rpc::Pollable; use log::trace; +use serde::Deserialize; use serde_json::Value; use std::cell::{Ref, RefCell, RefMut}; @@ -27,6 +28,73 @@ struct BedrockChatStream { response_metadata: RefCell, } + +/// [2025-06-29T18:11:10.458Z] [TRACE ] [golem_llm_bedrock] llm/bedrock/src/lib.rs:84: Received raw stream event: +/// { +/// "contentBlockIndex":1, +/// "delta":{ +/// "toolUse":{ +/// "input":" 10 +/// }"}}, +/// "p":"abcdefghijklmnopqrstuvwxyzAB" +/// } +/// { +/// "contentBlockIndex":0, +/// "delta": +/// { +/// "text":" German" +/// }, +/// "p":"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWX" +/// } +/// + +#[derive(Debug, Deserialize)] +#[serde(untagged)] +pub enum Delta { + ToolUse { + #[serde(rename = "toolUse")] + tool_use: ToolUse, + }, + Text { + text: String, + }, +} + +#[derive(Debug, Deserialize)] +pub struct ToolUse { + pub input: String, +} + +// Additional structs for different message types +#[derive(Debug, Deserialize)] +pub struct MessageStart { + pub p: String, + pub role: String, +} + +#[derive(Debug, Deserialize)] +pub struct MessageStop { + pub p: String, + #[serde(rename = "stopReason")] + pub stop_reason: String, +} + +#[derive(Debug, Deserialize)] +pub struct MetadataMessage { + pub p: String, + pub usage: Option, + pub metrics: Option, +} + +#[derive(Debug, Deserialize)] +pub struct EventContentBlock { + #[serde(rename = "contentBlockIndex")] + pub content_block_index: u32, + pub delta : Delta, + pub p: String, +} + + impl BedrockChatStream { pub fn new(stream: EventSource) -> LlmChatStream { LlmChatStream::new(BedrockChatStream { @@ -85,53 +153,47 @@ impl LlmChatStreamState for BedrockChatStream { let json: Value = serde_json::from_str(raw) .map_err(|err| format!("Failed to deserialize stream event: {err}"))?; - - if let Some(content_block_delta) = json.get("contentBlockDelta") { - if let Some(delta) = content_block_delta.get("delta") { - if let Some(text) = delta.get("text").and_then(|v| v.as_str()) { - return Ok(Some(StreamEvent::Delta(StreamDelta { - content: Some(vec![ContentPart::Text(text.to_string())]), - tool_calls: None, - }))); - } - } - } - - if let Some(content_block_start) = json.get("contentBlockStart") { - if let Some(start) = content_block_start.get("start") { - if let Some(tool_use) = start.get("toolUse") { - if let (Some(tool_use_id), Some(name)) = ( - tool_use.get("toolUseId").and_then(|v| v.as_str()), - tool_use.get("name").and_then(|v| v.as_str()), - ) { - if let Some(input) = tool_use.get("input") { + + // 1. Handle content block delta messages (contentBlockIndex + delta) + if json.get("contentBlockIndex").is_some() && json.get("delta").is_some() { + match serde_json::from_value::(json.clone()) { + Ok(event_content_block) => { + match event_content_block.delta { + Delta::Text { text } => { + return Ok(Some(StreamEvent::Delta(StreamDelta { + content: Some(vec![ContentPart::Text(text)]), + tool_calls: None, + }))); + } + Delta::ToolUse { tool_use } => { + // Handle tool use delta - this would need tool call ID and name from earlier message + // For now, just return the input as text return Ok(Some(StreamEvent::Delta(StreamDelta { - content: None, - tool_calls: Some(vec![ToolCall { - id: tool_use_id.to_string(), - name: name.to_string(), - arguments_json: serde_json::to_string(input).unwrap(), - }]), + content: Some(vec![ContentPart::Text(tool_use.input)]), + tool_calls: None, }))); } } } + Err(err) => { + trace!("Failed to parse as EventContentBlock: {}", err); + // Continue to other parsing attempts + } } } - if let Some(metadata) = json.get("metadata") { - if let Some(usage) = metadata.get("usage") { - if let Ok(bedrock_usage) = - serde_json::from_value::(usage.clone()) - { - self.response_metadata.borrow_mut().usage = Some(convert_usage(bedrock_usage)); - } + // 3. Handle message start (role + p) + if json.get("role").is_some() { + if let Ok(_message_start) = serde_json::from_value::(json.clone()) { + // Message start event - just metadata, no content to return + return Ok(None); } } - if let Some(message_stop) = json.get("messageStop") { - if let Some(stop_reason) = message_stop.get("stopReason").and_then(|v| v.as_str()) { - let stop_reason = match stop_reason { + // 4. Handle message stop with stopReason + if json.get("stopReason").is_some() { + if let Ok(message_stop) = serde_json::from_value::(json.clone()) { + let stop_reason = match message_stop.stop_reason.as_str() { "end_turn" => crate::client::StopReason::EndTurn, "tool_use" => crate::client::StopReason::ToolUse, "max_tokens" => crate::client::StopReason::MaxTokens, @@ -142,12 +204,22 @@ impl LlmChatStreamState for BedrockChatStream { }; self.response_metadata.borrow_mut().finish_reason = Some(stop_reason_to_finish_reason(stop_reason)); - } - let response_metadata = self.response_metadata.borrow().clone(); - return Ok(Some(StreamEvent::Finish(response_metadata))); + let response_metadata = self.response_metadata.borrow().clone(); + return Ok(Some(StreamEvent::Finish(response_metadata))); + } } + // 5. Handle metadata messages with usage/metrics + if json.get("usage").is_some() || json.get("metrics").is_some() { + if let Ok(metadata) = serde_json::from_value::(json.clone()) { + if let Some(usage) = metadata.usage { + self.response_metadata.borrow_mut().usage = Some(convert_usage(usage)); + } + // Metadata processed, no event to return + return Ok(None); + } + } Ok(None) } } diff --git a/llm/llm/Cargo.toml b/llm/llm/Cargo.toml index f12ada98b..fee601afe 100644 --- a/llm/llm/Cargo.toml +++ b/llm/llm/Cargo.toml @@ -12,6 +12,9 @@ path = "src/lib.rs" crate-type = ["rlib"] [dependencies] +aws-smithy-eventstream = "0.60.9" +aws-smithy-types = "1.3.2" +base64 = "0.22.1" golem-rust = { workspace = true } log = { workspace = true } mime = "0.3.17" diff --git a/llm/llm/src/event_source/aws_eventstream.rs b/llm/llm/src/event_source/aws_eventstream.rs new file mode 100644 index 000000000..bc5514392 --- /dev/null +++ b/llm/llm/src/event_source/aws_eventstream.rs @@ -0,0 +1,175 @@ +use std::task::Poll; + +use aws_smithy_eventstream::frame::{DecodedFrame, MessageFrameDecoder}; +use aws_smithy_types::event_stream::HeaderValue; +use base64::Engine; +use golem_rust::{ + bindings::wasi::io::streams::{InputStream, StreamError}, + wasm_rpc::Pollable, +}; +use log::trace; + +use crate::event_source::{ + stream::{LlmStream, StreamError as AwsEventStreamError}, + MessageEvent, +}; + +#[derive(Debug, Clone, Copy)] +pub enum AwsEventStreamState { + NotStarted, + Started, + Terminated, +} + +impl AwsEventStreamState { + fn is_terminated(self) -> bool { + matches!(self, Self::Terminated) + } +} + +/// A Stream of AWS EventStream events using the vnd.amazon.eventstream binary format +pub struct AwsEventStream { + stream: InputStream, + decoder: MessageFrameDecoder, + buffer: Vec, + state: AwsEventStreamState, + last_event_id: String, + subscription: Pollable, +} + +impl LlmStream for AwsEventStream { + fn new(stream: InputStream) -> Self { + let subscription = stream.subscribe(); + Self { + decoder: MessageFrameDecoder::new(), + buffer: Vec::new(), + state: AwsEventStreamState::NotStarted, + last_event_id: String::new(), + stream, + subscription, + } + } + + fn set_last_event_id(&mut self, id: impl Into) { + self.last_event_id = id.into(); + } + + fn last_event_id(&self) -> &str { + &self.last_event_id + } + + fn subscribe(&self) -> Pollable { + self.stream.subscribe() + } + + fn poll_next( + &mut self, + ) -> Poll>>> { + trace!("Polling for next AWS EventStream event"); + + if let Some(event) = try_decode_message(self)? { + return Poll::Ready(Some(Ok(event))); + } + + if self.state.is_terminated() { + return Poll::Ready(None); + } + + loop { + if self.subscription.ready() { + match self.stream.read(8192) { + Ok(bytes) => { + if bytes.is_empty() { + continue; + } + + if !self.state.is_terminated() { + self.state = AwsEventStreamState::Started; + } + + self.buffer.extend_from_slice(&bytes); + + // Try to decode complete messages from the updated buffer + if let Some(event) = try_decode_message(self)? { + return Poll::Ready(Some(Ok(event))); + } + } + Err(StreamError::Closed) => { + trace!("AWS EventStream closed"); + self.state = AwsEventStreamState::Terminated; + return Poll::Ready(None); + } + Err(err) => return Poll::Ready(Some(Err(AwsEventStreamError::Transport(err)))), + } + } else { + return Poll::Pending; + } + } + } +} + +fn try_decode_message( + stream: &mut AwsEventStream, +) -> Result, AwsEventStreamError> { + if stream.buffer.is_empty() { + return Ok(None); + } + + let mut buffer_slice = stream.buffer.as_slice(); + let original_len = buffer_slice.len(); + + match stream.decoder.decode_frame(&mut buffer_slice) { + Ok(DecodedFrame::Complete(message)) => { + trace!( + "Decoded AWS EventStream message with {} byte payload", + message.payload().len() + ); + + let consumed = original_len - buffer_slice.len(); + + let event_type = message + .headers() + .iter() + .find(|header| header.name().as_str() == ":event-type") + .and_then(|header| { + if let HeaderValue::String(s) = header.value() { + Some(s.as_str()) + } else { + None + } + }) + .unwrap_or("message"); + + let data = match std::str::from_utf8(message.payload()) { + Ok(s) => s.to_string(), + Err(_) => base64::engine::general_purpose::STANDARD.encode(message.payload()), + }; + + if let Some(id_header) = message + .headers() + .iter() + .find(|header| header.name().as_str() == ":event-id") + { + if let HeaderValue::String(id) = id_header.value() { + stream.last_event_id = id.as_str().to_string(); + } + } + + stream.buffer.drain(..consumed); + + let event = MessageEvent { + event: event_type.to_string(), + data, + id: stream.last_event_id.clone(), + retry: None, + }; + + Ok(Some(event)) + } + Ok(DecodedFrame::Incomplete) => Ok(None), + Err(err) => Err(AwsEventStreamError::Parser(nom::error::Error::new( + format!("AWS EventStream decode error: {}", err), + nom::error::ErrorKind::Tag, + ))), + } +} diff --git a/llm/llm/src/event_source/mod.rs b/llm/llm/src/event_source/mod.rs index 18c94dc83..8ed4689f3 100644 --- a/llm/llm/src/event_source/mod.rs +++ b/llm/llm/src/event_source/mod.rs @@ -2,6 +2,7 @@ // modified to use the wasi-http based reqwest, and wasi pollables pub mod error; +mod aws_eventstream; mod event_stream; mod message_event; mod ndjson_stream; @@ -11,6 +12,7 @@ mod utf8_stream; use crate::event_source::error::Error; use crate::event_source::event_stream::EventStream; +use aws_eventstream::AwsEventStream; use golem_rust::wasm_rpc::Pollable; pub use message_event::MessageEvent; use ndjson_stream::NdJsonStream; @@ -50,15 +52,17 @@ impl EventSource { >(response.get_raw_input_stream()) }; - let stream = if response + let content_type = response .headers() .get(&reqwest::header::CONTENT_TYPE) .unwrap() .to_str() - .unwrap() - .contains("ndjson") - { + .unwrap(); + + let stream = if content_type.contains("ndjson") { StreamType::NdJsonStream(NdJsonStream::new(handle)) + } else if content_type.contains("vnd.amazon.eventstream") { + StreamType::AwsEventStream(AwsEventStream::new(handle)) } else { StreamType::EventStream(EventStream::new(handle)) }; @@ -90,6 +94,7 @@ impl EventSource { match &self.stream { StreamType::EventStream(stream) => stream.subscribe(), StreamType::NdJsonStream(stream) => stream.subscribe(), + StreamType::AwsEventStream(stream) => stream.subscribe(), } } @@ -111,6 +116,12 @@ impl EventSource { Poll::Ready(None) => Poll::Ready(None), Poll::Pending => Poll::Pending, }, + StreamType::AwsEventStream(stream) => match stream.poll_next() { + Poll::Ready(Some(Ok(event))) => Poll::Ready(Some(Ok(Event::Message(event)))), + Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err.into()))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + }, } } } diff --git a/llm/llm/src/event_source/stream.rs b/llm/llm/src/event_source/stream.rs index 8f2933676..c145f2773 100644 --- a/llm/llm/src/event_source/stream.rs +++ b/llm/llm/src/event_source/stream.rs @@ -2,8 +2,8 @@ use core::fmt; use std::{string::FromUtf8Error, task::Poll}; use super::{ - event_stream::EventStream, ndjson_stream::NdJsonStream, utf8_stream::Utf8StreamError, - MessageEvent, + aws_eventstream::AwsEventStream, event_stream::EventStream, ndjson_stream::NdJsonStream, + utf8_stream::Utf8StreamError, MessageEvent, }; use golem_rust::{ bindings::wasi::io::streams::{InputStream, StreamError as WasiStreamError}, @@ -14,6 +14,7 @@ use nom::error::Error as NomError; pub enum StreamType { EventStream(EventStream), NdJsonStream(NdJsonStream), + AwsEventStream(AwsEventStream), } pub trait LlmStream { From 7c75d37d492b472a15662fbe3881dfe30117c555 Mon Sep 17 00:00:00 2001 From: Rutik7066 Date: Mon, 30 Jun 2025 08:31:53 +0000 Subject: [PATCH 6/8] lint --- llm/anthropic/src/conversions.rs | 2 +- llm/bedrock/src/client.rs | 178 +++----------------- llm/bedrock/src/conversions.rs | 96 +++++------ llm/bedrock/src/lib.rs | 20 +-- llm/grok/src/conversions.rs | 2 +- llm/llm/src/config.rs | 4 +- llm/llm/src/event_source/aws_eventstream.rs | 2 +- llm/llm/src/event_source/mod.rs | 8 +- llm/llm/src/event_source/ndjson_stream.rs | 2 +- llm/llm/src/event_source/stream.rs | 9 +- llm/ollama/src/client.rs | 2 +- llm/ollama/src/conversions.rs | 2 +- llm/openai/src/conversions.rs | 2 +- llm/openrouter/src/conversions.rs | 2 +- 14 files changed, 97 insertions(+), 234 deletions(-) diff --git a/llm/anthropic/src/conversions.rs b/llm/anthropic/src/conversions.rs index e332f1391..e7d3175a0 100644 --- a/llm/anthropic/src/conversions.rs +++ b/llm/anthropic/src/conversions.rs @@ -130,7 +130,7 @@ pub fn process_response(response: MessagesResponse) -> ChatEvent { Err(e) => { return ChatEvent::Error(Error { code: ErrorCode::InvalidRequest, - message: format!("Failed to decode base64 image data: {}", e), + message: format!("Failed to decode base64 image data: {e}"), provider_error_json: None, }); } diff --git a/llm/bedrock/src/client.rs b/llm/bedrock/src/client.rs index e8455c60f..fd0149c5a 100644 --- a/llm/bedrock/src/client.rs +++ b/llm/bedrock/src/client.rs @@ -55,7 +55,7 @@ impl BedrockClient { &self.region, "bedrock", "POST", - &format!("/model/{}/converse", model_id), + &format!("/model/{model_id}/converse"), &host, &body, ) @@ -71,11 +71,11 @@ impl BedrockClient { } let response: Response = request_builder.body(body).send().map_err(|err| { - trace!("HTTP request failed with error: {:?}", err); + trace!("HTTP request failed with error: {err:?}"); from_reqwest_error("Request failed", err) })?; - trace!("Received response from Bedrock API: {:?}", response); + trace!("Received response from Bedrock API: {response:?}"); parse_response(response) } @@ -104,7 +104,7 @@ impl BedrockClient { &self.region, "bedrock", "POST", - &format!("/model/{}/converse-stream", model_id), + &format!("/model/{model_id}/converse-stream"), &host, &body, ) @@ -121,7 +121,7 @@ impl BedrockClient { trace!("Sending streaming HTTP request to Bedrock..."); let response: Response = request_builder.body(body).send().map_err(|err| { - trace!("HTTP request failed with error: {:?}", err); + trace!("HTTP request failed with error: {err:?}"); from_reqwest_error("Request failed", err) })?; @@ -132,6 +132,7 @@ impl BedrockClient { } } +#[allow(clippy::too_many_arguments)] pub fn generate_sigv4_headers( access_key: &str, secret_key: &str, @@ -143,7 +144,7 @@ pub fn generate_sigv4_headers( body: &str, ) -> Result, Box> { use std::collections::BTreeMap; - + let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap(); let timestamp = OffsetDateTime::from_unix_timestamp(now.as_secs() as i64).unwrap(); @@ -166,13 +167,13 @@ pub fn generate_sigv4_headers( let (canonical_uri, canonical_query_string) = if let Some(query_pos) = uri.find('?') { let path = &uri[..query_pos]; let query = &uri[query_pos + 1..]; - + let encoded_path = if path.contains(':') { path.replace(':', "%3A") } else { path.to_string() }; - + let mut query_params: Vec<&str> = query.split('&').collect(); query_params.sort(); (encoded_path, query_params.join("&")) @@ -206,20 +207,17 @@ pub fn generate_sigv4_headers( let payload_hash = format!("{:x}", Sha256::digest(body.as_bytes())); let canonical_request = format!( - "{}\n{}\n{}\n{}\n{}\n{}", - method, canonical_uri, canonical_query_string, canonical_headers, signed_headers, payload_hash + "{method}\n{canonical_uri}\n{canonical_query_string}\n{canonical_headers}\n{signed_headers}\n{payload_hash}" ); - let credential_scope = format!("{}/{}/{}/aws4_request", date_str, region, service); + let credential_scope = format!("{date_str}/{region}/{service}/aws4_request"); let canonical_request_hash = format!("{:x}", Sha256::digest(canonical_request.as_bytes())); - let string_to_sign = format!( - "AWS4-HMAC-SHA256\n{}\n{}\n{}", - datetime_str, credential_scope, canonical_request_hash - ); + let string_to_sign = + format!("AWS4-HMAC-SHA256\n{datetime_str}\n{credential_scope}\n{canonical_request_hash}"); type HmacSha256 = Hmac; - let mut mac = HmacSha256::new_from_slice(format!("AWS4{}", secret_key).as_bytes())?; + let mut mac = HmacSha256::new_from_slice(format!("AWS4{secret_key}").as_bytes())?; mac.update(date_str.as_bytes()); let date_key = mac.finalize().into_bytes(); @@ -240,14 +238,16 @@ pub fn generate_sigv4_headers( let signature = format!("{:x}", mac.finalize().into_bytes()); let auth_header = format!( - "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}", - access_key, credential_scope, signed_headers, signature + "AWS4-HMAC-SHA256 Credential={access_key}/{credential_scope}, SignedHeaders={signed_headers}, Signature={signature}" ); let result_headers = vec![ ("authorization".to_string(), auth_header), ("x-amz-date".to_string(), datetime_str), - ("content-type".to_string(), "application/x-amz-json-1.0".to_string()), + ( + "content-type".to_string(), + "application/x-amz-json-1.0".to_string(), + ), ]; Ok(result_headers) @@ -255,8 +255,6 @@ pub fn generate_sigv4_headers( #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ConverseRequest { - #[serde(skip_serializing, rename = "modelId")] - pub model_id: String, pub messages: Vec, #[serde(skip_serializing_if = "Option::is_none")] pub system: Option>, @@ -290,7 +288,9 @@ pub enum Role { #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(untagged)] pub enum ContentBlock { - Text { text: String }, + Text { + text: String, + }, Image { image: ImageBlock, }, @@ -442,11 +442,13 @@ pub struct ConverseResponse { pub stop_reason: StopReason, pub usage: Usage, pub metrics: Metrics, - #[serde(rename = "additionalModelResponseFields", skip_serializing_if = "Option::is_none")] + #[serde( + rename = "additionalModelResponseFields", + skip_serializing_if = "Option::is_none" + )] pub additional_model_response_fields: Option, } - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Output { pub message: Message, @@ -509,134 +511,8 @@ fn parse_response(response: Response) -> Result = headers.into_iter().collect(); - - assert!(header_map.contains_key("authorization")); - assert!(header_map.contains_key("x-amz-date")); - assert!(header_map.contains_key("content-type")); - - let auth_header = &header_map["authorization"]; - assert!(auth_header.starts_with("AWS4-HMAC-SHA256 Credential=")); - assert!(auth_header.contains("SignedHeaders=")); - assert!(auth_header.contains("Signature=")); - - assert_eq!(header_map["content-type"], "application/x-amz-json-1.0"); - - let date_header = &header_map["x-amz-date"]; - assert!(date_header.ends_with('Z')); - assert!(date_header.contains('T')); - } - - #[test] - fn test_canonical_headers_ordering() { - let access_key = "AKIAIOSFODNN7EXAMPLE"; - let secret_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"; - let region = "us-east-1"; - let service = "bedrock"; - let method = "POST"; - let uri = "/model/test/converse"; - let host = "bedrock-runtime.us-east-1.amazonaws.com"; - let body = "{}"; - - let result = generate_sigv4_headers( - access_key, - secret_key, - region, - service, - method, - uri, - host, - body, - ); - - assert!(result.is_ok()); - let headers = result.unwrap(); - let header_map: std::collections::HashMap = headers.into_iter().collect(); - - let auth_header = &header_map["authorization"]; - - assert!(auth_header.contains("SignedHeaders=content-type;host;x-amz-date")); - } - - #[test] - fn test_bedrock_client_integration() { - let client = BedrockClient::new( - "AKIAIOSFODNN7EXAMPLE".to_string(), - "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(), - "us-east-1".to_string(), - ); - - let request = ConverseRequest { - model_id: "anthropic.claude-3-sonnet-20240229-v1:0".to_string(), - messages: vec![Message { - role: Role::User, - content: vec![ContentBlock::Text { - text: "Hello, how are you?".to_string(), - }], - }], - system: None, - inference_config: None, - tool_config: None, - guardrail_config: None, - additional_model_request_fields: None, - }; - - - let body = serde_json::to_string(&request).expect("Failed to serialize request"); - let host = format!("bedrock-runtime.{}.amazonaws.com", client.region); - - let headers_result = generate_sigv4_headers( - &client.access_key_id, - &client.secret_access_key, - &client.region, - "bedrock", - "POST", - "/model/anthropic.claude-3-sonnet-20240229-v1:0/converse", - &host, - &body, - ); - - assert!(headers_result.is_ok()); - let headers = headers_result.unwrap(); - - let header_names: Vec<&str> = headers.iter().map(|(k, _)| k.as_str()).collect(); - assert!(header_names.contains(&"authorization")); - assert!(header_names.contains(&"x-amz-date")); - assert!(header_names.contains(&"content-type")); - } -} \ No newline at end of file diff --git a/llm/bedrock/src/conversions.rs b/llm/bedrock/src/conversions.rs index 2d9215c8d..19b6854f9 100644 --- a/llm/bedrock/src/conversions.rs +++ b/llm/bedrock/src/conversions.rs @@ -1,14 +1,15 @@ use crate::client::{ - ContentBlock, ConverseRequest, ConverseResponse, ImageFormat, ImageSource as ClientImageSource, - InferenceConfig, Message as ClientMessage, Role as ClientRole, StopReason, SystemContentBlock, - Tool, ToolChoice, ToolConfig, ToolSpec, ToolResultContentBlock, ToolResultStatus, - ImageBlock, ToolUseBlock, ToolResultBlock, ToolInputSchema, ToolChoiceTool, + ContentBlock, ConverseRequest, ConverseResponse, ImageBlock, ImageFormat, + ImageSource as ClientImageSource, InferenceConfig, Message as ClientMessage, + Role as ClientRole, StopReason, SystemContentBlock, Tool, ToolChoice, ToolChoiceTool, + ToolConfig, ToolInputSchema, ToolResultBlock, ToolResultContentBlock, ToolResultStatus, + ToolSpec, ToolUseBlock, }; use base64::{engine::general_purpose, Engine as _}; use golem_llm::golem::llm::llm::{ ChatEvent, CompleteResponse, Config, ContentPart, Error, ErrorCode, FinishReason, - ImageReference, ImageSource, Message, ResponseMetadata, Role, ToolCall, - ToolDefinition, ToolResult, Usage, + ImageReference, ImageSource, Message, ResponseMetadata, Role, ToolCall, ToolDefinition, + ToolResult, Usage, }; use reqwest::{Client, Url}; use std::{collections::HashMap, fs, path::Path}; @@ -68,21 +69,17 @@ pub fn messages_to_request( for tool in &config.tools { tools.push(tool_definition_to_tool(tool)?); } - + let tool_choice = config.tool_choice.map(convert_tool_choice); - + if tools.is_empty() { None } else { - Some(ToolConfig { - tools, - tool_choice, - }) + Some(ToolConfig { tools, tool_choice }) } }; Ok(ConverseRequest { - model_id: config.model.clone(), messages: bedrock_messages, system: if system_messages.is_empty() { None @@ -98,7 +95,7 @@ pub fn messages_to_request( fn convert_tool_choice(tool_name: String) -> ToolChoice { use serde_json::Value; - + match tool_name.as_str() { "auto" => ToolChoice::Auto { auto: Value::Object(serde_json::Map::new()), @@ -130,18 +127,16 @@ pub fn process_response(response: ConverseResponse) -> ChatEvent { ImageFormat::Gif => "image/gif", ImageFormat::Webp => "image/webp", }; - contents.push(ContentPart::Image(ImageReference::Inline( - ImageSource { - data: decoded_data, - mime_type: mime_type.to_string(), - detail: None, - }, - ))); + contents.push(ContentPart::Image(ImageReference::Inline(ImageSource { + data: decoded_data, + mime_type: mime_type.to_string(), + detail: None, + }))); } Err(e) => { return ChatEvent::Error(Error { code: ErrorCode::InvalidRequest, - message: format!("Failed to decode base64 image data: {}", e), + message: format!("Failed to decode base64 image data: {e}"), provider_error_json: None, }); } @@ -160,7 +155,7 @@ pub fn process_response(response: ConverseResponse) -> ChatEvent { ChatEvent::ToolRequest(tool_calls) } else { let request_id = "bedrock-response".to_string(); - + let metadata = ResponseMetadata { finish_reason: Some(stop_reason_to_finish_reason(response.stop_reason)), usage: Some(convert_usage(response.usage)), @@ -178,9 +173,7 @@ pub fn process_response(response: ConverseResponse) -> ChatEvent { } } -pub fn tool_results_to_messages( - tool_results: Vec<(ToolCall, ToolResult)>, -) -> Vec { +pub fn tool_results_to_messages(tool_results: Vec<(ToolCall, ToolResult)>) -> Vec { let mut messages = Vec::new(); for (tool_call, tool_result) in tool_results { @@ -249,9 +242,7 @@ fn message_to_content(message: &Message) -> Result, Error> { for content_part in &message.content { match content_part { - ContentPart::Text(text) => result.push(ContentBlock::Text { - text: text.clone(), - }), + ContentPart::Text(text) => result.push(ContentBlock::Text { text: text.clone() }), ContentPart::Image(image_reference) => match image_reference { ImageReference::Url(image_url) => { let url = &image_url.url; @@ -260,38 +251,37 @@ fn message_to_content(message: &Message) -> Result, Error> { let client = Client::new(); let response = client.get(url).send().map_err(|e| Error { code: ErrorCode::InvalidRequest, - message: format!("Failed to fetch image from URL: {}", e), + message: format!("Failed to fetch image from URL: {e}"), provider_error_json: None, }); response.map(|r| { - format = match r.headers().get("Content-Type").unwrap().to_str().unwrap() { - "image/jpeg" => ImageFormat::Jpeg, - "image/png" => ImageFormat::Png, - "image/gif" => ImageFormat::Gif, - "image/webp" => ImageFormat::Webp, - _ => ImageFormat::Jpeg, - }; + format = + match r.headers().get("Content-Type").unwrap().to_str().unwrap() { + "image/jpeg" => ImageFormat::Jpeg, + "image/png" => ImageFormat::Png, + "image/gif" => ImageFormat::Gif, + "image/webp" => ImageFormat::Webp, + _ => ImageFormat::Jpeg, + }; r.bytes().unwrap().to_vec() }) } else { let path = Path::new(url); fs::read(path).map_err(|e| Error { code: ErrorCode::InvalidRequest, - message: format!("Failed to read image from path: {}", e), + message: format!("Failed to read image from path: {e}"), provider_error_json: None, }) }; - - let base64_data = general_purpose::STANDARD.encode(&bytes.unwrap()); + + let base64_data = general_purpose::STANDARD.encode(bytes.unwrap()); result.push(ContentBlock::Image { image: ImageBlock { format: ImageFormat::Png, - source: ClientImageSource { - bytes: base64_data, - }, + source: ClientImageSource { bytes: base64_data }, }, }); - }, + } ImageReference::Inline(image_source) => { let base64_data = general_purpose::STANDARD.encode(&image_source.data); let format = match image_source.mime_type.as_str() { @@ -305,9 +295,7 @@ fn message_to_content(message: &Message) -> Result, Error> { result.push(ContentBlock::Image { image: ImageBlock { format, - source: ClientImageSource { - bytes: base64_data, - }, + source: ClientImageSource { bytes: base64_data }, }, }); } @@ -323,9 +311,7 @@ fn message_to_system_content(message: &Message) -> Vec { for content_part in &message.content { match content_part { - ContentPart::Text(text) => result.push(SystemContentBlock::Text { - text: text.clone(), - }), + ContentPart::Text(text) => result.push(SystemContentBlock::Text { text: text.clone() }), ContentPart::Image(_) => {} } } @@ -335,7 +321,7 @@ fn message_to_system_content(message: &Message) -> Vec { fn tool_definition_to_tool(tool: &ToolDefinition) -> Result { use serde_json::Value; - + let schema_value = if tool.parameters_schema.trim().is_empty() { serde_json::json!({ "type": "object", @@ -354,14 +340,12 @@ fn tool_definition_to_tool(tool: &ToolDefinition) -> Result { } } }; - + Ok(Tool { tool_spec: ToolSpec { name: tool.name.clone(), description: tool.description.clone().unwrap_or_default(), - input_schema: ToolInputSchema { - json: schema_value, - }, + input_schema: ToolInputSchema { json: schema_value }, }, }) -} \ No newline at end of file +} diff --git a/llm/bedrock/src/lib.rs b/llm/bedrock/src/lib.rs index 7c16f2032..ddc2b3186 100644 --- a/llm/bedrock/src/lib.rs +++ b/llm/bedrock/src/lib.rs @@ -11,8 +11,8 @@ use golem_llm::config::with_config_keys; use golem_llm::durability::{DurableLLM, ExtendedGuest}; use golem_llm::event_source::EventSource; use golem_llm::golem::llm::llm::{ - ChatEvent, ChatStream, Config, ContentPart, Error, Guest, Message, ResponseMetadata, - Role, StreamDelta, StreamEvent, ToolCall, ToolResult, + ChatEvent, ChatStream, Config, ContentPart, Error, Guest, Message, ResponseMetadata, Role, + StreamDelta, StreamEvent, ToolCall, ToolResult, }; use golem_llm::LOGGING_STATE; use golem_rust::wasm_rpc::Pollable; @@ -28,8 +28,7 @@ struct BedrockChatStream { response_metadata: RefCell, } - -/// [2025-06-29T18:11:10.458Z] [TRACE ] [golem_llm_bedrock] llm/bedrock/src/lib.rs:84: Received raw stream event: +/// [2025-06-29T18:11:10.458Z] [TRACE ] [golem_llm_bedrock] llm/bedrock/src/lib.rs:84: Received raw stream event: /// { /// "contentBlockIndex":1, /// "delta":{ @@ -46,17 +45,17 @@ struct BedrockChatStream { /// }, /// "p":"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWX" /// } -/// +/// #[derive(Debug, Deserialize)] #[serde(untagged)] pub enum Delta { ToolUse { #[serde(rename = "toolUse")] - tool_use: ToolUse, + tool_use: ToolUse, }, Text { - text: String, + text: String, }, } @@ -90,11 +89,10 @@ pub struct MetadataMessage { pub struct EventContentBlock { #[serde(rename = "contentBlockIndex")] pub content_block_index: u32, - pub delta : Delta, + pub delta: Delta, pub p: String, } - impl BedrockChatStream { pub fn new(stream: EventSource) -> LlmChatStream { LlmChatStream::new(BedrockChatStream { @@ -153,7 +151,7 @@ impl LlmChatStreamState for BedrockChatStream { let json: Value = serde_json::from_str(raw) .map_err(|err| format!("Failed to deserialize stream event: {err}"))?; - + // 1. Handle content block delta messages (contentBlockIndex + delta) if json.get("contentBlockIndex").is_some() && json.get("delta").is_some() { match serde_json::from_value::(json.clone()) { @@ -176,7 +174,7 @@ impl LlmChatStreamState for BedrockChatStream { } } Err(err) => { - trace!("Failed to parse as EventContentBlock: {}", err); + trace!("Failed to parse as EventContentBlock: {err}"); // Continue to other parsing attempts } } diff --git a/llm/grok/src/conversions.rs b/llm/grok/src/conversions.rs index 68a5d570c..129c128ad 100644 --- a/llm/grok/src/conversions.rs +++ b/llm/grok/src/conversions.rs @@ -183,7 +183,7 @@ fn convert_content_parts(contents: Vec) -> crate::client::Content { let media_type = &image_source.mime_type; // This is already a string result.push(crate::client::ContentPart::ImageInput { image_url: crate::client::ImageUrl { - url: format!("data:{};base64,{}", media_type, base64_data), + url: format!("data:{media_type};base64,{base64_data}"), detail: image_source.detail.map(|d| d.into()), }, }); diff --git a/llm/llm/src/config.rs b/llm/llm/src/config.rs index 43584ad32..461010a06 100644 --- a/llm/llm/src/config.rs +++ b/llm/llm/src/config.rs @@ -30,7 +30,7 @@ pub fn with_config_keys( succeed: impl FnOnce(HashMap) -> R, ) -> R { let mut values = HashMap::new(); - + for key in keys { match std::env::var(key) { Ok(value) => { @@ -46,6 +46,6 @@ pub fn with_config_keys( } } } - + succeed(values) } diff --git a/llm/llm/src/event_source/aws_eventstream.rs b/llm/llm/src/event_source/aws_eventstream.rs index bc5514392..1d8e42520 100644 --- a/llm/llm/src/event_source/aws_eventstream.rs +++ b/llm/llm/src/event_source/aws_eventstream.rs @@ -168,7 +168,7 @@ fn try_decode_message( } Ok(DecodedFrame::Incomplete) => Ok(None), Err(err) => Err(AwsEventStreamError::Parser(nom::error::Error::new( - format!("AWS EventStream decode error: {}", err), + format!("AWS EventStream decode error: {err}"), nom::error::ErrorKind::Tag, ))), } diff --git a/llm/llm/src/event_source/mod.rs b/llm/llm/src/event_source/mod.rs index 8ed4689f3..4c8962503 100644 --- a/llm/llm/src/event_source/mod.rs +++ b/llm/llm/src/event_source/mod.rs @@ -1,8 +1,8 @@ // Based on https://github.com/jpopesculian/eventsource-stream and https://github.com/jpopesculian/reqwest-eventsource // modified to use the wasi-http based reqwest, and wasi pollables -pub mod error; mod aws_eventstream; +pub mod error; mod event_stream; mod message_event; mod ndjson_stream; @@ -151,7 +151,11 @@ fn check_response(response: Response) -> Result { matches!( (mime_type.type_(), mime_type.subtype()), (mime::TEXT, mime::EVENT_STREAM) - ) || mime_type.subtype().as_str().contains("ndjson") || content_type.to_str().unwrap_or("").contains("vnd.amazon.eventstream") + ) || mime_type.subtype().as_str().contains("ndjson") + || content_type + .to_str() + .unwrap_or("") + .contains("vnd.amazon.eventstream") }) .unwrap_or(false) { diff --git a/llm/llm/src/event_source/ndjson_stream.rs b/llm/llm/src/event_source/ndjson_stream.rs index e2f4cc1b2..1b8ef3773 100644 --- a/llm/llm/src/event_source/ndjson_stream.rs +++ b/llm/llm/src/event_source/ndjson_stream.rs @@ -126,7 +126,7 @@ fn try_parse_line( return Ok(None); } - trace!("Parsed NDJSON line: {}", line); + trace!("Parsed NDJSON line: {line}"); // Create a MessageEvent with the JSON line as data let event = MessageEvent { diff --git a/llm/llm/src/event_source/stream.rs b/llm/llm/src/event_source/stream.rs index c145f2773..02d65f6c4 100644 --- a/llm/llm/src/event_source/stream.rs +++ b/llm/llm/src/event_source/stream.rs @@ -2,7 +2,7 @@ use core::fmt; use std::{string::FromUtf8Error, task::Poll}; use super::{ - aws_eventstream::AwsEventStream, event_stream::EventStream, ndjson_stream::NdJsonStream, + aws_eventstream::AwsEventStream, event_stream::EventStream, ndjson_stream::NdJsonStream, utf8_stream::Utf8StreamError, MessageEvent, }; use golem_rust::{ @@ -11,6 +11,7 @@ use golem_rust::{ }; use nom::error::Error as NomError; +#[allow(clippy::enum_variant_names)] pub enum StreamType { EventStream(EventStream), NdJsonStream(NdJsonStream), @@ -57,9 +58,9 @@ where { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Self::Utf8(err) => f.write_fmt(format_args!("UTF8 error: {}", err)), - Self::Parser(err) => f.write_fmt(format_args!("Parse error: {}", err)), - Self::Transport(err) => f.write_fmt(format_args!("Transport error: {}", err)), + Self::Utf8(err) => f.write_fmt(format_args!("UTF8 error: {err}")), + Self::Parser(err) => f.write_fmt(format_args!("Parse error: {err}")), + Self::Transport(err) => f.write_fmt(format_args!("Transport error: {err}")), } } } diff --git a/llm/ollama/src/client.rs b/llm/ollama/src/client.rs index e9514a8dd..e2901e70c 100644 --- a/llm/ollama/src/client.rs +++ b/llm/ollama/src/client.rs @@ -335,7 +335,7 @@ pub fn image_to_base64(source: &str) -> Result Error { Error { code: ErrorCode::InternalError, - message: format!("{}: {}", context, err), + message: format!("{context}: {err}"), provider_error_json: None, } } diff --git a/llm/ollama/src/conversions.rs b/llm/ollama/src/conversions.rs index b1db65c61..8d64e954f 100644 --- a/llm/ollama/src/conversions.rs +++ b/llm/ollama/src/conversions.rs @@ -214,7 +214,7 @@ pub fn process_response(response: CompletionsResponse) -> ChatEvent { }; ChatEvent::Message(CompleteResponse { - id: format!("ollama-{}", timestamp), + id: format!("ollama-{timestamp}"), content, tool_calls, metadata, diff --git a/llm/openai/src/conversions.rs b/llm/openai/src/conversions.rs index 43694c0f3..a4989b0c1 100644 --- a/llm/openai/src/conversions.rs +++ b/llm/openai/src/conversions.rs @@ -138,7 +138,7 @@ pub fn content_part_to_inner_input_item(content_part: ContentPart) -> InnerInput ImageReference::Inline(image_source) => { let base64_data = general_purpose::STANDARD.encode(&image_source.data); let mime_type = &image_source.mime_type; // This is already a string - let data_url = format!("data:{};base64,{}", mime_type, base64_data); + let data_url = format!("data:{mime_type};base64,{base64_data}"); InnerInputItem::ImageInput { image_url: data_url, diff --git a/llm/openrouter/src/conversions.rs b/llm/openrouter/src/conversions.rs index d4db2d34c..61b5f973b 100644 --- a/llm/openrouter/src/conversions.rs +++ b/llm/openrouter/src/conversions.rs @@ -184,7 +184,7 @@ fn convert_content_parts(contents: Vec) -> crate::client::Content { let media_type = &image_source.mime_type; // This is already a string result.push(crate::client::ContentPart::ImageInput { image_url: crate::client::ImageUrl { - url: format!("data:{};base64,{}", media_type, base64_data), + url: format!("data:{media_type};base64,{base64_data}"), detail: image_source.detail.map(|d| d.into()), }, }); From 7af953794b6cafb1c51179089deab8346cf9f6d8 Mon Sep 17 00:00:00 2001 From: Rutik7066 Date: Mon, 7 Jul 2025 22:24:00 +0000 Subject: [PATCH 7/8] stream fix --- Makefile.toml | 7 +- llm/bedrock/src/lib.rs | 168 +++++++++----------- llm/llm/src/event_source/aws_eventstream.rs | 17 +- 3 files changed, 89 insertions(+), 103 deletions(-) diff --git a/Makefile.toml b/Makefile.toml index bfc2c6acb..74be1010b 100644 --- a/Makefile.toml +++ b/Makefile.toml @@ -3,8 +3,11 @@ default_to_workspace = false skip_core_tasks = true [tasks.clean] -command = "cargo" -args = ["clean"] +script = ''' +find . -type d \( -name "target" -o -name "wit-generated" -o -name "golem-temp" \) -exec rm -rf {} + +find . -name "bindings.rs" -delete +cargo clean +''' [tasks.unit-tests] command = "cargo" diff --git a/llm/bedrock/src/lib.rs b/llm/bedrock/src/lib.rs index ddc2b3186..72f2f96db 100644 --- a/llm/bedrock/src/lib.rs +++ b/llm/bedrock/src/lib.rs @@ -25,27 +25,18 @@ struct BedrockChatStream { stream: RefCell>, failure: Option, finished: RefCell, - response_metadata: RefCell, } -/// [2025-06-29T18:11:10.458Z] [TRACE ] [golem_llm_bedrock] llm/bedrock/src/lib.rs:84: Received raw stream event: -/// { -/// "contentBlockIndex":1, -/// "delta":{ -/// "toolUse":{ -/// "input":" 10 -/// }"}}, -/// "p":"abcdefghijklmnopqrstuvwxyzAB" -/// } -/// { -/// "contentBlockIndex":0, -/// "delta": -/// { -/// "text":" German" -/// }, -/// "p":"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWX" -/// } -/// +#[derive(Debug, Deserialize)] +pub struct EventContentBlock { + #[serde(rename = "contentBlockIndex")] + pub content_block_index: u32, + #[serde(skip_serializing_if = "Option::is_none")] + pub delta: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub start: Option, + pub p: String, +} #[derive(Debug, Deserialize)] #[serde(untagged)] @@ -64,33 +55,38 @@ pub struct ToolUse { pub input: String, } -// Additional structs for different message types #[derive(Debug, Deserialize)] -pub struct MessageStart { - pub p: String, - pub role: String, +pub struct ToolUseStart { + #[serde(rename = "toolUse")] + pub tool_use: ToolUseInfo, } #[derive(Debug, Deserialize)] -pub struct MessageStop { - pub p: String, - #[serde(rename = "stopReason")] - pub stop_reason: String, +pub struct ToolUseInfo { + pub name: String, + #[serde(rename = "toolUseId")] + pub tool_use_id: String, } #[derive(Debug, Deserialize)] pub struct MetadataMessage { pub p: String, - pub usage: Option, - pub metrics: Option, + pub usage: Option, + pub metrics: Option, } #[derive(Debug, Deserialize)] -pub struct EventContentBlock { - #[serde(rename = "contentBlockIndex")] - pub content_block_index: u32, - pub delta: Delta, - pub p: String, +pub struct Metrics { + #[serde(rename = "latencyMs")] + pub latency_ms: u32, +} + +#[derive(Debug, Deserialize)] +pub struct Usage { + #[serde(rename = "outputTokens")] + pub output_tokens: u32, + #[serde(rename = "totalTokens")] + pub total_tokens: u32, } impl BedrockChatStream { @@ -99,13 +95,6 @@ impl BedrockChatStream { stream: RefCell::new(Some(stream)), failure: None, finished: RefCell::new(false), - response_metadata: RefCell::new(ResponseMetadata { - finish_reason: None, - usage: None, - provider_id: None, - timestamp: None, - provider_metadata_json: None, - }), }) } @@ -114,13 +103,6 @@ impl BedrockChatStream { stream: RefCell::new(None), failure: Some(error), finished: RefCell::new(false), - response_metadata: RefCell::new(ResponseMetadata { - finish_reason: None, - usage: None, - provider_id: None, - timestamp: None, - provider_metadata_json: None, - }), }) } } @@ -152,11 +134,38 @@ impl LlmChatStreamState for BedrockChatStream { let json: Value = serde_json::from_str(raw) .map_err(|err| format!("Failed to deserialize stream event: {err}"))?; - // 1. Handle content block delta messages (contentBlockIndex + delta) - if json.get("contentBlockIndex").is_some() && json.get("delta").is_some() { - match serde_json::from_value::(json.clone()) { - Ok(event_content_block) => { - match event_content_block.delta { + if json.get("role").is_some() { + return Ok(Some(StreamEvent::Delta(StreamDelta { + content: Some(vec![ContentPart::Text(json.to_string())]), + tool_calls: None, + }))); + } + + if json.get("usage").is_some() || json.get("metrics").is_some() { + if let Ok(metadata) = serde_json::from_value::(json.clone()) { + let usage = metadata.usage.unwrap(); + return Ok(Some(StreamEvent::Finish(ResponseMetadata { + finish_reason: None, + usage: Some(golem_llm::golem::llm::llm::Usage { + input_tokens: Some(usage.total_tokens - usage.output_tokens), + output_tokens: Some(usage.output_tokens), + total_tokens: Some(usage.total_tokens), + }), + provider_id: None, + timestamp: None, + provider_metadata_json: if metadata.metrics.is_some() { + Some(format!("{:?}", metadata.metrics.unwrap())) + }else { + None + }, + }))); + } + } + + match serde_json::from_value::(json.clone()) { + Ok(event_content_block) => { + if let Some(delta) = event_content_block.delta { + match delta { Delta::Text { text } => { return Ok(Some(StreamEvent::Delta(StreamDelta { content: Some(vec![ContentPart::Text(text)]), @@ -164,8 +173,6 @@ impl LlmChatStreamState for BedrockChatStream { }))); } Delta::ToolUse { tool_use } => { - // Handle tool use delta - this would need tool call ID and name from earlier message - // For now, just return the input as text return Ok(Some(StreamEvent::Delta(StreamDelta { content: Some(vec![ContentPart::Text(tool_use.input)]), tool_calls: None, @@ -173,51 +180,20 @@ impl LlmChatStreamState for BedrockChatStream { } } } - Err(err) => { - trace!("Failed to parse as EventContentBlock: {err}"); - // Continue to other parsing attempts + if let Some(tool_use_start) = event_content_block.start { + return Ok(Some(StreamEvent::Delta(StreamDelta { + content: Some(vec![]), + tool_calls: Some(vec![ToolCall { + id: tool_use_start.tool_use.tool_use_id, + name: tool_use_start.tool_use.name, + arguments_json: "".to_string(), + }]), + }))); } } + Err(_) => {} } - // 3. Handle message start (role + p) - if json.get("role").is_some() { - if let Ok(_message_start) = serde_json::from_value::(json.clone()) { - // Message start event - just metadata, no content to return - return Ok(None); - } - } - - // 4. Handle message stop with stopReason - if json.get("stopReason").is_some() { - if let Ok(message_stop) = serde_json::from_value::(json.clone()) { - let stop_reason = match message_stop.stop_reason.as_str() { - "end_turn" => crate::client::StopReason::EndTurn, - "tool_use" => crate::client::StopReason::ToolUse, - "max_tokens" => crate::client::StopReason::MaxTokens, - "stop_sequence" => crate::client::StopReason::StopSequence, - "guardrail_intervened" => crate::client::StopReason::GuardrailIntervened, - "content_filtered" => crate::client::StopReason::ContentFiltered, - _ => crate::client::StopReason::EndTurn, - }; - self.response_metadata.borrow_mut().finish_reason = - Some(stop_reason_to_finish_reason(stop_reason)); - - let response_metadata = self.response_metadata.borrow().clone(); - return Ok(Some(StreamEvent::Finish(response_metadata))); - } - } - - // 5. Handle metadata messages with usage/metrics - if json.get("usage").is_some() || json.get("metrics").is_some() { - if let Ok(metadata) = serde_json::from_value::(json.clone()) { - if let Some(usage) = metadata.usage { - self.response_metadata.borrow_mut().usage = Some(convert_usage(usage)); - } - // Metadata processed, no event to return - return Ok(None); - } - } Ok(None) } } diff --git a/llm/llm/src/event_source/aws_eventstream.rs b/llm/llm/src/event_source/aws_eventstream.rs index 1d8e42520..6ff0463fd 100644 --- a/llm/llm/src/event_source/aws_eventstream.rs +++ b/llm/llm/src/event_source/aws_eventstream.rs @@ -67,8 +67,12 @@ impl LlmStream for AwsEventStream { ) -> Poll>>> { trace!("Polling for next AWS EventStream event"); - if let Some(event) = try_decode_message(self)? { - return Poll::Ready(Some(Ok(event))); + match try_decode_message(self) { + Ok(Some(event)) => { + return Poll::Ready(Some(Ok(event))); + }, + Err(err) => return Poll::Ready(Some(Err(err))), + _ => {} } if self.state.is_terminated() { @@ -89,9 +93,12 @@ impl LlmStream for AwsEventStream { self.buffer.extend_from_slice(&bytes); - // Try to decode complete messages from the updated buffer - if let Some(event) = try_decode_message(self)? { - return Poll::Ready(Some(Ok(event))); + match try_decode_message(self) { + Ok(Some(event)) => { + return Poll::Ready(Some(Ok(event))); + }, + Err(err) => return Poll::Ready(Some(Err(err))), + _ => {} } } Err(StreamError::Closed) => { From 4cd5469f3d7378b50c3e834bf121e37512fb049e Mon Sep 17 00:00:00 2001 From: Rutik7066 Date: Mon, 7 Jul 2025 22:37:10 +0000 Subject: [PATCH 8/8] lint --- llm/bedrock/src/lib.rs | 58 +++++++++------------ llm/llm/src/event_source/aws_eventstream.rs | 4 +- 2 files changed, 28 insertions(+), 34 deletions(-) diff --git a/llm/bedrock/src/lib.rs b/llm/bedrock/src/lib.rs index 72f2f96db..a0bca3f64 100644 --- a/llm/bedrock/src/lib.rs +++ b/llm/bedrock/src/lib.rs @@ -2,10 +2,7 @@ mod client; mod conversions; use crate::client::{BedrockClient, ConverseRequest}; -use crate::conversions::{ - convert_usage, messages_to_request, process_response, stop_reason_to_finish_reason, - tool_results_to_messages, -}; +use crate::conversions::{messages_to_request, process_response, tool_results_to_messages}; use golem_llm::chat_stream::{LlmChatStream, LlmChatStreamState}; use golem_llm::config::with_config_keys; use golem_llm::durability::{DurableLLM, ExtendedGuest}; @@ -155,43 +152,40 @@ impl LlmChatStreamState for BedrockChatStream { timestamp: None, provider_metadata_json: if metadata.metrics.is_some() { Some(format!("{:?}", metadata.metrics.unwrap())) - }else { + } else { None }, }))); } } - match serde_json::from_value::(json.clone()) { - Ok(event_content_block) => { - if let Some(delta) = event_content_block.delta { - match delta { - Delta::Text { text } => { - return Ok(Some(StreamEvent::Delta(StreamDelta { - content: Some(vec![ContentPart::Text(text)]), - tool_calls: None, - }))); - } - Delta::ToolUse { tool_use } => { - return Ok(Some(StreamEvent::Delta(StreamDelta { - content: Some(vec![ContentPart::Text(tool_use.input)]), - tool_calls: None, - }))); - } + if let Ok(event_content_block) = serde_json::from_value::(json.clone()) { + if let Some(delta) = event_content_block.delta { + match delta { + Delta::Text { text } => { + return Ok(Some(StreamEvent::Delta(StreamDelta { + content: Some(vec![ContentPart::Text(text)]), + tool_calls: None, + }))); + } + Delta::ToolUse { tool_use } => { + return Ok(Some(StreamEvent::Delta(StreamDelta { + content: Some(vec![ContentPart::Text(tool_use.input)]), + tool_calls: None, + }))); } } - if let Some(tool_use_start) = event_content_block.start { - return Ok(Some(StreamEvent::Delta(StreamDelta { - content: Some(vec![]), - tool_calls: Some(vec![ToolCall { - id: tool_use_start.tool_use.tool_use_id, - name: tool_use_start.tool_use.name, - arguments_json: "".to_string(), - }]), - }))); - } } - Err(_) => {} + if let Some(tool_use_start) = event_content_block.start { + return Ok(Some(StreamEvent::Delta(StreamDelta { + content: Some(vec![]), + tool_calls: Some(vec![ToolCall { + id: tool_use_start.tool_use.tool_use_id, + name: tool_use_start.tool_use.name, + arguments_json: "".to_string(), + }]), + }))); + } } Ok(None) diff --git a/llm/llm/src/event_source/aws_eventstream.rs b/llm/llm/src/event_source/aws_eventstream.rs index 6ff0463fd..efd0f3748 100644 --- a/llm/llm/src/event_source/aws_eventstream.rs +++ b/llm/llm/src/event_source/aws_eventstream.rs @@ -70,7 +70,7 @@ impl LlmStream for AwsEventStream { match try_decode_message(self) { Ok(Some(event)) => { return Poll::Ready(Some(Ok(event))); - }, + } Err(err) => return Poll::Ready(Some(Err(err))), _ => {} } @@ -96,7 +96,7 @@ impl LlmStream for AwsEventStream { match try_decode_message(self) { Ok(Some(event)) => { return Poll::Ready(Some(Ok(event))); - }, + } Err(err) => return Poll::Ready(Some(Err(err))), _ => {} }