@@ -66,7 +66,6 @@ impl BedrockClient {
66
66
} ) ?;
67
67
68
68
let mut request_builder = self . client . request ( Method :: POST , & url) ;
69
- request_builder = request_builder. header ( "content-type" , "application/json" ) ;
70
69
for ( key, value) in headers {
71
70
request_builder = request_builder. header ( key, value) ;
72
71
}
@@ -116,7 +115,6 @@ impl BedrockClient {
116
115
} ) ?;
117
116
118
117
let mut request_builder = self . client . request ( Method :: POST , & url) ;
119
- request_builder = request_builder. header ( "content-type" , "application/json" ) ;
120
118
for ( key, value) in headers {
121
119
request_builder = request_builder. header ( key, value) ;
122
120
}
@@ -134,6 +132,7 @@ impl BedrockClient {
134
132
}
135
133
}
136
134
135
+ /// FIXED: AWS SigV4 headers generation with proper content-type handling
137
136
pub fn generate_sigv4_headers (
138
137
access_key : & str ,
139
138
secret_key : & str ,
@@ -167,16 +166,18 @@ pub fn generate_sigv4_headers(
167
166
let path = if uri. starts_with ( '/' ) { uri } else { "/" } ;
168
167
let query = "" ;
169
168
170
- // Create canonical headers
169
+ // FIXED: Create canonical headers with content-type included and proper sorting
171
170
let mut headers: Vec < ( String , String ) > = vec ! [
172
171
( "host" . to_string( ) , host. to_string( ) ) ,
173
172
( "x-amz-date" . to_string( ) , datetime_str. clone( ) ) ,
173
+ ( "content-type" . to_string( ) , "application/x-amz-json-1.0" . to_string( ) ) ,
174
174
] ;
175
175
headers. sort_by ( |a, b| a. 0 . cmp ( & b. 0 ) ) ;
176
176
177
+ // FIXED: Proper trimming of header keys and values
177
178
let canonical_headers = headers
178
179
. iter ( )
179
- . map ( |( k, v) | format ! ( "{}:{}" , k, v) )
180
+ . map ( |( k, v) | format ! ( "{}:{}" , k. trim ( ) , v. trim ( ) ) )
180
181
. collect :: < Vec < _ > > ( )
181
182
. join ( "\n " )
182
183
+ "\n " ;
@@ -204,7 +205,7 @@ pub fn generate_sigv4_headers(
204
205
Sha256 :: digest( canonical_request. as_bytes( ) )
205
206
) ;
206
207
207
- // Calculate signature
208
+ // Calculate signature using AWS SigV4 key derivation
208
209
type HmacSha256 = Hmac < Sha256 > ;
209
210
210
211
let mut mac = HmacSha256 :: new_from_slice ( format ! ( "AWS4{}" , secret_key) . as_bytes ( ) ) ?;
@@ -233,9 +234,11 @@ pub fn generate_sigv4_headers(
233
234
access_key, credential_scope, signed_headers, signature
234
235
) ;
235
236
236
- let mut result_headers = vec ! [
237
+ // FIXED: Return all required headers including content-type
238
+ let result_headers = vec ! [
237
239
( "authorization" . to_string( ) , auth_header) ,
238
240
( "x-amz-date" . to_string( ) , datetime_str) ,
241
+ ( "content-type" . to_string( ) , "application/x-amz-json-1.0" . to_string( ) ) ,
239
242
] ;
240
243
241
244
Ok ( result_headers)
@@ -373,20 +376,17 @@ pub struct ToolConfig {
373
376
}
374
377
375
378
#[ derive( Debug , Clone , Serialize , Deserialize ) ]
376
- #[ serde( tag = "type" ) ]
377
- pub enum Tool {
379
+ pub struct Tool {
378
380
#[ serde( rename = "toolSpec" ) ]
379
- ToolSpec {
380
- name : String ,
381
- description : String ,
382
- #[ serde( rename = "inputSchema" ) ]
383
- input_schema : ToolInputSchema ,
384
- } ,
381
+ pub tool_spec : ToolSpec ,
385
382
}
386
383
387
384
#[ derive( Debug , Clone , Serialize , Deserialize ) ]
388
- pub struct ToolInputSchema {
389
- pub json : Value ,
385
+ pub struct ToolSpec {
386
+ pub name : String ,
387
+ pub description : String ,
388
+ #[ serde( rename = "inputSchema" ) ]
389
+ pub input_schema : Value ,
390
390
}
391
391
392
392
#[ derive( Debug , Clone , Serialize , Deserialize ) ]
@@ -406,16 +406,7 @@ pub struct GuardrailConfig {
406
406
pub guardrail_identifier : String ,
407
407
#[ serde( rename = "guardrailVersion" ) ]
408
408
pub guardrail_version : String ,
409
- #[ serde( skip_serializing_if = "Option::is_none" ) ]
410
- pub trace : Option < GuardrailTrace > ,
411
- }
412
-
413
- #[ derive( Debug , Clone , Serialize , Deserialize ) ]
414
- pub enum GuardrailTrace {
415
- #[ serde( rename = "enabled" ) ]
416
- Enabled ,
417
- #[ serde( rename = "disabled" ) ]
418
- Disabled ,
409
+ pub trace : Option < String > ,
419
410
}
420
411
421
412
#[ derive( Debug , Clone , Serialize , Deserialize ) ]
@@ -503,4 +494,135 @@ fn parse_response<T: DeserializeOwned + Debug>(response: Response) -> Result<T,
503
494
}
504
495
}
505
496
497
+ #[ cfg( test) ]
498
+ mod tests {
499
+ use super :: * ;
500
+
501
+ #[ test]
502
+ fn test_generate_sigv4_headers_basic ( ) {
503
+ let access_key = "AKIAIOSFODNN7EXAMPLE" ;
504
+ let secret_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" ;
505
+ let region = "us-east-1" ;
506
+ let service = "bedrock" ;
507
+ let method = "POST" ;
508
+ let uri = "/model/anthropic.claude-3-sonnet-20240229-v1:0/converse" ;
509
+ let host = "bedrock-runtime.us-east-1.amazonaws.com" ;
510
+ let body = r#"{"messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}]}"# ;
511
+
512
+ let result = generate_sigv4_headers (
513
+ access_key,
514
+ secret_key,
515
+ region,
516
+ service,
517
+ method,
518
+ uri,
519
+ host,
520
+ body,
521
+ ) ;
522
+
523
+ assert ! ( result. is_ok( ) ) ;
524
+ let headers = result. unwrap ( ) ;
525
+
526
+ // Check that required headers are present
527
+ let header_map: std:: collections:: HashMap < String , String > = headers. into_iter ( ) . collect ( ) ;
528
+
529
+ assert ! ( header_map. contains_key( "authorization" ) ) ;
530
+ assert ! ( header_map. contains_key( "x-amz-date" ) ) ;
531
+ assert ! ( header_map. contains_key( "content-type" ) ) ;
532
+
533
+ // Check authorization header format
534
+ let auth_header = & header_map[ "authorization" ] ;
535
+ assert ! ( auth_header. starts_with( "AWS4-HMAC-SHA256 Credential=" ) ) ;
536
+ assert ! ( auth_header. contains( "SignedHeaders=" ) ) ;
537
+ assert ! ( auth_header. contains( "Signature=" ) ) ;
538
+
539
+ // Check content-type
540
+ assert_eq ! ( header_map[ "content-type" ] , "application/x-amz-json-1.0" ) ;
541
+
542
+ // Check x-amz-date format (should be ISO8601)
543
+ let date_header = & header_map[ "x-amz-date" ] ;
544
+ assert_eq ! ( date_header. len( ) , 16 ) ; // YYYYMMDDTHHMMSSZ
545
+ assert ! ( date_header. ends_with( 'Z' ) ) ;
546
+ assert ! ( date_header. contains( 'T' ) ) ;
547
+ }
548
+
549
+ #[ test]
550
+ fn test_canonical_headers_ordering ( ) {
551
+ let access_key = "AKIAIOSFODNN7EXAMPLE" ;
552
+ let secret_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" ;
553
+ let region = "us-east-1" ;
554
+ let service = "bedrock" ;
555
+ let method = "POST" ;
556
+ let uri = "/model/test/converse" ;
557
+ let host = "bedrock-runtime.us-east-1.amazonaws.com" ;
558
+ let body = "{}" ;
559
+
560
+ let result = generate_sigv4_headers (
561
+ access_key,
562
+ secret_key,
563
+ region,
564
+ service,
565
+ method,
566
+ uri,
567
+ host,
568
+ body,
569
+ ) ;
570
+
571
+ assert ! ( result. is_ok( ) ) ;
572
+ let headers = result. unwrap ( ) ;
573
+ let header_map: std:: collections:: HashMap < String , String > = headers. into_iter ( ) . collect ( ) ;
574
+
575
+ let auth_header = & header_map[ "authorization" ] ;
576
+
577
+ // SignedHeaders should be in alphabetical order: content-type;host;x-amz-date
578
+ assert ! ( auth_header. contains( "SignedHeaders=content-type;host;x-amz-date" ) ) ;
579
+ }
580
+
581
+ #[ test]
582
+ fn test_bedrock_client_integration ( ) {
583
+ let client = BedrockClient :: new (
584
+ "AKIAIOSFODNN7EXAMPLE" . to_string ( ) ,
585
+ "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" . to_string ( ) ,
586
+ "us-east-1" . to_string ( ) ,
587
+ ) ;
588
+
589
+ let request = ConverseRequest {
590
+ model_id : "anthropic.claude-3-sonnet-20240229-v1:0" . to_string ( ) ,
591
+ messages : vec ! [ Message {
592
+ role: Role :: User ,
593
+ content: vec![ ContentBlock :: Text {
594
+ text: "Hello, how are you?" . to_string( ) ,
595
+ } ] ,
596
+ } ] ,
597
+ system : None ,
598
+ inference_config : None ,
599
+ tool_config : None ,
600
+ guardrail_config : None ,
601
+ additional_model_request_fields : None ,
602
+ } ;
603
+
604
+ // This test verifies that serialization and signing work together
605
+ let body = serde_json:: to_string ( & request) . expect ( "Failed to serialize request" ) ;
606
+ let host = format ! ( "bedrock-runtime.{}.amazonaws.com" , client. region) ;
506
607
608
+ let headers_result = generate_sigv4_headers (
609
+ & client. access_key_id ,
610
+ & client. secret_access_key ,
611
+ & client. region ,
612
+ "bedrock" ,
613
+ "POST" ,
614
+ "/model/anthropic.claude-3-sonnet-20240229-v1:0/converse" ,
615
+ & host,
616
+ & body,
617
+ ) ;
618
+
619
+ assert ! ( headers_result. is_ok( ) ) ;
620
+ let headers = headers_result. unwrap ( ) ;
621
+
622
+ // Verify all required headers are present for AWS API call
623
+ let header_names: Vec < & str > = headers. iter ( ) . map ( |( k, _) | k. as_str ( ) ) . collect ( ) ;
624
+ assert ! ( header_names. contains( & "authorization" ) ) ;
625
+ assert ! ( header_names. contains( & "x-amz-date" ) ) ;
626
+ assert ! ( header_names. contains( & "content-type" ) ) ;
627
+ }
628
+ }
0 commit comments