@@ -66,7 +66,6 @@ impl BedrockClient {
66
66
} ) ?;
67
67
68
68
let mut request_builder = self . client . request ( Method :: POST , & url) ;
69
- request_builder = request_builder. header ( "content-type" , "application/json" ) ;
70
69
for ( key, value) in headers {
71
70
request_builder = request_builder. header ( key, value) ;
72
71
}
@@ -116,7 +115,6 @@ impl BedrockClient {
116
115
} ) ?;
117
116
118
117
let mut request_builder = self . client . request ( Method :: POST , & url) ;
119
- request_builder = request_builder. header ( "content-type" , "application/json" ) ;
120
118
for ( key, value) in headers {
121
119
request_builder = request_builder. header ( key, value) ;
122
120
}
@@ -144,6 +142,8 @@ pub fn generate_sigv4_headers(
144
142
host : & str ,
145
143
body : & str ,
146
144
) -> Result < Vec < ( String , String ) > , Box < dyn std:: error:: Error > > {
145
+ use std:: collections:: BTreeMap ;
146
+
147
147
let now = SystemTime :: now ( ) . duration_since ( UNIX_EPOCH ) . unwrap ( ) ;
148
148
let timestamp = OffsetDateTime :: from_unix_timestamp ( now. as_secs ( ) as i64 ) . unwrap ( ) ;
149
149
@@ -163,48 +163,60 @@ pub fn generate_sigv4_headers(
163
163
timestamp. second( )
164
164
) ;
165
165
166
- // Create canonical request
167
- let path = if uri. starts_with ( '/' ) { uri } else { "/" } ;
168
- let query = "" ;
169
-
170
- // Create canonical headers
171
- let mut headers: Vec < ( String , String ) > = vec ! [
172
- ( "host" . to_string( ) , host. to_string( ) ) ,
173
- ( "x-amz-date" . to_string( ) , datetime_str. clone( ) ) ,
174
- ] ;
175
- headers. sort_by ( |a, b| a. 0 . cmp ( & b. 0 ) ) ;
166
+ let ( canonical_uri, canonical_query_string) = if let Some ( query_pos) = uri. find ( '?' ) {
167
+ let path = & uri[ ..query_pos] ;
168
+ let query = & uri[ query_pos + 1 ..] ;
169
+
170
+ let encoded_path = if path. contains ( ':' ) {
171
+ path. replace ( ':' , "%3A" )
172
+ } else {
173
+ path. to_string ( )
174
+ } ;
175
+
176
+ let mut query_params: Vec < & str > = query. split ( '&' ) . collect ( ) ;
177
+ query_params. sort ( ) ;
178
+ ( encoded_path, query_params. join ( "&" ) )
179
+ } else {
180
+ let encoded_path = if uri. contains ( ':' ) {
181
+ uri. replace ( ':' , "%3A" )
182
+ } else {
183
+ uri. to_string ( )
184
+ } ;
185
+ ( encoded_path, String :: new ( ) )
186
+ } ;
187
+
188
+ let mut headers = BTreeMap :: new ( ) ;
189
+ headers. insert ( "content-type" , "application/x-amz-json-1.0" ) ;
190
+ headers. insert ( "host" , host) ;
191
+ headers. insert ( "x-amz-date" , & datetime_str) ;
176
192
177
193
let canonical_headers = headers
178
194
. iter ( )
179
- . map ( |( k, v) | format ! ( "{}:{}" , k, v) )
195
+ . map ( |( k, v) | format ! ( "{}:{}" , k. to_lowercase ( ) . trim ( ) , v. trim ( ) ) )
180
196
. collect :: < Vec < _ > > ( )
181
197
. join ( "\n " )
182
198
+ "\n " ;
183
199
184
200
let signed_headers = headers
185
- . iter ( )
186
- . map ( |( k , _ ) | k. as_str ( ) )
201
+ . keys ( )
202
+ . map ( |k | k. to_lowercase ( ) )
187
203
. collect :: < Vec < _ > > ( )
188
204
. join ( ";" ) ;
189
205
190
- // Hash payload
191
206
let payload_hash = format ! ( "{:x}" , Sha256 :: digest( body. as_bytes( ) ) ) ;
192
207
193
208
let canonical_request = format ! (
194
209
"{}\n {}\n {}\n {}\n {}\n {}" ,
195
- method, path , query , canonical_headers, signed_headers, payload_hash
210
+ method, canonical_uri , canonical_query_string , canonical_headers, signed_headers, payload_hash
196
211
) ;
197
212
198
- // Create string to sign
199
213
let credential_scope = format ! ( "{}/{}/{}/aws4_request" , date_str, region, service) ;
214
+ let canonical_request_hash = format ! ( "{:x}" , Sha256 :: digest( canonical_request. as_bytes( ) ) ) ;
200
215
let string_to_sign = format ! (
201
- "AWS4-HMAC-SHA256\n {}\n {}\n {:x}" ,
202
- datetime_str,
203
- credential_scope,
204
- Sha256 :: digest( canonical_request. as_bytes( ) )
216
+ "AWS4-HMAC-SHA256\n {}\n {}\n {}" ,
217
+ datetime_str, credential_scope, canonical_request_hash
205
218
) ;
206
219
207
- // Calculate signature
208
220
type HmacSha256 = Hmac < Sha256 > ;
209
221
210
222
let mut mac = HmacSha256 :: new_from_slice ( format ! ( "AWS4{}" , secret_key) . as_bytes ( ) ) ?;
@@ -227,15 +239,15 @@ pub fn generate_sigv4_headers(
227
239
mac. update ( string_to_sign. as_bytes ( ) ) ;
228
240
let signature = format ! ( "{:x}" , mac. finalize( ) . into_bytes( ) ) ;
229
241
230
- // Create authorization header
231
242
let auth_header = format ! (
232
243
"AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}" ,
233
244
access_key, credential_scope, signed_headers, signature
234
245
) ;
235
246
236
- let mut result_headers = vec ! [
247
+ let result_headers = vec ! [
237
248
( "authorization" . to_string( ) , auth_header) ,
238
249
( "x-amz-date" . to_string( ) , datetime_str) ,
250
+ ( "content-type" . to_string( ) , "application/x-amz-json-1.0" . to_string( ) ) ,
239
251
] ;
240
252
241
253
Ok ( result_headers)
@@ -373,20 +385,17 @@ pub struct ToolConfig {
373
385
}
374
386
375
387
#[ derive( Debug , Clone , Serialize , Deserialize ) ]
376
- #[ serde( tag = "type" ) ]
377
- pub enum Tool {
388
+ pub struct Tool {
378
389
#[ serde( rename = "toolSpec" ) ]
379
- ToolSpec {
380
- name : String ,
381
- description : String ,
382
- #[ serde( rename = "inputSchema" ) ]
383
- input_schema : ToolInputSchema ,
384
- } ,
390
+ pub tool_spec : ToolSpec ,
385
391
}
386
392
387
393
#[ derive( Debug , Clone , Serialize , Deserialize ) ]
388
- pub struct ToolInputSchema {
389
- pub json : Value ,
394
+ pub struct ToolSpec {
395
+ pub name : String ,
396
+ pub description : String ,
397
+ #[ serde( rename = "inputSchema" ) ]
398
+ pub input_schema : Value ,
390
399
}
391
400
392
401
#[ derive( Debug , Clone , Serialize , Deserialize ) ]
@@ -406,16 +415,7 @@ pub struct GuardrailConfig {
406
415
pub guardrail_identifier : String ,
407
416
#[ serde( rename = "guardrailVersion" ) ]
408
417
pub guardrail_version : String ,
409
- #[ serde( skip_serializing_if = "Option::is_none" ) ]
410
- pub trace : Option < GuardrailTrace > ,
411
- }
412
-
413
- #[ derive( Debug , Clone , Serialize , Deserialize ) ]
414
- pub enum GuardrailTrace {
415
- #[ serde( rename = "enabled" ) ]
416
- Enabled ,
417
- #[ serde( rename = "disabled" ) ]
418
- Disabled ,
418
+ pub trace : Option < String > ,
419
419
}
420
420
421
421
#[ derive( Debug , Clone , Serialize , Deserialize ) ]
@@ -503,4 +503,128 @@ fn parse_response<T: DeserializeOwned + Debug>(response: Response) -> Result<T,
503
503
}
504
504
}
505
505
506
+ #[ cfg( test) ]
507
+ mod tests {
508
+ use super :: * ;
509
+
510
+ #[ test]
511
+ fn test_generate_sigv4_headers_basic ( ) {
512
+ let access_key = "AKIAIOSFODNN7EXAMPLE" ;
513
+ let secret_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" ;
514
+ let region = "us-east-1" ;
515
+ let service = "bedrock" ;
516
+ let method = "POST" ;
517
+ let uri = "/model/anthropic.claude-3-sonnet-20240229-v1:0/converse" ;
518
+ let host = "bedrock-runtime.us-east-1.amazonaws.com" ;
519
+ let body = r#"{"messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}]}"# ;
520
+
521
+ let result = generate_sigv4_headers (
522
+ access_key,
523
+ secret_key,
524
+ region,
525
+ service,
526
+ method,
527
+ uri,
528
+ host,
529
+ body,
530
+ ) ;
531
+
532
+ assert ! ( result. is_ok( ) ) ;
533
+ let headers = result. unwrap ( ) ;
534
+
535
+ let header_map: std:: collections:: HashMap < String , String > = headers. into_iter ( ) . collect ( ) ;
536
+
537
+ assert ! ( header_map. contains_key( "authorization" ) ) ;
538
+ assert ! ( header_map. contains_key( "x-amz-date" ) ) ;
539
+ assert ! ( header_map. contains_key( "content-type" ) ) ;
540
+
541
+ let auth_header = & header_map[ "authorization" ] ;
542
+ assert ! ( auth_header. starts_with( "AWS4-HMAC-SHA256 Credential=" ) ) ;
543
+ assert ! ( auth_header. contains( "SignedHeaders=" ) ) ;
544
+ assert ! ( auth_header. contains( "Signature=" ) ) ;
545
+
546
+ assert_eq ! ( header_map[ "content-type" ] , "application/x-amz-json-1.0" ) ;
547
+
548
+ let date_header = & header_map[ "x-amz-date" ] ;
549
+ assert ! ( date_header. ends_with( 'Z' ) ) ;
550
+ assert ! ( date_header. contains( 'T' ) ) ;
551
+ }
552
+
553
+ #[ test]
554
+ fn test_canonical_headers_ordering ( ) {
555
+ let access_key = "AKIAIOSFODNN7EXAMPLE" ;
556
+ let secret_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" ;
557
+ let region = "us-east-1" ;
558
+ let service = "bedrock" ;
559
+ let method = "POST" ;
560
+ let uri = "/model/test/converse" ;
561
+ let host = "bedrock-runtime.us-east-1.amazonaws.com" ;
562
+ let body = "{}" ;
563
+
564
+ let result = generate_sigv4_headers (
565
+ access_key,
566
+ secret_key,
567
+ region,
568
+ service,
569
+ method,
570
+ uri,
571
+ host,
572
+ body,
573
+ ) ;
574
+
575
+ assert ! ( result. is_ok( ) ) ;
576
+ let headers = result. unwrap ( ) ;
577
+ let header_map: std:: collections:: HashMap < String , String > = headers. into_iter ( ) . collect ( ) ;
578
+
579
+ let auth_header = & header_map[ "authorization" ] ;
580
+
581
+ assert ! ( auth_header. contains( "SignedHeaders=content-type;host;x-amz-date" ) ) ;
582
+ }
583
+
584
+ #[ test]
585
+ fn test_bedrock_client_integration ( ) {
586
+ let client = BedrockClient :: new (
587
+ "AKIAIOSFODNN7EXAMPLE" . to_string ( ) ,
588
+ "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" . to_string ( ) ,
589
+ "us-east-1" . to_string ( ) ,
590
+ ) ;
591
+
592
+ let request = ConverseRequest {
593
+ model_id : "anthropic.claude-3-sonnet-20240229-v1:0" . to_string ( ) ,
594
+ messages : vec ! [ Message {
595
+ role: Role :: User ,
596
+ content: vec![ ContentBlock :: Text {
597
+ text: "Hello, how are you?" . to_string( ) ,
598
+ } ] ,
599
+ } ] ,
600
+ system : None ,
601
+ inference_config : None ,
602
+ tool_config : None ,
603
+ guardrail_config : None ,
604
+ additional_model_request_fields : None ,
605
+ } ;
606
+
607
+
608
+ let body = serde_json:: to_string ( & request) . expect ( "Failed to serialize request" ) ;
609
+ let host = format ! ( "bedrock-runtime.{}.amazonaws.com" , client. region) ;
610
+
611
+ let headers_result = generate_sigv4_headers (
612
+ & client. access_key_id ,
613
+ & client. secret_access_key ,
614
+ & client. region ,
615
+ "bedrock" ,
616
+ "POST" ,
617
+ "/model/anthropic.claude-3-sonnet-20240229-v1:0/converse" ,
618
+ & host,
619
+ & body,
620
+ ) ;
621
+
622
+ assert ! ( headers_result. is_ok( ) ) ;
623
+ let headers = headers_result. unwrap ( ) ;
506
624
625
+ let header_names: Vec < & str > = headers. iter ( ) . map ( |( k, _) | k. as_str ( ) ) . collect ( ) ;
626
+ assert ! ( header_names. contains( & "authorization" ) ) ;
627
+ assert ! ( header_names. contains( & "x-amz-date" ) ) ;
628
+ assert ! ( header_names. contains( & "content-type" ) ) ;
629
+ }
630
+ }
0 commit comments