Skip to content

Commit 1d6fb47

Browse files
committed
cooprate with list avalibale models, remove model options
1 parent f9448b5 commit 1d6fb47

File tree

5 files changed

+122
-50
lines changed

5 files changed

+122
-50
lines changed

crates/chat-cli/src/api_client/error.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use amzn_codewhisperer_client::operation::create_subscription_token::CreateSubscriptionTokenError;
22
use amzn_codewhisperer_client::operation::generate_completions::GenerateCompletionsError;
33
use amzn_codewhisperer_client::operation::list_available_customizations::ListAvailableCustomizationsError;
4+
use amzn_codewhisperer_client::operation::list_available_models::ListAvailableModelsError;
45
use amzn_codewhisperer_client::operation::list_available_profiles::ListAvailableProfilesError;
56
use amzn_codewhisperer_client::operation::send_telemetry_event::SendTelemetryEventError;
67
pub use amzn_codewhisperer_streaming_client::operation::generate_assistant_response::GenerateAssistantResponseError;
@@ -93,6 +94,9 @@ pub enum ApiClientError {
9394
// Credential errors
9495
#[error("failed to load credentials: {}", .0)]
9596
Credentials(CredentialsError),
97+
98+
#[error(transparent)]
99+
ListAvailableModelsError(#[from] SdkError<ListAvailableModelsError, HttpResponse>),
96100
}
97101

98102
impl ApiClientError {
@@ -116,6 +120,7 @@ impl ApiClientError {
116120
Self::ModelOverloadedError { status_code, .. } => *status_code,
117121
Self::MonthlyLimitReached { status_code } => *status_code,
118122
Self::Credentials(_e) => None,
123+
Self::ListAvailableModelsError(e) => sdk_status_code(e),
119124
}
120125
}
121126
}
@@ -141,6 +146,7 @@ impl ReasonCode for ApiClientError {
141146
Self::ModelOverloadedError { .. } => "ModelOverloadedError".to_string(),
142147
Self::MonthlyLimitReached { .. } => "MonthlyLimitReached".to_string(),
143148
Self::Credentials(_) => "CredentialsError".to_string(),
149+
Self::ListAvailableModelsError(e) => sdk_error_code(e),
144150
}
145151
}
146152
}
@@ -188,6 +194,10 @@ mod tests {
188194
ListAvailableCustomizationsError::unhandled("<unhandled>"),
189195
response(),
190196
)),
197+
ApiClientError::ListAvailableModelsError(SdkError::service_error(
198+
ListAvailableModelsError::unhandled("<unhandled>"),
199+
response(),
200+
)),
191201
ApiClientError::ListAvailableServices(SdkError::service_error(
192202
ListCustomizationsError::unhandled("<unhandled>"),
193203
response(),

crates/chat-cli/src/api_client/mod.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ use std::time::Duration;
1212

1313
use amzn_codewhisperer_client::Client as CodewhispererClient;
1414
use amzn_codewhisperer_client::operation::create_subscription_token::CreateSubscriptionTokenOutput;
15+
use amzn_codewhisperer_client::types::Origin::Cli;
1516
use amzn_codewhisperer_client::types::{
17+
Model,
1618
OptOutPreference,
1719
SubscriptionStatus,
1820
TelemetryEvent,
@@ -232,6 +234,46 @@ impl ApiClient {
232234
Ok(profiles)
233235
}
234236

237+
pub async fn list_available_models(&self) -> Result<(Vec<Model>, Option<Model>), ApiClientError> {
238+
if cfg!(test) {
239+
return Ok((
240+
vec![
241+
Model::builder()
242+
.model_id("model-1")
243+
.description("Test Model 1")
244+
.build()
245+
.unwrap(),
246+
],
247+
Some(
248+
Model::builder()
249+
.model_id("model-1")
250+
.description("Test Model 1")
251+
.build()
252+
.unwrap(),
253+
),
254+
));
255+
}
256+
257+
// todo yifan: add default_model once API is ready
258+
let mut models = Vec::new();
259+
let default_model = None;
260+
let request = self
261+
.client
262+
.list_available_models()
263+
.set_origin(Some(Cli))
264+
.set_profile_arn(self.profile.as_ref().map(|p| p.arn.clone()));
265+
let mut paginator = request.into_paginator().send();
266+
267+
while let Some(models_output) = paginator.next().await {
268+
models.extend(models_output?.models().iter().cloned());
269+
// if default_model.is_none() && output.default_model().is_some() {
270+
// default_model = output.default_model().cloned();
271+
// }
272+
}
273+
274+
Ok((models, default_model))
275+
}
276+
235277
pub async fn create_subscription_token(&self) -> Result<CreateSubscriptionTokenOutput, ApiClientError> {
236278
if cfg!(test) {
237279
return Ok(CreateSubscriptionTokenOutput::builder()

crates/chat-cli/src/cli/chat/cli/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ impl SlashCommand {
108108
Self::Hooks(args) => args.execute(session).await,
109109
Self::Usage(args) => args.execute(os, session).await,
110110
Self::Mcp(args) => args.execute(session).await,
111-
Self::Model(args) => args.execute(session).await,
111+
Self::Model(args) => args.execute(os, session).await,
112112
Self::Subscribe(args) => args.execute(os, session).await,
113113
Self::Persist(subcommand) => subcommand.execute(os, session).await,
114114
// Self::Root(subcommand) => {

crates/chat-cli/src/cli/chat/cli/model.rs

Lines changed: 49 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -20,44 +20,63 @@ use crate::cli::chat::{
2020
};
2121
use crate::os::Os;
2222

23-
pub struct ModelOption {
24-
pub name: &'static str,
25-
pub model_id: &'static str,
26-
}
23+
// pub struct ModelOption {
24+
// pub name: &'static str,
25+
// pub model_id: &'static str,
26+
// }
2727

28-
pub const MODEL_OPTIONS: [ModelOption; 2] = [
29-
ModelOption {
30-
name: "claude-4-sonnet",
31-
model_id: "CLAUDE_SONNET_4_20250514_V1_0",
32-
},
33-
ModelOption {
34-
name: "claude-3.7-sonnet",
35-
model_id: "CLAUDE_3_7_SONNET_20250219_V1_0",
36-
},
37-
];
28+
// pub const MODEL_OPTIONS: [ModelOption; 2] = [
29+
// ModelOption {
30+
// name: "claude-4-sonnet",
31+
// model_id: "CLAUDE_SONNET_4_20250514_V1_0",
32+
// },
33+
// ModelOption {
34+
// name: "claude-3.7-sonnet",
35+
// model_id: "CLAUDE_3_7_SONNET_20250219_V1_0",
36+
// },
37+
// ];
3838

3939
#[deny(missing_docs)]
4040
#[derive(Debug, PartialEq, Args)]
4141
pub struct ModelArgs;
4242

4343
impl ModelArgs {
44-
pub async fn execute(self, session: &mut ChatSession) -> Result<ChatState, ChatError> {
45-
Ok(select_model(session)?.unwrap_or(ChatState::PromptUser {
44+
pub async fn execute(self, os: &mut Os, session: &mut ChatSession) -> Result<ChatState, ChatError> {
45+
Ok(select_model(os, session).await?.unwrap_or(ChatState::PromptUser {
4646
skip_printing_tools: false,
4747
}))
4848
}
4949
}
5050

51-
pub fn select_model(session: &mut ChatSession) -> Result<Option<ChatState>, ChatError> {
51+
pub async fn select_model(os: &mut Os, session: &mut ChatSession) -> Result<Option<ChatState>, ChatError> {
5252
queue!(session.stderr, style::Print("\n"))?;
53+
54+
// Fetch available models from service
55+
let (models, _default_model) = os
56+
.client
57+
.list_available_models()
58+
.await
59+
.map_err(|e| ChatError::Custom(format!("Failed to fetch available models: {}", e).into()))?;
60+
61+
if models.is_empty() {
62+
queue!(
63+
session.stderr,
64+
style::SetForegroundColor(Color::Red),
65+
style::Print("No models available\n"),
66+
style::ResetColor
67+
)?;
68+
return Ok(None);
69+
}
70+
5371
let active_model_id = session.conversation.model.as_deref();
54-
let labels: Vec<String> = MODEL_OPTIONS
72+
73+
let labels: Vec<String> = models
5574
.iter()
5675
.map(|opt| {
57-
if (opt.model_id.is_empty() && active_model_id.is_none()) || Some(opt.model_id) == active_model_id {
58-
format!("{} (active)", opt.name)
76+
if Some(opt.model_id.as_str()) == active_model_id {
77+
format!("{} (active)", opt.model_id)
5978
} else {
60-
opt.name.to_owned()
79+
opt.model_id.to_owned()
6180
}
6281
})
6382
.collect();
@@ -83,14 +102,14 @@ pub fn select_model(session: &mut ChatSession) -> Result<Option<ChatState>, Chat
83102
queue!(session.stderr, style::ResetColor)?;
84103

85104
if let Some(index) = selection {
86-
let selected = &MODEL_OPTIONS[index];
105+
let selected = &models[index];
87106
let model_id_str = selected.model_id.to_string();
88107
session.conversation.model = Some(model_id_str);
89108

90109
queue!(
91110
session.stderr,
92111
style::Print("\n"),
93-
style::Print(format!(" Using {}\n\n", selected.name)),
112+
style::Print(format!(" Using {}\n\n", selected.model_id)),
94113
style::ResetColor,
95114
style::SetForegroundColor(Color::Reset),
96115
style::SetBackgroundColor(Color::Reset),
@@ -106,21 +125,24 @@ pub fn select_model(session: &mut ChatSession) -> Result<Option<ChatState>, Chat
106125

107126
/// Returns Claude 3.7 for: Amazon IDC users, FRA region users
108127
/// Returns Claude 4.0 for: Builder ID users, other regions
109-
pub async fn default_model_id(os: &Os) -> &'static str {
128+
pub async fn default_model_id(os: &Os) -> String {
129+
if let Ok((_, Some(default_model))) = os.client.list_available_models().await {
130+
return default_model.model_id().to_string();
131+
}
110132
// Check FRA region first
111133
if let Ok(Some(profile)) = os.database.get_auth_profile() {
112134
if profile.arn.split(':').nth(3) == Some("eu-central-1") {
113-
return "CLAUDE_3_7_SONNET_20250219_V1_0";
135+
return "CLAUDE_3_7_SONNET_20250219_V1_0".to_string();
114136
}
115137
}
116138

117139
// Check if Amazon IDC user
118140
if let Ok(Some(token)) = BuilderIdToken::load(&os.database).await {
119141
if matches!(token.token_type(), TokenType::IamIdentityCenter) && token.is_amzn_user() {
120-
return "CLAUDE_3_7_SONNET_20250219_V1_0";
142+
return "CLAUDE_3_7_SONNET_20250219_V1_0".to_string();
121143
}
122144
}
123145

124146
// Default to 4.0
125-
"CLAUDE_SONNET_4_20250514_V1_0"
147+
"CLAUDE_SONNET_4_20250514_V1_0".to_string()
126148
}

crates/chat-cli/src/cli/chat/mod.rs

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -126,10 +126,7 @@ use crate::auth::AuthError;
126126
use crate::auth::builder_id::is_idc_user;
127127
use crate::cli::agent::Agents;
128128
use crate::cli::chat::cli::SlashCommand;
129-
use crate::cli::chat::cli::model::{
130-
MODEL_OPTIONS,
131-
default_model_id,
132-
};
129+
use crate::cli::chat::cli::model::default_model_id;
133130
use crate::cli::chat::cli::prompts::{
134131
GetPromptError,
135132
PromptsSubcommand,
@@ -267,15 +264,16 @@ impl ChatArgs {
267264
};
268265

269266
// If modelId is specified, verify it exists before starting the chat
270-
let model_id: Option<String> = if let Some(model_name) = self.model {
271-
let model_name_lower = model_name.to_lowercase();
272-
match MODEL_OPTIONS.iter().find(|opt| opt.name == model_name_lower) {
267+
let model_id: Option<String> = if let Some(requested_model_id) = self.model {
268+
let requested_model_id_lower = requested_model_id.to_lowercase();
269+
let (models, _) = os.client.list_available_models().await?;
270+
match models.iter().find(|opt| opt.model_id == requested_model_id_lower) {
273271
Some(opt) => Some((opt.model_id).to_string()),
274272
None => {
275-
let available_names: Vec<&str> = MODEL_OPTIONS.iter().map(|opt| opt.name).collect();
273+
let available_names: Vec<&str> = models.iter().map(|opt| opt.model_id()).collect();
276274
bail!(
277275
"Model '{}' does not exist. Available models: {}",
278-
model_name,
276+
requested_model_id,
279277
available_names.join(", ")
280278
);
281279
},
@@ -516,17 +514,19 @@ impl ChatSession {
516514
tool_config: HashMap<String, ToolSpec>,
517515
interactive: bool,
518516
) -> Result<Self> {
517+
let (models, _) = os.client.list_available_models().await?;
518+
519519
let valid_model_id = match model_id {
520520
Some(id) => id,
521521
None => {
522522
let from_settings = os
523523
.database
524524
.settings
525525
.get_string(Setting::ChatDefaultModel)
526-
.and_then(|model_name| {
527-
MODEL_OPTIONS
526+
.and_then(|model_id| {
527+
models
528528
.iter()
529-
.find(|opt| opt.name == model_name)
529+
.find(|opt| opt.model_id == model_id)
530530
.map(|opt| opt.model_id.to_owned())
531531
});
532532

@@ -1133,15 +1133,13 @@ impl ChatSession {
11331133
self.stderr.flush()?;
11341134

11351135
if let Some(ref id) = self.conversation.model {
1136-
if let Some(model_option) = MODEL_OPTIONS.iter().find(|option| option.model_id == *id) {
1137-
execute!(
1138-
self.stderr,
1139-
style::SetForegroundColor(Color::Cyan),
1140-
style::Print(format!("🤖 You are chatting with {}\n", model_option.name)),
1141-
style::SetForegroundColor(Color::Reset),
1142-
style::Print("\n")
1143-
)?;
1144-
}
1136+
execute!(
1137+
self.stderr,
1138+
style::SetForegroundColor(Color::Cyan),
1139+
style::Print(format!("🤖 You are chatting with {}\n", id)),
1140+
style::SetForegroundColor(Color::Reset),
1141+
style::Print("\n")
1142+
)?;
11451143
}
11461144

11471145
if let Some(user_input) = self.initial_input.take() {
@@ -2321,7 +2319,7 @@ impl ChatSession {
23212319
}
23222320

23232321
async fn retry_model_overload(&mut self, os: &mut Os) -> Result<ChatState, ChatError> {
2324-
match select_model(self) {
2322+
match select_model(os, self).await {
23252323
Ok(Some(_)) => (),
23262324
Ok(None) => {
23272325
// User did not select a model, so reset the current request state.

0 commit comments

Comments
 (0)