Skip to content

Commit 975b00a

Browse files
authored
fix: fixed bug with base64 encoding on AWS Bedrock (#432)
* fix: fixed bug with base64 encoding on AWS Bedrock
* fix: use uuid v4 for document names
* fix: prefix document names
1 parent 8c12b16 commit 975b00a

File tree

3 files changed

+40
-10
lines changed

3 files changed

+40
-10
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rig-bedrock/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@ aws-sdk-bedrockruntime = "1.77.0"
1818
aws-smithy-types = "1.3.0"
1919
base64 = "0.22.1"
2020
async-stream = "0.3.6"
21+
uuid = { version = "1.16.0", features = ["v4"]}
2122

2223
[dev-dependencies]
2324
anyhow = "1.0.75"

rig-bedrock/src/types/document.rs

Lines changed: 38 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,12 @@
11
use aws_sdk_bedrockruntime::types as aws_bedrock;
2-
32
use rig::{
43
completion::CompletionError,
54
message::{ContentFormat, Document},
65
};
76

87
pub(crate) use crate::types::media_types::RigDocumentMediaType;
98
use base64::{prelude::BASE64_STANDARD, Engine};
9+
use uuid::Uuid;
1010

1111
#[derive(Clone)]
1212
pub struct RigDocument(pub Document);
@@ -15,27 +15,33 @@ impl TryFrom<RigDocument> for aws_bedrock::DocumentBlock {
1515
type Error = CompletionError;
1616

1717
fn try_from(value: RigDocument) -> Result<Self, Self::Error> {
18-
let maybe_format = value
18+
let document_media_type = value
1919
.0
2020
.media_type
2121
.map(|doc| RigDocumentMediaType(doc).try_into());
2222

23-
let format = match maybe_format {
23+
let document_media_type = match document_media_type {
2424
Some(Ok(document_format)) => Ok(Some(document_format)),
2525
Some(Err(err)) => Err(err),
2626
None => Ok(None),
2727
}?;
2828

29-
let document_data = BASE64_STANDARD
30-
.decode(value.0.data)
31-
.map_err(|e| CompletionError::ProviderError(e.to_string()))?;
29+
let document_data = match value.0.format {
30+
Some(ContentFormat::Base64) => BASE64_STANDARD
31+
.decode(value.0.data)
32+
.map_err(|e| CompletionError::ProviderError(e.to_string()))?,
33+
_ => value.0.data.as_bytes().to_vec(),
34+
};
35+
3236
let data = aws_smithy_types::Blob::new(document_data);
3337
let document_source = aws_bedrock::DocumentSource::Bytes(data);
3438

39+
let random_string = Uuid::new_v4().simple().to_string();
40+
let document_name = format!("document-{}", random_string);
3541
let result = aws_bedrock::DocumentBlock::builder()
3642
.source(document_source)
37-
.name("Document")
38-
.set_format(format)
43+
.name(document_name)
44+
.set_format(document_media_type)
3945
.build()
4046
.map_err(|e| CompletionError::ProviderError(e.to_string()))?;
4147
Ok(result)
@@ -82,13 +88,35 @@ mod tests {
8288
fn test_document_to_aws_document() {
8389
let rig_document = RigDocument(Document {
8490
data: "data".into(),
85-
format: Some(ContentFormat::Base64),
91+
format: Some(ContentFormat::String),
8692
media_type: Some(DocumentMediaType::PDF),
8793
});
8894
let aws_document: Result<aws_bedrock::DocumentBlock, _> = rig_document.clone().try_into();
8995
assert_eq!(aws_document.is_ok(), true);
9096
let aws_document = aws_document.unwrap();
9197
assert_eq!(aws_document.format, aws_bedrock::DocumentFormat::Pdf);
98+
let document_data = rig_document.0.data.as_bytes().to_vec();
99+
let aws_document_bytes = aws_document
100+
.source()
101+
.unwrap()
102+
.as_bytes()
103+
.unwrap()
104+
.as_ref()
105+
.to_owned();
106+
107+
let doc_name = aws_document.name;
108+
assert!(doc_name.starts_with("document-"));
109+
assert_eq!(aws_document_bytes, document_data)
110+
}
111+
112+
#[test]
113+
fn test_base64_document_to_aws_document() {
114+
let rig_document = RigDocument(Document {
115+
data: "data".into(),
116+
format: Some(ContentFormat::Base64),
117+
media_type: Some(DocumentMediaType::PDF),
118+
});
119+
let aws_document: aws_bedrock::DocumentBlock = rig_document.clone().try_into().unwrap();
92120
let document_data = BASE64_STANDARD.decode(rig_document.0.data).unwrap();
93121
let aws_document_bytes = aws_document
94122
.source()
@@ -104,7 +132,7 @@ mod tests {
104132
fn test_unsupported_document_to_aws_document() {
105133
let rig_document = RigDocument(Document {
106134
data: "data".into(),
107-
format: Some(ContentFormat::Base64),
135+
format: Some(ContentFormat::String),
108136
media_type: Some(DocumentMediaType::Javascript),
109137
});
110138
let aws_document: Result<aws_bedrock::DocumentBlock, _> = rig_document.clone().try_into();

0 commit comments

Comments
 (0)