Skip to content

Commit feedb3b

Browse files
committed
allow llm model selection
1 parent 4208203 commit feedb3b

File tree

3 files changed: 67 additions and 50 deletions.

mdbook-goals/src/llm.rs

Lines changed: 52 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,6 @@ use aws_sdk_bedrockruntime::types::{
1313
ContentBlock, ContentBlockDelta, ConversationRole, ConverseStreamOutput,
1414
InferenceConfiguration, Message,
1515
};
16-
use serde::{Deserialize, Serialize};
1716

1817
pub struct LargeLanguageModel {
1918
#[expect(dead_code)]
@@ -22,65 +21,69 @@ pub struct LargeLanguageModel {
2221
#[expect(dead_code)]
2322
bedrock_client: aws_sdk_bedrock::Client,
2423
inference_parameters: InferenceConfiguration,
25-
model_id: ArgModel,
24+
model_id: String,
2625
}
2726

28-
#[derive(Clone, Serialize, Deserialize, Debug, Copy)]
29-
pub enum ArgModel {
30-
Llama270b,
31-
CohereCommand,
32-
ClaudeV2,
33-
ClaudeV21,
34-
ClaudeV3Sonnet,
35-
ClaudeV3Haiku,
36-
ClaudeV35Sonnet,
37-
Jurrasic2Ultra,
38-
TitanTextExpressV1,
39-
Mixtral8x7bInstruct,
40-
Mistral7bInstruct,
41-
MistralLarge,
42-
MistralLarge2,
43-
}
44-
45-
impl std::fmt::Display for ArgModel {
46-
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47-
write!(f, "{}", self.model_id_str())
48-
}
49-
}
50-
51-
impl ArgModel {
52-
pub fn model_id_str(&self) -> &'static str {
53-
match self {
54-
ArgModel::ClaudeV2 => "anthropic.claude-v2",
55-
ArgModel::ClaudeV21 => "anthropic.claude-v2:1",
56-
ArgModel::ClaudeV3Haiku => "anthropic.claude-3-haiku-20240307-v1:0",
57-
ArgModel::ClaudeV3Sonnet => "anthropic.claude-3-sonnet-20240229-v1:0",
58-
ArgModel::ClaudeV35Sonnet => "anthropic.claude-3-5-sonnet-20240620-v1:0",
59-
ArgModel::Llama270b => "meta.llama2-70b-chat-v1",
60-
ArgModel::CohereCommand => "cohere.command-text-v14",
61-
ArgModel::Jurrasic2Ultra => "ai21.j2-ultra-v1",
62-
ArgModel::TitanTextExpressV1 => "amazon.titan-text-express-v1",
63-
ArgModel::Mixtral8x7bInstruct => "mistral.mixtral-8x7b-instruct-v0:1",
64-
ArgModel::Mistral7bInstruct => "mistral.mistral-7b-instruct-v0:2",
65-
ArgModel::MistralLarge => "mistral.mistral-large-2402-v1:0",
66-
ArgModel::MistralLarge2 => "mistral.mistral-large-2407-v1:0",
67-
}
68-
}
69-
}
27+
const MODELS: &[(&str, &str)] = &[
28+
("ClaudeV2", "anthropic.claude-v2"),
29+
("ClaudeV21", "anthropic.claude-v2:1"),
30+
("ClaudeV3Haiku", "anthropic.claude-3-haiku-20240307-v1:0"),
31+
("ClaudeV3Sonnet", "anthropic.claude-3-sonnet-20240229-v1:0"),
32+
(
33+
"ClaudeV35Sonnet",
34+
"anthropic.claude-3-5-sonnet-20240620-v1:0",
35+
),
36+
("Llama270b", "meta.llama2-70b-chat-v1"),
37+
("CohereCommand", "cohere.command-text-v14"),
38+
("Jurrasic2Ultra", "ai21.j2-ultra-v1"),
39+
("TitanTextExpressV1", "amazon.titan-text-express-v1"),
40+
("Mixtral8x7bInstruct", "mistral.mixtral-8x7b-instruct-v0:1"),
41+
("Mistral7bInstruct", "mistral.mistral-7b-instruct-v0:2"),
42+
("MistralLarge", "mistral.mistral-large-2402-v1:0"),
43+
("MistralLarge2", "mistral.mistral-large-2407-v1:0"),
44+
];
7045

7146
impl LargeLanguageModel {
72-
pub async fn new() -> Self {
73-
let aws_config = Self::aws_config("us-east-1", "default").await;
47+
pub async fn new(model_id: Option<&str>, region: Option<&str>) -> anyhow::Result<Self> {
48+
let model_id = Self::lookup_model_id(model_id)?;
49+
let region = region.unwrap_or("us-east-1");
50+
51+
let aws_config = Self::aws_config(region, "default").await;
7452
let bedrock_runtime_client = aws_sdk_bedrockruntime::Client::new(&aws_config);
7553
let bedrock_client = aws_sdk_bedrock::Client::new(&aws_config);
7654
let inference_parameters = InferenceConfiguration::builder().build();
77-
Self {
55+
Ok(Self {
7856
aws_config,
7957
bedrock_runtime_client,
8058
bedrock_client,
8159
inference_parameters,
82-
model_id: ArgModel::ClaudeV3Sonnet,
60+
model_id,
61+
})
62+
}
63+
64+
fn lookup_model_id(model_id: Option<&str>) -> anyhow::Result<String> {
65+
let Some(s) = model_id else {
66+
return Self::lookup_model_id(Some("ClaudeV3Sonnet"));
67+
};
68+
69+
if s.contains(".") {
70+
return Ok(s.to_string());
8371
}
72+
73+
for &(key, value) in MODELS {
74+
if key == s {
75+
return Ok(value.to_string());
76+
}
77+
}
78+
79+
anyhow::bail!(
80+
"unknown model-id; try one of the following: [{}]",
81+
MODELS
82+
.iter()
83+
.map(|&(k, _)| k)
84+
.collect::<Vec<_>>()
85+
.join(", ")
86+
);
8487
}
8588

8689
pub async fn query(&self, prompt: &str, query: &str) -> anyhow::Result<String> {
@@ -89,7 +92,7 @@ impl LargeLanguageModel {
8992
let mut output = self
9093
.bedrock_runtime_client
9194
.converse_stream()
92-
.model_id(self.model_id.model_id_str())
95+
.model_id(&self.model_id)
9396
.messages(
9497
Message::builder()
9598
.role(ConversationRole::Assistant)

mdbook-goals/src/main.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,14 @@ enum Command {
110110
/// End date for comments.
111111
/// If not given, no end date.
112112
end_date: Option<chrono::NaiveDate>,
113+
114+
/// Set a custom model id for the LLM.
115+
#[structopt(long)]
116+
model_id: Option<String>,
117+
118+
/// Set a custom region.
119+
#[structopt(long)]
120+
region: Option<String>,
113121
},
114122
}
115123

@@ -167,6 +175,8 @@ async fn main() -> anyhow::Result<()> {
167175
end_date,
168176
quick,
169177
vscode,
178+
model_id,
179+
region,
170180
} => {
171181
updates::updates(
172182
&opt.repository,
@@ -176,6 +186,8 @@ async fn main() -> anyhow::Result<()> {
176186
end_date,
177187
*quick,
178188
*vscode,
189+
model_id.as_deref(),
190+
region.as_deref(),
179191
)
180192
.await?;
181193
}

mdbook-goals/src/updates.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,14 @@ pub async fn updates(
4545
end_date: &Option<NaiveDate>,
4646
quick: bool,
4747
vscode: bool,
48+
model_id: Option<&str>,
49+
region: Option<&str>,
4850
) -> anyhow::Result<()> {
4951
if output_file.is_none() && !vscode {
5052
anyhow::bail!("either `--output-file` or `--vscode` must be specified");
5153
}
5254

53-
let llm = LargeLanguageModel::new().await;
55+
let llm = LargeLanguageModel::new(model_id, region).await?;
5456

5557
let issues = list_issue_titles_in_milestone(repository, milestone)?;
5658

0 commit comments

Comments
 (0)