@@ -13,7 +13,6 @@ use aws_sdk_bedrockruntime::types::{
13
13
ContentBlock , ContentBlockDelta , ConversationRole , ConverseStreamOutput ,
14
14
InferenceConfiguration , Message ,
15
15
} ;
16
- use serde:: { Deserialize , Serialize } ;
17
16
18
17
pub struct LargeLanguageModel {
19
18
#[ expect( dead_code) ]
@@ -22,65 +21,69 @@ pub struct LargeLanguageModel {
22
21
#[ expect( dead_code) ]
23
22
bedrock_client : aws_sdk_bedrock:: Client ,
24
23
inference_parameters : InferenceConfiguration ,
25
- model_id : ArgModel ,
24
+ model_id : String ,
26
25
}
27
26
28
- #[ derive( Clone , Serialize , Deserialize , Debug , Copy ) ]
29
- pub enum ArgModel {
30
- Llama270b ,
31
- CohereCommand ,
32
- ClaudeV2 ,
33
- ClaudeV21 ,
34
- ClaudeV3Sonnet ,
35
- ClaudeV3Haiku ,
36
- ClaudeV35Sonnet ,
37
- Jurrasic2Ultra ,
38
- TitanTextExpressV1 ,
39
- Mixtral8x7bInstruct ,
40
- Mistral7bInstruct ,
41
- MistralLarge ,
42
- MistralLarge2 ,
43
- }
44
-
45
- impl std:: fmt:: Display for ArgModel {
46
- fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
47
- write ! ( f, "{}" , self . model_id_str( ) )
48
- }
49
- }
50
-
51
- impl ArgModel {
52
- pub fn model_id_str ( & self ) -> & ' static str {
53
- match self {
54
- ArgModel :: ClaudeV2 => "anthropic.claude-v2" ,
55
- ArgModel :: ClaudeV21 => "anthropic.claude-v2:1" ,
56
- ArgModel :: ClaudeV3Haiku => "anthropic.claude-3-haiku-20240307-v1:0" ,
57
- ArgModel :: ClaudeV3Sonnet => "anthropic.claude-3-sonnet-20240229-v1:0" ,
58
- ArgModel :: ClaudeV35Sonnet => "anthropic.claude-3-5-sonnet-20240620-v1:0" ,
59
- ArgModel :: Llama270b => "meta.llama2-70b-chat-v1" ,
60
- ArgModel :: CohereCommand => "cohere.command-text-v14" ,
61
- ArgModel :: Jurrasic2Ultra => "ai21.j2-ultra-v1" ,
62
- ArgModel :: TitanTextExpressV1 => "amazon.titan-text-express-v1" ,
63
- ArgModel :: Mixtral8x7bInstruct => "mistral.mixtral-8x7b-instruct-v0:1" ,
64
- ArgModel :: Mistral7bInstruct => "mistral.mistral-7b-instruct-v0:2" ,
65
- ArgModel :: MistralLarge => "mistral.mistral-large-2402-v1:0" ,
66
- ArgModel :: MistralLarge2 => "mistral.mistral-large-2407-v1:0" ,
67
- }
68
- }
69
- }
27
/// Lookup table mapping human-friendly model names (as accepted on the
/// command line / API surface) to their Amazon Bedrock model identifiers.
/// Order matters only for the error message emitted by `lookup_model_id`,
/// which lists the friendly names in this order.
const MODELS: &[(&str, &str)] = &[
    ("ClaudeV2", "anthropic.claude-v2"),
    ("ClaudeV21", "anthropic.claude-v2:1"),
    ("ClaudeV3Haiku", "anthropic.claude-3-haiku-20240307-v1:0"),
    ("ClaudeV3Sonnet", "anthropic.claude-3-sonnet-20240229-v1:0"),
    ("ClaudeV35Sonnet", "anthropic.claude-3-5-sonnet-20240620-v1:0"),
    ("Llama270b", "meta.llama2-70b-chat-v1"),
    ("CohereCommand", "cohere.command-text-v14"),
    ("Jurrasic2Ultra", "ai21.j2-ultra-v1"),
    ("TitanTextExpressV1", "amazon.titan-text-express-v1"),
    ("Mixtral8x7bInstruct", "mistral.mixtral-8x7b-instruct-v0:1"),
    ("Mistral7bInstruct", "mistral.mistral-7b-instruct-v0:2"),
    ("MistralLarge", "mistral.mistral-large-2402-v1:0"),
    ("MistralLarge2", "mistral.mistral-large-2407-v1:0"),
];
70
45
71
46
impl LargeLanguageModel {
72
- pub async fn new ( ) -> Self {
73
- let aws_config = Self :: aws_config ( "us-east-1" , "default" ) . await ;
47
+ pub async fn new ( model_id : Option < & str > , region : Option < & str > ) -> anyhow:: Result < Self > {
48
+ let model_id = Self :: lookup_model_id ( model_id) ?;
49
+ let region = region. unwrap_or ( "us-east-1" ) ;
50
+
51
+ let aws_config = Self :: aws_config ( region, "default" ) . await ;
74
52
let bedrock_runtime_client = aws_sdk_bedrockruntime:: Client :: new ( & aws_config) ;
75
53
let bedrock_client = aws_sdk_bedrock:: Client :: new ( & aws_config) ;
76
54
let inference_parameters = InferenceConfiguration :: builder ( ) . build ( ) ;
77
- Self {
55
+ Ok ( Self {
78
56
aws_config,
79
57
bedrock_runtime_client,
80
58
bedrock_client,
81
59
inference_parameters,
82
- model_id : ArgModel :: ClaudeV3Sonnet ,
60
+ model_id,
61
+ } )
62
+ }
63
+
64
+ fn lookup_model_id ( model_id : Option < & str > ) -> anyhow:: Result < String > {
65
+ let Some ( s) = model_id else {
66
+ return Self :: lookup_model_id ( Some ( "ClaudeV3Sonnet" ) ) ;
67
+ } ;
68
+
69
+ if s. contains ( "." ) {
70
+ return Ok ( s. to_string ( ) ) ;
83
71
}
72
+
73
+ for & ( key, value) in MODELS {
74
+ if key == s {
75
+ return Ok ( value. to_string ( ) ) ;
76
+ }
77
+ }
78
+
79
+ anyhow:: bail!(
80
+ "unknown model-id; try one of the following: [{}]" ,
81
+ MODELS
82
+ . iter( )
83
+ . map( |& ( k, _) | k)
84
+ . collect:: <Vec <_>>( )
85
+ . join( ", " )
86
+ ) ;
84
87
}
85
88
86
89
pub async fn query ( & self , prompt : & str , query : & str ) -> anyhow:: Result < String > {
@@ -89,7 +92,7 @@ impl LargeLanguageModel {
89
92
let mut output = self
90
93
. bedrock_runtime_client
91
94
. converse_stream ( )
92
- . model_id ( self . model_id . model_id_str ( ) )
95
+ . model_id ( & self . model_id )
93
96
. messages (
94
97
Message :: builder ( )
95
98
. role ( ConversationRole :: Assistant )
0 commit comments