Skip to content

Commit b3efa4c

Browse files
avoid circular imports
1 parent 9599dab commit b3efa4c

File tree

3 files changed

+36
-40
lines changed

3 files changed

+36
-40
lines changed

ads/aqua/common/utils.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,12 @@
3838
MODEL_BY_REFERENCE_OSS_PATH_KEY,
3939
SERVICE_MANAGED_CONTAINER_URI_SCHEME,
4040
SUPPORTED_FILE_FORMATS,
41+
TGI_INFERENCE_RESTRICTED_PARAMS,
4142
UNKNOWN,
4243
UNKNOWN_JSON_STR,
44+
VLLM_INFERENCE_RESTRICTED_PARAMS,
4345
)
4446
from ads.aqua.data import AquaResourceIdentifier
45-
from ads.aqua.modeldeployment.constants import (
46-
TGIInferenceRestrictedParams,
47-
VLLMInferenceRestrictedParams,
48-
)
4947
from ads.common.auth import default_signer
5048
from ads.common.decorator.threaded import threaded
5149
from ads.common.extended_enum import ExtendedEnumMeta
@@ -897,13 +895,13 @@ def get_restricted_params_by_container(container_type_name: str) -> set:
897895
898896
Returns
899897
-------
900-
InferenceContainerParamType value
898+
Set of restricted params based on container type
901899
902900
"""
903901
# check substring instead of direct match in case container_type_name changes in the future
904902
if InferenceContainerType.CONTAINER_TYPE_VLLM in container_type_name.lower():
905-
return VLLMInferenceRestrictedParams
903+
return VLLM_INFERENCE_RESTRICTED_PARAMS
906904
elif InferenceContainerType.CONTAINER_TYPE_TGI in container_type_name.lower():
907-
return TGIInferenceRestrictedParams
905+
return TGI_INFERENCE_RESTRICTED_PARAMS
908906
else:
909907
return set()

ads/aqua/constants.py

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54
"""This module defines constants used in ads.aqua module."""
@@ -45,19 +44,34 @@
4544
SUPPORTED_FILE_FORMATS = ["jsonl"]
4645
MODEL_BY_REFERENCE_OSS_PATH_KEY = "artifact_location"
4746

48-
CONSOLE_LINK_RESOURCE_TYPE_MAPPING = dict(
49-
datasciencemodel="models",
50-
datasciencemodeldeployment="model-deployments",
51-
datasciencemodeldeploymentdev="model-deployments",
52-
datasciencemodeldeploymentint="model-deployments",
53-
datasciencemodeldeploymentpre="model-deployments",
54-
datasciencejob="jobs",
55-
datasciencejobrun="job-runs",
56-
datasciencejobrundev="job-runs",
57-
datasciencejobrunint="job-runs",
58-
datasciencejobrunpre="job-runs",
59-
datasciencemodelversionset="model-version-sets",
60-
datasciencemodelversionsetpre="model-version-sets",
61-
datasciencemodelversionsetint="model-version-sets",
62-
datasciencemodelversionsetdev="model-version-sets",
63-
)
47+
CONSOLE_LINK_RESOURCE_TYPE_MAPPING = {
48+
"datasciencemodel": "models",
49+
"datasciencemodeldeployment": "model-deployments",
50+
"datasciencemodeldeploymentdev": "model-deployments",
51+
"datasciencemodeldeploymentint": "model-deployments",
52+
"datasciencemodeldeploymentpre": "model-deployments",
53+
"datasciencejob": "jobs",
54+
"datasciencejobrun": "job-runs",
55+
"datasciencejobrundev": "job-runs",
56+
"datasciencejobrunint": "job-runs",
57+
"datasciencejobrunpre": "job-runs",
58+
"datasciencemodelversionset": "model-version-sets",
59+
"datasciencemodelversionsetpre": "model-version-sets",
60+
"datasciencemodelversionsetint": "model-version-sets",
61+
"datasciencemodelversionsetdev": "model-version-sets",
62+
}
63+
64+
VLLM_INFERENCE_RESTRICTED_PARAMS = {
65+
"--tensor-parallel-size",
66+
"--port",
67+
"--host",
68+
"--served-model-name",
69+
"--seed",
70+
}
71+
TGI_INFERENCE_RESTRICTED_PARAMS = {
72+
"--port",
73+
"--hostname",
74+
"--num-shard",
75+
"--sharded",
76+
"--trust-remote-code",
77+
}

ads/aqua/modeldeployment/constants.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

@@ -9,18 +8,3 @@
98
109
This module contains constants used in Aqua Model Deployment.
1110
"""
12-
13-
VLLMInferenceRestrictedParams = {
14-
"--tensor-parallel-size",
15-
"--port",
16-
"--host",
17-
"--served-model-name",
18-
"--seed",
19-
}
20-
TGIInferenceRestrictedParams = {
21-
"--port",
22-
"--hostname",
23-
"--num-shard",
24-
"--sharded",
25-
"--trust-remote-code",
26-
}

0 commit comments

Comments
 (0)