1
1
use aws_sdk_bedrockruntime:: types as aws_bedrock;
2
-
3
2
use rig:: {
4
3
completion:: CompletionError ,
5
4
message:: { ContentFormat , Document } ,
6
5
} ;
7
6
8
7
pub ( crate ) use crate :: types:: media_types:: RigDocumentMediaType ;
9
8
use base64:: { prelude:: BASE64_STANDARD , Engine } ;
9
+ use uuid:: Uuid ;
10
10
11
11
#[ derive( Clone ) ]
12
12
pub struct RigDocument ( pub Document ) ;
@@ -15,27 +15,33 @@ impl TryFrom<RigDocument> for aws_bedrock::DocumentBlock {
15
15
type Error = CompletionError ;
16
16
17
17
fn try_from ( value : RigDocument ) -> Result < Self , Self :: Error > {
18
- let maybe_format = value
18
+ let document_media_type = value
19
19
. 0
20
20
. media_type
21
21
. map ( |doc| RigDocumentMediaType ( doc) . try_into ( ) ) ;
22
22
23
- let format = match maybe_format {
23
+ let document_media_type = match document_media_type {
24
24
Some ( Ok ( document_format) ) => Ok ( Some ( document_format) ) ,
25
25
Some ( Err ( err) ) => Err ( err) ,
26
26
None => Ok ( None ) ,
27
27
} ?;
28
28
29
- let document_data = BASE64_STANDARD
30
- . decode ( value. 0 . data )
31
- . map_err ( |e| CompletionError :: ProviderError ( e. to_string ( ) ) ) ?;
29
+ let document_data = match value. 0 . format {
30
+ Some ( ContentFormat :: Base64 ) => BASE64_STANDARD
31
+ . decode ( value. 0 . data )
32
+ . map_err ( |e| CompletionError :: ProviderError ( e. to_string ( ) ) ) ?,
33
+ _ => value. 0 . data . as_bytes ( ) . to_vec ( ) ,
34
+ } ;
35
+
32
36
let data = aws_smithy_types:: Blob :: new ( document_data) ;
33
37
let document_source = aws_bedrock:: DocumentSource :: Bytes ( data) ;
34
38
39
+ let random_string = Uuid :: new_v4 ( ) . simple ( ) . to_string ( ) ;
40
+ let document_name = format ! ( "document-{}" , random_string) ;
35
41
let result = aws_bedrock:: DocumentBlock :: builder ( )
36
42
. source ( document_source)
37
- . name ( "Document" )
38
- . set_format ( format )
43
+ . name ( document_name )
44
+ . set_format ( document_media_type )
39
45
. build ( )
40
46
. map_err ( |e| CompletionError :: ProviderError ( e. to_string ( ) ) ) ?;
41
47
Ok ( result)
@@ -82,13 +88,35 @@ mod tests {
82
88
fn test_document_to_aws_document ( ) {
83
89
let rig_document = RigDocument ( Document {
84
90
data : "data" . into ( ) ,
85
- format : Some ( ContentFormat :: Base64 ) ,
91
+ format : Some ( ContentFormat :: String ) ,
86
92
media_type : Some ( DocumentMediaType :: PDF ) ,
87
93
} ) ;
88
94
let aws_document: Result < aws_bedrock:: DocumentBlock , _ > = rig_document. clone ( ) . try_into ( ) ;
89
95
assert_eq ! ( aws_document. is_ok( ) , true ) ;
90
96
let aws_document = aws_document. unwrap ( ) ;
91
97
assert_eq ! ( aws_document. format, aws_bedrock:: DocumentFormat :: Pdf ) ;
98
+ let document_data = rig_document. 0 . data . as_bytes ( ) . to_vec ( ) ;
99
+ let aws_document_bytes = aws_document
100
+ . source ( )
101
+ . unwrap ( )
102
+ . as_bytes ( )
103
+ . unwrap ( )
104
+ . as_ref ( )
105
+ . to_owned ( ) ;
106
+
107
+ let doc_name = aws_document. name ;
108
+ assert ! ( doc_name. starts_with( "document-" ) ) ;
109
+ assert_eq ! ( aws_document_bytes, document_data)
110
+ }
111
+
112
+ #[ test]
113
+ fn test_base64_document_to_aws_document ( ) {
114
+ let rig_document = RigDocument ( Document {
115
+ data : "data" . into ( ) ,
116
+ format : Some ( ContentFormat :: Base64 ) ,
117
+ media_type : Some ( DocumentMediaType :: PDF ) ,
118
+ } ) ;
119
+ let aws_document: aws_bedrock:: DocumentBlock = rig_document. clone ( ) . try_into ( ) . unwrap ( ) ;
92
120
let document_data = BASE64_STANDARD . decode ( rig_document. 0 . data ) . unwrap ( ) ;
93
121
let aws_document_bytes = aws_document
94
122
. source ( )
@@ -104,7 +132,7 @@ mod tests {
104
132
fn test_unsupported_document_to_aws_document ( ) {
105
133
let rig_document = RigDocument ( Document {
106
134
data : "data" . into ( ) ,
107
- format : Some ( ContentFormat :: Base64 ) ,
135
+ format : Some ( ContentFormat :: String ) ,
108
136
media_type : Some ( DocumentMediaType :: Javascript ) ,
109
137
} ) ;
110
138
let aws_document: Result < aws_bedrock:: DocumentBlock , _ > = rig_document. clone ( ) . try_into ( ) ;
0 commit comments