Skip to content

Commit 2013a83

Browse files
authored
fix: update error message for deploy (#130)
1 parent 65e9c1c commit 2013a83

File tree

1 file changed

+35
-30
lines changed

1 file changed

+35
-30
lines changed

src/emd/commands/deploy.py

Lines changed: 35 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,14 @@
4040
from questionary import Style
4141

4242
custom_style = Style([
43-
('qmark', 'fg:#66BB6A bold'), # 问题前的标记
44-
('question', 'fg:default'), # 将问题文本颜色设置为默认颜色
45-
('answer', 'fg:#4CAF50 bold'), # 提交的答案文本
46-
('pointer', 'fg:#66BB6A bold'), # 选择提示符
47-
('highlighted', 'fg:#4CAF50 bold'), # 高亮的选项
48-
('selected', 'fg:#A5D6A7 bold'), # 选中的选项
49-
('disabled', 'fg:#CED4DA italic'), # 禁用的选项
50-
('error', 'fg:#F44336 bold'), # 错误信息
43+
('qmark', 'fg:#66BB6A bold'),
44+
('question', 'fg:default'),
45+
('answer', 'fg:#4CAF50 bold'),
46+
('pointer', 'fg:#66BB6A bold'),
47+
('highlighted', 'fg:#4CAF50 bold'),
48+
('selected', 'fg:#A5D6A7 bold'),
49+
('disabled', 'fg:#CED4DA italic'),
50+
('error', 'fg:#F44336 bold'),
5151
])
5252

5353
def show_help(choice):
@@ -145,7 +145,6 @@ def is_valid_model_tag(name,pattern=MODEL_TAG_PATTERN):
145145
return bool(re.match(pattern, name))
146146

147147

148-
# Define a natural sort key function to handle numeric values in model names
149148
def natural_sort_key(s):
150149
# Split the string into text and numeric parts
151150
return [int(c) if c.isdigit() else float(c) if c.replace('.', '', 1).isdigit() else c.lower()
@@ -173,30 +172,31 @@ def ask_model_id(region, allow_local_deploy, only_allow_local_deploy, model_id=N
173172
session = PromptSession(
174173
completer=completer,
175174
complete_while_typing=True,
175+
rprompt=HTML('<span fg="#888888">(Run "emd list-supported-models" for full model list)</span>')
176176
)
177177

178178
def get_prompt_message():
179-
buffer = get_app().current_buffer
180-
if buffer.text:
181-
return HTML('<b>? Enter model name: </b>')
182-
else:
183-
return HTML('<b>? Enter model name: </b><span fg="#888888">(Type to search, run "emd list-supported-models" for full model list)</span>')
179+
return HTML('<b>? Enter model name: </b>')
184180

185-
selected_model = session.prompt(get_prompt_message, pre_run=lambda: get_app().current_buffer.start_completion())
181+
while True:
182+
selected_model = session.prompt(
183+
get_prompt_message,
184+
pre_run=lambda: get_app().current_buffer.start_completion()
185+
)
186186

187-
if not selected_model:
188-
console.print("[bold yellow]Model selection cancelled[/bold yellow]")
189-
raise typer.Exit(0)
187+
if not selected_model:
188+
console.print("[bold yellow]Model selection cancelled[/bold yellow]")
189+
raise typer.Exit(0)
190190

191-
if selected_model not in model_ids:
192-
console.print(f"[bold red]Invalid model name: {selected_model}[/bold red]")
193-
raise typer.Exit(1)
191+
if selected_model not in model_ids:
192+
console.print(f"[bold #FFA726]Invalid model name, please try again or press Ctrl+C to cancel[/bold #FFA726]")
193+
continue
194194

195-
return selected_model
195+
return selected_model
196196

197197
except Exception as e:
198198
if not isinstance(e, (ModelNotSupported, typer.Exit)):
199-
console.print(f"[bold red]Error during model selection: {str(e)}[/bold red]")
199+
console.print(f"[bold #FFA726]Error during model selection: {str(e)}[/bold #FFA726]")
200200
raise typer.Exit(1)
201201
raise
202202

@@ -304,7 +304,10 @@ def deploy(
304304
],
305305
style=custom_style
306306
).ask()
307-
service_type = Service.get_service_from_name(service_name).service_type
307+
try:
308+
service_type = Service.get_service_from_name(service_name).service_type
309+
except:
310+
raise typer.Exit(0)
308311
else:
309312
service_type = supported_services[0].service_type
310313
console.print(f"[bold blue]service type: {supported_services[0].name}[/bold blue]")
@@ -509,8 +512,7 @@ def deploy(
509512
if extra_params is None:
510513
while True:
511514
extra_params = questionary.text(
512-
"(Optional) Additional parameters, you can skip by pressing Enter:",
513-
instruction="Usage: https://aws-samples.github.io/easy-model-deployer/en/best_deployment_practices/#extra-parameters-usage",
515+
"(Optional) Additional parameters, usage (https://aws-samples.github.io/easy-model-deployer/en/best_deployment_practices/#extra-parameters-usage), you can skip by pressing Enter:",
514516
default="{}"
515517
).ask()
516518

@@ -543,10 +545,13 @@ def deploy(
543545
# else:
544546
# console.print(f"[bold blue] model tag: {model_tag}[/bold blue]")
545547
# break
546-
if not model_tag and not is_valid_model_tag(model_tag):
547-
console.print(f"[bold blue]invalid model tag: {model_tag}. Please ensure that the tag complies with the standard rules: {MODEL_TAG_PATTERN}.[/bold blue]")
548-
else:
549-
break
548+
try:
549+
if not model_tag and not is_valid_model_tag(model_tag):
550+
console.print(f"[bold blue]invalid model tag: {model_tag}. Please ensure that the tag complies with the standard rules: {MODEL_TAG_PATTERN}.[/bold blue]")
551+
else:
552+
break
553+
except:
554+
raise typer.Exit(0)
550555

551556
if not model_tag:
552557
raise ValueError("Model tag is required.")

0 commit comments

Comments
 (0)