|
40 | 40 | from questionary import Style
|
41 | 41 |
|
42 | 42 | custom_style = Style([
|
43 |
| - ('qmark', 'fg:#66BB6A bold'), # 问题前的标记 |
44 |
| - ('question', 'fg:default'), # 将问题文本颜色设置为默认颜色 |
45 |
| - ('answer', 'fg:#4CAF50 bold'), # 提交的答案文本 |
46 |
| - ('pointer', 'fg:#66BB6A bold'), # 选择提示符 |
47 |
| - ('highlighted', 'fg:#4CAF50 bold'), # 高亮的选项 |
48 |
| - ('selected', 'fg:#A5D6A7 bold'), # 选中的选项 |
49 |
| - ('disabled', 'fg:#CED4DA italic'), # 禁用的选项 |
50 |
| - ('error', 'fg:#F44336 bold'), # 错误信息 |
| 43 | + ('qmark', 'fg:#66BB6A bold'), |
| 44 | + ('question', 'fg:default'), |
| 45 | + ('answer', 'fg:#4CAF50 bold'), |
| 46 | + ('pointer', 'fg:#66BB6A bold'), |
| 47 | + ('highlighted', 'fg:#4CAF50 bold'), |
| 48 | + ('selected', 'fg:#A5D6A7 bold'), |
| 49 | + ('disabled', 'fg:#CED4DA italic'), |
| 50 | + ('error', 'fg:#F44336 bold'), |
51 | 51 | ])
|
52 | 52 |
|
53 | 53 | def show_help(choice):
|
@@ -145,7 +145,6 @@ def is_valid_model_tag(name,pattern=MODEL_TAG_PATTERN):
|
145 | 145 | return bool(re.match(pattern, name))
|
146 | 146 |
|
147 | 147 |
|
148 |
| -# Define a natural sort key function to handle numeric values in model names |
149 | 148 | def natural_sort_key(s):
|
150 | 149 | # Split the string into text and numeric parts
|
151 | 150 | return [int(c) if c.isdigit() else float(c) if c.replace('.', '', 1).isdigit() else c.lower()
|
@@ -173,30 +172,31 @@ def ask_model_id(region, allow_local_deploy, only_allow_local_deploy, model_id=N
|
173 | 172 | session = PromptSession(
|
174 | 173 | completer=completer,
|
175 | 174 | complete_while_typing=True,
|
| 175 | + rprompt=HTML('<span fg="#888888">(Run "emd list-supported-models" for full model list)</span>') |
176 | 176 | )
|
177 | 177 |
|
178 | 178 | def get_prompt_message():
|
179 |
| - buffer = get_app().current_buffer |
180 |
| - if buffer.text: |
181 |
| - return HTML('<b>? Enter model name: </b>') |
182 |
| - else: |
183 |
| - return HTML('<b>? Enter model name: </b><span fg="#888888">(Type to search, run "emd list-supported-models" for full model list)</span>') |
| 179 | + return HTML('<b>? Enter model name: </b>') |
184 | 180 |
|
185 |
| - selected_model = session.prompt(get_prompt_message, pre_run=lambda: get_app().current_buffer.start_completion()) |
| 181 | + while True: |
| 182 | + selected_model = session.prompt( |
| 183 | + get_prompt_message, |
| 184 | + pre_run=lambda: get_app().current_buffer.start_completion() |
| 185 | + ) |
186 | 186 |
|
187 |
| - if not selected_model: |
188 |
| - console.print("[bold yellow]Model selection cancelled[/bold yellow]") |
189 |
| - raise typer.Exit(0) |
| 187 | + if not selected_model: |
| 188 | + console.print("[bold yellow]Model selection cancelled[/bold yellow]") |
| 189 | + raise typer.Exit(0) |
190 | 190 |
|
191 |
| - if selected_model not in model_ids: |
192 |
| - console.print(f"[bold red]Invalid model name: {selected_model}[/bold red]") |
193 |
| - raise typer.Exit(1) |
| 191 | + if selected_model not in model_ids: |
| 192 | + console.print(f"[bold #FFA726]Invalid model name, please try again or press Ctrl+C to cancel[/bold #FFA726]") |
| 193 | + continue |
194 | 194 |
|
195 |
| - return selected_model |
| 195 | + return selected_model |
196 | 196 |
|
197 | 197 | except Exception as e:
|
198 | 198 | if not isinstance(e, (ModelNotSupported, typer.Exit)):
|
199 |
| - console.print(f"[bold red]Error during model selection: {str(e)}[/bold red]") |
| 199 | + console.print(f"[bold #FFA726]Error during model selection: {str(e)}[/bold #FFA726]") |
200 | 200 | raise typer.Exit(1)
|
201 | 201 | raise
|
202 | 202 |
|
@@ -304,7 +304,10 @@ def deploy(
|
304 | 304 | ],
|
305 | 305 | style=custom_style
|
306 | 306 | ).ask()
|
307 |
| - service_type = Service.get_service_from_name(service_name).service_type |
| 307 | + try: |
| 308 | + service_type = Service.get_service_from_name(service_name).service_type |
| 309 | + except: |
| 310 | + raise typer.Exit(0) |
308 | 311 | else:
|
309 | 312 | service_type = supported_services[0].service_type
|
310 | 313 | console.print(f"[bold blue]service type: {supported_services[0].name}[/bold blue]")
|
@@ -509,8 +512,7 @@ def deploy(
|
509 | 512 | if extra_params is None:
|
510 | 513 | while True:
|
511 | 514 | extra_params = questionary.text(
|
512 |
| - "(Optional) Additional parameters, you can skip by pressing Enter:", |
513 |
| - instruction="Usage: https://aws-samples.github.io/easy-model-deployer/en/best_deployment_practices/#extra-parameters-usage", |
| 515 | + "(Optional) Additional parameters, usage (https://aws-samples.github.io/easy-model-deployer/en/best_deployment_practices/#extra-parameters-usage), you can skip by pressing Enter:", |
514 | 516 | default="{}"
|
515 | 517 | ).ask()
|
516 | 518 |
|
@@ -543,10 +545,13 @@ def deploy(
|
543 | 545 | # else:
|
544 | 546 | # console.print(f"[bold blue] model tag: {model_tag}[/bold blue]")
|
545 | 547 | # break
|
546 |
| - if not model_tag and not is_valid_model_tag(model_tag): |
547 |
| - console.print(f"[bold blue]invalid model tag: {model_tag}. Please ensure that the tag complies with the standard rules: {MODEL_TAG_PATTERN}.[/bold blue]") |
548 |
| - else: |
549 |
| - break |
| 548 | + try: |
| 549 | + if not model_tag and not is_valid_model_tag(model_tag): |
| 550 | + console.print(f"[bold blue]invalid model tag: {model_tag}. Please ensure that the tag complies with the standard rules: {MODEL_TAG_PATTERN}.[/bold blue]") |
| 551 | + else: |
| 552 | + break |
| 553 | + except: |
| 554 | + raise typer.Exit(0) |
550 | 555 |
|
551 | 556 | if not model_tag:
|
552 | 557 | raise ValueError("Model tag is required.")
|
|
0 commit comments