1
1
use crate :: {
2
- async_utils,
2
+ async_utils:: UnsafeFuture ,
3
3
conversions:: { self , from_converse_sdk_error, from_converse_stream_sdk_error, BedrockInput } ,
4
4
stream:: BedrockChatStream ,
5
+ wasi_client:: WasiClient ,
5
6
} ;
6
7
use aws_config:: BehaviorVersion ;
7
8
use aws_sdk_bedrockruntime:: {
@@ -12,7 +13,6 @@ use aws_sdk_bedrockruntime::{
12
13
converse_stream:: builders:: ConverseStreamFluentBuilder ,
13
14
} ,
14
15
} ;
15
- use aws_smithy_wasm:: wasi:: WasiHttpClientBuilder ;
16
16
use aws_types:: region;
17
17
use golem_llm:: {
18
18
config:: { get_config_key, get_config_key_or_none} ,
@@ -27,69 +27,59 @@ pub struct Bedrock {
27
27
}
28
28
29
29
impl Bedrock {
30
- pub fn new ( ) -> Result < Self , llm:: Error > {
30
+ pub async fn new ( reactor : wasi_async_runtime :: Reactor ) -> Result < Self , llm:: Error > {
31
31
let environment = BedrockEnvironment :: load_from_env ( ) ?;
32
32
33
- let wasi_http = WasiHttpClientBuilder :: new ( ) . build ( ) ;
34
-
35
- let runtime = async_utils:: get_async_runtime ( ) ;
36
-
37
- runtime. block_on ( async {
38
- let sdk_config = aws_config:: defaults ( BehaviorVersion :: latest ( ) )
39
- . region ( environment. aws_region ( ) )
40
- . http_client ( wasi_http)
41
- . credentials_provider ( environment. aws_credentials ( ) )
42
- . sleep_impl ( WasiSleep )
43
- . load ( )
44
- . await ;
45
- let client = bedrock:: Client :: new ( & sdk_config) ;
46
- Ok ( Self { client } )
47
- } )
33
+ let sdk_config = aws_config:: defaults ( BehaviorVersion :: latest ( ) )
34
+ . region ( environment. aws_region ( ) )
35
+ . http_client ( WasiClient :: new ( reactor. clone ( ) ) )
36
+ . credentials_provider ( environment. aws_credentials ( ) )
37
+ . sleep_impl ( WasiSleep :: new ( reactor) )
38
+ . load ( )
39
+ . await ;
40
+ let client = bedrock:: Client :: new ( & sdk_config) ;
41
+ Ok ( Self { client } )
48
42
}
49
43
50
- pub fn converse (
44
+ pub async fn converse (
51
45
& self ,
52
46
messages : Vec < llm:: Message > ,
53
47
config : llm:: Config ,
54
48
tool_results : Option < Vec < ( llm:: ToolCall , llm:: ToolResult ) > > ,
55
49
) -> llm:: ChatEvent {
56
50
let bedrock_input = BedrockInput :: from ( messages, config, tool_results) ;
57
51
58
- let runtime = async_utils:: get_async_runtime ( ) ;
59
-
60
52
match bedrock_input {
61
53
Err ( err) => llm:: ChatEvent :: Error ( err) ,
62
54
Ok ( input) => {
63
55
trace ! ( "Sending request to AWS Bedrock: {input:?}" ) ;
64
- runtime. block_on ( async {
65
- let model_id = input. model_id . clone ( ) ;
66
- let response = self
67
- . init_converse ( input)
68
- . send ( )
69
- . await
70
- . map_err ( |e| from_converse_sdk_error ( model_id, e) ) ;
71
-
72
- match response {
73
- Err ( err) => llm:: ChatEvent :: Error ( err) ,
74
- Ok ( response) => {
75
- let event = match response. stop_reason ( ) {
76
- bedrock:: types:: StopReason :: ToolUse => {
77
- conversions:: converse_output_to_tool_calls ( response)
78
- . map ( llm:: ChatEvent :: ToolRequest )
79
- }
80
- _ => conversions:: converse_output_to_complete_response ( response)
81
- . map ( llm:: ChatEvent :: Message ) ,
82
- } ;
83
-
84
- event. unwrap_or_else ( llm:: ChatEvent :: Error )
85
- }
56
+ let model_id = input. model_id . clone ( ) ;
57
+ let response = self
58
+ . init_converse ( input)
59
+ . send ( )
60
+ . await
61
+ . map_err ( |e| from_converse_sdk_error ( model_id, e) ) ;
62
+
63
+ match response {
64
+ Err ( err) => llm:: ChatEvent :: Error ( err) ,
65
+ Ok ( response) => {
66
+ let event = match response. stop_reason ( ) {
67
+ bedrock:: types:: StopReason :: ToolUse => {
68
+ conversions:: converse_output_to_tool_calls ( response)
69
+ . map ( llm:: ChatEvent :: ToolRequest )
70
+ }
71
+ _ => conversions:: converse_output_to_complete_response ( response)
72
+ . map ( llm:: ChatEvent :: Message ) ,
73
+ } ;
74
+
75
+ event. unwrap_or_else ( llm:: ChatEvent :: Error )
86
76
}
87
- } )
77
+ }
88
78
}
89
79
}
90
80
}
91
81
92
- pub fn converse_stream (
82
+ pub async fn converse_stream (
93
83
& self ,
94
84
messages : Vec < llm:: Message > ,
95
85
config : llm:: Config ,
@@ -99,22 +89,19 @@ impl Bedrock {
99
89
match bedrock_input {
100
90
Err ( err) => BedrockChatStream :: failed ( err) ,
101
91
Ok ( input) => {
102
- let runtime = async_utils:: get_async_runtime ( ) ;
103
92
trace ! ( "Sending request to AWS Bedrock: {input:?}" ) ;
104
- runtime. block_on ( async {
105
- let model_id = input. model_id . clone ( ) ;
106
- let response = self
107
- . init_converse_stream ( input)
108
- . send ( )
109
- . await
110
- . map_err ( |e| from_converse_stream_sdk_error ( model_id, e) ) ;
111
-
112
- trace ! ( "Creating AWS Bedrock event stream" ) ;
113
- match response {
114
- Ok ( response) => BedrockChatStream :: new ( response. stream ) ,
115
- Err ( error) => BedrockChatStream :: failed ( error) ,
116
- }
117
- } )
93
+ let model_id = input. model_id . clone ( ) ;
94
+ let response = self
95
+ . init_converse_stream ( input)
96
+ . send ( )
97
+ . await
98
+ . map_err ( |e| from_converse_stream_sdk_error ( model_id, e) ) ;
99
+
100
+ trace ! ( "Creating AWS Bedrock event stream" ) ;
101
+ match response {
102
+ Ok ( response) => BedrockChatStream :: new ( response. stream ) ,
103
+ Err ( error) => BedrockChatStream :: failed ( error) ,
104
+ }
118
105
}
119
106
}
120
107
}
@@ -146,15 +133,15 @@ impl Bedrock {
146
133
}
147
134
148
135
#[ derive( Debug ) ]
149
- struct BedrockEnvironment {
136
+ pub struct BedrockEnvironment {
150
137
access_key_id : String ,
151
138
region : String ,
152
139
secret_access_key : String ,
153
140
session_token : Option < String > ,
154
141
}
155
142
156
143
impl BedrockEnvironment {
157
- fn load_from_env ( ) -> Result < Self , llm:: Error > {
144
+ pub fn load_from_env ( ) -> Result < Self , llm:: Error > {
158
145
Ok ( Self {
159
146
access_key_id : get_config_key ( "AWS_ACCESS_KEY_ID" ) ?,
160
147
region : get_config_key ( "AWS_REGION" ) ?,
@@ -179,12 +166,32 @@ impl BedrockEnvironment {
179
166
}
180
167
181
168
#[ derive( Debug , Clone ) ]
182
- struct WasiSleep ;
169
+ struct WasiSleep {
170
+ reactor : wasi_async_runtime:: Reactor ,
171
+ }
172
+
173
+ impl WasiSleep {
174
+ fn new ( reactor : wasi_async_runtime:: Reactor ) -> Self {
175
+ Self { reactor }
176
+ }
177
+ }
178
+
183
179
impl AsyncSleep for WasiSleep {
184
180
fn sleep ( & self , duration : std:: time:: Duration ) -> Sleep {
185
- Sleep :: new ( Box :: pin ( async move {
181
+ let reactor = self . reactor . clone ( ) ;
182
+
183
+ let fut = async move {
186
184
let nanos = duration. as_nanos ( ) as u64 ;
187
- monotonic_clock:: subscribe_duration ( nanos) . block ( ) ;
188
- } ) )
185
+ let pollable = monotonic_clock:: subscribe_duration ( nanos) ;
186
+
187
+ reactor
188
+ . clone ( )
189
+ . wait_for ( unsafe { std:: mem:: transmute ( pollable) } )
190
+ . await ;
191
+ } ;
192
+ Sleep :: new ( Box :: pin ( UnsafeFuture :: new ( fut) ) )
189
193
}
190
194
}
195
+
196
+ unsafe impl Send for WasiSleep { }
197
+ unsafe impl Sync for WasiSleep { }
0 commit comments