Skip to content

Commit d68436b

Browse files
committed
feat: implement wasi_client for aws operations. Use reqwest based wasi async client
1 parent 8686a0c commit d68436b

File tree

8 files changed

+699
-136
lines changed

8 files changed

+699
-136
lines changed

Cargo.lock

Lines changed: 379 additions & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

llm/bedrock/Cargo.toml

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@ aws-types = { version = "1.3.4", default-features = false }
1616
aws-smithy-wasm = { version = "0.1.4", default-features = false }
1717
aws-sdk-bedrockruntime = { version = "1.56.0", default-features = false }
1818
aws-smithy-types = { version = "1.3.1" }
19+
aws-smithy-runtime-api = "1.8.3"
1920

2021
wasi-async-runtime = "0.1.2"
2122

@@ -26,11 +27,15 @@ golem-llm = { workspace = true }
2627

2728
golem-rust = { workspace = true }
2829
log = { workspace = true }
29-
reqwest = { workspace = true }
30+
reqwest = { git = "https://github.com/golemcloud/reqwest", branch = "update-july-2025", features = [
31+
"json",
32+
"async",
33+
] }
3034
serde = { workspace = true }
3135
serde_json = { workspace = true }
3236
wit-bindgen-rt = { workspace = true }
3337
base64 = { workspace = true }
38+
bytes = "1.10.1"
3439

3540
[lib]
3641
crate-type = ["cdylib"]

llm/bedrock/src/async_utils.rs

Lines changed: 36 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -7,10 +7,43 @@ pub fn get_async_runtime() -> AsyncRuntime {
77
pub struct AsyncRuntime;
88

99
impl AsyncRuntime {
10-
pub fn block_on<F>(self, f: F) -> F::Output
10+
pub fn block_on<F, Fut>(self, f: F) -> Fut::Output
1111
where
12-
F: Future,
12+
F: FnOnce(wasi_async_runtime::Reactor) -> Fut,
13+
Fut: Future,
1314
{
14-
wasi_async_runtime::block_on(|_| f)
15+
wasi_async_runtime::block_on(f)
16+
}
17+
}
18+
19+
#[derive(Clone)]
20+
pub struct UnsafeFuture<Fut> {
21+
inner: Fut,
22+
}
23+
24+
impl<F> UnsafeFuture<F>
25+
where
26+
F: Future,
27+
{
28+
pub fn new(inner: F) -> Self {
29+
Self { inner }
30+
}
31+
}
32+
33+
unsafe impl<F> Send for UnsafeFuture<F> where F: Future {}
34+
unsafe impl<F> Sync for UnsafeFuture<F> where F: Future {}
35+
36+
impl<F> Future for UnsafeFuture<F>
37+
where
38+
F: Future,
39+
{
40+
type Output = F::Output;
41+
42+
fn poll(
43+
mut self: std::pin::Pin<&mut Self>,
44+
cx: &mut std::task::Context<'_>,
45+
) -> std::task::Poll<Self::Output> {
46+
let pinned_future = unsafe { self.as_mut().map_unchecked_mut(|this| &mut this.inner) };
47+
pinned_future.poll(cx)
1548
}
1649
}

llm/bedrock/src/client.rs

Lines changed: 73 additions & 66 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,8 @@
11
use crate::{
2-
async_utils,
2+
async_utils::UnsafeFuture,
33
conversions::{self, from_converse_sdk_error, from_converse_stream_sdk_error, BedrockInput},
44
stream::BedrockChatStream,
5+
wasi_client::WasiClient,
56
};
67
use aws_config::BehaviorVersion;
78
use aws_sdk_bedrockruntime::{
@@ -12,7 +13,6 @@ use aws_sdk_bedrockruntime::{
1213
converse_stream::builders::ConverseStreamFluentBuilder,
1314
},
1415
};
15-
use aws_smithy_wasm::wasi::WasiHttpClientBuilder;
1616
use aws_types::region;
1717
use golem_llm::{
1818
config::{get_config_key, get_config_key_or_none},
@@ -27,69 +27,59 @@ pub struct Bedrock {
2727
}
2828

2929
impl Bedrock {
30-
pub fn new() -> Result<Self, llm::Error> {
30+
pub async fn new(reactor: wasi_async_runtime::Reactor) -> Result<Self, llm::Error> {
3131
let environment = BedrockEnvironment::load_from_env()?;
3232

33-
let wasi_http = WasiHttpClientBuilder::new().build();
34-
35-
let runtime = async_utils::get_async_runtime();
36-
37-
runtime.block_on(async {
38-
let sdk_config = aws_config::defaults(BehaviorVersion::latest())
39-
.region(environment.aws_region())
40-
.http_client(wasi_http)
41-
.credentials_provider(environment.aws_credentials())
42-
.sleep_impl(WasiSleep)
43-
.load()
44-
.await;
45-
let client = bedrock::Client::new(&sdk_config);
46-
Ok(Self { client })
47-
})
33+
let sdk_config = aws_config::defaults(BehaviorVersion::latest())
34+
.region(environment.aws_region())
35+
.http_client(WasiClient::new(reactor.clone()))
36+
.credentials_provider(environment.aws_credentials())
37+
.sleep_impl(WasiSleep::new(reactor))
38+
.load()
39+
.await;
40+
let client = bedrock::Client::new(&sdk_config);
41+
Ok(Self { client })
4842
}
4943

50-
pub fn converse(
44+
pub async fn converse(
5145
&self,
5246
messages: Vec<llm::Message>,
5347
config: llm::Config,
5448
tool_results: Option<Vec<(llm::ToolCall, llm::ToolResult)>>,
5549
) -> llm::ChatEvent {
5650
let bedrock_input = BedrockInput::from(messages, config, tool_results);
5751

58-
let runtime = async_utils::get_async_runtime();
59-
6052
match bedrock_input {
6153
Err(err) => llm::ChatEvent::Error(err),
6254
Ok(input) => {
6355
trace!("Sending request to AWS Bedrock: {input:?}");
64-
runtime.block_on(async {
65-
let model_id = input.model_id.clone();
66-
let response = self
67-
.init_converse(input)
68-
.send()
69-
.await
70-
.map_err(|e| from_converse_sdk_error(model_id, e));
71-
72-
match response {
73-
Err(err) => llm::ChatEvent::Error(err),
74-
Ok(response) => {
75-
let event = match response.stop_reason() {
76-
bedrock::types::StopReason::ToolUse => {
77-
conversions::converse_output_to_tool_calls(response)
78-
.map(llm::ChatEvent::ToolRequest)
79-
}
80-
_ => conversions::converse_output_to_complete_response(response)
81-
.map(llm::ChatEvent::Message),
82-
};
83-
84-
event.unwrap_or_else(llm::ChatEvent::Error)
85-
}
56+
let model_id = input.model_id.clone();
57+
let response = self
58+
.init_converse(input)
59+
.send()
60+
.await
61+
.map_err(|e| from_converse_sdk_error(model_id, e));
62+
63+
match response {
64+
Err(err) => llm::ChatEvent::Error(err),
65+
Ok(response) => {
66+
let event = match response.stop_reason() {
67+
bedrock::types::StopReason::ToolUse => {
68+
conversions::converse_output_to_tool_calls(response)
69+
.map(llm::ChatEvent::ToolRequest)
70+
}
71+
_ => conversions::converse_output_to_complete_response(response)
72+
.map(llm::ChatEvent::Message),
73+
};
74+
75+
event.unwrap_or_else(llm::ChatEvent::Error)
8676
}
87-
})
77+
}
8878
}
8979
}
9080
}
9181

92-
pub fn converse_stream(
82+
pub async fn converse_stream(
9383
&self,
9484
messages: Vec<llm::Message>,
9585
config: llm::Config,
@@ -99,22 +89,19 @@ impl Bedrock {
9989
match bedrock_input {
10090
Err(err) => BedrockChatStream::failed(err),
10191
Ok(input) => {
102-
let runtime = async_utils::get_async_runtime();
10392
trace!("Sending request to AWS Bedrock: {input:?}");
104-
runtime.block_on(async {
105-
let model_id = input.model_id.clone();
106-
let response = self
107-
.init_converse_stream(input)
108-
.send()
109-
.await
110-
.map_err(|e| from_converse_stream_sdk_error(model_id, e));
111-
112-
trace!("Creating AWS Bedrock event stream");
113-
match response {
114-
Ok(response) => BedrockChatStream::new(response.stream),
115-
Err(error) => BedrockChatStream::failed(error),
116-
}
117-
})
93+
let model_id = input.model_id.clone();
94+
let response = self
95+
.init_converse_stream(input)
96+
.send()
97+
.await
98+
.map_err(|e| from_converse_stream_sdk_error(model_id, e));
99+
100+
trace!("Creating AWS Bedrock event stream");
101+
match response {
102+
Ok(response) => BedrockChatStream::new(response.stream),
103+
Err(error) => BedrockChatStream::failed(error),
104+
}
118105
}
119106
}
120107
}
@@ -146,15 +133,15 @@ impl Bedrock {
146133
}
147134

148135
#[derive(Debug)]
149-
struct BedrockEnvironment {
136+
pub struct BedrockEnvironment {
150137
access_key_id: String,
151138
region: String,
152139
secret_access_key: String,
153140
session_token: Option<String>,
154141
}
155142

156143
impl BedrockEnvironment {
157-
fn load_from_env() -> Result<Self, llm::Error> {
144+
pub fn load_from_env() -> Result<Self, llm::Error> {
158145
Ok(Self {
159146
access_key_id: get_config_key("AWS_ACCESS_KEY_ID")?,
160147
region: get_config_key("AWS_REGION")?,
@@ -179,12 +166,32 @@ impl BedrockEnvironment {
179166
}
180167

181168
#[derive(Debug, Clone)]
182-
struct WasiSleep;
169+
struct WasiSleep {
170+
reactor: wasi_async_runtime::Reactor,
171+
}
172+
173+
impl WasiSleep {
174+
fn new(reactor: wasi_async_runtime::Reactor) -> Self {
175+
Self { reactor }
176+
}
177+
}
178+
183179
impl AsyncSleep for WasiSleep {
184180
fn sleep(&self, duration: std::time::Duration) -> Sleep {
185-
Sleep::new(Box::pin(async move {
181+
let reactor = self.reactor.clone();
182+
183+
let fut = async move {
186184
let nanos = duration.as_nanos() as u64;
187-
monotonic_clock::subscribe_duration(nanos).block();
188-
}))
185+
let pollable = monotonic_clock::subscribe_duration(nanos);
186+
187+
reactor
188+
.clone()
189+
.wait_for(unsafe { std::mem::transmute(pollable) })
190+
.await;
191+
};
192+
Sleep::new(Box::pin(UnsafeFuture::new(fut)))
189193
}
190194
}
195+
196+
unsafe impl Send for WasiSleep {}
197+
unsafe impl Sync for WasiSleep {}

llm/bedrock/src/conversions.rs

Lines changed: 33 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,8 @@ use aws_sdk_bedrockruntime::{
1414
};
1515
use golem_llm::golem::llm::llm;
1616

17+
use crate::async_utils::get_async_runtime;
18+
1719
#[derive(Debug)]
1820
pub struct BedrockInput {
1921
pub model_id: String,
@@ -225,34 +227,39 @@ fn get_image_content_block_from_url(url: &str) -> Result<bedrock::types::Content
225227
}
226228

227229
fn get_bytes_from_url(url: &str) -> Result<Vec<u8>, llm::Error> {
228-
let client = reqwest::Client::builder()
229-
.build()
230-
.expect("Failed to initialize HTTP client");
231-
232-
let response = client.get(url).send().map_err(|err| {
233-
custom_error(
234-
llm::ErrorCode::InvalidRequest,
235-
format!("Could not read image bytes from url: {url}, cause: {err}"),
236-
)
237-
})?;
238-
if !response.status().is_success() {
239-
return Err(custom_error(
240-
llm::ErrorCode::InvalidRequest,
241-
format!(
242-
"Could not read image bytes from url: {url}, cause: request failed with status: {}",
243-
response.status()
244-
),
245-
));
246-
}
230+
let runtime = get_async_runtime();
247231

248-
let bytes = response.bytes().map_err(|err| {
249-
custom_error(
250-
llm::ErrorCode::InvalidRequest,
251-
format!("Could not read image bytes from url: {url}, cause: {err}"),
252-
)
253-
})?;
232+
runtime.block_on(|reactor| async {
233+
let client = reqwest::Client::builder(reactor)
234+
.build()
235+
.expect("Failed to initialize HTTP client");
236+
237+
let response = client.get(url).send().await.map_err(|err| {
238+
custom_error(
239+
llm::ErrorCode::InvalidRequest,
240+
format!("Could not read image bytes from url: {url}, cause: {err}"),
241+
)
242+
})?;
243+
if !response.status().is_success() {
244+
return Err(custom_error(
245+
llm::ErrorCode::InvalidRequest,
246+
format!(
247+
"Could not read image bytes from url: {url}, cause: request failed with status: {}",
248+
response.status()
249+
),
250+
));
251+
}
252+
253+
let bytes = response.bytes().await.map_err(|err| {
254+
custom_error(
255+
llm::ErrorCode::InvalidRequest,
256+
format!("Could not read image bytes from url: {url}, cause: {err}"),
257+
)
258+
})?;
259+
260+
Ok(bytes.to_vec())
254261

255-
Ok(bytes.to_vec())
262+
})
256263
}
257264

258265
fn str_to_bedrock_mime_type(mime_type: &str) -> Result<ImageFormat, llm::Error> {

0 commit comments

Comments
 (0)