28
28
from emd .utils .decorators import catch_aws_credential_errors ,check_emd_env_exist ,load_aws_profile
29
29
from emd .utils .logger_utils import make_layout
30
30
from emd .utils .exceptions import ModelNotSupported ,ServiceNotSupported ,InstanceNotSupported
31
+ from prompt_toolkit import prompt
32
+ from prompt_toolkit .completion import FuzzyWordCompleter
31
33
32
34
app = typer .Typer (pretty_exceptions_enable = False )
33
35
console = Console ()
@@ -52,13 +54,26 @@ def show_help(choice):
52
54
return f"{ choice } (shortcut)"
53
55
54
56
55
- def supported_models_filter (region :str ,support_models :list [Model ]):
57
+ def supported_models_filter (
58
+ region :str ,
59
+ allow_local_deploy ,
60
+ only_allow_local_deploy ,
61
+ support_models :list [Model ]
62
+ ):
56
63
ret = []
57
64
is_cn_region = check_cn_region (region )
58
65
59
66
for model in support_models :
60
67
if is_cn_region and not model .allow_china_region :
61
68
continue
69
+
70
+ # Skip models that only support local services when local deployment is not allowed
71
+ if not allow_local_deploy :
72
+ # Check if all supported services are local services
73
+ all_local_services = all (service .service_type == ServiceType .LOCAL for service in model .supported_services )
74
+ if all_local_services :
75
+ continue
76
+
62
77
ret .append (model )
63
78
return ret
64
79
@@ -130,72 +145,60 @@ def is_valid_model_tag(name,pattern=MODEL_TAG_PATTERN):
130
145
return bool (re .match (pattern , name ))
131
146
132
147
133
- def ask_model_id (region ,model_id = None ):
148
+ # Define a natural sort key function to handle numeric values in model names
149
+ def natural_sort_key (s ):
150
+ # Split the string into text and numeric parts
151
+ return [int (c ) if c .isdigit () else float (c ) if c .replace ('.' , '' , 1 ).isdigit () else c .lower ()
152
+ for c in re .split (r'(\d+\.\d+|\d+)' , s )]
153
+
154
+
155
+ def ask_model_id (region , allow_local_deploy , only_allow_local_deploy , model_id = None ):
134
156
if model_id is not None :
135
157
return model_id
136
158
137
- # step 1: select model series name
138
- support_models :list [Model ] = sorted (
139
- [Model .get_model (m ) for m in Model .get_supported_models ()
140
- if hasattr (Model .get_model (m ), 'model_series' ) and hasattr (Model .get_model (m ).model_series , 'model_series_name' )],
141
- key = lambda x :x .model_series .model_series_name
142
- )
143
- # filter models
144
- support_models = supported_models_filter (region ,support_models )
159
+ try :
160
+ supported_models = [Model .get_model (m ) for m in Model .get_supported_models ()]
161
+ filtered_models = supported_models_filter (region , allow_local_deploy , only_allow_local_deploy , supported_models )
145
162
146
- if not support_models :
147
- raise ModelNotSupported (region )
163
+ if not filtered_models :
164
+ raise ModelNotSupported (region )
148
165
149
- model_series_map = defaultdict (list )
150
- for model in support_models :
151
- model_series_map [model .model_series .model_series_name ].append (model )
152
-
153
- def _get_series_description (models :list [Model ]):
154
- model = models [0 ]
155
- description = "\n "
156
- description += model .model_series .description
157
- description += f"\n reference link: { model .model_series .reference_link } "
158
- description += "\n Supported models: " + "\n - " + "\n - " .join (model .model_id for model in models )
159
- return description
160
-
161
- series_name = select_with_help (
162
- "Select the model series:" ,
163
- choices = [
164
- Choice (
165
- title = series_name ,
166
- description = _get_series_description (models ),
167
- )
168
- for series_name ,models in model_series_map .items ()
169
- ],
170
- show_description = True ,
171
- style = custom_style
172
- ).ask ()
173
- if series_name is None :
174
- raise typer .Exit (0 )
166
+ model_ids = sorted ([model .model_id for model in filtered_models ], key = natural_sort_key )
167
+ completer = FuzzyWordCompleter (model_ids , WORD = True )
175
168
176
- def _get_model_description (model :Model ):
177
- description = f"\n \n ModelType: { model .model_type } \n Application Scenario: { model .application_scenario } "
178
- if model .description :
179
- description += f"\n Description: { model .description } "
180
- return description
181
-
182
- # step 2 select model_id
183
- model_id = select_with_help (
184
- "Select the model name:" ,
185
- choices = [
186
- Choice (
187
- title = model .model_id ,
188
- description = _get_model_description (model )
189
- )
190
- for model in model_series_map [series_name ]
191
- ],
192
- show_description = True ,
193
- style = custom_style
194
- ).ask ()
195
-
196
- if model_id is None :
197
- raise typer .Exit (0 )
198
- return model_id
169
+ from prompt_toolkit .formatted_text import HTML
170
+ from prompt_toolkit import PromptSession
171
+ from prompt_toolkit .application .current import get_app
172
+
173
+ session = PromptSession (
174
+ completer = completer ,
175
+ complete_while_typing = True ,
176
+ )
177
+
178
+ def get_prompt_message ():
179
+ buffer = get_app ().current_buffer
180
+ if buffer .text :
181
+ return HTML ('<b>? Enter model name: </b>' )
182
+ else :
183
+ return HTML ('<b>? Enter model name: </b><span fg="#888888">(Type to search, run "emd list-supported-models" for full model list)</span>' )
184
+
185
+ selected_model = session .prompt (get_prompt_message , pre_run = lambda : get_app ().current_buffer .start_completion ())
186
+
187
+ if not selected_model :
188
+ console .print ("[bold yellow]Model selection cancelled[/bold yellow]" )
189
+ raise typer .Exit (0 )
190
+
191
+ if selected_model not in model_ids :
192
+ console .print (f"[bold red]Invalid model name: { selected_model } [/bold red]" )
193
+ raise typer .Exit (1 )
194
+
195
+ return selected_model
196
+
197
+ except Exception as e :
198
+ if not isinstance (e , (ModelNotSupported , typer .Exit )):
199
+ console .print (f"[bold red]Error during model selection: { str (e )} [/bold red]" )
200
+ raise typer .Exit (1 )
201
+ raise
199
202
200
203
201
204
#@app.callback(invoke_without_command=True)(invoke_without_command=True)
@@ -268,7 +271,12 @@ def deploy(
268
271
269
272
vpc_id = None
270
273
# ask model id
271
- model_id = ask_model_id (region ,model_id = model_id )
274
+ model_id = ask_model_id (
275
+ region ,
276
+ allow_local_deploy ,
277
+ only_allow_local_deploy ,
278
+ model_id = model_id
279
+ )
272
280
273
281
if not check_model_support_on_cn_region (model_id ,region ):
274
282
raise ModelNotSupported (region ,model_id = model_id )
@@ -286,7 +294,7 @@ def deploy(
286
294
if service_type is None :
287
295
if len (supported_services ) > 1 :
288
296
service_name = select_with_help (
289
- "Select the service for deployment :" ,
297
+ "Select model hosting service :" ,
290
298
choices = [
291
299
Choice (
292
300
title = service .name ,
@@ -328,7 +336,7 @@ def deploy(
328
336
vpc_name = next ((tag ['Value' ] for tag in vpc .get ('Tags' , []) if tag .get ('Key' ) == 'Name' ), None )
329
337
vpc ['Name' ] = vpc_name if vpc_name else '-'
330
338
emd_vpc = select_with_help (
331
- "Select the VPC (Virtual Private Cloud) you want to deploy the ESC service :" ,
339
+ "Select VPC (Virtual Private Cloud):" ,
332
340
choices = [
333
341
Choice (
334
342
title = f"{ emd_default_vpc ['VpcId' ]} ({ emd_default_vpc ['CidrBlock' ]} ) (EMD-vpc)" if emd_default_vpc else 'Create a new VPC' ,
@@ -412,7 +420,7 @@ def deploy(
412
420
if instance_type is None :
413
421
if len (supported_instances )> 1 :
414
422
instance_type = select_with_help (
415
- "Select the instance type:" ,
423
+ "Select instance type:" ,
416
424
choices = [
417
425
Choice (
418
426
title = instance .instance_type ,
@@ -449,7 +457,7 @@ def deploy(
449
457
if engine_type is None :
450
458
if len (supported_engines )> 1 :
451
459
engine_type = select_with_help (
452
- "Select the inference engine to use :" ,
460
+ "Select inference engine:" ,
453
461
choices = [
454
462
Choice (
455
463
title = engine .engine_type ,
@@ -476,7 +484,7 @@ def deploy(
476
484
if framework_type is None :
477
485
if len (supported_frameworks )> 1 :
478
486
framework_type = select_with_help (
479
- "Select the inference engine to use :" ,
487
+ "Select inference engine:" ,
480
488
choices = [
481
489
Choice (
482
490
title = framework .framework_type ,
@@ -488,7 +496,6 @@ def deploy(
488
496
).ask ()
489
497
else :
490
498
framework_type = supported_frameworks [0 ].framework_type
491
- console .print (f"[bold blue]framework type: { framework_type } [/bold blue]" )
492
499
else :
493
500
supported_framework_types = model .supported_framework_types
494
501
console .print (f"[bold blue]framework type: { framework_type } [/bold blue]" )
@@ -502,8 +509,8 @@ def deploy(
502
509
if extra_params is None :
503
510
while True :
504
511
extra_params = questionary .text (
505
- "(Optional) Additional deployment parameters (JSON string or local file path) , you can skip by pressing Enter:" ,
506
- instruction = "Parameters format : https://aws-samples.github.io/easy-model-deployer/en/installation/#- extra-params. " ,
512
+ "(Optional) Additional parameters, you can skip by pressing Enter:" ,
513
+ instruction = "Usage : https://aws-samples.github.io/easy-model-deployer/en/best_deployment_practices/# extra-parameters-usage " ,
507
514
default = "{}"
508
515
).ask ()
509
516
@@ -528,7 +535,7 @@ def deploy(
528
535
if not skip_confirm and not service_type == ServiceType .LOCAL :
529
536
while True :
530
537
model_tag = questionary .text (
531
- "(Optional) Add a model deployment tag (custom label), you can skip by pressing Enter:" ,
538
+ "(Optional) Custom tag (label), you can skip by pressing Enter:" ,
532
539
default = MODEL_DEFAULT_TAG
533
540
).ask ()
534
541
# if model_tag == MODEL_DEFAULT_TAG:
@@ -547,7 +554,7 @@ def deploy(
547
554
548
555
if not skip_confirm :
549
556
if not typer .confirm (
550
- "Would you like to proceed with the deployment ? Please verify your selections above." ,
557
+ "Ready to deploy ? Please confirm your selections above." ,
551
558
abort = True ,
552
559
):
553
560
raise typer .Exit (0 )
0 commit comments