Skip to content

Commit 65e9c1c

Browse files
authored
fix: update deployment prompts and messages (#129)
* fix: update deployment prompts and messages * fix: fix prompt list and expand by default * chore: clean up
1 parent 3b5908c commit 65e9c1c

File tree

2 files changed

+82
-75
lines changed

2 files changed

+82
-75
lines changed

src/emd/commands/deploy.py

Lines changed: 79 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
from emd.utils.decorators import catch_aws_credential_errors,check_emd_env_exist,load_aws_profile
2929
from emd.utils.logger_utils import make_layout
3030
from emd.utils.exceptions import ModelNotSupported,ServiceNotSupported,InstanceNotSupported
31+
from prompt_toolkit import prompt
32+
from prompt_toolkit.completion import FuzzyWordCompleter
3133

3234
app = typer.Typer(pretty_exceptions_enable=False)
3335
console = Console()
@@ -52,13 +54,26 @@ def show_help(choice):
5254
return f"{choice} (shortcut)"
5355

5456

55-
def supported_models_filter(region:str,support_models:list[Model]):
57+
def supported_models_filter(
58+
region:str,
59+
allow_local_deploy,
60+
only_allow_local_deploy,
61+
support_models:list[Model]
62+
):
5663
ret = []
5764
is_cn_region = check_cn_region(region)
5865

5966
for model in support_models:
6067
if is_cn_region and not model.allow_china_region:
6168
continue
69+
70+
# Skip models that only support local services when local deployment is not allowed
71+
if not allow_local_deploy:
72+
# Check if all supported services are local services
73+
all_local_services = all(service.service_type == ServiceType.LOCAL for service in model.supported_services)
74+
if all_local_services:
75+
continue
76+
6277
ret.append(model)
6378
return ret
6479

@@ -130,72 +145,60 @@ def is_valid_model_tag(name,pattern=MODEL_TAG_PATTERN):
130145
return bool(re.match(pattern, name))
131146

132147

133-
def ask_model_id(region,model_id=None):
148+
# Define a natural sort key function to handle numeric values in model names
149+
def natural_sort_key(s):
150+
# Split the string into text and numeric parts
151+
return [int(c) if c.isdigit() else float(c) if c.replace('.', '', 1).isdigit() else c.lower()
152+
for c in re.split(r'(\d+\.\d+|\d+)', s)]
153+
154+
155+
def ask_model_id(region, allow_local_deploy, only_allow_local_deploy, model_id=None):
134156
if model_id is not None:
135157
return model_id
136158

137-
# step 1: select model series name
138-
support_models:list[Model] = sorted(
139-
[Model.get_model(m) for m in Model.get_supported_models()
140-
if hasattr(Model.get_model(m), 'model_series') and hasattr(Model.get_model(m).model_series, 'model_series_name')],
141-
key=lambda x:x.model_series.model_series_name
142-
)
143-
# filter models
144-
support_models = supported_models_filter(region,support_models)
159+
try:
160+
supported_models = [Model.get_model(m) for m in Model.get_supported_models()]
161+
filtered_models = supported_models_filter(region, allow_local_deploy, only_allow_local_deploy, supported_models)
145162

146-
if not support_models:
147-
raise ModelNotSupported(region)
163+
if not filtered_models:
164+
raise ModelNotSupported(region)
148165

149-
model_series_map = defaultdict(list)
150-
for model in support_models:
151-
model_series_map[model.model_series.model_series_name].append(model)
152-
153-
def _get_series_description(models:list[Model]):
154-
model = models[0]
155-
description = "\n"
156-
description += model.model_series.description
157-
description += f"\nreference link: {model.model_series.reference_link}"
158-
description += "\nSupported models: "+ "\n - " + "\n - ".join(model.model_id for model in models)
159-
return description
160-
161-
series_name = select_with_help(
162-
"Select the model series:",
163-
choices=[
164-
Choice(
165-
title=series_name,
166-
description=_get_series_description(models),
167-
)
168-
for series_name,models in model_series_map.items()
169-
],
170-
show_description=True,
171-
style=custom_style
172-
).ask()
173-
if series_name is None:
174-
raise typer.Exit(0)
166+
model_ids = sorted([model.model_id for model in filtered_models], key=natural_sort_key)
167+
completer = FuzzyWordCompleter(model_ids, WORD=True)
175168

176-
def _get_model_description(model:Model):
177-
description=f"\n\nModelType: {model.model_type}\nApplication Scenario: {model.application_scenario}"
178-
if model.description:
179-
description += f"\nDescription: {model.description}"
180-
return description
181-
182-
# step 2 select model_id
183-
model_id = select_with_help(
184-
"Select the model name:",
185-
choices=[
186-
Choice(
187-
title=model.model_id,
188-
description=_get_model_description(model)
189-
)
190-
for model in model_series_map[series_name]
191-
],
192-
show_description=True,
193-
style=custom_style
194-
).ask()
195-
196-
if model_id is None:
197-
raise typer.Exit(0)
198-
return model_id
169+
from prompt_toolkit.formatted_text import HTML
170+
from prompt_toolkit import PromptSession
171+
from prompt_toolkit.application.current import get_app
172+
173+
session = PromptSession(
174+
completer=completer,
175+
complete_while_typing=True,
176+
)
177+
178+
def get_prompt_message():
179+
buffer = get_app().current_buffer
180+
if buffer.text:
181+
return HTML('<b>? Enter model name: </b>')
182+
else:
183+
return HTML('<b>? Enter model name: </b><span fg="#888888">(Type to search, run "emd list-supported-models" for full model list)</span>')
184+
185+
selected_model = session.prompt(get_prompt_message, pre_run=lambda: get_app().current_buffer.start_completion())
186+
187+
if not selected_model:
188+
console.print("[bold yellow]Model selection cancelled[/bold yellow]")
189+
raise typer.Exit(0)
190+
191+
if selected_model not in model_ids:
192+
console.print(f"[bold red]Invalid model name: {selected_model}[/bold red]")
193+
raise typer.Exit(1)
194+
195+
return selected_model
196+
197+
except Exception as e:
198+
if not isinstance(e, (ModelNotSupported, typer.Exit)):
199+
console.print(f"[bold red]Error during model selection: {str(e)}[/bold red]")
200+
raise typer.Exit(1)
201+
raise
199202

200203

201204
#@app.callback(invoke_without_command=True)(invoke_without_command=True)
@@ -268,7 +271,12 @@ def deploy(
268271

269272
vpc_id = None
270273
# ask model id
271-
model_id = ask_model_id(region,model_id=model_id)
274+
model_id = ask_model_id(
275+
region,
276+
allow_local_deploy,
277+
only_allow_local_deploy,
278+
model_id=model_id
279+
)
272280

273281
if not check_model_support_on_cn_region(model_id,region):
274282
raise ModelNotSupported(region,model_id=model_id)
@@ -286,7 +294,7 @@ def deploy(
286294
if service_type is None:
287295
if len(supported_services) > 1:
288296
service_name = select_with_help(
289-
"Select the service for deployment:",
297+
"Select model hosting service:",
290298
choices=[
291299
Choice(
292300
title=service.name,
@@ -328,7 +336,7 @@ def deploy(
328336
vpc_name = next((tag['Value'] for tag in vpc.get('Tags', []) if tag.get('Key') == 'Name'), None)
329337
vpc['Name'] = vpc_name if vpc_name else '-'
330338
emd_vpc = select_with_help(
331-
"Select the VPC (Virtual Private Cloud) you want to deploy the ESC service:",
339+
"Select VPC (Virtual Private Cloud):",
332340
choices=[
333341
Choice(
334342
title=f"{emd_default_vpc['VpcId']} ({emd_default_vpc['CidrBlock']}) (EMD-vpc)" if emd_default_vpc else 'Create a new VPC',
@@ -412,7 +420,7 @@ def deploy(
412420
if instance_type is None:
413421
if len(supported_instances)>1:
414422
instance_type = select_with_help(
415-
"Select the instance type:",
423+
"Select instance type:",
416424
choices=[
417425
Choice(
418426
title=instance.instance_type,
@@ -449,7 +457,7 @@ def deploy(
449457
if engine_type is None:
450458
if len(supported_engines)>1:
451459
engine_type = select_with_help(
452-
"Select the inference engine to use:",
460+
"Select inference engine:",
453461
choices=[
454462
Choice(
455463
title=engine.engine_type,
@@ -476,7 +484,7 @@ def deploy(
476484
if framework_type is None:
477485
if len(supported_frameworks)>1:
478486
framework_type = select_with_help(
479-
"Select the inference engine to use:",
487+
"Select inference engine:",
480488
choices=[
481489
Choice(
482490
title=framework.framework_type,
@@ -488,7 +496,6 @@ def deploy(
488496
).ask()
489497
else:
490498
framework_type = supported_frameworks[0].framework_type
491-
console.print(f"[bold blue]framework type: {framework_type}[/bold blue]")
492499
else:
493500
supported_framework_types = model.supported_framework_types
494501
console.print(f"[bold blue]framework type: {framework_type}[/bold blue]")
@@ -502,8 +509,8 @@ def deploy(
502509
if extra_params is None:
503510
while True:
504511
extra_params = questionary.text(
505-
"(Optional) Additional deployment parameters (JSON string or local file path), you can skip by pressing Enter:",
506-
instruction="Parameters format: https://aws-samples.github.io/easy-model-deployer/en/installation/#-extra-params.",
512+
"(Optional) Additional parameters, you can skip by pressing Enter:",
513+
instruction="Usage: https://aws-samples.github.io/easy-model-deployer/en/best_deployment_practices/#extra-parameters-usage",
507514
default="{}"
508515
).ask()
509516

@@ -528,7 +535,7 @@ def deploy(
528535
if not skip_confirm and not service_type == ServiceType.LOCAL:
529536
while True:
530537
model_tag = questionary.text(
531-
"(Optional) Add a model deployment tag (custom label), you can skip by pressing Enter:",
538+
"(Optional) Custom tag (label), you can skip by pressing Enter:",
532539
default=MODEL_DEFAULT_TAG
533540
).ask()
534541
# if model_tag == MODEL_DEFAULT_TAG:
@@ -547,7 +554,7 @@ def deploy(
547554

548555
if not skip_confirm:
549556
if not typer.confirm(
550-
"Would you like to proceed with the deployment? Please verify your selections above.",
557+
"Ready to deploy? Please confirm your selections above.",
551558
abort=True,
552559
):
553560
raise typer.Exit(0)

src/emd/models/services.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
"SageMakerEndpointName": ValueWithDefault(name="sagemaker_endpoint_name",default="Auto-generate"),
2222
"APIKey": ValueWithDefault(name="api_key",default="")
2323
},
24-
name = "Amazon SageMaker AI Real-time inference with OpenAI Compatible API",
24+
name = "Amazon SageMaker AI Real-time inference with OpenAI-Compatible API",
2525
service_type=ServiceType.SAGEMAKER,
2626
description="Amazon SageMaker Real-time inference provides low-latency, interactive inference through fully managed endpoints that support autoscaling. It provides an OpenAI-compatible REST API (e.g., /v1/completions) via an Application Load Balancer (ALB).\n(https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints.html)",
2727
support_cn_region = True
@@ -64,7 +64,7 @@
6464
"AutoScalingTargetValue": ValueWithDefault(name="auto_scaling_target_value",default=10),
6565
"APIKey": ValueWithDefault(name="api_key",default="")
6666
},
67-
name = "Amazon SageMaker AI Asynchronous inference with OpenAI Compatible API",
67+
name = "Amazon SageMaker AI Asynchronous inference with OpenAI-Compatible API",
6868
service_type=ServiceType.SAGEMAKER_ASYNC,
6969
description="Amazon SageMaker Asynchronous Inference queues requests for processing asynchronously, making it suitable for large payloads (up to 1GB) and long processing times (up to one hour), while also enabling cost savings by autoscaling to zero when idle. It provides an OpenAI-compatible REST API (e.g., /v1/completions) via an Application Load Balancer (ALB).\n(https://docs.aws.amazon.com/sagemaker/latest/dg/async-inference.html)",
7070
support_cn_region = True
@@ -107,7 +107,7 @@
107107
"ContainerGpu":"instance_gpu_num",
108108
"APIKey": ValueWithDefault(name="api_key",default="")
109109
},
110-
name = "Amazon ECS with OpenAI Compatible API",
110+
name = "Amazon ECS with OpenAI-Compatible API",
111111
service_type=ServiceType.ECS,
112112
description="Amazon Elastic Container Service is a fully managed service that runs containerized applications in clusters with auto scaling. It provides an OpenAI-compatible REST API (e.g., /v1/completions) via an Application Load Balancer (ALB), enabling integration with AI models for tasks like chatbots or document analysis. (https://docs.aws.amazon.com/AmazonECS/latest/developerguide)",
113113
support_cn_region = True,

0 commit comments

Comments
 (0)