@@ -55,7 +55,7 @@ impl BedrockClient {
55
55
& self . region ,
56
56
"bedrock" ,
57
57
"POST" ,
58
- & format ! ( "/model/{}/converse" , model_id ) ,
58
+ & format ! ( "/model/{model_id }/converse" ) ,
59
59
& host,
60
60
& body,
61
61
)
@@ -71,11 +71,11 @@ impl BedrockClient {
71
71
}
72
72
73
73
let response: Response = request_builder. body ( body) . send ( ) . map_err ( |err| {
74
- trace ! ( "HTTP request failed with error: {:?}" , err ) ;
74
+ trace ! ( "HTTP request failed with error: {err :?}" ) ;
75
75
from_reqwest_error ( "Request failed" , err)
76
76
} ) ?;
77
77
78
- trace ! ( "Received response from Bedrock API: {:?}" , response ) ;
78
+ trace ! ( "Received response from Bedrock API: {response :?}" ) ;
79
79
80
80
parse_response ( response)
81
81
}
@@ -104,7 +104,7 @@ impl BedrockClient {
104
104
& self . region ,
105
105
"bedrock" ,
106
106
"POST" ,
107
- & format ! ( "/model/{}/converse-stream" , model_id ) ,
107
+ & format ! ( "/model/{model_id }/converse-stream" ) ,
108
108
& host,
109
109
& body,
110
110
)
@@ -121,7 +121,7 @@ impl BedrockClient {
121
121
122
122
trace ! ( "Sending streaming HTTP request to Bedrock..." ) ;
123
123
let response: Response = request_builder. body ( body) . send ( ) . map_err ( |err| {
124
- trace ! ( "HTTP request failed with error: {:?}" , err ) ;
124
+ trace ! ( "HTTP request failed with error: {err :?}" ) ;
125
125
from_reqwest_error ( "Request failed" , err)
126
126
} ) ?;
127
127
@@ -132,6 +132,7 @@ impl BedrockClient {
132
132
}
133
133
}
134
134
135
+ #[ allow( clippy:: too_many_arguments) ]
135
136
pub fn generate_sigv4_headers (
136
137
access_key : & str ,
137
138
secret_key : & str ,
@@ -143,7 +144,7 @@ pub fn generate_sigv4_headers(
143
144
body : & str ,
144
145
) -> Result < Vec < ( String , String ) > , Box < dyn std:: error:: Error > > {
145
146
use std:: collections:: BTreeMap ;
146
-
147
+
147
148
let now = SystemTime :: now ( ) . duration_since ( UNIX_EPOCH ) . unwrap ( ) ;
148
149
let timestamp = OffsetDateTime :: from_unix_timestamp ( now. as_secs ( ) as i64 ) . unwrap ( ) ;
149
150
@@ -166,13 +167,13 @@ pub fn generate_sigv4_headers(
166
167
let ( canonical_uri, canonical_query_string) = if let Some ( query_pos) = uri. find ( '?' ) {
167
168
let path = & uri[ ..query_pos] ;
168
169
let query = & uri[ query_pos + 1 ..] ;
169
-
170
+
170
171
let encoded_path = if path. contains ( ':' ) {
171
172
path. replace ( ':' , "%3A" )
172
173
} else {
173
174
path. to_string ( )
174
175
} ;
175
-
176
+
176
177
let mut query_params: Vec < & str > = query. split ( '&' ) . collect ( ) ;
177
178
query_params. sort ( ) ;
178
179
( encoded_path, query_params. join ( "&" ) )
@@ -206,20 +207,17 @@ pub fn generate_sigv4_headers(
206
207
let payload_hash = format ! ( "{:x}" , Sha256 :: digest( body. as_bytes( ) ) ) ;
207
208
208
209
let canonical_request = format ! (
209
- "{}\n {}\n {}\n {}\n {}\n {}" ,
210
- method, canonical_uri, canonical_query_string, canonical_headers, signed_headers, payload_hash
210
+ "{method}\n {canonical_uri}\n {canonical_query_string}\n {canonical_headers}\n {signed_headers}\n {payload_hash}"
211
211
) ;
212
212
213
- let credential_scope = format ! ( "{}/{}/{}/aws4_request" , date_str , region , service ) ;
213
+ let credential_scope = format ! ( "{date_str }/{region }/{service }/aws4_request" ) ;
214
214
let canonical_request_hash = format ! ( "{:x}" , Sha256 :: digest( canonical_request. as_bytes( ) ) ) ;
215
- let string_to_sign = format ! (
216
- "AWS4-HMAC-SHA256\n {}\n {}\n {}" ,
217
- datetime_str, credential_scope, canonical_request_hash
218
- ) ;
215
+ let string_to_sign =
216
+ format ! ( "AWS4-HMAC-SHA256\n {datetime_str}\n {credential_scope}\n {canonical_request_hash}" ) ;
219
217
220
218
type HmacSha256 = Hmac < Sha256 > ;
221
219
222
- let mut mac = HmacSha256 :: new_from_slice ( format ! ( "AWS4{}" , secret_key ) . as_bytes ( ) ) ?;
220
+ let mut mac = HmacSha256 :: new_from_slice ( format ! ( "AWS4{secret_key}" ) . as_bytes ( ) ) ?;
223
221
mac. update ( date_str. as_bytes ( ) ) ;
224
222
let date_key = mac. finalize ( ) . into_bytes ( ) ;
225
223
@@ -240,23 +238,23 @@ pub fn generate_sigv4_headers(
240
238
let signature = format ! ( "{:x}" , mac. finalize( ) . into_bytes( ) ) ;
241
239
242
240
let auth_header = format ! (
243
- "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}" ,
244
- access_key, credential_scope, signed_headers, signature
241
+ "AWS4-HMAC-SHA256 Credential={access_key}/{credential_scope}, SignedHeaders={signed_headers}, Signature={signature}"
245
242
) ;
246
243
247
244
let result_headers = vec ! [
248
245
( "authorization" . to_string( ) , auth_header) ,
249
246
( "x-amz-date" . to_string( ) , datetime_str) ,
250
- ( "content-type" . to_string( ) , "application/x-amz-json-1.0" . to_string( ) ) ,
247
+ (
248
+ "content-type" . to_string( ) ,
249
+ "application/x-amz-json-1.0" . to_string( ) ,
250
+ ) ,
251
251
] ;
252
252
253
253
Ok ( result_headers)
254
254
}
255
255
256
256
#[ derive( Debug , Clone , Serialize , Deserialize ) ]
257
257
pub struct ConverseRequest {
258
- #[ serde( skip_serializing, rename = "modelId" ) ]
259
- pub model_id : String ,
260
258
pub messages : Vec < Message > ,
261
259
#[ serde( skip_serializing_if = "Option::is_none" ) ]
262
260
pub system : Option < Vec < SystemContentBlock > > ,
@@ -290,7 +288,9 @@ pub enum Role {
290
288
#[ derive( Debug , Clone , Serialize , Deserialize ) ]
291
289
#[ serde( untagged) ]
292
290
pub enum ContentBlock {
293
- Text { text : String } ,
291
+ Text {
292
+ text : String ,
293
+ } ,
294
294
Image {
295
295
image : ImageBlock ,
296
296
} ,
@@ -442,11 +442,13 @@ pub struct ConverseResponse {
442
442
pub stop_reason : StopReason ,
443
443
pub usage : Usage ,
444
444
pub metrics : Metrics ,
445
- #[ serde( rename = "additionalModelResponseFields" , skip_serializing_if = "Option::is_none" ) ]
445
+ #[ serde(
446
+ rename = "additionalModelResponseFields" ,
447
+ skip_serializing_if = "Option::is_none"
448
+ ) ]
446
449
pub additional_model_response_fields : Option < Value > ,
447
450
}
448
451
449
-
450
452
#[ derive( Debug , Clone , Serialize , Deserialize ) ]
451
453
pub struct Output {
452
454
pub message : Message ,
@@ -509,134 +511,8 @@ fn parse_response<T: DeserializeOwned + Debug>(response: Response) -> Result<T,
509
511
510
512
Err ( Error {
511
513
code : error_code_from_status ( status) ,
512
- message : format ! ( "Request failed with {status}: {}" , body ) ,
514
+ message : format ! ( "Request failed with {status}: {body}" ) ,
513
515
provider_error_json : Some ( body) ,
514
516
} )
515
517
}
516
518
}
517
-
518
- #[ cfg( test) ]
519
- mod tests {
520
- use super :: * ;
521
-
522
- #[ test]
523
- fn test_generate_sigv4_headers_basic ( ) {
524
- let access_key = "AKIAIOSFODNN7EXAMPLE" ;
525
- let secret_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" ;
526
- let region = "us-east-1" ;
527
- let service = "bedrock" ;
528
- let method = "POST" ;
529
- let uri = "/model/anthropic.claude-3-sonnet-20240229-v1:0/converse" ;
530
- let host = "bedrock-runtime.us-east-1.amazonaws.com" ;
531
- let body = r#"{"messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}]}"# ;
532
-
533
- let result = generate_sigv4_headers (
534
- access_key,
535
- secret_key,
536
- region,
537
- service,
538
- method,
539
- uri,
540
- host,
541
- body,
542
- ) ;
543
-
544
- assert ! ( result. is_ok( ) ) ;
545
- let headers = result. unwrap ( ) ;
546
-
547
- let header_map: std:: collections:: HashMap < String , String > = headers. into_iter ( ) . collect ( ) ;
548
-
549
- assert ! ( header_map. contains_key( "authorization" ) ) ;
550
- assert ! ( header_map. contains_key( "x-amz-date" ) ) ;
551
- assert ! ( header_map. contains_key( "content-type" ) ) ;
552
-
553
- let auth_header = & header_map[ "authorization" ] ;
554
- assert ! ( auth_header. starts_with( "AWS4-HMAC-SHA256 Credential=" ) ) ;
555
- assert ! ( auth_header. contains( "SignedHeaders=" ) ) ;
556
- assert ! ( auth_header. contains( "Signature=" ) ) ;
557
-
558
- assert_eq ! ( header_map[ "content-type" ] , "application/x-amz-json-1.0" ) ;
559
-
560
- let date_header = & header_map[ "x-amz-date" ] ;
561
- assert ! ( date_header. ends_with( 'Z' ) ) ;
562
- assert ! ( date_header. contains( 'T' ) ) ;
563
- }
564
-
565
- #[ test]
566
- fn test_canonical_headers_ordering ( ) {
567
- let access_key = "AKIAIOSFODNN7EXAMPLE" ;
568
- let secret_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" ;
569
- let region = "us-east-1" ;
570
- let service = "bedrock" ;
571
- let method = "POST" ;
572
- let uri = "/model/test/converse" ;
573
- let host = "bedrock-runtime.us-east-1.amazonaws.com" ;
574
- let body = "{}" ;
575
-
576
- let result = generate_sigv4_headers (
577
- access_key,
578
- secret_key,
579
- region,
580
- service,
581
- method,
582
- uri,
583
- host,
584
- body,
585
- ) ;
586
-
587
- assert ! ( result. is_ok( ) ) ;
588
- let headers = result. unwrap ( ) ;
589
- let header_map: std:: collections:: HashMap < String , String > = headers. into_iter ( ) . collect ( ) ;
590
-
591
- let auth_header = & header_map[ "authorization" ] ;
592
-
593
- assert ! ( auth_header. contains( "SignedHeaders=content-type;host;x-amz-date" ) ) ;
594
- }
595
-
596
- #[ test]
597
- fn test_bedrock_client_integration ( ) {
598
- let client = BedrockClient :: new (
599
- "AKIAIOSFODNN7EXAMPLE" . to_string ( ) ,
600
- "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" . to_string ( ) ,
601
- "us-east-1" . to_string ( ) ,
602
- ) ;
603
-
604
- let request = ConverseRequest {
605
- model_id : "anthropic.claude-3-sonnet-20240229-v1:0" . to_string ( ) ,
606
- messages : vec ! [ Message {
607
- role: Role :: User ,
608
- content: vec![ ContentBlock :: Text {
609
- text: "Hello, how are you?" . to_string( ) ,
610
- } ] ,
611
- } ] ,
612
- system : None ,
613
- inference_config : None ,
614
- tool_config : None ,
615
- guardrail_config : None ,
616
- additional_model_request_fields : None ,
617
- } ;
618
-
619
-
620
- let body = serde_json:: to_string ( & request) . expect ( "Failed to serialize request" ) ;
621
- let host = format ! ( "bedrock-runtime.{}.amazonaws.com" , client. region) ;
622
-
623
- let headers_result = generate_sigv4_headers (
624
- & client. access_key_id ,
625
- & client. secret_access_key ,
626
- & client. region ,
627
- "bedrock" ,
628
- "POST" ,
629
- "/model/anthropic.claude-3-sonnet-20240229-v1:0/converse" ,
630
- & host,
631
- & body,
632
- ) ;
633
-
634
- assert ! ( headers_result. is_ok( ) ) ;
635
- let headers = headers_result. unwrap ( ) ;
636
-
637
- let header_names: Vec < & str > = headers. iter ( ) . map ( |( k, _) | k. as_str ( ) ) . collect ( ) ;
638
- assert ! ( header_names. contains( & "authorization" ) ) ;
639
- assert ! ( header_names. contains( & "x-amz-date" ) ) ;
640
- assert ! ( header_names. contains( & "content-type" ) ) ;
641
- }
642
- }
0 commit comments