Skip to content

Commit 6d7b28f

Browse files
update deployment create API
1 parent 3d5b790 commit 6d7b28f

File tree

7 files changed

+257
-15
lines changed

7 files changed

+257
-15
lines changed

ads/aqua/extension/deployment_handler.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

@@ -8,8 +7,8 @@
87
from tornado.web import HTTPError
98

109
from ads.aqua.common.decorator import handle_exceptions
11-
from ads.aqua.extension.errors import Errors
1210
from ads.aqua.extension.base_handler import AquaAPIhandler
11+
from ads.aqua.extension.errors import Errors
1312
from ads.aqua.modeldeployment import AquaDeploymentApp, MDInferenceResponse
1413
from ads.aqua.modeldeployment.entities import ModelParams
1514
from ads.config import COMPARTMENT_OCID, PROJECT_OCID
@@ -66,8 +65,8 @@ def post(self, *args, **kwargs):
6665
"""
6766
try:
6867
input_data = self.get_json_body()
69-
except Exception:
70-
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
68+
except Exception as ex:
69+
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
7170

7271
if not input_data:
7372
raise HTTPError(400, Errors.NO_INPUT_DATA)
@@ -100,6 +99,8 @@ def post(self, *args, **kwargs):
10099
health_check_port = input_data.get("health_check_port")
101100
env_var = input_data.get("env_var")
102101
container_family = input_data.get("container_family")
102+
ocpus = input_data.get("ocpus")
103+
memory_in_gbs = input_data.get("memory_in_gbs")
103104

104105
self.finish(
105106
AquaDeploymentApp().create(
@@ -119,6 +120,8 @@ def post(self, *args, **kwargs):
119120
health_check_port=health_check_port,
120121
env_var=env_var,
121122
container_family=container_family,
123+
ocpus=ocpus,
124+
memory_in_gbs=memory_in_gbs,
122125
)
123126
)
124127

@@ -153,9 +156,7 @@ def validate_predict_url(endpoint):
153156
return False
154157
if not url.netloc:
155158
return False
156-
if not url.path.endswith("/predict"):
157-
return False
158-
return True
159+
return url.path.endswith("/predict")
159160
except Exception:
160161
return False
161162

@@ -170,8 +171,8 @@ def post(self, *args, **kwargs):
170171
"""
171172
try:
172173
input_data = self.get_json_body()
173-
except Exception:
174-
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
174+
except Exception as ex:
175+
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
175176

176177
if not input_data:
177178
raise HTTPError(400, Errors.NO_INPUT_DATA)
@@ -192,10 +193,10 @@ def post(self, *args, **kwargs):
192193
)
193194
try:
194195
model_params_obj = ModelParams(**model_params)
195-
except:
196+
except Exception as ex:
196197
raise HTTPError(
197198
400, Errors.INVALID_INPUT_DATA_FORMAT.format("model_params")
198-
)
199+
) from ex
199200

200201
return self.finish(
201202
MDInferenceResponse(prompt, model_params_obj).get_model_deployment_response(
@@ -236,8 +237,8 @@ def post(self, *args, **kwargs):
236237
"""
237238
try:
238239
input_data = self.get_json_body()
239-
except Exception:
240-
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
240+
except Exception as ex:
241+
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
241242

242243
if not input_data:
243244
raise HTTPError(400, Errors.NO_INPUT_DATA)

ads/aqua/modeldeployment/deployment.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

55
import logging
6-
from typing import Dict, List, Union
6+
from typing import Dict, List, Optional, Union
77

88
from ads.aqua.app import AquaApp, logger
99
from ads.aqua.common.enums import (
@@ -102,6 +102,8 @@ def create(
102102
health_check_port: int = None,
103103
env_var: Dict = None,
104104
container_family: str = None,
105+
memory_in_gbs: Optional[float] = None,
106+
ocpus: Optional[float] = None,
105107
) -> "AquaDeployment":
106108
"""
107109
Creates a new Aqua deployment
@@ -142,6 +144,10 @@ def create(
142144
Environment variable for the deployment, by default None.
143145
container_family: str
144146
The image family of model deployment container runtime. Required for unverified Aqua models.
147+
memory_in_gbs: float
148+
The memory in gbs for the shape selected.
149+
ocpus: float
150+
The ocpu count for the shape selected.
145151
Returns
146152
-------
147153
AquaDeployment
@@ -325,6 +331,11 @@ def create(
325331
log_id=predict_log_id,
326332
)
327333
)
334+
if memory_in_gbs and ocpus and infrastructure.shape_name.endswith("Flex"):
335+
infrastructure.with_shape_config_details(
336+
ocpus=ocpus,
337+
memory_in_gbs=memory_in_gbs,
338+
)
328339
# configure model deployment runtime
329340
container_runtime = (
330341
ModelDeploymentContainerRuntime()
@@ -338,6 +349,7 @@ def create(
338349
.with_overwrite_existing_artifact(True)
339350
.with_remove_existing_artifact(True)
340351
)
352+
341353
# configure model deployment and deploy model on container runtime
342354
deployment = (
343355
ModelDeployment()
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
kind: deployment
2+
spec:
3+
createdBy: ocid1.user.oc1..<OCID>
4+
displayName: model-deployment-name
5+
freeformTags:
6+
OCI_AQUA: active
7+
aqua_model_name: model-name
8+
id: "ocid1.datasciencemodeldeployment.oc1.<region>.<MD_OCID>"
9+
infrastructure:
10+
kind: infrastructure
11+
spec:
12+
bandwidthMbps: 10
13+
compartmentId: ocid1.compartment.oc1..<OCID>
14+
deploymentType: SINGLE_MODEL
15+
policyType: FIXED_SIZE
16+
projectId: ocid1.datascienceproject.oc1.iad.<OCID>
17+
replica: 1
18+
shapeName: "VM.Standard.A1.Flex"
19+
shapeConfigDetails:
20+
memoryInGBs: 60.0
21+
ocpus: 10
22+
type: datascienceModelDeployment
23+
lifecycleState: CREATING
24+
modelDeploymentUrl: "https://modeldeployment.customer-oci.com/ocid1.datasciencemodeldeployment.oc1.<region>.<MD_OCID>"
25+
runtime:
26+
kind: runtime
27+
spec:
28+
env:
29+
BASE_MODEL: service_models/model-name/artifact
30+
BASE_MODEL_FILE: model-name.gguf
31+
MODEL_DEPLOY_ENABLE_STREAMING: 'true'
32+
MODEL_DEPLOY_PREDICT_ENDPOINT: /v1/completions
33+
MODEL_DEPLOY_HEALTH_ENDPOINT: /v1/models
34+
healthCheckPort: 8080
35+
image: "dsmc://image-name:1.0.0.0"
36+
modelUri: "ocid1.datasciencemodeldeployment.oc1.<region>.<MODEL_OCID>"
37+
serverPort: 8080
38+
type: container
39+
timeCreated: 2024-01-01T00:00:00.000000+00:00
40+
type: modelDeployment

tests/unitary/with_extras/aqua/test_data/deployment/deployment_config.json

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,26 @@
55
"TGI_PARAMS": "--max-stop-sequences 6",
66
"VLLM_PARAMS": "--max-model-len 4096"
77
}
8+
},
9+
"VM.Standard.A1.Flex": {
10+
"parameters": {},
11+
"shape_info": {
12+
"configs": [
13+
{
14+
"memory_in_gbs": 128,
15+
"ocpu": 32
16+
},
17+
{
18+
"memory_in_gbs": 256,
19+
"ocpu": 64
20+
}
21+
],
22+
"type": "CPU"
23+
}
824
}
925
},
1026
"shape": [
11-
"VM.GPU.A10.1"
27+
"VM.GPU.A10.1",
28+
"VM.Standard.A1.Flex"
1229
]
1330
}

tests/unitary/with_extras/aqua/test_data/ui/container_index.json

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,27 @@
11
{
22
"containerSpec": {
3+
"odsc-llama-cpp-serving": {
4+
"cliParam": "",
5+
"envVars": [
6+
{
7+
"MODEL_DEPLOY_PREDICT_ENDPOINT": "/v1/completions"
8+
},
9+
{
10+
"MODEL_DEPLOY_HEALTH_ENDPOINT": "/v1/models"
11+
},
12+
{
13+
"MODEL_DEPLOY_ENABLE_STREAMING": "true"
14+
},
15+
{
16+
"PORT": "8080"
17+
},
18+
{
19+
"HEALTH_CHECK_PORT": "8080"
20+
}
21+
],
22+
"healthCheckPort": "8080",
23+
"serverPort": "8080"
24+
},
325
"odsc-tgi-serving": {
426
"cliParam": "--sharded true --trust-remote-code",
527
"envVars": [
@@ -39,6 +61,14 @@
3961
"serverPort": "8080"
4062
}
4163
},
64+
"odsc-llama-cpp-serving": [
65+
{
66+
"displayName": "llama.cpp",
67+
"name": "dsmc://odsc-llama-cpp-serving",
68+
"type": "inference",
69+
"version": "0.2.75.2"
70+
}
71+
],
4272
"odsc-llm-evaluate": [
4373
{
4474
"name": "dsmc://odsc-llm-evaluate",

0 commit comments

Comments
 (0)