Skip to content

Commit 3cbc715

Browse files
committed
sigv4 fixes
1 parent 47cfffa commit 3cbc715

File tree

5 files changed

+179
-51
lines changed

5 files changed

+179
-51
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

llm/bedrock/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ base64 = { workspace = true }
2828
hmac = "0.12"
2929
sha2 = "0.10"
3030
time = { version = "0.3", features = ["formatting"] }
31+
percent-encoding = "2.3"
3132

3233
[package.metadata.component]
3334
package = "golem:llm-bedrock"

llm/bedrock/src/client.rs

Lines changed: 169 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ impl BedrockClient {
6666
})?;
6767

6868
let mut request_builder = self.client.request(Method::POST, &url);
69-
request_builder = request_builder.header("content-type", "application/json");
7069
for (key, value) in headers {
7170
request_builder = request_builder.header(key, value);
7271
}
@@ -116,7 +115,6 @@ impl BedrockClient {
116115
})?;
117116

118117
let mut request_builder = self.client.request(Method::POST, &url);
119-
request_builder = request_builder.header("content-type", "application/json");
120118
for (key, value) in headers {
121119
request_builder = request_builder.header(key, value);
122120
}
@@ -144,6 +142,8 @@ pub fn generate_sigv4_headers(
144142
host: &str,
145143
body: &str,
146144
) -> Result<Vec<(String, String)>, Box<dyn std::error::Error>> {
145+
use std::collections::BTreeMap;
146+
147147
let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
148148
let timestamp = OffsetDateTime::from_unix_timestamp(now.as_secs() as i64).unwrap();
149149

@@ -163,48 +163,60 @@ pub fn generate_sigv4_headers(
163163
timestamp.second()
164164
);
165165

166-
// Create canonical request
167-
let path = if uri.starts_with('/') { uri } else { "/" };
168-
let query = "";
169-
170-
// Create canonical headers
171-
let mut headers: Vec<(String, String)> = vec![
172-
("host".to_string(), host.to_string()),
173-
("x-amz-date".to_string(), datetime_str.clone()),
174-
];
175-
headers.sort_by(|a, b| a.0.cmp(&b.0));
166+
let (canonical_uri, canonical_query_string) = if let Some(query_pos) = uri.find('?') {
167+
let path = &uri[..query_pos];
168+
let query = &uri[query_pos + 1..];
169+
170+
let encoded_path = if path.contains(':') {
171+
path.replace(':', "%3A")
172+
} else {
173+
path.to_string()
174+
};
175+
176+
let mut query_params: Vec<&str> = query.split('&').collect();
177+
query_params.sort();
178+
(encoded_path, query_params.join("&"))
179+
} else {
180+
let encoded_path = if uri.contains(':') {
181+
uri.replace(':', "%3A")
182+
} else {
183+
uri.to_string()
184+
};
185+
(encoded_path, String::new())
186+
};
187+
188+
let mut headers = BTreeMap::new();
189+
headers.insert("content-type", "application/x-amz-json-1.0");
190+
headers.insert("host", host);
191+
headers.insert("x-amz-date", &datetime_str);
176192

177193
let canonical_headers = headers
178194
.iter()
179-
.map(|(k, v)| format!("{}:{}", k, v))
195+
.map(|(k, v)| format!("{}:{}", k.to_lowercase().trim(), v.trim()))
180196
.collect::<Vec<_>>()
181197
.join("\n")
182198
+ "\n";
183199

184200
let signed_headers = headers
185-
.iter()
186-
.map(|(k, _)| k.as_str())
201+
.keys()
202+
.map(|k| k.to_lowercase())
187203
.collect::<Vec<_>>()
188204
.join(";");
189205

190-
// Hash payload
191206
let payload_hash = format!("{:x}", Sha256::digest(body.as_bytes()));
192207

193208
let canonical_request = format!(
194209
"{}\n{}\n{}\n{}\n{}\n{}",
195-
method, path, query, canonical_headers, signed_headers, payload_hash
210+
method, canonical_uri, canonical_query_string, canonical_headers, signed_headers, payload_hash
196211
);
197212

198-
// Create string to sign
199213
let credential_scope = format!("{}/{}/{}/aws4_request", date_str, region, service);
214+
let canonical_request_hash = format!("{:x}", Sha256::digest(canonical_request.as_bytes()));
200215
let string_to_sign = format!(
201-
"AWS4-HMAC-SHA256\n{}\n{}\n{:x}",
202-
datetime_str,
203-
credential_scope,
204-
Sha256::digest(canonical_request.as_bytes())
216+
"AWS4-HMAC-SHA256\n{}\n{}\n{}",
217+
datetime_str, credential_scope, canonical_request_hash
205218
);
206219

207-
// Calculate signature
208220
type HmacSha256 = Hmac<Sha256>;
209221

210222
let mut mac = HmacSha256::new_from_slice(format!("AWS4{}", secret_key).as_bytes())?;
@@ -227,15 +239,15 @@ pub fn generate_sigv4_headers(
227239
mac.update(string_to_sign.as_bytes());
228240
let signature = format!("{:x}", mac.finalize().into_bytes());
229241

230-
// Create authorization header
231242
let auth_header = format!(
232243
"AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}",
233244
access_key, credential_scope, signed_headers, signature
234245
);
235246

236-
let mut result_headers = vec![
247+
let result_headers = vec![
237248
("authorization".to_string(), auth_header),
238249
("x-amz-date".to_string(), datetime_str),
250+
("content-type".to_string(), "application/x-amz-json-1.0".to_string()),
239251
];
240252

241253
Ok(result_headers)
@@ -373,20 +385,17 @@ pub struct ToolConfig {
373385
}
374386

375387
#[derive(Debug, Clone, Serialize, Deserialize)]
376-
#[serde(tag = "type")]
377-
pub enum Tool {
388+
pub struct Tool {
378389
#[serde(rename = "toolSpec")]
379-
ToolSpec {
380-
name: String,
381-
description: String,
382-
#[serde(rename = "inputSchema")]
383-
input_schema: ToolInputSchema,
384-
},
390+
pub tool_spec: ToolSpec,
385391
}
386392

387393
#[derive(Debug, Clone, Serialize, Deserialize)]
388-
pub struct ToolInputSchema {
389-
pub json: Value,
394+
pub struct ToolSpec {
395+
pub name: String,
396+
pub description: String,
397+
#[serde(rename = "inputSchema")]
398+
pub input_schema: Value,
390399
}
391400

392401
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -406,16 +415,7 @@ pub struct GuardrailConfig {
406415
pub guardrail_identifier: String,
407416
#[serde(rename = "guardrailVersion")]
408417
pub guardrail_version: String,
409-
#[serde(skip_serializing_if = "Option::is_none")]
410-
pub trace: Option<GuardrailTrace>,
411-
}
412-
413-
#[derive(Debug, Clone, Serialize, Deserialize)]
414-
pub enum GuardrailTrace {
415-
#[serde(rename = "enabled")]
416-
Enabled,
417-
#[serde(rename = "disabled")]
418-
Disabled,
418+
pub trace: Option<String>,
419419
}
420420

421421
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -503,4 +503,128 @@ fn parse_response<T: DeserializeOwned + Debug>(response: Response) -> Result<T,
503503
}
504504
}
505505

506+
#[cfg(test)]
507+
mod tests {
508+
use super::*;
509+
510+
#[test]
511+
fn test_generate_sigv4_headers_basic() {
512+
let access_key = "AKIAIOSFODNN7EXAMPLE";
513+
let secret_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY";
514+
let region = "us-east-1";
515+
let service = "bedrock";
516+
let method = "POST";
517+
let uri = "/model/anthropic.claude-3-sonnet-20240229-v1:0/converse";
518+
let host = "bedrock-runtime.us-east-1.amazonaws.com";
519+
let body = r#"{"messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}]}"#;
520+
521+
let result = generate_sigv4_headers(
522+
access_key,
523+
secret_key,
524+
region,
525+
service,
526+
method,
527+
uri,
528+
host,
529+
body,
530+
);
531+
532+
assert!(result.is_ok());
533+
let headers = result.unwrap();
534+
535+
let header_map: std::collections::HashMap<String, String> = headers.into_iter().collect();
536+
537+
assert!(header_map.contains_key("authorization"));
538+
assert!(header_map.contains_key("x-amz-date"));
539+
assert!(header_map.contains_key("content-type"));
540+
541+
let auth_header = &header_map["authorization"];
542+
assert!(auth_header.starts_with("AWS4-HMAC-SHA256 Credential="));
543+
assert!(auth_header.contains("SignedHeaders="));
544+
assert!(auth_header.contains("Signature="));
545+
546+
assert_eq!(header_map["content-type"], "application/x-amz-json-1.0");
547+
548+
let date_header = &header_map["x-amz-date"];
549+
assert!(date_header.ends_with('Z'));
550+
assert!(date_header.contains('T'));
551+
}
552+
553+
#[test]
554+
fn test_canonical_headers_ordering() {
555+
let access_key = "AKIAIOSFODNN7EXAMPLE";
556+
let secret_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY";
557+
let region = "us-east-1";
558+
let service = "bedrock";
559+
let method = "POST";
560+
let uri = "/model/test/converse";
561+
let host = "bedrock-runtime.us-east-1.amazonaws.com";
562+
let body = "{}";
563+
564+
let result = generate_sigv4_headers(
565+
access_key,
566+
secret_key,
567+
region,
568+
service,
569+
method,
570+
uri,
571+
host,
572+
body,
573+
);
574+
575+
assert!(result.is_ok());
576+
let headers = result.unwrap();
577+
let header_map: std::collections::HashMap<String, String> = headers.into_iter().collect();
578+
579+
let auth_header = &header_map["authorization"];
580+
581+
assert!(auth_header.contains("SignedHeaders=content-type;host;x-amz-date"));
582+
}
583+
584+
#[test]
585+
fn test_bedrock_client_integration() {
586+
let client = BedrockClient::new(
587+
"AKIAIOSFODNN7EXAMPLE".to_string(),
588+
"wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(),
589+
"us-east-1".to_string(),
590+
);
591+
592+
let request = ConverseRequest {
593+
model_id: "anthropic.claude-3-sonnet-20240229-v1:0".to_string(),
594+
messages: vec![Message {
595+
role: Role::User,
596+
content: vec![ContentBlock::Text {
597+
text: "Hello, how are you?".to_string(),
598+
}],
599+
}],
600+
system: None,
601+
inference_config: None,
602+
tool_config: None,
603+
guardrail_config: None,
604+
additional_model_request_fields: None,
605+
};
606+
607+
608+
let body = serde_json::to_string(&request).expect("Failed to serialize request");
609+
let host = format!("bedrock-runtime.{}.amazonaws.com", client.region);
610+
611+
let headers_result = generate_sigv4_headers(
612+
&client.access_key_id,
613+
&client.secret_access_key,
614+
&client.region,
615+
"bedrock",
616+
"POST",
617+
"/model/anthropic.claude-3-sonnet-20240229-v1:0/converse",
618+
&host,
619+
&body,
620+
);
621+
622+
assert!(headers_result.is_ok());
623+
let headers = headers_result.unwrap();
506624

625+
let header_names: Vec<&str> = headers.iter().map(|(k, _)| k.as_str()).collect();
626+
assert!(header_names.contains(&"authorization"));
627+
assert!(header_names.contains(&"x-amz-date"));
628+
assert!(header_names.contains(&"content-type"));
629+
}
630+
}

llm/bedrock/src/conversions.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use crate::client::{
22
ContentBlock, ConverseRequest, ConverseResponse, ImageFormat, ImageSource as ClientImageSource,
33
InferenceConfig, Message as ClientMessage, Role as ClientRole, StopReason, SystemContentBlock,
4-
Tool, ToolChoice, ToolConfig, ToolInputSchema, ToolResultContentBlock, ToolResultStatus,
4+
Tool, ToolChoice, ToolConfig, ToolSpec, ToolResultContentBlock, ToolResultStatus,
55
};
66
use base64::{engine::general_purpose, Engine as _};
77
use golem_llm::golem::llm::llm::{
@@ -287,10 +287,12 @@ fn message_to_system_content(message: &Message) -> Vec<SystemContentBlock> {
287287

288288
fn tool_definition_to_tool(tool: &ToolDefinition) -> Result<Tool, Error> {
289289
match serde_json::from_str(&tool.parameters_schema) {
290-
Ok(json_schema) => Ok(Tool::ToolSpec {
291-
name: tool.name.clone(),
292-
description: tool.description.clone().unwrap_or_default(),
293-
input_schema: ToolInputSchema { json: json_schema },
290+
Ok(json_schema) => Ok(Tool {
291+
tool_spec: ToolSpec {
292+
name: tool.name.clone(),
293+
description: tool.description.clone().unwrap_or_default(),
294+
input_schema: json_schema,
295+
},
294296
}),
295297
Err(error) => Err(Error {
296298
code: ErrorCode::InternalError,

test/components-rust/test-llm/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ path = "wit-generated"
3838

3939
[package.metadata.component.target.dependencies]
4040
"golem:llm" = { path = "wit-generated/deps/golem-llm" }
41-
"wasi:io" = { path = "wit-generated/deps/io" }
4241
"wasi:clocks" = { path = "wit-generated/deps/clocks" }
42+
"wasi:io" = { path = "wit-generated/deps/io" }
4343
"golem:rpc" = { path = "wit-generated/deps/golem-rpc" }
4444
"test:helper-client" = { path = "wit-generated/deps/test_helper-client" }
4545
"test:llm-exports" = { path = "wit-generated/deps/test_llm-exports" }

0 commit comments

Comments
 (0)