diff --git a/Cargo.lock b/Cargo.lock index 0865d6ade..d5243af1a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -47,6 +47,35 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +[[package]] +name = "aws-smithy-eventstream" +version = "0.60.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "338a3642c399c0a5d157648426110e199ca7fd1c689cc395676b81aa563700c4" +dependencies = [ + "aws-smithy-types", + "bytes", + "crc32fast", +] + +[[package]] +name = "aws-smithy-types" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d498595448e43de7f4296b7b7a18a8a02c61ec9349128c80a368f7c3b4ab11a8" +dependencies = [ + "base64-simd", + "bytes", + "bytes-utils", + "itoa", + "num-integer", + "pin-project-lite", + "pin-utils", + "ryu", + "serde", + "time", +] + [[package]] name = "base64" version = "0.21.7" @@ -59,12 +88,31 @@ version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "base64-simd" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "339abbe78e73178762e23bea9dfd08e697eb3f3301cd4be981c0f78ba5859195" +dependencies = [ + "outref", + "vsimd", +] + [[package]] name = "bitflags" version = "2.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967" +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "bumpalo" version = "3.17.0" @@ -77,6 +125,16 @@ version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" +[[package]] +name = "bytes-utils" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dafe3a8757b027e2be6e4e5601ed563c55989fcf1546e933c66c8eb3a058d35" +dependencies = [ + "bytes", + "either", +] + [[package]] name = "camino" version = "1.1.9" @@ -145,6 +203,15 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + [[package]] name = "crc32fast" version = "1.4.2" @@ -154,6 +221,36 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "deranged" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c9e6a11ca8224451684bc0d7d5a7adbf8f2fd6887261a1cfc3c0432f9d4068e" +dependencies = [ + "powerfmt", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", + "subtle", +] + [[package]] name = "displaydoc" version = "0.2.5" @@ -165,6 +262,12 @@ dependencies = [ "syn", ] +[[package]] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" + [[package]] name = "encoding_rs" version = "0.8.35" @@ -300,6 +403,16 @@ dependencies = [ "slab", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getrandom" version = "0.3.3" @@ -336,6 +449,9 @@ dependencies = [ name = "golem-llm" version = "0.0.0" dependencies = [ + "aws-smithy-eventstream", + "aws-smithy-types", + "base64 0.22.1", "golem-rust", "log", "mime", @@ -360,6 +476,24 @@ dependencies = [ "wit-bindgen-rt 0.40.0", ] +[[package]] +name = "golem-llm-bedrock" +version = "0.0.0" +dependencies = [ + "base64 0.22.1", + "golem-llm", + "golem-rust", + "hmac", + "log", + "percent-encoding", + "reqwest", + "serde", + "serde_json", + "sha2", + "time", + "wit-bindgen-rt 0.40.0", +] + [[package]] name = "golem-llm-grok" version = "0.0.0" @@ -481,6 +615,15 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + [[package]] name = "http" version = "1.3.1" @@ -733,6 +876,21 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -748,6 +906,12 @@ version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +[[package]] +name = "outref" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a80800c0488c3a21695ea981a54918fbb37abf04f4d0720c453632255e2ff0e" + [[package]] name = "percent-encoding" version = "2.3.1" @@ -775,6 +939,12 @@ dependencies = [ "zerovec", ] +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + [[package]] name = "prettyplease" version = "0.2.32" @@ -902,6 +1072,17 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbfa15b3dddfee50a0fff136974b3e1bde555604ba463834a7eb7deb6417705d" +[[package]] +name = "sha2" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "shlex" version = "1.3.0" @@ -938,6 +1119,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + [[package]] name = "syn" version = "2.0.101" @@ -989,6 +1176,37 @@ dependencies = [ "syn", ] +[[package]] +name = "time" +version = "0.3.41" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a7619e19bc266e0f9c5e6686659d394bc57973859340060a69221e57dbc0c40" +dependencies = [ + "deranged", + "itoa", + "num-conv", + "powerfmt", + "serde", + "time-core", + "time-macros", +] + +[[package]] +name = "time-core" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c9e9a38711f559d9e3ce1cdb06dd7c5b8ea546bc90052da6d06bb76da74bb07c" + +[[package]] +name = "time-macros" +version = "0.2.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3526739392ec93fd8b359c8e98514cb3e8e021beb4e5f597b00a0221f8ed8a49" +dependencies = [ + "num-conv", + "time-core", +] + [[package]] name = "tinystr" version = "0.8.1" @@ -1011,6 +1229,12 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" +[[package]] +name = "typenum" +version = "1.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" + [[package]] name = "unicase" version = "2.8.1" @@ -1065,6 +1289,18 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + +[[package]] +name = "vsimd" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c3082ca00d5a5ef149bb8b555a72ae84c9c59f7250f013ac822ac2e49b19c64" + [[package]] name = "wasi" version = "0.14.2+wasi-0.2.4" diff --git a/Cargo.toml b/Cargo.toml index 7bea1e1e5..a71042e4f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ resolver = "2" members = [ "llm/llm", "llm/anthropic", + "llm/bedrock", "llm/grok", "llm/ollama", "llm/openai", @@ -25,4 +26,4 @@ reqwest = { git = "https://github.com/golemcloud/reqwest", branch = "update-may- serde = { version = "1.0", features = ["derive"] } serde_json = { version = "1.0" } wit-bindgen-rt = { version = "0.40.0", features = ["bitflags"] } -base64 = { version = "0.22.1" } +base64 = { version = "0.22.1" } \ No newline at end of file diff --git a/Makefile.toml b/Makefile.toml index cc443bc6a..74be1010b 100644 --- a/Makefile.toml +++ b/Makefile.toml @@ -3,8 +3,11 @@ default_to_workspace = false skip_core_tasks = true [tasks.clean] -command = "cargo" -args = ["clean"] +script = ''' +find . -type d \( -name "target" -o -name "wit-generated" -o -name "golem-temp" \) -exec rm -rf {} + +find . -name "bindings.rs" -delete +cargo clean +''' [tasks.unit-tests] command = "cargo" @@ -137,7 +140,7 @@ script = ''' is_portable = eq ${1} "--portable" -targets = array llm_openai llm_anthropic llm_grok llm_openrouter llm_ollama +targets = array llm_openai llm_anthropic llm_grok llm_openrouter llm_ollama llm_bedrock for target in ${targets} if is_portable cp target/wasm32-wasip1/debug/golem_${target}.wasm components/debug/golem_${target}-portable.wasm @@ -153,7 +156,7 @@ script = ''' is_portable = eq ${1} "--portable" -targets = array llm_openai llm_anthropic llm_grok llm_openrouter llm_ollama +targets = array llm_openai llm_anthropic llm_grok llm_openrouter llm_ollama llm_bedrock for target in ${targets} if is_portable cp target/wasm32-wasip1/release/golem_${target}.wasm components/release/golem_${target}-portable.wasm diff --git a/llm/Makefile.toml b/llm/Makefile.toml index 5b7b92a89..5ca26320a 100644 --- a/llm/Makefile.toml +++ b/llm/Makefile.toml @@ -5,6 +5,7 @@ skip_core_tasks = true [tasks.build] run_task = { name = [ "build-anthropic", + "build-bedrock", "build-grok", "build-openai", "build-openrouter", @@ -14,6 +15,7 @@ run_task = { name = [ [tasks.build-portable] run_task = { name = [ "build-anthropic-portable", + "build-bedrock-portable", "build-grok-portable", "build-openai-portable", "build-openrouter-portable", @@ -23,6 +25,7 @@ run_task = { name = [ [tasks.release-build] run_task = { name = [ "release-build-anthropic", + "release-build-bedrock", "release-build-grok", "release-build-openai", "release-build-openrouter", @@ -32,6 +35,7 @@ run_task = { name = [ [tasks.release-build-portable] run_task = { name = [ "release-build-anthropic-portable", + "release-build-bedrock-portable", "release-build-grok-portable", "release-build-openai-portable", "release-build-openrouter-portable", @@ -60,6 +64,16 @@ install_crate = { crate_name = "cargo-component", version = "0.20.0" } command = "cargo-component" args = ["build", "-p", "golem-llm-anthropic", "--no-default-features"] +[tasks.build-bedrock] +install_crate = { crate_name = "cargo-component", version = "0.20.0" } +command = "cargo-component" +args = ["build", "-p", "golem-llm-bedrock"] + +[tasks.build-bedrock-portable] +install_crate = { crate_name = "cargo-component", version = "0.20.0" } +command = "cargo-component" +args = ["build", "-p", "golem-llm-bedrock", "--no-default-features"] + [tasks.build-grok] install_crate = { crate_name = "cargo-component", version = "0.20.0" } command = "cargo-component" @@ -117,6 +131,22 @@ args = [ "--no-default-features", ] +[tasks.release-build-bedrock] +install_crate = { crate_name = "cargo-component", version = "0.20.0" } +command = "cargo-component" +args = ["build", "-p", "golem-llm-bedrock", "--release"] + +[tasks.release-build-bedrock-portable] +install_crate = { crate_name = "cargo-component", version = "0.20.0" } +command = "cargo-component" +args = [ + "build", + "-p", + "golem-llm-bedrock", + "--release", + "--no-default-features", +] + [tasks.release-build-grok] install_crate = { crate_name = "cargo-component", version = "0.20.0" } command = "cargo-component" @@ -163,7 +193,7 @@ dependencies = ["wit-update"] script_runner = "@duckscript" script = """ -modules = array llm openai anthropic grok openrouter ollama +modules = array llm openai anthropic bedrock grok openrouter ollama for module in ${modules} rm -r ${module}/wit/deps @@ -200,6 +230,8 @@ golem-cli --version golem-cli app clean golem-cli app build -b anthropic-debug golem-cli app clean +golem-cli app build -b bedrock-debug +golem-cli app clean golem-cli app build -b grok-debug golem-cli app clean golem-cli app build -b openai-debug diff --git a/llm/anthropic/src/bindings.rs b/llm/anthropic/src/bindings.rs index 70c5f1fd5..1a54d6167 100644 --- a/llm/anthropic/src/bindings.rs +++ b/llm/anthropic/src/bindings.rs @@ -1,12 +1,15 @@ -// Generated by `wit-bindgen` 0.36.0. DO NOT EDIT! +// Generated by `wit-bindgen` 0.41.0. DO NOT EDIT! // Options used: // * runtime_path: "wit_bindgen_rt" // * with "golem:llm/llm@1.0.0" = "golem_llm::golem::llm::llm" // * generate_unused_types use golem_llm::golem::llm::llm as __with_name0; #[cfg(target_arch = "wasm32")] -#[link_section = "component-type:wit-bindgen:0.36.0:golem:llm-anthropic@1.0.0:llm-library:encoded world"] +#[unsafe( + link_section = "component-type:wit-bindgen:0.41.0:golem:llm-anthropic@1.0.0:llm-library:encoded world" +)] #[doc(hidden)] +#[allow(clippy::octal_escapes)] pub static __WIT_BINDGEN_COMPONENT_TYPE: [u8; 1762] = *b"\ \0asm\x0d\0\x01\0\0\x19\x16wit-component-encoding\x04\0\x07\xe0\x0c\x01A\x02\x01\ A\x02\x01BO\x01m\x04\x04user\x09assistant\x06system\x04tool\x04\0\x04role\x03\0\0\ @@ -43,8 +46,8 @@ ng-get-next\x01B\x01p\x15\x01@\x02\x08messages\xc3\0\x06config)\06\x04\0\x04send \0\x06config)\06\x04\0\x08continue\x01G\x01i=\x01@\x02\x08messages\xc3\0\x06conf\ ig)\0\xc8\0\x04\0\x06stream\x01I\x04\0\x13golem:llm/llm@1.0.0\x05\0\x04\0%golem:\ llm-anthropic/llm-library@1.0.0\x04\0\x0b\x11\x01\0\x0bllm-library\x03\0\0\0G\x09\ -producers\x01\x0cprocessed-by\x02\x0dwit-component\x070.220.0\x10wit-bindgen-rus\ -t\x060.36.0"; +producers\x01\x0cprocessed-by\x02\x0dwit-component\x070.227.1\x10wit-bindgen-rus\ +t\x060.41.0"; #[inline(never)] #[doc(hidden)] pub fn __link_custom_section_describing_imports() { diff --git a/llm/anthropic/src/conversions.rs b/llm/anthropic/src/conversions.rs index e332f1391..e7d3175a0 100644 --- a/llm/anthropic/src/conversions.rs +++ b/llm/anthropic/src/conversions.rs @@ -130,7 +130,7 @@ pub fn process_response(response: MessagesResponse) -> ChatEvent { Err(e) => { return ChatEvent::Error(Error { code: ErrorCode::InvalidRequest, - message: format!("Failed to decode base64 image data: {}", e), + message: format!("Failed to decode base64 image data: {e}"), provider_error_json: None, }); } diff --git a/llm/bedrock/Cargo.toml b/llm/bedrock/Cargo.toml new file mode 100644 index 000000000..1ca88260b --- /dev/null +++ b/llm/bedrock/Cargo.toml @@ -0,0 +1,47 @@ +[package] +name = "golem-llm-bedrock" +version = "0.0.0" +edition = "2021" +license = "Apache-2.0" +homepage = "https://golem.cloud" +repository = "https://github.com/golemcloud/golem-llm" +description = "WebAssembly component for working with AWS Bedrock APIs, with special support for Golem Cloud" + +[lib] +path = "src/lib.rs" +crate-type = ["cdylib"] + +[features] +default = ["durability"] +durability = ["golem-rust/durability", "golem-llm/durability"] + +[dependencies] +golem-llm = { workspace = true } + +golem-rust = { workspace = true } +log = { workspace = true } +reqwest = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +wit-bindgen-rt = { workspace = true } +base64 = { workspace = true } +hmac = "0.12" +sha2 = "0.10" +time = { version = "0.3", features = ["formatting"] } +percent-encoding = "2.3" + +[package.metadata.component] +package = "golem:llm-bedrock" + +[package.metadata.component.bindings] +generate_unused_types = true + +[package.metadata.component.bindings.with] +"golem:llm/llm@1.0.0" = "golem_llm::golem::llm::llm" + +[package.metadata.component.target] +path = "wit" + +[package.metadata.component.target.dependencies] +"golem:llm" = { path = "wit/deps/golem-llm" } +"wasi:io" = { path = "wit/deps/wasi:io" } \ No newline at end of file diff --git a/llm/bedrock/src/bindings.rs b/llm/bedrock/src/bindings.rs new file mode 100644 index 000000000..c173c3169 --- /dev/null +++ b/llm/bedrock/src/bindings.rs @@ -0,0 +1,55 @@ +// Generated by `wit-bindgen` 0.41.0. DO NOT EDIT! +// Options used: +// * runtime_path: "wit_bindgen_rt" +// * with "golem:llm/llm@1.0.0" = "golem_llm::golem::llm::llm" +// * generate_unused_types +use golem_llm::golem::llm::llm as __with_name0; +#[cfg(target_arch = "wasm32")] +#[unsafe( + link_section = "component-type:wit-bindgen:0.41.0:golem:llm-bedrock@1.0.0:llm-library:encoded world" +)] +#[doc(hidden)] +#[allow(clippy::octal_escapes)] +pub static __WIT_BINDGEN_COMPONENT_TYPE: [u8; 1760] = *b"\ +\0asm\x0d\0\x01\0\0\x19\x16wit-component-encoding\x04\0\x07\xde\x0c\x01A\x02\x01\ +A\x02\x01BO\x01m\x04\x04user\x09assistant\x06system\x04tool\x04\0\x04role\x03\0\0\ +\x01m\x06\x0finvalid-request\x15authentication-failed\x13rate-limit-exceeded\x0e\ +internal-error\x0bunsupported\x07unknown\x04\0\x0aerror-code\x03\0\x02\x01m\x06\x04\ +stop\x06length\x0atool-calls\x0econtent-filter\x05error\x05other\x04\0\x0dfinish\ +-reason\x03\0\x04\x01m\x03\x03low\x04high\x04auto\x04\0\x0cimage-detail\x03\0\x06\ +\x01k\x07\x01r\x02\x03urls\x06detail\x08\x04\0\x09image-url\x03\0\x09\x01p}\x01r\ +\x03\x04data\x0b\x09mime-types\x06detail\x08\x04\0\x0cimage-source\x03\0\x0c\x01\ +q\x02\x03url\x01\x0a\0\x06inline\x01\x0d\0\x04\0\x0fimage-reference\x03\0\x0e\x01\ +q\x02\x04text\x01s\0\x05image\x01\x0f\0\x04\0\x0ccontent-part\x03\0\x10\x01ks\x01\ +p\x11\x01r\x03\x04role\x01\x04name\x12\x07content\x13\x04\0\x07message\x03\0\x14\ +\x01r\x03\x04names\x0bdescription\x12\x11parameters-schemas\x04\0\x0ftool-defini\ +tion\x03\0\x16\x01r\x03\x02ids\x04names\x0earguments-jsons\x04\0\x09tool-call\x03\ +\0\x18\x01ky\x01r\x04\x02ids\x04names\x0bresult-jsons\x11execution-time-ms\x1a\x04\ +\0\x0ctool-success\x03\0\x1b\x01r\x04\x02ids\x04names\x0derror-messages\x0aerror\ +-code\x12\x04\0\x0ctool-failure\x03\0\x1d\x01q\x02\x07success\x01\x1c\0\x05error\ +\x01\x1e\0\x04\0\x0btool-result\x03\0\x1f\x01r\x02\x03keys\x05values\x04\0\x02kv\ +\x03\0!\x01kv\x01ps\x01k$\x01p\x17\x01p\"\x01r\x07\x05models\x0btemperature#\x0a\ +max-tokens\x1a\x0estop-sequences%\x05tools&\x0btool-choice\x12\x10provider-optio\ +ns'\x04\0\x06config\x03\0(\x01r\x03\x0cinput-tokens\x1a\x0doutput-tokens\x1a\x0c\ +total-tokens\x1a\x04\0\x05usage\x03\0*\x01k\x05\x01k+\x01r\x05\x0dfinish-reason,\ +\x05usage-\x0bprovider-id\x12\x09timestamp\x12\x16provider-metadata-json\x12\x04\ +\0\x11response-metadata\x03\0.\x01p\x19\x01r\x04\x02ids\x07content\x13\x0atool-c\ +alls0\x08metadata/\x04\0\x11complete-response\x03\01\x01r\x03\x04code\x03\x07mes\ +sages\x13provider-error-json\x12\x04\0\x05error\x03\03\x01q\x03\x07message\x012\0\ +\x0ctool-request\x010\0\x05error\x014\0\x04\0\x0achat-event\x03\05\x01k\x13\x01k\ +0\x01r\x02\x07content7\x0atool-calls8\x04\0\x0cstream-delta\x03\09\x01q\x03\x05d\ +elta\x01:\0\x06finish\x01/\0\x05error\x014\0\x04\0\x0cstream-event\x03\0;\x04\0\x0b\ +chat-stream\x03\x01\x01h=\x01p<\x01k?\x01@\x01\x04self>\0\xc0\0\x04\0\x1c[method\ +]chat-stream.get-next\x01A\x01@\x01\x04self>\0?\x04\0%[method]chat-stream.blocki\ +ng-get-next\x01B\x01p\x15\x01@\x02\x08messages\xc3\0\x06config)\06\x04\0\x04send\ +\x01D\x01o\x02\x19\x20\x01p\xc5\0\x01@\x03\x08messages\xc3\0\x0ctool-results\xc6\ +\0\x06config)\06\x04\0\x08continue\x01G\x01i=\x01@\x02\x08messages\xc3\0\x06conf\ +ig)\0\xc8\0\x04\0\x06stream\x01I\x04\0\x13golem:llm/llm@1.0.0\x05\0\x04\0#golem:\ +llm-bedrock/llm-library@1.0.0\x04\0\x0b\x11\x01\0\x0bllm-library\x03\0\0\0G\x09p\ +roducers\x01\x0cprocessed-by\x02\x0dwit-component\x070.227.1\x10wit-bindgen-rust\ +\x060.41.0"; +#[inline(never)] +#[doc(hidden)] +pub fn __link_custom_section_describing_imports() { + wit_bindgen_rt::maybe_link_cabi_realloc(); +} diff --git a/llm/bedrock/src/client.rs b/llm/bedrock/src/client.rs new file mode 100644 index 000000000..fd0149c5a --- /dev/null +++ b/llm/bedrock/src/client.rs @@ -0,0 +1,518 @@ +use golem_llm::error::{error_code_from_status, from_event_source_error, from_reqwest_error}; +use golem_llm::event_source::EventSource; +use golem_llm::golem::llm::llm::{Error, ErrorCode}; +use hmac::{Hmac, Mac}; +use log::trace; +use reqwest::{Client, Method, Response}; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use sha2::{Digest, Sha256}; +use std::fmt::Debug; +use std::time::{SystemTime, UNIX_EPOCH}; +use time::OffsetDateTime; + +/// AWS Bedrock client for creating model responses +pub struct BedrockClient { + access_key_id: String, + secret_access_key: String, + region: String, + client: Client, +} + +impl BedrockClient { + pub fn new(access_key_id: String, secret_access_key: String, region: String) -> Self { + let client = Client::new(); + Self { + access_key_id, + secret_access_key, + region, + client, + } + } + + pub fn converse( + &self, + model_id: &str, + request: ConverseRequest, + ) -> Result { + trace!("Sending request to Bedrock API: {request:?}"); + let url = format!( + "https://bedrock-runtime.{}.amazonaws.com/model/{}/converse", + self.region, model_id + ); + + let body = serde_json::to_string(&request).map_err(|err| Error { + code: ErrorCode::InternalError, + message: "Failed to serialize request".to_string(), + provider_error_json: Some(err.to_string()), + })?; + + let host = format!("bedrock-runtime.{}.amazonaws.com", self.region); + let headers = generate_sigv4_headers( + &self.access_key_id, + &self.secret_access_key, + &self.region, + "bedrock", + "POST", + &format!("/model/{model_id}/converse"), + &host, + &body, + ) + .map_err(|err| Error { + code: ErrorCode::InternalError, + message: "Failed to sign headers".to_string(), + provider_error_json: Some(err.to_string()), + })?; + + let mut request_builder = self.client.request(Method::POST, &url); + for (key, value) in headers { + request_builder = request_builder.header(key, value); + } + + let response: Response = request_builder.body(body).send().map_err(|err| { + trace!("HTTP request failed with error: {err:?}"); + from_reqwest_error("Request failed", err) + })?; + + trace!("Received response from Bedrock API: {response:?}"); + + parse_response(response) + } + + pub fn converse_stream( + &self, + model_id: &str, + request: ConverseRequest, + ) -> Result { + trace!("Sending streaming request to Bedrock API: {request:?}"); + let url = format!( + "https://bedrock-runtime.{}.amazonaws.com/model/{}/converse-stream", + self.region, model_id + ); + + let body = serde_json::to_string(&request).map_err(|err| Error { + code: ErrorCode::InternalError, + message: "Failed to serialize request".to_string(), + provider_error_json: Some(err.to_string()), + })?; + + let host = format!("bedrock-runtime.{}.amazonaws.com", self.region); + let headers = generate_sigv4_headers( + &self.access_key_id, + &self.secret_access_key, + &self.region, + "bedrock", + "POST", + &format!("/model/{model_id}/converse-stream"), + &host, + &body, + ) + .map_err(|err| Error { + code: ErrorCode::InternalError, + message: "Failed to sign headers".to_string(), + provider_error_json: Some(err.to_string()), + })?; + + let mut request_builder = self.client.request(Method::POST, &url); + for (key, value) in headers { + request_builder = request_builder.header(key, value); + } + + trace!("Sending streaming HTTP request to Bedrock..."); + let response: Response = request_builder.body(body).send().map_err(|err| { + trace!("HTTP request failed with error: {err:?}"); + from_reqwest_error("Request failed", err) + })?; + + trace!("Initializing SSE stream"); + trace!("Response: {:?}", response.headers().clone()); + EventSource::new(response) + .map_err(|err| from_event_source_error("Failed to create SSE stream", err)) + } +} + +#[allow(clippy::too_many_arguments)] +pub fn generate_sigv4_headers( + access_key: &str, + secret_key: &str, + region: &str, + service: &str, + method: &str, + uri: &str, + host: &str, + body: &str, +) -> Result, Box> { + use std::collections::BTreeMap; + + let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap(); + let timestamp = OffsetDateTime::from_unix_timestamp(now.as_secs() as i64).unwrap(); + + let date_str = format!( + "{:04}{:02}{:02}", + timestamp.year(), + timestamp.month() as u8, + timestamp.day() + ); + let datetime_str = format!( + "{:04}{:02}{:02}T{:02}{:02}{:02}Z", + timestamp.year(), + timestamp.month() as u8, + timestamp.day(), + timestamp.hour(), + timestamp.minute(), + timestamp.second() + ); + + let (canonical_uri, canonical_query_string) = if let Some(query_pos) = uri.find('?') { + let path = &uri[..query_pos]; + let query = &uri[query_pos + 1..]; + + let encoded_path = if path.contains(':') { + path.replace(':', "%3A") + } else { + path.to_string() + }; + + let mut query_params: Vec<&str> = query.split('&').collect(); + query_params.sort(); + (encoded_path, query_params.join("&")) + } else { + let encoded_path = if uri.contains(':') { + uri.replace(':', "%3A") + } else { + uri.to_string() + }; + (encoded_path, String::new()) + }; + + let mut headers = BTreeMap::new(); + headers.insert("content-type", "application/x-amz-json-1.0"); + headers.insert("host", host); + headers.insert("x-amz-date", &datetime_str); + + let canonical_headers = headers + .iter() + .map(|(k, v)| format!("{}:{}", k.to_lowercase().trim(), v.trim())) + .collect::>() + .join("\n") + + "\n"; + + let signed_headers = headers + .keys() + .map(|k| k.to_lowercase()) + .collect::>() + .join(";"); + + let payload_hash = format!("{:x}", Sha256::digest(body.as_bytes())); + + let canonical_request = format!( + "{method}\n{canonical_uri}\n{canonical_query_string}\n{canonical_headers}\n{signed_headers}\n{payload_hash}" + ); + + let credential_scope = format!("{date_str}/{region}/{service}/aws4_request"); + let canonical_request_hash = format!("{:x}", Sha256::digest(canonical_request.as_bytes())); + let string_to_sign = + format!("AWS4-HMAC-SHA256\n{datetime_str}\n{credential_scope}\n{canonical_request_hash}"); + + type HmacSha256 = Hmac; + + let mut mac = HmacSha256::new_from_slice(format!("AWS4{secret_key}").as_bytes())?; + mac.update(date_str.as_bytes()); + let date_key = mac.finalize().into_bytes(); + + let mut mac = HmacSha256::new_from_slice(&date_key)?; + mac.update(region.as_bytes()); + let region_key = mac.finalize().into_bytes(); + + let mut mac = HmacSha256::new_from_slice(®ion_key)?; + mac.update(service.as_bytes()); + let service_key = mac.finalize().into_bytes(); + + let mut mac = HmacSha256::new_from_slice(&service_key)?; + mac.update(b"aws4_request"); + let signing_key = mac.finalize().into_bytes(); + + let mut mac = HmacSha256::new_from_slice(&signing_key)?; + mac.update(string_to_sign.as_bytes()); + let signature = format!("{:x}", mac.finalize().into_bytes()); + + let auth_header = format!( + "AWS4-HMAC-SHA256 Credential={access_key}/{credential_scope}, SignedHeaders={signed_headers}, Signature={signature}" + ); + + let result_headers = vec![ + ("authorization".to_string(), auth_header), + ("x-amz-date".to_string(), datetime_str), + ( + "content-type".to_string(), + "application/x-amz-json-1.0".to_string(), + ), + ]; + + Ok(result_headers) +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ConverseRequest { + pub messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub system: Option>, + #[serde(rename = "inferenceConfig", skip_serializing_if = "Option::is_none")] + pub inference_config: Option, + #[serde(rename = "toolConfig", skip_serializing_if = "Option::is_none")] + pub tool_config: Option, + #[serde(rename = "guardrailConfig", skip_serializing_if = "Option::is_none")] + pub guardrail_config: Option, + #[serde( + rename = "additionalModelRequestFields", + skip_serializing_if = "Option::is_none" + )] + pub additional_model_request_fields: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Message { + pub role: Role, + pub content: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum Role { + #[serde(rename = "user")] + User, + #[serde(rename = "assistant")] + Assistant, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ContentBlock { + Text { + text: String, + }, + Image { + image: ImageBlock, + }, + ToolUse { + #[serde(rename = "toolUse")] + tool_use: ToolUseBlock, + }, + ToolResult { + #[serde(rename = "toolResult")] + tool_result: ToolResultBlock, + }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ImageBlock { + pub format: ImageFormat, + pub source: ImageSource, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolUseBlock { + #[serde(rename = "toolUseId")] + pub tool_use_id: String, + pub name: String, + pub input: Value, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolResultBlock { + #[serde(rename = "toolUseId")] + pub tool_use_id: String, + pub content: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub status: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum ImageFormat { + #[serde(rename = "png")] + Png, + #[serde(rename = "jpeg")] + Jpeg, + #[serde(rename = "gif")] + Gif, + #[serde(rename = "webp")] + Webp, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ImageSource { + pub bytes: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum ToolResultContentBlock { + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "image")] + Image { + #[serde(rename = "format")] + format: ImageFormat, + #[serde(rename = "source")] + source: ImageSource, + }, + #[serde(rename = "json")] + Json { json: Value }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum ToolResultStatus { + #[serde(rename = "success")] + Success, + #[serde(rename = "error")] + Error, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum SystemContentBlock { + #[serde(rename = "text")] + Text { text: String }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct InferenceConfig { + #[serde(rename = "maxTokens", skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(rename = "topP", skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(rename = "stopSequences", skip_serializing_if = "Option::is_none")] + pub stop_sequences: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolConfig { + pub tools: Vec, + #[serde(rename = "toolChoice", skip_serializing_if = "Option::is_none")] + pub tool_choice: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Tool { + #[serde(rename = "toolSpec")] + pub tool_spec: ToolSpec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolSpec { + pub name: String, + pub description: String, + #[serde(rename = "inputSchema")] + pub input_schema: ToolInputSchema, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolInputSchema { + pub json: Value, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ToolChoice { + Auto { auto: serde_json::Value }, + Any { any: serde_json::Value }, + Tool { tool: ToolChoiceTool }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolChoiceTool { + pub name: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GuardrailConfig { + #[serde(rename = "guardrailIdentifier")] + pub guardrail_identifier: String, + #[serde(rename = "guardrailVersion")] + pub guardrail_version: String, + pub trace: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ConverseResponse { + pub output: Output, + #[serde(rename = "stopReason")] + pub stop_reason: StopReason, + pub usage: Usage, + pub metrics: Metrics, + #[serde( + rename = "additionalModelResponseFields", + skip_serializing_if = "Option::is_none" + )] + pub additional_model_response_fields: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Output { + pub message: Message, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum StopReason { + #[serde(rename = "end_turn")] + EndTurn, + #[serde(rename = "tool_use")] + ToolUse, + #[serde(rename = "max_tokens")] + MaxTokens, + #[serde(rename = "stop_sequence")] + StopSequence, + #[serde(rename = "guardrail_intervened")] + GuardrailIntervened, + #[serde(rename = "content_filtered")] + ContentFiltered, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Usage { + #[serde(rename = "inputTokens")] + pub input_tokens: u32, + #[serde(rename = "outputTokens")] + pub output_tokens: u32, + #[serde(rename = "totalTokens")] + pub total_tokens: u32, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Metrics { + #[serde(rename = "latencyMs")] + pub latency_ms: u64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ErrorResponse { + pub message: String, + #[serde(rename = "type")] + pub error_type: String, +} + +fn parse_response(response: Response) -> Result { + let status = response.status(); + if status.is_success() { + let body = response + .json::() + .map_err(|err| from_reqwest_error("Failed to decode response body", err))?; + + trace!("Received response from Bedrock API: {body:?}"); + + Ok(body) + } else { + let body = response + .text() + .map_err(|err| from_reqwest_error("Failed to receive error response body", err))?; + trace!("Received {status} response from Bedrock API: {body:?}"); + + Err(Error { + code: error_code_from_status(status), + message: format!("Request failed with {status}: {body}"), + provider_error_json: Some(body), + }) + } +} diff --git a/llm/bedrock/src/conversions.rs b/llm/bedrock/src/conversions.rs new file mode 100644 index 000000000..19b6854f9 --- /dev/null +++ b/llm/bedrock/src/conversions.rs @@ -0,0 +1,351 @@ +use crate::client::{ + ContentBlock, ConverseRequest, ConverseResponse, ImageBlock, ImageFormat, + ImageSource as ClientImageSource, InferenceConfig, Message as ClientMessage, + Role as ClientRole, StopReason, SystemContentBlock, Tool, ToolChoice, ToolChoiceTool, + ToolConfig, ToolInputSchema, ToolResultBlock, ToolResultContentBlock, ToolResultStatus, + ToolSpec, ToolUseBlock, +}; +use base64::{engine::general_purpose, Engine as _}; +use golem_llm::golem::llm::llm::{ + ChatEvent, CompleteResponse, Config, ContentPart, Error, ErrorCode, FinishReason, + ImageReference, ImageSource, Message, ResponseMetadata, Role, ToolCall, ToolDefinition, + ToolResult, Usage, +}; +use reqwest::{Client, Url}; +use std::{collections::HashMap, fs, path::Path}; + +pub fn messages_to_request( + messages: Vec, + config: Config, +) -> Result { + let options = config + .provider_options + .into_iter() + .map(|kv| (kv.key, kv.value)) + .collect::>(); + + let mut bedrock_messages = Vec::new(); + let mut system_messages = Vec::new(); + + for message in &messages { + match message.role { + Role::System => { + system_messages.extend(message_to_system_content(message)); + } + Role::User | Role::Assistant | Role::Tool => { + bedrock_messages.push(ClientMessage { + role: match &message.role { + Role::User | Role::Tool => ClientRole::User, + Role::Assistant => ClientRole::Assistant, + Role::System => unreachable!(), + }, + content: message_to_content(message)?, + }); + } + } + } + + let inference_config = if config.max_tokens.is_some() + || config.temperature.is_some() + || config.stop_sequences.is_some() + || options.contains_key("top_p") + { + Some(InferenceConfig { + max_tokens: config.max_tokens, + temperature: config.temperature, + top_p: options + .get("top_p") + .and_then(|top_p_s| top_p_s.parse::().ok()), + stop_sequences: config.stop_sequences, + }) + } else { + None + }; + + let tool_config = if config.tools.is_empty() { + None + } else { + let mut tools = Vec::new(); + for tool in &config.tools { + tools.push(tool_definition_to_tool(tool)?); + } + + let tool_choice = config.tool_choice.map(convert_tool_choice); + + if tools.is_empty() { + None + } else { + Some(ToolConfig { tools, tool_choice }) + } + }; + + Ok(ConverseRequest { + messages: bedrock_messages, + system: if system_messages.is_empty() { + None + } else { + Some(system_messages) + }, + inference_config, + tool_config, + guardrail_config: None, + additional_model_request_fields: None, + }) +} + +fn convert_tool_choice(tool_name: String) -> ToolChoice { + use serde_json::Value; + + match tool_name.as_str() { + "auto" => ToolChoice::Auto { + auto: Value::Object(serde_json::Map::new()), + }, + "any" => ToolChoice::Any { + any: Value::Object(serde_json::Map::new()), + }, + name => ToolChoice::Tool { + tool: ToolChoiceTool { + name: name.to_string(), + }, + }, + } +} + +pub fn process_response(response: ConverseResponse) -> ChatEvent { + let mut contents = Vec::new(); + let mut tool_calls = Vec::new(); + + for content in response.output.message.content { + match content { + ContentBlock::Text { text } => contents.push(ContentPart::Text(text)), + ContentBlock::Image { image } => { + match general_purpose::STANDARD.decode(&image.source.bytes) { + Ok(decoded_data) => { + let mime_type = match image.format { + ImageFormat::Jpeg => "image/jpeg", + ImageFormat::Png => "image/png", + ImageFormat::Gif => "image/gif", + ImageFormat::Webp => "image/webp", + }; + contents.push(ContentPart::Image(ImageReference::Inline(ImageSource { + data: decoded_data, + mime_type: mime_type.to_string(), + detail: None, + }))); + } + Err(e) => { + return ChatEvent::Error(Error { + code: ErrorCode::InvalidRequest, + message: format!("Failed to decode base64 image data: {e}"), + provider_error_json: None, + }); + } + } + } + ContentBlock::ToolUse { tool_use } => tool_calls.push(ToolCall { + id: tool_use.tool_use_id, + name: tool_use.name, + arguments_json: serde_json::to_string(&tool_use.input).unwrap(), + }), + ContentBlock::ToolResult { .. } => {} + } + } + + if contents.is_empty() && !tool_calls.is_empty() { + ChatEvent::ToolRequest(tool_calls) + } else { + let request_id = "bedrock-response".to_string(); + + let metadata = ResponseMetadata { + finish_reason: Some(stop_reason_to_finish_reason(response.stop_reason)), + usage: Some(convert_usage(response.usage)), + provider_id: Some(request_id.clone()), + timestamp: None, + provider_metadata_json: None, + }; + + ChatEvent::Message(CompleteResponse { + id: request_id, + content: contents, + tool_calls, + metadata, + }) + } +} + +pub fn tool_results_to_messages(tool_results: Vec<(ToolCall, ToolResult)>) -> Vec { + let mut messages = Vec::new(); + + for (tool_call, tool_result) in tool_results { + messages.push(ClientMessage { + content: vec![ContentBlock::ToolUse { + tool_use: ToolUseBlock { + tool_use_id: tool_call.id.clone(), + name: tool_call.name, + input: serde_json::from_str(&tool_call.arguments_json).unwrap(), + }, + }], + role: ClientRole::Assistant, + }); + + let (content, status) = match tool_result { + ToolResult::Success(success) => ( + vec![ToolResultContentBlock::Text { + text: success.result_json, + }], + Some(ToolResultStatus::Success), + ), + ToolResult::Error(error) => ( + vec![ToolResultContentBlock::Text { + text: error.error_message, + }], + Some(ToolResultStatus::Error), + ), + }; + + messages.push(ClientMessage { + content: vec![ContentBlock::ToolResult { + tool_result: ToolResultBlock { + tool_use_id: tool_call.id, + content, + status, + }, + }], + role: ClientRole::User, + }); + } + + messages +} + +pub fn stop_reason_to_finish_reason(stop_reason: StopReason) -> FinishReason { + match stop_reason { + StopReason::EndTurn => FinishReason::Stop, + StopReason::ToolUse => FinishReason::ToolCalls, + StopReason::MaxTokens => FinishReason::Length, + StopReason::StopSequence => FinishReason::Stop, + StopReason::GuardrailIntervened => FinishReason::ContentFilter, + StopReason::ContentFiltered => FinishReason::ContentFilter, + } +} + +pub fn convert_usage(usage: crate::client::Usage) -> Usage { + Usage { + input_tokens: Some(usage.input_tokens), + output_tokens: Some(usage.output_tokens), + total_tokens: Some(usage.total_tokens), + } +} + +fn message_to_content(message: &Message) -> Result, Error> { + let mut result = Vec::new(); + + for content_part in &message.content { + match content_part { + ContentPart::Text(text) => result.push(ContentBlock::Text { text: text.clone() }), + ContentPart::Image(image_reference) => match image_reference { + ImageReference::Url(image_url) => { + let url = &image_url.url; + let mut format = ImageFormat::Png; + let bytes = if Url::parse(url).is_ok() { + let client = Client::new(); + let response = client.get(url).send().map_err(|e| Error { + code: ErrorCode::InvalidRequest, + message: format!("Failed to fetch image from URL: {e}"), + provider_error_json: None, + }); + response.map(|r| { + format = + match r.headers().get("Content-Type").unwrap().to_str().unwrap() { + "image/jpeg" => ImageFormat::Jpeg, + "image/png" => ImageFormat::Png, + "image/gif" => ImageFormat::Gif, + "image/webp" => ImageFormat::Webp, + _ => ImageFormat::Jpeg, + }; + r.bytes().unwrap().to_vec() + }) + } else { + let path = Path::new(url); + fs::read(path).map_err(|e| Error { + code: ErrorCode::InvalidRequest, + message: format!("Failed to read image from path: {e}"), + provider_error_json: None, + }) + }; + + let base64_data = general_purpose::STANDARD.encode(bytes.unwrap()); + result.push(ContentBlock::Image { + image: ImageBlock { + format: ImageFormat::Png, + source: ClientImageSource { bytes: base64_data }, + }, + }); + } + ImageReference::Inline(image_source) => { + let base64_data = general_purpose::STANDARD.encode(&image_source.data); + let format = match image_source.mime_type.as_str() { + "image/jpeg" => ImageFormat::Jpeg, + "image/png" => ImageFormat::Png, + "image/gif" => ImageFormat::Gif, + "image/webp" => ImageFormat::Webp, + _ => ImageFormat::Jpeg, + }; + + result.push(ContentBlock::Image { + image: ImageBlock { + format, + source: ClientImageSource { bytes: base64_data }, + }, + }); + } + }, + } + } + + Ok(result) +} + +fn message_to_system_content(message: &Message) -> Vec { + let mut result = Vec::new(); + + for content_part in &message.content { + match content_part { + ContentPart::Text(text) => result.push(SystemContentBlock::Text { text: text.clone() }), + ContentPart::Image(_) => {} + } + } + + result +} + +fn tool_definition_to_tool(tool: &ToolDefinition) -> Result { + use serde_json::Value; + + let schema_value = if tool.parameters_schema.trim().is_empty() { + serde_json::json!({ + "type": "object", + "properties": {}, + "additionalProperties": false + }) + } else { + match serde_json::from_str::(&tool.parameters_schema) { + Ok(value) => value, + Err(error) => { + return Err(Error { + code: ErrorCode::InternalError, + message: format!("Failed to parse tool parameters for {}: {error}", tool.name), + provider_error_json: None, + }); + } + } + }; + + Ok(Tool { + tool_spec: ToolSpec { + name: tool.name.clone(), + description: tool.description.clone().unwrap_or_default(), + input_schema: ToolInputSchema { json: schema_value }, + }, + }) +} diff --git a/llm/bedrock/src/lib.rs b/llm/bedrock/src/lib.rs new file mode 100644 index 000000000..a0bca3f64 --- /dev/null +++ b/llm/bedrock/src/lib.rs @@ -0,0 +1,372 @@ +mod client; +mod conversions; + +use crate::client::{BedrockClient, ConverseRequest}; +use crate::conversions::{messages_to_request, process_response, tool_results_to_messages}; +use golem_llm::chat_stream::{LlmChatStream, LlmChatStreamState}; +use golem_llm::config::with_config_keys; +use golem_llm::durability::{DurableLLM, ExtendedGuest}; +use golem_llm::event_source::EventSource; +use golem_llm::golem::llm::llm::{ + ChatEvent, ChatStream, Config, ContentPart, Error, Guest, Message, ResponseMetadata, Role, + StreamDelta, StreamEvent, ToolCall, ToolResult, +}; +use golem_llm::LOGGING_STATE; +use golem_rust::wasm_rpc::Pollable; +use log::trace; +use serde::Deserialize; +use serde_json::Value; +use std::cell::{Ref, RefCell, RefMut}; + +struct BedrockChatStream { + stream: RefCell>, + failure: Option, + finished: RefCell, +} + +#[derive(Debug, Deserialize)] +pub struct EventContentBlock { + #[serde(rename = "contentBlockIndex")] + pub content_block_index: u32, + #[serde(skip_serializing_if = "Option::is_none")] + pub delta: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub start: Option, + pub p: String, +} + +#[derive(Debug, Deserialize)] +#[serde(untagged)] +pub enum Delta { + ToolUse { + #[serde(rename = "toolUse")] + tool_use: ToolUse, + }, + Text { + text: String, + }, +} + +#[derive(Debug, Deserialize)] +pub struct ToolUse { + pub input: String, +} + +#[derive(Debug, Deserialize)] +pub struct ToolUseStart { + #[serde(rename = "toolUse")] + pub tool_use: ToolUseInfo, +} + +#[derive(Debug, Deserialize)] +pub struct ToolUseInfo { + pub name: String, + #[serde(rename = "toolUseId")] + pub tool_use_id: String, +} + +#[derive(Debug, Deserialize)] +pub struct MetadataMessage { + pub p: String, + pub usage: Option, + pub metrics: Option, +} + +#[derive(Debug, Deserialize)] +pub struct Metrics { + #[serde(rename = "latencyMs")] + pub latency_ms: u32, +} + +#[derive(Debug, Deserialize)] +pub struct Usage { + #[serde(rename = "outputTokens")] + pub output_tokens: u32, + #[serde(rename = "totalTokens")] + pub total_tokens: u32, +} + +impl BedrockChatStream { + pub fn new(stream: EventSource) -> LlmChatStream { + LlmChatStream::new(BedrockChatStream { + stream: RefCell::new(Some(stream)), + failure: None, + finished: RefCell::new(false), + }) + } + + pub fn failed(error: Error) -> LlmChatStream { + LlmChatStream::new(BedrockChatStream { + stream: RefCell::new(None), + failure: Some(error), + finished: RefCell::new(false), + }) + } +} + +impl LlmChatStreamState for BedrockChatStream { + fn failure(&self) -> &Option { + &self.failure + } + + fn is_finished(&self) -> bool { + *self.finished.borrow() + } + + fn set_finished(&self) { + *self.finished.borrow_mut() = true; + } + + fn stream(&self) -> Ref> { + self.stream.borrow() + } + + fn stream_mut(&self) -> RefMut> { + self.stream.borrow_mut() + } + + fn decode_message(&self, raw: &str) -> Result, String> { + trace!("Received raw stream event: {raw}"); + + let json: Value = serde_json::from_str(raw) + .map_err(|err| format!("Failed to deserialize stream event: {err}"))?; + + if json.get("role").is_some() { + return Ok(Some(StreamEvent::Delta(StreamDelta { + content: Some(vec![ContentPart::Text(json.to_string())]), + tool_calls: None, + }))); + } + + if json.get("usage").is_some() || json.get("metrics").is_some() { + if let Ok(metadata) = serde_json::from_value::(json.clone()) { + let usage = metadata.usage.unwrap(); + return Ok(Some(StreamEvent::Finish(ResponseMetadata { + finish_reason: None, + usage: Some(golem_llm::golem::llm::llm::Usage { + input_tokens: Some(usage.total_tokens - usage.output_tokens), + output_tokens: Some(usage.output_tokens), + total_tokens: Some(usage.total_tokens), + }), + provider_id: None, + timestamp: None, + provider_metadata_json: if metadata.metrics.is_some() { + Some(format!("{:?}", metadata.metrics.unwrap())) + } else { + None + }, + }))); + } + } + + if let Ok(event_content_block) = serde_json::from_value::(json.clone()) { + if let Some(delta) = event_content_block.delta { + match delta { + Delta::Text { text } => { + return Ok(Some(StreamEvent::Delta(StreamDelta { + content: Some(vec![ContentPart::Text(text)]), + tool_calls: None, + }))); + } + Delta::ToolUse { tool_use } => { + return Ok(Some(StreamEvent::Delta(StreamDelta { + content: Some(vec![ContentPart::Text(tool_use.input)]), + tool_calls: None, + }))); + } + } + } + if let Some(tool_use_start) = event_content_block.start { + return Ok(Some(StreamEvent::Delta(StreamDelta { + content: Some(vec![]), + tool_calls: Some(vec![ToolCall { + id: tool_use_start.tool_use.tool_use_id, + name: tool_use_start.tool_use.name, + arguments_json: "".to_string(), + }]), + }))); + } + } + + Ok(None) + } +} + +struct BedrockComponent; + +impl BedrockComponent { + const ACCESS_KEY_ID_ENV_VAR: &'static str = "AWS_ACCESS_KEY_ID"; + const SECRET_ACCESS_KEY_ENV_VAR: &'static str = "AWS_SECRET_ACCESS_KEY"; + const REGION_ENV_VAR: &'static str = "AWS_REGION"; + + fn request(client: BedrockClient, model_id: &str, request: ConverseRequest) -> ChatEvent { + match client.converse(model_id, request) { + Ok(response) => process_response(response), + Err(err) => ChatEvent::Error(err), + } + } + + fn streaming_request( + client: BedrockClient, + model_id: &str, + request: ConverseRequest, + ) -> LlmChatStream { + match client.converse_stream(model_id, request) { + Ok(stream) => BedrockChatStream::new(stream), + Err(err) => BedrockChatStream::failed(err), + } + } +} + +impl Guest for BedrockComponent { + type ChatStream = LlmChatStream; + + fn send(messages: Vec, config: Config) -> ChatEvent { + LOGGING_STATE.with_borrow_mut(|state| state.init()); + with_config_keys( + &[ + Self::ACCESS_KEY_ID_ENV_VAR, + Self::SECRET_ACCESS_KEY_ENV_VAR, + Self::REGION_ENV_VAR, + ], + ChatEvent::Error, + |bedrock_api_keys| { + let client = BedrockClient::new( + bedrock_api_keys[Self::ACCESS_KEY_ID_ENV_VAR].clone(), + bedrock_api_keys[Self::SECRET_ACCESS_KEY_ENV_VAR].clone(), + bedrock_api_keys[Self::REGION_ENV_VAR].clone(), + ); + + match messages_to_request(messages, config.clone()) { + Ok(request) => Self::request(client, &config.model, request), + Err(err) => ChatEvent::Error(err), + } + }, + ) + } + + fn continue_( + messages: Vec, + tool_results: Vec<(ToolCall, ToolResult)>, + config: Config, + ) -> ChatEvent { + LOGGING_STATE.with_borrow_mut(|state| state.init()); + + with_config_keys( + &[ + Self::ACCESS_KEY_ID_ENV_VAR, + Self::SECRET_ACCESS_KEY_ENV_VAR, + Self::REGION_ENV_VAR, + ], + ChatEvent::Error, + |bedrock_api_keys| { + let client = BedrockClient::new( + bedrock_api_keys[Self::ACCESS_KEY_ID_ENV_VAR].clone(), + bedrock_api_keys[Self::SECRET_ACCESS_KEY_ENV_VAR].clone(), + bedrock_api_keys[Self::REGION_ENV_VAR].clone(), + ); + + match messages_to_request(messages, config.clone()) { + Ok(mut request) => { + request + .messages + .extend(tool_results_to_messages(tool_results)); + Self::request(client, &config.model, request) + } + Err(err) => ChatEvent::Error(err), + } + }, + ) + } + + fn stream(messages: Vec, config: Config) -> ChatStream { + ChatStream::new(Self::unwrapped_stream(messages, config)) + } +} + +impl ExtendedGuest for BedrockComponent { + fn unwrapped_stream( + messages: Vec, + config: Config, + ) -> LlmChatStream { + LOGGING_STATE.with_borrow_mut(|state| state.init()); + + with_config_keys( + &[ + Self::ACCESS_KEY_ID_ENV_VAR, + Self::SECRET_ACCESS_KEY_ENV_VAR, + Self::REGION_ENV_VAR, + ], + BedrockChatStream::failed, + |bedrock_api_keys| { + let client = BedrockClient::new( + bedrock_api_keys[Self::ACCESS_KEY_ID_ENV_VAR].clone(), + bedrock_api_keys[Self::SECRET_ACCESS_KEY_ENV_VAR].clone(), + bedrock_api_keys[Self::REGION_ENV_VAR].clone(), + ); + + match messages_to_request(messages, config.clone()) { + Ok(request) => Self::streaming_request(client, &config.model, request), + Err(err) => BedrockChatStream::failed(err), + } + }, + ) + } + + fn retry_prompt(original_messages: &[Message], partial_result: &[StreamDelta]) -> Vec { + let mut extended_messages = Vec::new(); + extended_messages.push(Message { + role: Role::System, + name: None, + content: vec![ + ContentPart::Text( + "You were asked the same question previously, but the response was interrupted before completion. \ + Please continue your response from where you left off. \ + Do not include the part of the response that was already seen.".to_string()), + ], + }); + extended_messages.push(Message { + role: Role::User, + name: None, + content: vec![ContentPart::Text( + "Here is the original question:".to_string(), + )], + }); + extended_messages.extend_from_slice(original_messages); + + let mut partial_result_as_content = Vec::new(); + for delta in partial_result { + if let Some(contents) = &delta.content { + partial_result_as_content.extend_from_slice(contents); + } + if let Some(tool_calls) = &delta.tool_calls { + for tool_call in tool_calls { + partial_result_as_content.push(ContentPart::Text(format!( + "", + tool_call.id, tool_call.name, tool_call.arguments_json, + ))); + } + } + } + + extended_messages.push(Message { + role: Role::User, + name: None, + content: vec![ContentPart::Text( + "Here is the partial response that was successfully received:".to_string(), + )] + .into_iter() + .chain(partial_result_as_content) + .collect(), + }); + extended_messages + } + + fn subscribe(stream: &Self::ChatStream) -> Pollable { + stream.subscribe() + } +} + +type DurableBedrockComponent = DurableLLM; + +golem_llm::export_llm!(DurableBedrockComponent with_types_in golem_llm); diff --git a/llm/bedrock/wit/bedrock.wit b/llm/bedrock/wit/bedrock.wit new file mode 100644 index 000000000..266ba7293 --- /dev/null +++ b/llm/bedrock/wit/bedrock.wit @@ -0,0 +1,7 @@ +package golem:llm-bedrock@1.0.0; + +world llm-library { + include golem:llm/llm-library@1.0.0; + + +} \ No newline at end of file diff --git a/llm/bedrock/wit/deps/golem-llm/golem-llm.wit b/llm/bedrock/wit/deps/golem-llm/golem-llm.wit new file mode 100644 index 000000000..67854470a --- /dev/null +++ b/llm/bedrock/wit/deps/golem-llm/golem-llm.wit @@ -0,0 +1,194 @@ +package golem:llm@1.0.0; + +interface llm { + // --- Roles, Error Codes, Finish Reasons --- + + enum role { + user, + assistant, + system, + tool, + } + + enum error-code { + invalid-request, + authentication-failed, + rate-limit-exceeded, + internal-error, + unsupported, + unknown, + } + + enum finish-reason { + stop, + length, + tool-calls, + content-filter, + error, + other, + } + + enum image-detail { + low, + high, + auto, + } + + // --- Message Content --- + + record image-url { + url: string, + detail: option, + } + + record image-source { + data: list, + mime-type: string, + detail: option, + } + + variant image-reference { + url(image-url), + inline(image-source), + } + + variant content-part { + text(string), + image(image-reference), + } + + record message { + role: role, + name: option, + content: list, + } + + // --- Tooling --- + + record tool-definition { + name: string, + description: option, + parameters-schema: string, + } + + record tool-call { + id: string, + name: string, + arguments-json: string, + } + + record tool-success { + id: string, + name: string, + result-json: string, + execution-time-ms: option, + } + + record tool-failure { + id: string, + name: string, + error-message: string, + error-code: option, + } + + variant tool-result { + success(tool-success), + error(tool-failure), + } + + // --- Configuration --- + + record kv { + key: string, + value: string, + } + + record config { + model: string, + temperature: option, + max-tokens: option, + stop-sequences: option>, + tools: list, + tool-choice: option, + provider-options: list, + } + + // --- Usage / Metadata --- + + record usage { + input-tokens: option, + output-tokens: option, + total-tokens: option, + } + + record response-metadata { + finish-reason: option, + usage: option, + provider-id: option, + timestamp: option, + provider-metadata-json: option, + } + + record complete-response { + id: string, + content: list, + tool-calls: list, + metadata: response-metadata, + } + + // --- Error Handling --- + + record error { + code: error-code, + message: string, + provider-error-json: option, + } + + // --- Chat Response Variants --- + + variant chat-event { + message(complete-response), + tool-request(list), + error(error), + } + + // --- Streaming --- + + record stream-delta { + content: option>, + tool-calls: option>, + } + + variant stream-event { + delta(stream-delta), + finish(response-metadata), + error(error), + } + + resource chat-stream { + get-next: func() -> option>; + blocking-get-next: func() -> list; + } + + // --- Core Functions --- + + send: func( + messages: list, + config: config + ) -> chat-event; + + continue: func( + messages: list, + tool-results: list>, + config: config + ) -> chat-event; + + %stream: func( + messages: list, + config: config + ) -> chat-stream; +} + +world llm-library { + export llm; +} diff --git a/llm/bedrock/wit/deps/wasi:io/error.wit b/llm/bedrock/wit/deps/wasi:io/error.wit new file mode 100644 index 000000000..97c606877 --- /dev/null +++ b/llm/bedrock/wit/deps/wasi:io/error.wit @@ -0,0 +1,34 @@ +package wasi:io@0.2.3; + +@since(version = 0.2.0) +interface error { + /// A resource which represents some error information. + /// + /// The only method provided by this resource is `to-debug-string`, + /// which provides some human-readable information about the error. + /// + /// In the `wasi:io` package, this resource is returned through the + /// `wasi:io/streams/stream-error` type. + /// + /// To provide more specific error information, other interfaces may + /// offer functions to "downcast" this error into more specific types. For example, + /// errors returned from streams derived from filesystem types can be described using + /// the filesystem's own error-code type. This is done using the function + /// `wasi:filesystem/types/filesystem-error-code`, which takes a `borrow` + /// parameter and returns an `option`. + /// + /// The set of functions which can "downcast" an `error` into a more + /// concrete type is open. + @since(version = 0.2.0) + resource error { + /// Returns a string that is suitable to assist humans in debugging + /// this error. + /// + /// WARNING: The returned string should not be consumed mechanically! + /// It may change across platforms, hosts, or other implementation + /// details. Parsing this string is a major platform-compatibility + /// hazard. + @since(version = 0.2.0) + to-debug-string: func() -> string; + } +} diff --git a/llm/bedrock/wit/deps/wasi:io/poll.wit b/llm/bedrock/wit/deps/wasi:io/poll.wit new file mode 100644 index 000000000..9bcbe8e03 --- /dev/null +++ b/llm/bedrock/wit/deps/wasi:io/poll.wit @@ -0,0 +1,47 @@ +package wasi:io@0.2.3; + +/// A poll API intended to let users wait for I/O events on multiple handles +/// at once. +@since(version = 0.2.0) +interface poll { + /// `pollable` represents a single I/O event which may be ready, or not. + @since(version = 0.2.0) + resource pollable { + + /// Return the readiness of a pollable. This function never blocks. + /// + /// Returns `true` when the pollable is ready, and `false` otherwise. + @since(version = 0.2.0) + ready: func() -> bool; + + /// `block` returns immediately if the pollable is ready, and otherwise + /// blocks until ready. + /// + /// This function is equivalent to calling `poll.poll` on a list + /// containing only this pollable. + @since(version = 0.2.0) + block: func(); + } + + /// Poll for completion on a set of pollables. + /// + /// This function takes a list of pollables, which identify I/O sources of + /// interest, and waits until one or more of the events is ready for I/O. + /// + /// The result `list` contains one or more indices of handles in the + /// argument list that is ready for I/O. + /// + /// This function traps if either: + /// - the list is empty, or: + /// - the list contains more elements than can be indexed with a `u32` value. + /// + /// A timeout can be implemented by adding a pollable from the + /// wasi-clocks API to the list. + /// + /// This function does not return a `result`; polling in itself does not + /// do any I/O so it doesn't fail. If any of the I/O sources identified by + /// the pollables has an error, it is indicated by marking the source as + /// being ready for I/O. + @since(version = 0.2.0) + poll: func(in: list>) -> list; +} diff --git a/llm/bedrock/wit/deps/wasi:io/streams.wit b/llm/bedrock/wit/deps/wasi:io/streams.wit new file mode 100644 index 000000000..0de084629 --- /dev/null +++ b/llm/bedrock/wit/deps/wasi:io/streams.wit @@ -0,0 +1,290 @@ +package wasi:io@0.2.3; + +/// WASI I/O is an I/O abstraction API which is currently focused on providing +/// stream types. +/// +/// In the future, the component model is expected to add built-in stream types; +/// when it does, they are expected to subsume this API. +@since(version = 0.2.0) +interface streams { + @since(version = 0.2.0) + use error.{error}; + @since(version = 0.2.0) + use poll.{pollable}; + + /// An error for input-stream and output-stream operations. + @since(version = 0.2.0) + variant stream-error { + /// The last operation (a write or flush) failed before completion. + /// + /// More information is available in the `error` payload. + /// + /// After this, the stream will be closed. All future operations return + /// `stream-error::closed`. + last-operation-failed(error), + /// The stream is closed: no more input will be accepted by the + /// stream. A closed output-stream will return this error on all + /// future operations. + closed + } + + /// An input bytestream. + /// + /// `input-stream`s are *non-blocking* to the extent practical on underlying + /// platforms. I/O operations always return promptly; if fewer bytes are + /// promptly available than requested, they return the number of bytes promptly + /// available, which could even be zero. To wait for data to be available, + /// use the `subscribe` function to obtain a `pollable` which can be polled + /// for using `wasi:io/poll`. + @since(version = 0.2.0) + resource input-stream { + /// Perform a non-blocking read from the stream. + /// + /// When the source of a `read` is binary data, the bytes from the source + /// are returned verbatim. When the source of a `read` is known to the + /// implementation to be text, bytes containing the UTF-8 encoding of the + /// text are returned. + /// + /// This function returns a list of bytes containing the read data, + /// when successful. The returned list will contain up to `len` bytes; + /// it may return fewer than requested, but not more. The list is + /// empty when no bytes are available for reading at this time. The + /// pollable given by `subscribe` will be ready when more bytes are + /// available. + /// + /// This function fails with a `stream-error` when the operation + /// encounters an error, giving `last-operation-failed`, or when the + /// stream is closed, giving `closed`. + /// + /// When the caller gives a `len` of 0, it represents a request to + /// read 0 bytes. If the stream is still open, this call should + /// succeed and return an empty list, or otherwise fail with `closed`. + /// + /// The `len` parameter is a `u64`, which could represent a list of u8 which + /// is not possible to allocate in wasm32, or not desirable to allocate as + /// as a return value by the callee. The callee may return a list of bytes + /// less than `len` in size while more bytes are available for reading. + @since(version = 0.2.0) + read: func( + /// The maximum number of bytes to read + len: u64 + ) -> result, stream-error>; + + /// Read bytes from a stream, after blocking until at least one byte can + /// be read. Except for blocking, behavior is identical to `read`. + @since(version = 0.2.0) + blocking-read: func( + /// The maximum number of bytes to read + len: u64 + ) -> result, stream-error>; + + /// Skip bytes from a stream. Returns number of bytes skipped. + /// + /// Behaves identical to `read`, except instead of returning a list + /// of bytes, returns the number of bytes consumed from the stream. + @since(version = 0.2.0) + skip: func( + /// The maximum number of bytes to skip. + len: u64, + ) -> result; + + /// Skip bytes from a stream, after blocking until at least one byte + /// can be skipped. Except for blocking behavior, identical to `skip`. + @since(version = 0.2.0) + blocking-skip: func( + /// The maximum number of bytes to skip. + len: u64, + ) -> result; + + /// Create a `pollable` which will resolve once either the specified stream + /// has bytes available to read or the other end of the stream has been + /// closed. + /// The created `pollable` is a child resource of the `input-stream`. + /// Implementations may trap if the `input-stream` is dropped before + /// all derived `pollable`s created with this function are dropped. + @since(version = 0.2.0) + subscribe: func() -> pollable; + } + + + /// An output bytestream. + /// + /// `output-stream`s are *non-blocking* to the extent practical on + /// underlying platforms. Except where specified otherwise, I/O operations also + /// always return promptly, after the number of bytes that can be written + /// promptly, which could even be zero. To wait for the stream to be ready to + /// accept data, the `subscribe` function to obtain a `pollable` which can be + /// polled for using `wasi:io/poll`. + /// + /// Dropping an `output-stream` while there's still an active write in + /// progress may result in the data being lost. Before dropping the stream, + /// be sure to fully flush your writes. + @since(version = 0.2.0) + resource output-stream { + /// Check readiness for writing. This function never blocks. + /// + /// Returns the number of bytes permitted for the next call to `write`, + /// or an error. Calling `write` with more bytes than this function has + /// permitted will trap. + /// + /// When this function returns 0 bytes, the `subscribe` pollable will + /// become ready when this function will report at least 1 byte, or an + /// error. + @since(version = 0.2.0) + check-write: func() -> result; + + /// Perform a write. This function never blocks. + /// + /// When the destination of a `write` is binary data, the bytes from + /// `contents` are written verbatim. When the destination of a `write` is + /// known to the implementation to be text, the bytes of `contents` are + /// transcoded from UTF-8 into the encoding of the destination and then + /// written. + /// + /// Precondition: check-write gave permit of Ok(n) and contents has a + /// length of less than or equal to n. Otherwise, this function will trap. + /// + /// returns Err(closed) without writing if the stream has closed since + /// the last call to check-write provided a permit. + @since(version = 0.2.0) + write: func( + contents: list + ) -> result<_, stream-error>; + + /// Perform a write of up to 4096 bytes, and then flush the stream. Block + /// until all of these operations are complete, or an error occurs. + /// + /// This is a convenience wrapper around the use of `check-write`, + /// `subscribe`, `write`, and `flush`, and is implemented with the + /// following pseudo-code: + /// + /// ```text + /// let pollable = this.subscribe(); + /// while !contents.is_empty() { + /// // Wait for the stream to become writable + /// pollable.block(); + /// let Ok(n) = this.check-write(); // eliding error handling + /// let len = min(n, contents.len()); + /// let (chunk, rest) = contents.split_at(len); + /// this.write(chunk ); // eliding error handling + /// contents = rest; + /// } + /// this.flush(); + /// // Wait for completion of `flush` + /// pollable.block(); + /// // Check for any errors that arose during `flush` + /// let _ = this.check-write(); // eliding error handling + /// ``` + @since(version = 0.2.0) + blocking-write-and-flush: func( + contents: list + ) -> result<_, stream-error>; + + /// Request to flush buffered output. This function never blocks. + /// + /// This tells the output-stream that the caller intends any buffered + /// output to be flushed. the output which is expected to be flushed + /// is all that has been passed to `write` prior to this call. + /// + /// Upon calling this function, the `output-stream` will not accept any + /// writes (`check-write` will return `ok(0)`) until the flush has + /// completed. The `subscribe` pollable will become ready when the + /// flush has completed and the stream can accept more writes. + @since(version = 0.2.0) + flush: func() -> result<_, stream-error>; + + /// Request to flush buffered output, and block until flush completes + /// and stream is ready for writing again. + @since(version = 0.2.0) + blocking-flush: func() -> result<_, stream-error>; + + /// Create a `pollable` which will resolve once the output-stream + /// is ready for more writing, or an error has occurred. When this + /// pollable is ready, `check-write` will return `ok(n)` with n>0, or an + /// error. + /// + /// If the stream is closed, this pollable is always ready immediately. + /// + /// The created `pollable` is a child resource of the `output-stream`. + /// Implementations may trap if the `output-stream` is dropped before + /// all derived `pollable`s created with this function are dropped. + @since(version = 0.2.0) + subscribe: func() -> pollable; + + /// Write zeroes to a stream. + /// + /// This should be used precisely like `write` with the exact same + /// preconditions (must use check-write first), but instead of + /// passing a list of bytes, you simply pass the number of zero-bytes + /// that should be written. + @since(version = 0.2.0) + write-zeroes: func( + /// The number of zero-bytes to write + len: u64 + ) -> result<_, stream-error>; + + /// Perform a write of up to 4096 zeroes, and then flush the stream. + /// Block until all of these operations are complete, or an error + /// occurs. + /// + /// This is a convenience wrapper around the use of `check-write`, + /// `subscribe`, `write-zeroes`, and `flush`, and is implemented with + /// the following pseudo-code: + /// + /// ```text + /// let pollable = this.subscribe(); + /// while num_zeroes != 0 { + /// // Wait for the stream to become writable + /// pollable.block(); + /// let Ok(n) = this.check-write(); // eliding error handling + /// let len = min(n, num_zeroes); + /// this.write-zeroes(len); // eliding error handling + /// num_zeroes -= len; + /// } + /// this.flush(); + /// // Wait for completion of `flush` + /// pollable.block(); + /// // Check for any errors that arose during `flush` + /// let _ = this.check-write(); // eliding error handling + /// ``` + @since(version = 0.2.0) + blocking-write-zeroes-and-flush: func( + /// The number of zero-bytes to write + len: u64 + ) -> result<_, stream-error>; + + /// Read from one stream and write to another. + /// + /// The behavior of splice is equivalent to: + /// 1. calling `check-write` on the `output-stream` + /// 2. calling `read` on the `input-stream` with the smaller of the + /// `check-write` permitted length and the `len` provided to `splice` + /// 3. calling `write` on the `output-stream` with that read data. + /// + /// Any error reported by the call to `check-write`, `read`, or + /// `write` ends the splice and reports that error. + /// + /// This function returns the number of bytes transferred; it may be less + /// than `len`. + @since(version = 0.2.0) + splice: func( + /// The stream to read from + src: borrow, + /// The number of bytes to splice + len: u64, + ) -> result; + + /// Read from one stream and write to another, with blocking. + /// + /// This is similar to `splice`, except that it blocks until the + /// `output-stream` is ready for writing, and the `input-stream` + /// is ready for reading, before performing the `splice`. + @since(version = 0.2.0) + blocking-splice: func( + /// The stream to read from + src: borrow, + /// The number of bytes to splice + len: u64, + ) -> result; + } +} diff --git a/llm/bedrock/wit/deps/wasi:io/world.wit b/llm/bedrock/wit/deps/wasi:io/world.wit new file mode 100644 index 000000000..f1d2102dc --- /dev/null +++ b/llm/bedrock/wit/deps/wasi:io/world.wit @@ -0,0 +1,10 @@ +package wasi:io@0.2.3; + +@since(version = 0.2.0) +world imports { + @since(version = 0.2.0) + import streams; + + @since(version = 0.2.0) + import poll; +} diff --git a/llm/grok/src/bindings.rs b/llm/grok/src/bindings.rs index 2a101583e..c2f601347 100644 --- a/llm/grok/src/bindings.rs +++ b/llm/grok/src/bindings.rs @@ -1,12 +1,15 @@ -// Generated by `wit-bindgen` 0.36.0. DO NOT EDIT! +// Generated by `wit-bindgen` 0.41.0. DO NOT EDIT! // Options used: // * runtime_path: "wit_bindgen_rt" // * with "golem:llm/llm@1.0.0" = "golem_llm::golem::llm::llm" // * generate_unused_types use golem_llm::golem::llm::llm as __with_name0; #[cfg(target_arch = "wasm32")] -#[link_section = "component-type:wit-bindgen:0.36.0:golem:llm-grok@1.0.0:llm-library:encoded world"] +#[unsafe( + link_section = "component-type:wit-bindgen:0.41.0:golem:llm-grok@1.0.0:llm-library:encoded world" +)] #[doc(hidden)] +#[allow(clippy::octal_escapes)] pub static __WIT_BINDGEN_COMPONENT_TYPE: [u8; 1757] = *b"\ \0asm\x0d\0\x01\0\0\x19\x16wit-component-encoding\x04\0\x07\xdb\x0c\x01A\x02\x01\ A\x02\x01BO\x01m\x04\x04user\x09assistant\x06system\x04tool\x04\0\x04role\x03\0\0\ @@ -43,8 +46,8 @@ ng-get-next\x01B\x01p\x15\x01@\x02\x08messages\xc3\0\x06config)\06\x04\0\x04send \0\x06config)\06\x04\0\x08continue\x01G\x01i=\x01@\x02\x08messages\xc3\0\x06conf\ ig)\0\xc8\0\x04\0\x06stream\x01I\x04\0\x13golem:llm/llm@1.0.0\x05\0\x04\0\x20gol\ em:llm-grok/llm-library@1.0.0\x04\0\x0b\x11\x01\0\x0bllm-library\x03\0\0\0G\x09p\ -roducers\x01\x0cprocessed-by\x02\x0dwit-component\x070.220.0\x10wit-bindgen-rust\ -\x060.36.0"; +roducers\x01\x0cprocessed-by\x02\x0dwit-component\x070.227.1\x10wit-bindgen-rust\ +\x060.41.0"; #[inline(never)] #[doc(hidden)] pub fn __link_custom_section_describing_imports() { diff --git a/llm/grok/src/conversions.rs b/llm/grok/src/conversions.rs index 68a5d570c..129c128ad 100644 --- a/llm/grok/src/conversions.rs +++ b/llm/grok/src/conversions.rs @@ -183,7 +183,7 @@ fn convert_content_parts(contents: Vec) -> crate::client::Content { let media_type = &image_source.mime_type; // This is already a string result.push(crate::client::ContentPart::ImageInput { image_url: crate::client::ImageUrl { - url: format!("data:{};base64,{}", media_type, base64_data), + url: format!("data:{media_type};base64,{base64_data}"), detail: image_source.detail.map(|d| d.into()), }, }); diff --git a/llm/llm/Cargo.toml b/llm/llm/Cargo.toml index f12ada98b..fee601afe 100644 --- a/llm/llm/Cargo.toml +++ b/llm/llm/Cargo.toml @@ -12,6 +12,9 @@ path = "src/lib.rs" crate-type = ["rlib"] [dependencies] +aws-smithy-eventstream = "0.60.9" +aws-smithy-types = "1.3.2" +base64 = "0.22.1" golem-rust = { workspace = true } log = { workspace = true } mime = "0.3.17" diff --git a/llm/llm/src/config.rs b/llm/llm/src/config.rs index de9822c36..461010a06 100644 --- a/llm/llm/src/config.rs +++ b/llm/llm/src/config.rs @@ -1,5 +1,5 @@ use crate::golem::llm::llm::{Error, ErrorCode}; -use std::ffi::OsStr; +use std::{collections::HashMap, ffi::OsStr}; /// Gets an expected configuration value from the environment, and fails if its is not found /// using the `fail` function. Otherwise, it runs `succeed` with the configuration value. @@ -21,3 +21,31 @@ pub fn with_config_key( } } } + +/// Gets multiple expected configuration values from the environment, and fails if any is not found +/// using the `fail` function. Otherwise, it runs `succeed` with all configuration values. +pub fn with_config_keys( + keys: &[&str], + fail: impl FnOnce(Error) -> R, + succeed: impl FnOnce(HashMap) -> R, +) -> R { + let mut values = HashMap::new(); + + for key in keys { + match std::env::var(key) { + Ok(value) => { + values.insert(key.to_string(), value); + } + Err(_) => { + let error = Error { + code: ErrorCode::InternalError, + message: format!("Missing config key: {key}"), + provider_error_json: None, + }; + return fail(error); + } + } + } + + succeed(values) +} diff --git a/llm/llm/src/event_source/aws_eventstream.rs b/llm/llm/src/event_source/aws_eventstream.rs new file mode 100644 index 000000000..efd0f3748 --- /dev/null +++ b/llm/llm/src/event_source/aws_eventstream.rs @@ -0,0 +1,182 @@ +use std::task::Poll; + +use aws_smithy_eventstream::frame::{DecodedFrame, MessageFrameDecoder}; +use aws_smithy_types::event_stream::HeaderValue; +use base64::Engine; +use golem_rust::{ + bindings::wasi::io::streams::{InputStream, StreamError}, + wasm_rpc::Pollable, +}; +use log::trace; + +use crate::event_source::{ + stream::{LlmStream, StreamError as AwsEventStreamError}, + MessageEvent, +}; + +#[derive(Debug, Clone, Copy)] +pub enum AwsEventStreamState { + NotStarted, + Started, + Terminated, +} + +impl AwsEventStreamState { + fn is_terminated(self) -> bool { + matches!(self, Self::Terminated) + } +} + +/// A Stream of AWS EventStream events using the vnd.amazon.eventstream binary format +pub struct AwsEventStream { + stream: InputStream, + decoder: MessageFrameDecoder, + buffer: Vec, + state: AwsEventStreamState, + last_event_id: String, + subscription: Pollable, +} + +impl LlmStream for AwsEventStream { + fn new(stream: InputStream) -> Self { + let subscription = stream.subscribe(); + Self { + decoder: MessageFrameDecoder::new(), + buffer: Vec::new(), + state: AwsEventStreamState::NotStarted, + last_event_id: String::new(), + stream, + subscription, + } + } + + fn set_last_event_id(&mut self, id: impl Into) { + self.last_event_id = id.into(); + } + + fn last_event_id(&self) -> &str { + &self.last_event_id + } + + fn subscribe(&self) -> Pollable { + self.stream.subscribe() + } + + fn poll_next( + &mut self, + ) -> Poll>>> { + trace!("Polling for next AWS EventStream event"); + + match try_decode_message(self) { + Ok(Some(event)) => { + return Poll::Ready(Some(Ok(event))); + } + Err(err) => return Poll::Ready(Some(Err(err))), + _ => {} + } + + if self.state.is_terminated() { + return Poll::Ready(None); + } + + loop { + if self.subscription.ready() { + match self.stream.read(8192) { + Ok(bytes) => { + if bytes.is_empty() { + continue; + } + + if !self.state.is_terminated() { + self.state = AwsEventStreamState::Started; + } + + self.buffer.extend_from_slice(&bytes); + + match try_decode_message(self) { + Ok(Some(event)) => { + return Poll::Ready(Some(Ok(event))); + } + Err(err) => return Poll::Ready(Some(Err(err))), + _ => {} + } + } + Err(StreamError::Closed) => { + trace!("AWS EventStream closed"); + self.state = AwsEventStreamState::Terminated; + return Poll::Ready(None); + } + Err(err) => return Poll::Ready(Some(Err(AwsEventStreamError::Transport(err)))), + } + } else { + return Poll::Pending; + } + } + } +} + +fn try_decode_message( + stream: &mut AwsEventStream, +) -> Result, AwsEventStreamError> { + if stream.buffer.is_empty() { + return Ok(None); + } + + let mut buffer_slice = stream.buffer.as_slice(); + let original_len = buffer_slice.len(); + + match stream.decoder.decode_frame(&mut buffer_slice) { + Ok(DecodedFrame::Complete(message)) => { + trace!( + "Decoded AWS EventStream message with {} byte payload", + message.payload().len() + ); + + let consumed = original_len - buffer_slice.len(); + + let event_type = message + .headers() + .iter() + .find(|header| header.name().as_str() == ":event-type") + .and_then(|header| { + if let HeaderValue::String(s) = header.value() { + Some(s.as_str()) + } else { + None + } + }) + .unwrap_or("message"); + + let data = match std::str::from_utf8(message.payload()) { + Ok(s) => s.to_string(), + Err(_) => base64::engine::general_purpose::STANDARD.encode(message.payload()), + }; + + if let Some(id_header) = message + .headers() + .iter() + .find(|header| header.name().as_str() == ":event-id") + { + if let HeaderValue::String(id) = id_header.value() { + stream.last_event_id = id.as_str().to_string(); + } + } + + stream.buffer.drain(..consumed); + + let event = MessageEvent { + event: event_type.to_string(), + data, + id: stream.last_event_id.clone(), + retry: None, + }; + + Ok(Some(event)) + } + Ok(DecodedFrame::Incomplete) => Ok(None), + Err(err) => Err(AwsEventStreamError::Parser(nom::error::Error::new( + format!("AWS EventStream decode error: {err}"), + nom::error::ErrorKind::Tag, + ))), + } +} diff --git a/llm/llm/src/event_source/mod.rs b/llm/llm/src/event_source/mod.rs index 2e92d4dd8..4c8962503 100644 --- a/llm/llm/src/event_source/mod.rs +++ b/llm/llm/src/event_source/mod.rs @@ -1,6 +1,7 @@ // Based on https://github.com/jpopesculian/eventsource-stream and https://github.com/jpopesculian/reqwest-eventsource // modified to use the wasi-http based reqwest, and wasi pollables +mod aws_eventstream; pub mod error; mod event_stream; mod message_event; @@ -11,6 +12,7 @@ mod utf8_stream; use crate::event_source::error::Error; use crate::event_source::event_stream::EventStream; +use aws_eventstream::AwsEventStream; use golem_rust::wasm_rpc::Pollable; pub use message_event::MessageEvent; use ndjson_stream::NdJsonStream; @@ -50,15 +52,17 @@ impl EventSource { >(response.get_raw_input_stream()) }; - let stream = if response + let content_type = response .headers() .get(&reqwest::header::CONTENT_TYPE) .unwrap() .to_str() - .unwrap() - .contains("ndjson") - { + .unwrap(); + + let stream = if content_type.contains("ndjson") { StreamType::NdJsonStream(NdJsonStream::new(handle)) + } else if content_type.contains("vnd.amazon.eventstream") { + StreamType::AwsEventStream(AwsEventStream::new(handle)) } else { StreamType::EventStream(EventStream::new(handle)) }; @@ -90,6 +94,7 @@ impl EventSource { match &self.stream { StreamType::EventStream(stream) => stream.subscribe(), StreamType::NdJsonStream(stream) => stream.subscribe(), + StreamType::AwsEventStream(stream) => stream.subscribe(), } } @@ -111,6 +116,12 @@ impl EventSource { Poll::Ready(None) => Poll::Ready(None), Poll::Pending => Poll::Pending, }, + StreamType::AwsEventStream(stream) => match stream.poll_next() { + Poll::Ready(Some(Ok(event))) => Poll::Ready(Some(Ok(Event::Message(event)))), + Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err.into()))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + }, } } } @@ -141,6 +152,10 @@ fn check_response(response: Response) -> Result { (mime_type.type_(), mime_type.subtype()), (mime::TEXT, mime::EVENT_STREAM) ) || mime_type.subtype().as_str().contains("ndjson") + || content_type + .to_str() + .unwrap_or("") + .contains("vnd.amazon.eventstream") }) .unwrap_or(false) { diff --git a/llm/llm/src/event_source/ndjson_stream.rs b/llm/llm/src/event_source/ndjson_stream.rs index e2f4cc1b2..1b8ef3773 100644 --- a/llm/llm/src/event_source/ndjson_stream.rs +++ b/llm/llm/src/event_source/ndjson_stream.rs @@ -126,7 +126,7 @@ fn try_parse_line( return Ok(None); } - trace!("Parsed NDJSON line: {}", line); + trace!("Parsed NDJSON line: {line}"); // Create a MessageEvent with the JSON line as data let event = MessageEvent { diff --git a/llm/llm/src/event_source/stream.rs b/llm/llm/src/event_source/stream.rs index 8f2933676..02d65f6c4 100644 --- a/llm/llm/src/event_source/stream.rs +++ b/llm/llm/src/event_source/stream.rs @@ -2,8 +2,8 @@ use core::fmt; use std::{string::FromUtf8Error, task::Poll}; use super::{ - event_stream::EventStream, ndjson_stream::NdJsonStream, utf8_stream::Utf8StreamError, - MessageEvent, + aws_eventstream::AwsEventStream, event_stream::EventStream, ndjson_stream::NdJsonStream, + utf8_stream::Utf8StreamError, MessageEvent, }; use golem_rust::{ bindings::wasi::io::streams::{InputStream, StreamError as WasiStreamError}, @@ -11,9 +11,11 @@ use golem_rust::{ }; use nom::error::Error as NomError; +#[allow(clippy::enum_variant_names)] pub enum StreamType { EventStream(EventStream), NdJsonStream(NdJsonStream), + AwsEventStream(AwsEventStream), } pub trait LlmStream { @@ -56,9 +58,9 @@ where { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Self::Utf8(err) => f.write_fmt(format_args!("UTF8 error: {}", err)), - Self::Parser(err) => f.write_fmt(format_args!("Parse error: {}", err)), - Self::Transport(err) => f.write_fmt(format_args!("Transport error: {}", err)), + Self::Utf8(err) => f.write_fmt(format_args!("UTF8 error: {err}")), + Self::Parser(err) => f.write_fmt(format_args!("Parse error: {err}")), + Self::Transport(err) => f.write_fmt(format_args!("Transport error: {err}")), } } } diff --git a/llm/ollama/src/bindings.rs b/llm/ollama/src/bindings.rs index dbb704704..269cd07fb 100644 --- a/llm/ollama/src/bindings.rs +++ b/llm/ollama/src/bindings.rs @@ -1,12 +1,15 @@ -// Generated by `wit-bindgen` 0.36.0. DO NOT EDIT! +// Generated by `wit-bindgen` 0.41.0. DO NOT EDIT! // Options used: // * runtime_path: "wit_bindgen_rt" // * with "golem:llm/llm@1.0.0" = "golem_llm::golem::llm::llm" // * generate_unused_types use golem_llm::golem::llm::llm as __with_name0; #[cfg(target_arch = "wasm32")] -#[link_section = "component-type:wit-bindgen:0.36.0:golem:llm-ollama@1.0.0:llm-library:encoded world"] +#[unsafe( + link_section = "component-type:wit-bindgen:0.41.0:golem:llm-ollama@1.0.0:llm-library:encoded world" +)] #[doc(hidden)] +#[allow(clippy::octal_escapes)] pub static __WIT_BINDGEN_COMPONENT_TYPE: [u8; 1759] = *b"\ \0asm\x0d\0\x01\0\0\x19\x16wit-component-encoding\x04\0\x07\xdd\x0c\x01A\x02\x01\ A\x02\x01BO\x01m\x04\x04user\x09assistant\x06system\x04tool\x04\0\x04role\x03\0\0\ @@ -43,8 +46,8 @@ ng-get-next\x01B\x01p\x15\x01@\x02\x08messages\xc3\0\x06config)\06\x04\0\x04send \0\x06config)\06\x04\0\x08continue\x01G\x01i=\x01@\x02\x08messages\xc3\0\x06conf\ ig)\0\xc8\0\x04\0\x06stream\x01I\x04\0\x13golem:llm/llm@1.0.0\x05\0\x04\0\"golem\ :llm-ollama/llm-library@1.0.0\x04\0\x0b\x11\x01\0\x0bllm-library\x03\0\0\0G\x09p\ -roducers\x01\x0cprocessed-by\x02\x0dwit-component\x070.220.0\x10wit-bindgen-rust\ -\x060.36.0"; +roducers\x01\x0cprocessed-by\x02\x0dwit-component\x070.227.1\x10wit-bindgen-rust\ +\x060.41.0"; #[inline(never)] #[doc(hidden)] pub fn __link_custom_section_describing_imports() { diff --git a/llm/ollama/src/client.rs b/llm/ollama/src/client.rs index e9514a8dd..e2901e70c 100644 --- a/llm/ollama/src/client.rs +++ b/llm/ollama/src/client.rs @@ -335,7 +335,7 @@ pub fn image_to_base64(source: &str) -> Result Error { Error { code: ErrorCode::InternalError, - message: format!("{}: {}", context, err), + message: format!("{context}: {err}"), provider_error_json: None, } } diff --git a/llm/ollama/src/conversions.rs b/llm/ollama/src/conversions.rs index b1db65c61..8d64e954f 100644 --- a/llm/ollama/src/conversions.rs +++ b/llm/ollama/src/conversions.rs @@ -214,7 +214,7 @@ pub fn process_response(response: CompletionsResponse) -> ChatEvent { }; ChatEvent::Message(CompleteResponse { - id: format!("ollama-{}", timestamp), + id: format!("ollama-{timestamp}"), content, tool_calls, metadata, diff --git a/llm/openai/src/bindings.rs b/llm/openai/src/bindings.rs index c960248a8..6d0a77280 100644 --- a/llm/openai/src/bindings.rs +++ b/llm/openai/src/bindings.rs @@ -1,12 +1,15 @@ -// Generated by `wit-bindgen` 0.36.0. DO NOT EDIT! +// Generated by `wit-bindgen` 0.41.0. DO NOT EDIT! // Options used: // * runtime_path: "wit_bindgen_rt" // * with "golem:llm/llm@1.0.0" = "golem_llm::golem::llm::llm" // * generate_unused_types use golem_llm::golem::llm::llm as __with_name0; #[cfg(target_arch = "wasm32")] -#[link_section = "component-type:wit-bindgen:0.36.0:golem:llm-openai@1.0.0:llm-library:encoded world"] +#[unsafe( + link_section = "component-type:wit-bindgen:0.41.0:golem:llm-openai@1.0.0:llm-library:encoded world" +)] #[doc(hidden)] +#[allow(clippy::octal_escapes)] pub static __WIT_BINDGEN_COMPONENT_TYPE: [u8; 1759] = *b"\ \0asm\x0d\0\x01\0\0\x19\x16wit-component-encoding\x04\0\x07\xdd\x0c\x01A\x02\x01\ A\x02\x01BO\x01m\x04\x04user\x09assistant\x06system\x04tool\x04\0\x04role\x03\0\0\ @@ -43,8 +46,8 @@ ng-get-next\x01B\x01p\x15\x01@\x02\x08messages\xc3\0\x06config)\06\x04\0\x04send \0\x06config)\06\x04\0\x08continue\x01G\x01i=\x01@\x02\x08messages\xc3\0\x06conf\ ig)\0\xc8\0\x04\0\x06stream\x01I\x04\0\x13golem:llm/llm@1.0.0\x05\0\x04\0\"golem\ :llm-openai/llm-library@1.0.0\x04\0\x0b\x11\x01\0\x0bllm-library\x03\0\0\0G\x09p\ -roducers\x01\x0cprocessed-by\x02\x0dwit-component\x070.220.0\x10wit-bindgen-rust\ -\x060.36.0"; +roducers\x01\x0cprocessed-by\x02\x0dwit-component\x070.227.1\x10wit-bindgen-rust\ +\x060.41.0"; #[inline(never)] #[doc(hidden)] pub fn __link_custom_section_describing_imports() { diff --git a/llm/openai/src/conversions.rs b/llm/openai/src/conversions.rs index 43694c0f3..a4989b0c1 100644 --- a/llm/openai/src/conversions.rs +++ b/llm/openai/src/conversions.rs @@ -138,7 +138,7 @@ pub fn content_part_to_inner_input_item(content_part: ContentPart) -> InnerInput ImageReference::Inline(image_source) => { let base64_data = general_purpose::STANDARD.encode(&image_source.data); let mime_type = &image_source.mime_type; // This is already a string - let data_url = format!("data:{};base64,{}", mime_type, base64_data); + let data_url = format!("data:{mime_type};base64,{base64_data}"); InnerInputItem::ImageInput { image_url: data_url, diff --git a/llm/openrouter/src/bindings.rs b/llm/openrouter/src/bindings.rs index ba2accf7e..1300cde97 100644 --- a/llm/openrouter/src/bindings.rs +++ b/llm/openrouter/src/bindings.rs @@ -1,12 +1,15 @@ -// Generated by `wit-bindgen` 0.36.0. DO NOT EDIT! +// Generated by `wit-bindgen` 0.41.0. DO NOT EDIT! // Options used: // * runtime_path: "wit_bindgen_rt" // * with "golem:llm/llm@1.0.0" = "golem_llm::golem::llm::llm" // * generate_unused_types use golem_llm::golem::llm::llm as __with_name0; #[cfg(target_arch = "wasm32")] -#[link_section = "component-type:wit-bindgen:0.36.0:golem:llm-openrouter@1.0.0:llm-library:encoded world"] +#[unsafe( + link_section = "component-type:wit-bindgen:0.41.0:golem:llm-openrouter@1.0.0:llm-library:encoded world" +)] #[doc(hidden)] +#[allow(clippy::octal_escapes)] pub static __WIT_BINDGEN_COMPONENT_TYPE: [u8; 1763] = *b"\ \0asm\x0d\0\x01\0\0\x19\x16wit-component-encoding\x04\0\x07\xe1\x0c\x01A\x02\x01\ A\x02\x01BO\x01m\x04\x04user\x09assistant\x06system\x04tool\x04\0\x04role\x03\0\0\ @@ -43,8 +46,8 @@ ng-get-next\x01B\x01p\x15\x01@\x02\x08messages\xc3\0\x06config)\06\x04\0\x04send \0\x06config)\06\x04\0\x08continue\x01G\x01i=\x01@\x02\x08messages\xc3\0\x06conf\ ig)\0\xc8\0\x04\0\x06stream\x01I\x04\0\x13golem:llm/llm@1.0.0\x05\0\x04\0&golem:\ llm-openrouter/llm-library@1.0.0\x04\0\x0b\x11\x01\0\x0bllm-library\x03\0\0\0G\x09\ -producers\x01\x0cprocessed-by\x02\x0dwit-component\x070.220.0\x10wit-bindgen-rus\ -t\x060.36.0"; +producers\x01\x0cprocessed-by\x02\x0dwit-component\x070.227.1\x10wit-bindgen-rus\ +t\x060.41.0"; #[inline(never)] #[doc(hidden)] pub fn __link_custom_section_describing_imports() { diff --git a/llm/openrouter/src/conversions.rs b/llm/openrouter/src/conversions.rs index d4db2d34c..61b5f973b 100644 --- a/llm/openrouter/src/conversions.rs +++ b/llm/openrouter/src/conversions.rs @@ -184,7 +184,7 @@ fn convert_content_parts(contents: Vec) -> crate::client::Content { let media_type = &image_source.mime_type; // This is already a string result.push(crate::client::ContentPart::ImageInput { image_url: crate::client::ImageUrl { - url: format!("data:{};base64,{}", media_type, base64_data), + url: format!("data:{media_type};base64,{base64_data}"), detail: image_source.detail.map(|d| d.into()), }, }); diff --git a/test/components-rust/test-llm/Cargo.toml b/test/components-rust/test-llm/Cargo.toml index 7f6242874..d4acefb4f 100644 --- a/test/components-rust/test-llm/Cargo.toml +++ b/test/components-rust/test-llm/Cargo.toml @@ -15,6 +15,7 @@ grok = [] openai = [] openrouter = [] ollama = [] +bedrock = [] [dependencies] # To use common shared libs, use the following: @@ -37,8 +38,8 @@ path = "wit-generated" [package.metadata.component.target.dependencies] "golem:llm" = { path = "wit-generated/deps/golem-llm" } -"wasi:clocks" = { path = "wit-generated/deps/clocks" } "wasi:io" = { path = "wit-generated/deps/io" } +"wasi:clocks" = { path = "wit-generated/deps/clocks" } "golem:rpc" = { path = "wit-generated/deps/golem-rpc" } "test:helper-client" = { path = "wit-generated/deps/test_helper-client" } "test:llm-exports" = { path = "wit-generated/deps/test_llm-exports" } diff --git a/test/components-rust/test-llm/golem.yaml b/test/components-rust/test-llm/golem.yaml index 6efa177c7..7076791d1 100644 --- a/test/components-rust/test-llm/golem.yaml +++ b/test/components-rust/test-llm/golem.yaml @@ -139,6 +139,32 @@ components: clean: - src/bindings.rs + bedrock-debug: + files: + - sourcePath: ../../data/cat.png + targetPath: /data/cat.png + permissions: read-only + build: + - command: cargo component build --no-default-features --features bedrock + sources: + - src + - wit-generated + - ../../common-rust + targets: + - ../../target/wasm32-wasip1/debug/test_llm.wasm + - command: wac plug --plug ../../../target/wasm32-wasip1/debug/golem_llm_bedrock.wasm ../../target/wasm32-wasip1/debug/test_llm.wasm -o ../../target/wasm32-wasip1/debug/test_bedrock_plugged.wasm + sources: + - ../../target/wasm32-wasip1/debug/test_llm.wasm + - ../../../target/wasm32-wasip1/debug/golem_llm_bedrock.wasm + targets: + - ../../target/wasm32-wasip1/debug/test_bedrock_plugged.wasm + sourceWit: wit + generatedWit: wit-generated + componentWasm: ../../target/wasm32-wasip1/debug/test_bedrock_plugged.wasm + linkedWasm: ../../golem-temp/components/test_bedrock_debug.wasm + clean: + - src/bindings.rs + # RELEASE PROFILES openai-release: files: @@ -270,6 +296,34 @@ components: clean: - src/bindings.rs + + bedrock-release: + files: + - sourcePath: ../../data/cat.png + targetPath: /data/cat.png + permissions: read-only + build: + - command: cargo component build --release --no-default-features --features bedrock + sources: + - src + - wit-generated + - ../../common-rust + targets: + - ../../target/wasm32-wasip1/release/test_llm.wasm + - command: wac plug --plug ../../../target/wasm32-wasip1/release/golem_llm_bedrock.wasm ../../target/wasm32-wasip1/release/test_llm.wasm -o ../../target/wasm32-wasip1/release/test_bedrock_plugged.wasm + sources: + - ../../target/wasm32-wasip1/release/test_llm.wasm + - ../../../target/wasm32-wasip1/release/golem_llm_bedrock.wasm + targets: + - ../../target/wasm32-wasip1/release/test_bedrock_plugged.wasm + sourceWit: wit + generatedWit: wit-generated + componentWasm: ../../target/wasm32-wasip1/release/test_bedrock_plugged.wasm + linkedWasm: ../../golem-temp/components/test_bedrock_release.wasm + clean: + - src/bindings.rs + + defaultProfile: openai-debug dependencies: diff --git a/test/components-rust/test-llm/src/lib.rs b/test/components-rust/test-llm/src/lib.rs index fa11684de..e0c7a8a0b 100644 --- a/test/components-rust/test-llm/src/lib.rs +++ b/test/components-rust/test-llm/src/lib.rs @@ -19,6 +19,8 @@ const MODEL: &'static str = "grok-3-beta"; const MODEL: &'static str = "openrouter/auto"; #[cfg(feature = "ollama")] const MODEL: &'static str = "qwen3:1.7b"; +#[cfg(feature = "bedrock")] +const MODEL: &'static str = "anthropic.claude-3-haiku-20240307-v1:0"; #[cfg(feature = "openai")] const IMAGE_MODEL: &'static str = "gpt-4o-mini"; @@ -30,6 +32,8 @@ const IMAGE_MODEL: &'static str = "grok-2-vision-latest"; const IMAGE_MODEL: &'static str = "openrouter/auto"; #[cfg(feature = "ollama")] const IMAGE_MODEL: &'static str = "gemma3:4b"; +#[cfg(feature = "bedrock")] +const IMAGE_MODEL: &'static str = "anthropic.claude-3-haiku-20240307-v1:0"; impl Guest for Component { /// test1 demonstrates a simple, non-streaming text question-answer interaction with the LLM.