Skip to content

Commit 348323d

Browse files
committed
feat: switch to wasi-async-runtime from tokio
1 parent 70a0146 commit 348323d

File tree

7 files changed

+96
-22
lines changed

7 files changed

+96
-22
lines changed

Cargo.lock

Lines changed: 62 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

llm/bedrock/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ aws-smithy-wasm = { version = "0.1.4", default-features = false }
1717
aws-sdk-bedrockruntime = { version = "1.56.0", default-features = false }
1818
aws-smithy-types = { version = "1.3.1" }
1919

20-
tokio = { version = "1.43.0", features = ["rt", "sync", "time"] }
20+
wasi-async-runtime = "0.1.2"
2121

2222
# To infer mime types of downloaded images before passing to bedrock
2323
infer = { version = "0.19.0", default-features = false }

llm/bedrock/src/async_utils.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
use std::future::Future;
2+
3+
pub fn get_async_runtime() -> AsyncRuntime {
4+
AsyncRuntime
5+
}
6+
7+
pub struct AsyncRuntime;
8+
9+
impl AsyncRuntime {
10+
pub fn block_on<F>(self, f: F) -> F::Output
11+
where
12+
F: Future,
13+
{
14+
wasi_async_runtime::block_on(|_| async { f.await })
15+
}
16+
}

llm/bedrock/src/client.rs

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use crate::{
2+
async_utils,
23
conversions::{self, from_converse_sdk_error, from_converse_stream_sdk_error, BedrockInput},
34
stream::BedrockChatStream,
45
};
@@ -17,6 +18,7 @@ use golem_llm::{
1718
config::{get_config_key, get_config_key_or_none},
1819
golem::llm::llm,
1920
};
21+
use golem_rust::bindings::wasi::clocks::monotonic_clock;
2022
use log::trace;
2123

2224
#[derive(Debug)]
@@ -30,14 +32,14 @@ impl Bedrock {
3032

3133
let wasi_http = WasiHttpClientBuilder::new().build();
3234

33-
let runtime = get_async_runtime();
35+
let runtime = async_utils::get_async_runtime();
3436

3537
runtime.block_on(async {
3638
let sdk_config = aws_config::defaults(BehaviorVersion::latest())
3739
.region(environment.aws_region())
3840
.http_client(wasi_http)
3941
.credentials_provider(environment.aws_credentials())
40-
.sleep_impl(TokioSleep)
42+
.sleep_impl(WasiSleep)
4143
.load()
4244
.await;
4345
let client = bedrock::Client::new(&sdk_config);
@@ -53,7 +55,7 @@ impl Bedrock {
5355
) -> llm::ChatEvent {
5456
let bedrock_input = BedrockInput::from(messages, config, tool_results);
5557

56-
let runtime = get_async_runtime();
58+
let runtime = async_utils::get_async_runtime();
5759

5860
match bedrock_input {
5961
Err(err) => llm::ChatEvent::Error(err),
@@ -97,7 +99,7 @@ impl Bedrock {
9799
match bedrock_input {
98100
Err(err) => BedrockChatStream::failed(err),
99101
Ok(input) => {
100-
let runtime = get_async_runtime();
102+
let runtime = async_utils::get_async_runtime();
101103
trace!("Sending request to AWS Bedrock: {input:?}");
102104
runtime.block_on(async {
103105
let model_id = input.model_id.clone();
@@ -176,19 +178,13 @@ impl BedrockEnvironment {
176178
}
177179
}
178180

179-
pub fn get_async_runtime() -> tokio::runtime::Runtime {
180-
tokio::runtime::Builder::new_current_thread()
181-
.enable_time()
182-
.build()
183-
.unwrap()
184-
}
185-
186181
#[derive(Debug, Clone)]
187-
struct TokioSleep;
188-
impl AsyncSleep for TokioSleep {
182+
struct WasiSleep;
183+
impl AsyncSleep for WasiSleep {
189184
fn sleep(&self, duration: std::time::Duration) -> Sleep {
190185
Sleep::new(Box::pin(async move {
191-
tokio::time::sleep(duration).await;
186+
let nanos = duration.as_nanos() as u64;
187+
monotonic_clock::subscribe_duration(nanos).block();
192188
}))
193189
}
194190
}

llm/bedrock/src/lib.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@ use golem_llm::{
66
golem::llm::llm::{self, ChatEvent, ChatStream, Config, Guest, Message, ToolCall, ToolResult},
77
LOGGING_STATE,
88
};
9+
use golem_rust::bindings::wasi::clocks::monotonic_clock;
910
use stream::BedrockChatStream;
1011

12+
mod async_utils;
1113
mod client;
1214
mod conversions;
1315
mod stream;
@@ -116,7 +118,8 @@ impl ExtendedGuest for BedrockComponent {
116118
}
117119

118120
fn subscribe(_stream: &Self::ChatStream) -> golem_rust::wasm_rpc::Pollable {
119-
unimplemented!()
121+
// this function will never get called in bedrock implementation because of `golem-llm/nopoll` feature flag
122+
monotonic_clock::subscribe_duration(0)
120123
}
121124
}
122125

llm/bedrock/src/stream.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use golem_llm::golem::llm::llm;
66
use std::cell::{RefCell, RefMut};
77

88
use crate::{
9-
client::get_async_runtime,
9+
async_utils,
1010
conversions::{converse_stream_output_to_stream_event, custom_error, merge_metadata},
1111
};
1212

@@ -53,7 +53,7 @@ impl BedrockChatStream {
5353
}
5454
fn get_single_event(&self) -> Option<llm::StreamEvent> {
5555
if let Some(stream) = self.stream_mut().as_mut() {
56-
let runtime = get_async_runtime();
56+
let runtime = async_utils::get_async_runtime();
5757

5858
runtime.block_on(async move {
5959
let token = stream.recv().await;

test/components-rust/test-llm/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ path = "wit-generated"
3838

3939
[package.metadata.component.target.dependencies]
4040
"golem:llm" = { path = "wit-generated/deps/golem-llm" }
41-
"wasi:clocks" = { path = "wit-generated/deps/clocks" }
4241
"wasi:io" = { path = "wit-generated/deps/io" }
42+
"wasi:clocks" = { path = "wit-generated/deps/clocks" }
4343
"golem:rpc" = { path = "wit-generated/deps/golem-rpc" }
4444
"test:helper-client" = { path = "wit-generated/deps/test_helper-client" }
4545
"test:llm-exports" = { path = "wit-generated/deps/test_llm-exports" }

0 commit comments

Comments
 (0)