Skip to content

Commit 45317f6

Browse files
committed
sigv4 wip
1 parent 47cfffa commit 45317f6

File tree

5 files changed

+158
-32
lines changed

5 files changed

+158
-32
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

llm/bedrock/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ base64 = { workspace = true }
2828
hmac = "0.12"
2929
sha2 = "0.10"
3030
time = { version = "0.3", features = ["formatting"] }
31+
percent-encoding = "2.3"
3132

3233
[package.metadata.component]
3334
package = "golem:llm-bedrock"

llm/bedrock/src/client.rs

Lines changed: 148 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ impl BedrockClient {
6666
})?;
6767

6868
let mut request_builder = self.client.request(Method::POST, &url);
69-
request_builder = request_builder.header("content-type", "application/json");
7069
for (key, value) in headers {
7170
request_builder = request_builder.header(key, value);
7271
}
@@ -116,7 +115,6 @@ impl BedrockClient {
116115
})?;
117116

118117
let mut request_builder = self.client.request(Method::POST, &url);
119-
request_builder = request_builder.header("content-type", "application/json");
120118
for (key, value) in headers {
121119
request_builder = request_builder.header(key, value);
122120
}
@@ -134,6 +132,7 @@ impl BedrockClient {
134132
}
135133
}
136134

135+
/// FIXED: AWS SigV4 headers generation with proper content-type handling
137136
pub fn generate_sigv4_headers(
138137
access_key: &str,
139138
secret_key: &str,
@@ -167,16 +166,18 @@ pub fn generate_sigv4_headers(
167166
let path = if uri.starts_with('/') { uri } else { "/" };
168167
let query = "";
169168

170-
// Create canonical headers
169+
// FIXED: Create canonical headers with content-type included and proper sorting
171170
let mut headers: Vec<(String, String)> = vec![
172171
("host".to_string(), host.to_string()),
173172
("x-amz-date".to_string(), datetime_str.clone()),
173+
("content-type".to_string(), "application/x-amz-json-1.0".to_string()),
174174
];
175175
headers.sort_by(|a, b| a.0.cmp(&b.0));
176176

177+
// FIXED: Proper trimming of header keys and values
177178
let canonical_headers = headers
178179
.iter()
179-
.map(|(k, v)| format!("{}:{}", k, v))
180+
.map(|(k, v)| format!("{}:{}", k.trim(), v.trim()))
180181
.collect::<Vec<_>>()
181182
.join("\n")
182183
+ "\n";
@@ -204,7 +205,7 @@ pub fn generate_sigv4_headers(
204205
Sha256::digest(canonical_request.as_bytes())
205206
);
206207

207-
// Calculate signature
208+
// Calculate signature using AWS SigV4 key derivation
208209
type HmacSha256 = Hmac<Sha256>;
209210

210211
let mut mac = HmacSha256::new_from_slice(format!("AWS4{}", secret_key).as_bytes())?;
@@ -233,9 +234,11 @@ pub fn generate_sigv4_headers(
233234
access_key, credential_scope, signed_headers, signature
234235
);
235236

236-
let mut result_headers = vec![
237+
// FIXED: Return all required headers including content-type
238+
let result_headers = vec![
237239
("authorization".to_string(), auth_header),
238240
("x-amz-date".to_string(), datetime_str),
241+
("content-type".to_string(), "application/x-amz-json-1.0".to_string()),
239242
];
240243

241244
Ok(result_headers)
@@ -373,20 +376,17 @@ pub struct ToolConfig {
373376
}
374377

375378
#[derive(Debug, Clone, Serialize, Deserialize)]
376-
#[serde(tag = "type")]
377-
pub enum Tool {
379+
pub struct Tool {
378380
#[serde(rename = "toolSpec")]
379-
ToolSpec {
380-
name: String,
381-
description: String,
382-
#[serde(rename = "inputSchema")]
383-
input_schema: ToolInputSchema,
384-
},
381+
pub tool_spec: ToolSpec,
385382
}
386383

387384
#[derive(Debug, Clone, Serialize, Deserialize)]
388-
pub struct ToolInputSchema {
389-
pub json: Value,
385+
pub struct ToolSpec {
386+
pub name: String,
387+
pub description: String,
388+
#[serde(rename = "inputSchema")]
389+
pub input_schema: Value,
390390
}
391391

392392
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -406,16 +406,7 @@ pub struct GuardrailConfig {
406406
pub guardrail_identifier: String,
407407
#[serde(rename = "guardrailVersion")]
408408
pub guardrail_version: String,
409-
#[serde(skip_serializing_if = "Option::is_none")]
410-
pub trace: Option<GuardrailTrace>,
411-
}
412-
413-
#[derive(Debug, Clone, Serialize, Deserialize)]
414-
pub enum GuardrailTrace {
415-
#[serde(rename = "enabled")]
416-
Enabled,
417-
#[serde(rename = "disabled")]
418-
Disabled,
409+
pub trace: Option<String>,
419410
}
420411

421412
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -503,4 +494,135 @@ fn parse_response<T: DeserializeOwned + Debug>(response: Response) -> Result<T,
503494
}
504495
}
505496

497+
#[cfg(test)]
498+
mod tests {
499+
use super::*;
500+
501+
#[test]
502+
fn test_generate_sigv4_headers_basic() {
503+
let access_key = "AKIAIOSFODNN7EXAMPLE";
504+
let secret_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY";
505+
let region = "us-east-1";
506+
let service = "bedrock";
507+
let method = "POST";
508+
let uri = "/model/anthropic.claude-3-sonnet-20240229-v1:0/converse";
509+
let host = "bedrock-runtime.us-east-1.amazonaws.com";
510+
let body = r#"{"messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}]}"#;
511+
512+
let result = generate_sigv4_headers(
513+
access_key,
514+
secret_key,
515+
region,
516+
service,
517+
method,
518+
uri,
519+
host,
520+
body,
521+
);
522+
523+
assert!(result.is_ok());
524+
let headers = result.unwrap();
525+
526+
// Check that required headers are present
527+
let header_map: std::collections::HashMap<String, String> = headers.into_iter().collect();
528+
529+
assert!(header_map.contains_key("authorization"));
530+
assert!(header_map.contains_key("x-amz-date"));
531+
assert!(header_map.contains_key("content-type"));
532+
533+
// Check authorization header format
534+
let auth_header = &header_map["authorization"];
535+
assert!(auth_header.starts_with("AWS4-HMAC-SHA256 Credential="));
536+
assert!(auth_header.contains("SignedHeaders="));
537+
assert!(auth_header.contains("Signature="));
538+
539+
// Check content-type
540+
assert_eq!(header_map["content-type"], "application/x-amz-json-1.0");
541+
542+
// Check x-amz-date format (should be ISO8601)
543+
let date_header = &header_map["x-amz-date"];
544+
assert_eq!(date_header.len(), 16); // YYYYMMDDTHHMMSSZ
545+
assert!(date_header.ends_with('Z'));
546+
assert!(date_header.contains('T'));
547+
}
548+
549+
#[test]
550+
fn test_canonical_headers_ordering() {
551+
let access_key = "AKIAIOSFODNN7EXAMPLE";
552+
let secret_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY";
553+
let region = "us-east-1";
554+
let service = "bedrock";
555+
let method = "POST";
556+
let uri = "/model/test/converse";
557+
let host = "bedrock-runtime.us-east-1.amazonaws.com";
558+
let body = "{}";
559+
560+
let result = generate_sigv4_headers(
561+
access_key,
562+
secret_key,
563+
region,
564+
service,
565+
method,
566+
uri,
567+
host,
568+
body,
569+
);
570+
571+
assert!(result.is_ok());
572+
let headers = result.unwrap();
573+
let header_map: std::collections::HashMap<String, String> = headers.into_iter().collect();
574+
575+
let auth_header = &header_map["authorization"];
576+
577+
// SignedHeaders should be in alphabetical order: content-type;host;x-amz-date
578+
assert!(auth_header.contains("SignedHeaders=content-type;host;x-amz-date"));
579+
}
580+
581+
#[test]
582+
fn test_bedrock_client_integration() {
583+
let client = BedrockClient::new(
584+
"AKIAIOSFODNN7EXAMPLE".to_string(),
585+
"wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(),
586+
"us-east-1".to_string(),
587+
);
588+
589+
let request = ConverseRequest {
590+
model_id: "anthropic.claude-3-sonnet-20240229-v1:0".to_string(),
591+
messages: vec![Message {
592+
role: Role::User,
593+
content: vec![ContentBlock::Text {
594+
text: "Hello, how are you?".to_string(),
595+
}],
596+
}],
597+
system: None,
598+
inference_config: None,
599+
tool_config: None,
600+
guardrail_config: None,
601+
additional_model_request_fields: None,
602+
};
603+
604+
// This test verifies that serialization and signing work together
605+
let body = serde_json::to_string(&request).expect("Failed to serialize request");
606+
let host = format!("bedrock-runtime.{}.amazonaws.com", client.region);
506607

608+
let headers_result = generate_sigv4_headers(
609+
&client.access_key_id,
610+
&client.secret_access_key,
611+
&client.region,
612+
"bedrock",
613+
"POST",
614+
"/model/anthropic.claude-3-sonnet-20240229-v1:0/converse",
615+
&host,
616+
&body,
617+
);
618+
619+
assert!(headers_result.is_ok());
620+
let headers = headers_result.unwrap();
621+
622+
// Verify all required headers are present for AWS API call
623+
let header_names: Vec<&str> = headers.iter().map(|(k, _)| k.as_str()).collect();
624+
assert!(header_names.contains(&"authorization"));
625+
assert!(header_names.contains(&"x-amz-date"));
626+
assert!(header_names.contains(&"content-type"));
627+
}
628+
}

llm/bedrock/src/conversions.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use crate::client::{
22
ContentBlock, ConverseRequest, ConverseResponse, ImageFormat, ImageSource as ClientImageSource,
33
InferenceConfig, Message as ClientMessage, Role as ClientRole, StopReason, SystemContentBlock,
4-
Tool, ToolChoice, ToolConfig, ToolInputSchema, ToolResultContentBlock, ToolResultStatus,
4+
Tool, ToolChoice, ToolConfig, ToolSpec, ToolResultContentBlock, ToolResultStatus,
55
};
66
use base64::{engine::general_purpose, Engine as _};
77
use golem_llm::golem::llm::llm::{
@@ -287,10 +287,12 @@ fn message_to_system_content(message: &Message) -> Vec<SystemContentBlock> {
287287

288288
fn tool_definition_to_tool(tool: &ToolDefinition) -> Result<Tool, Error> {
289289
match serde_json::from_str(&tool.parameters_schema) {
290-
Ok(json_schema) => Ok(Tool::ToolSpec {
291-
name: tool.name.clone(),
292-
description: tool.description.clone().unwrap_or_default(),
293-
input_schema: ToolInputSchema { json: json_schema },
290+
Ok(json_schema) => Ok(Tool {
291+
tool_spec: ToolSpec {
292+
name: tool.name.clone(),
293+
description: tool.description.clone().unwrap_or_default(),
294+
input_schema: json_schema,
295+
},
294296
}),
295297
Err(error) => Err(Error {
296298
code: ErrorCode::InternalError,

test/components-rust/test-llm/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ path = "wit-generated"
3838

3939
[package.metadata.component.target.dependencies]
4040
"golem:llm" = { path = "wit-generated/deps/golem-llm" }
41-
"wasi:io" = { path = "wit-generated/deps/io" }
4241
"wasi:clocks" = { path = "wit-generated/deps/clocks" }
42+
"wasi:io" = { path = "wit-generated/deps/io" }
4343
"golem:rpc" = { path = "wit-generated/deps/golem-rpc" }
4444
"test:helper-client" = { path = "wit-generated/deps/test_helper-client" }
4545
"test:llm-exports" = { path = "wit-generated/deps/test_llm-exports" }

0 commit comments

Comments
 (0)