Skip to content

Commit b46958e

Browse files
authored
refactor: modify model list (#95)
* merge * merge * add Mistral-Small-3.1-24B-Instruct-2503 * modify qwq-32b deploy * add txgemma model; * modify model list command * fix typo
1 parent ab1e566 commit b46958e

File tree

3 files changed

+15
-5
lines changed

3 files changed

+15
-5
lines changed

src/emd/cli.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,17 @@
8383

8484
@app.command(help="List supported models")
8585
@catch_aws_credential_errors
86-
def list_supported_models(model_id: Annotated[
86+
def list_supported_models(
87+
model_id: Annotated[
8788
str, typer.Argument(help="Model ID")
88-
] = None):
89+
] = None,
90+
detail: Annotated[
91+
Optional[bool],
92+
typer.Option("-a", "--detail", help="output model information in details.")
93+
] = False
94+
):
8995
# console.print("[bold blue]Retrieving models...[/bold blue]")
90-
support_models = Model.get_supported_models()
96+
support_models = Model.get_supported_models(detail=detail)
9197
if model_id:
9298
support_models = [model for _model_id,model in support_models.items() if _model_id == model_id]
9399
r = json.dumps(support_models,indent=2,ensure_ascii=False)

src/emd/models/llms/txgemma.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
supported_frameworks=[
4848
fastapi_framework
4949
],
50+
allow_china_region=True,
5051
huggingface_model_id="google/txgemma-9b-chat",
5152
modelscope_model_id="AI-ModelScope/txgemma-9b-chat",
5253
model_files_download_source=ModelFilesDownloadSource.MODELSCOPE,
@@ -79,6 +80,7 @@
7980
supported_frameworks=[
8081
fastapi_framework
8182
],
83+
allow_china_region=True,
8284
huggingface_model_id="google/txgemma-27b-chat",
8385
modelscope_model_id="AI-ModelScope/txgemma-27b-chat",
8486
model_files_download_source=ModelFilesDownloadSource.MODELSCOPE,

src/emd/models/model.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,10 @@ def get_model(cls ,model_id:str,update:dict = None) -> T:
210210
return model
211211

212212
@classmethod
213-
def get_supported_models(cls) -> dict:
214-
return {model_id: model.model_type for model_id,model in cls.model_map.items()}
213+
def get_supported_models(cls,detail=False) -> dict:
214+
if not detail:
215+
return {model_id: model.model_type for model_id,model in cls.model_map.items()}
216+
return {model_id: model.model_dump() for model_id,model in cls.model_map.items()}
215217

216218
def find_current_engine(self,engine_type:str) -> dict:
217219
supported_engines:List[Engine] = self.supported_engines

0 commit comments

Comments
 (0)