Skip to content

Commit c267386

Browse files
[ODSC-59697] Llama cpp deployment support (#902)
2 parents 961aa05 + dc39f58 commit c267386

File tree

11 files changed

+330
-54
lines changed

11 files changed

+330
-54
lines changed
Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,25 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright (c) 2024 Oracle and/or its affiliates.
4+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5+
16
import json
7+
from importlib import metadata
28
from typing import List, Union
39

410
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, fetch_service_compartment
511
from ads.aqua.common.decorator import handle_exceptions
612
from ads.aqua.common.errors import AquaResourceAccessError
713
from ads.aqua.common.utils import known_realm
814
from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler
9-
from ads.aqua.extension.models.ws_models import RequestResponseType, AdsVersionResponse, AdsVersionRequest, \
10-
CompatibilityCheckResponse
11-
from importlib import metadata
15+
from ads.aqua.extension.models.ws_models import (
16+
AdsVersionResponse,
17+
CompatibilityCheckResponse,
18+
RequestResponseType,
19+
)
1220

1321

1422
class AquaCommonWsMsgHandler(AquaWSMsgHandler):
15-
1623
@staticmethod
1724
def get_message_types() -> List[RequestResponseType]:
1825
return [RequestResponseType.AdsVersion, RequestResponseType.CompatibilityCheck]
@@ -21,25 +28,30 @@ def __init__(self, message: Union[str, bytes]):
2128
super().__init__(message)
2229

2330
@handle_exceptions
24-
def process(self) -> AdsVersionResponse | CompatibilityCheckResponse:
31+
def process(self) -> Union[AdsVersionResponse, CompatibilityCheckResponse]:
2532
request = json.loads(self.message)
26-
if request.get('kind') == 'AdsVersion':
33+
if request.get("kind") == "AdsVersion":
2734
version = metadata.version("oracle_ads")
2835
response = AdsVersionResponse(
2936
message_id=request.get("message_id"),
3037
kind=RequestResponseType.AdsVersion,
31-
data=version)
38+
data=version,
39+
)
3240
return response
33-
if request.get('kind') == 'CompatibilityCheck':
41+
if request.get("kind") == "CompatibilityCheck":
3442
if ODSC_MODEL_COMPARTMENT_OCID or fetch_service_compartment():
35-
return CompatibilityCheckResponse(message_id=request.get("message_id"),
36-
kind=RequestResponseType.CompatibilityCheck,
37-
data={'status': 'ok'})
43+
return CompatibilityCheckResponse(
44+
message_id=request.get("message_id"),
45+
kind=RequestResponseType.CompatibilityCheck,
46+
data={"status": "ok"},
47+
)
3848
elif known_realm():
39-
return CompatibilityCheckResponse(message_id=request.get("message_id"),
40-
kind=RequestResponseType.CompatibilityCheck,
41-
data={'status': 'compatible'})
49+
return CompatibilityCheckResponse(
50+
message_id=request.get("message_id"),
51+
kind=RequestResponseType.CompatibilityCheck,
52+
data={"status": "compatible"},
53+
)
4254
else:
4355
raise AquaResourceAccessError(
44-
f"The AI Quick actions extension is not compatible in the given region."
56+
"The AI Quick actions extension is not compatible in the given region."
4557
)

ads/aqua/extension/deployment_handler.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

@@ -8,8 +7,8 @@
87
from tornado.web import HTTPError
98

109
from ads.aqua.common.decorator import handle_exceptions
11-
from ads.aqua.extension.errors import Errors
1210
from ads.aqua.extension.base_handler import AquaAPIhandler
11+
from ads.aqua.extension.errors import Errors
1312
from ads.aqua.modeldeployment import AquaDeploymentApp, MDInferenceResponse
1413
from ads.aqua.modeldeployment.entities import ModelParams
1514
from ads.config import COMPARTMENT_OCID, PROJECT_OCID
@@ -66,8 +65,8 @@ def post(self, *args, **kwargs):
6665
"""
6766
try:
6867
input_data = self.get_json_body()
69-
except Exception:
70-
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
68+
except Exception as ex:
69+
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
7170

7271
if not input_data:
7372
raise HTTPError(400, Errors.NO_INPUT_DATA)
@@ -100,6 +99,8 @@ def post(self, *args, **kwargs):
10099
health_check_port = input_data.get("health_check_port")
101100
env_var = input_data.get("env_var")
102101
container_family = input_data.get("container_family")
102+
ocpus = input_data.get("ocpus")
103+
memory_in_gbs = input_data.get("memory_in_gbs")
103104

104105
self.finish(
105106
AquaDeploymentApp().create(
@@ -119,6 +120,8 @@ def post(self, *args, **kwargs):
119120
health_check_port=health_check_port,
120121
env_var=env_var,
121122
container_family=container_family,
123+
ocpus=ocpus,
124+
memory_in_gbs=memory_in_gbs,
122125
)
123126
)
124127

@@ -153,9 +156,7 @@ def validate_predict_url(endpoint):
153156
return False
154157
if not url.netloc:
155158
return False
156-
if not url.path.endswith("/predict"):
157-
return False
158-
return True
159+
return url.path.endswith("/predict")
159160
except Exception:
160161
return False
161162

@@ -170,8 +171,8 @@ def post(self, *args, **kwargs):
170171
"""
171172
try:
172173
input_data = self.get_json_body()
173-
except Exception:
174-
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
174+
except Exception as ex:
175+
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
175176

176177
if not input_data:
177178
raise HTTPError(400, Errors.NO_INPUT_DATA)
@@ -192,10 +193,10 @@ def post(self, *args, **kwargs):
192193
)
193194
try:
194195
model_params_obj = ModelParams(**model_params)
195-
except:
196+
except Exception as ex:
196197
raise HTTPError(
197198
400, Errors.INVALID_INPUT_DATA_FORMAT.format("model_params")
198-
)
199+
) from ex
199200

200201
return self.finish(
201202
MDInferenceResponse(prompt, model_params_obj).get_model_deployment_response(
@@ -236,8 +237,8 @@ def post(self, *args, **kwargs):
236237
"""
237238
try:
238239
input_data = self.get_json_body()
239-
except Exception:
240-
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
240+
except Exception as ex:
241+
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
241242

242243
if not input_data:
243244
raise HTTPError(400, Errors.NO_INPUT_DATA)
Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,35 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright (c) 2024 Oracle and/or its affiliates.
4+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5+
16
import json
27
from typing import List, Union
38

49
from ads.aqua.common.decorator import handle_exceptions
510
from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler
6-
from ads.aqua.extension.models.ws_models import RequestResponseType, ListDeploymentResponse, ListDeploymentRequest, \
7-
ModelDeploymentDetailsResponse
11+
from ads.aqua.extension.models.ws_models import (
12+
ListDeploymentResponse,
13+
ModelDeploymentDetailsResponse,
14+
RequestResponseType,
15+
)
816
from ads.aqua.modeldeployment import AquaDeploymentApp
917
from ads.config import COMPARTMENT_OCID
1018

1119

1220
class AquaDeploymentWSMsgHandler(AquaWSMsgHandler):
13-
1421
def __init__(self, message: Union[str, bytes]):
1522
super().__init__(message)
1623

1724
@staticmethod
1825
def get_message_types() -> List[RequestResponseType]:
19-
return [RequestResponseType.ListDeployments, RequestResponseType.DeploymentDetails]
26+
return [
27+
RequestResponseType.ListDeployments,
28+
RequestResponseType.DeploymentDetails,
29+
]
2030

2131
@handle_exceptions
22-
def process(self) -> ListDeploymentResponse | ModelDeploymentDetailsResponse:
32+
def process(self) -> Union[ListDeploymentResponse, ModelDeploymentDetailsResponse]:
2333
request = json.loads(self.message)
2434
if request.get("kind") == "ListDeployments":
2535
deployment_list = AquaDeploymentApp().list(
@@ -33,8 +43,12 @@ def process(self) -> ListDeploymentResponse | ModelDeploymentDetailsResponse:
3343
)
3444
return response
3545
elif request.get("kind") == "DeploymentDetails":
36-
deployment_details = AquaDeploymentApp().get(request.get("model_deployment_id"))
37-
response = ModelDeploymentDetailsResponse(message_id=request.get("message_id"),
38-
kind=RequestResponseType.DeploymentDetails,
39-
data=deployment_details)
46+
deployment_details = AquaDeploymentApp().get(
47+
request.get("model_deployment_id")
48+
)
49+
response = ModelDeploymentDetailsResponse(
50+
message_id=request.get("message_id"),
51+
kind=RequestResponseType.DeploymentDetails,
52+
data=deployment_details,
53+
)
4054
return response

ads/aqua/extension/evaluation_ws_msg_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def __init__(self, message: Union[str, bytes]):
2929
super().__init__(message)
3030

3131
@handle_exceptions
32-
def process(self) -> ListEvaluationsResponse | EvaluationDetailsResponse:
32+
def process(self) -> Union[ListEvaluationsResponse, EvaluationDetailsResponse]:
3333
request = json.loads(self.message)
3434
if request["kind"] == "ListEvaluations":
3535
return self.list_evaluations(request)
Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,49 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright (c) 2024 Oracle and/or its affiliates.
4+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5+
16
import json
27
from typing import List, Union
38

49
from ads.aqua.common.decorator import handle_exceptions
510
from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler
6-
from ads.aqua.extension.models.ws_models import RequestResponseType, ListModelsResponse, ListModelsRequest, \
7-
ModelDetailsResponse
11+
from ads.aqua.extension.models.ws_models import (
12+
ListModelsResponse,
13+
ModelDetailsResponse,
14+
RequestResponseType,
15+
)
816
from ads.aqua.model import AquaModelApp
917

1018

1119
class AquaModelWSMsgHandler(AquaWSMsgHandler):
12-
1320
def __init__(self, message: Union[str, bytes]):
1421
super().__init__(message)
1522

1623
@staticmethod
1724
def get_message_types() -> List[RequestResponseType]:
18-
return [RequestResponseType.ListModels,RequestResponseType.ModelDetails]
25+
return [RequestResponseType.ListModels, RequestResponseType.ModelDetails]
1926

2027
@handle_exceptions
21-
def process(self) -> ListModelsResponse | ModelDetailsResponse:
28+
def process(self) -> Union[ListModelsResponse, ModelDetailsResponse]:
2229
request = json.loads(self.message)
23-
if request.get('kind') == 'ListModels':
30+
if request.get("kind") == "ListModels":
2431
models_list = AquaModelApp().list(
2532
compartment_id=request.get("compartment_id"),
2633
project_id=request.get("project_id"),
27-
model_type=request.get("model_type")
34+
model_type=request.get("model_type"),
2835
)
2936
response = ListModelsResponse(
3037
message_id=request.get("message_id"),
3138
kind=RequestResponseType.ListModels,
3239
data=models_list,
3340
)
3441
return response
35-
elif request.get('kind') == 'ModelDetails':
36-
model_id=request.get("model_id")
37-
response=AquaModelApp().get(model_id)
38-
return ModelDetailsResponse(message_id=request.get("message_id"),
39-
kind=RequestResponseType.ModelDetails,
40-
data=response)
41-
42+
elif request.get("kind") == "ModelDetails":
43+
model_id = request.get("model_id")
44+
response = AquaModelApp().get(model_id)
45+
return ModelDetailsResponse(
46+
message_id=request.get("message_id"),
47+
kind=RequestResponseType.ModelDetails,
48+
data=response,
49+
)

ads/aqua/modeldeployment/deployment.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

55
import logging
6-
from typing import Dict, List, Union
6+
from typing import Dict, List, Optional, Union
77

88
from ads.aqua.app import AquaApp, logger
99
from ads.aqua.common.enums import (
@@ -102,6 +102,8 @@ def create(
102102
health_check_port: int = None,
103103
env_var: Dict = None,
104104
container_family: str = None,
105+
memory_in_gbs: Optional[float] = None,
106+
ocpus: Optional[float] = None,
105107
) -> "AquaDeployment":
106108
"""
107109
Creates a new Aqua deployment
@@ -142,6 +144,10 @@ def create(
142144
Environment variable for the deployment, by default None.
143145
container_family: str
144146
The image family of the model deployment container runtime. Required for unverified Aqua models.
147+
memory_in_gbs: float
148+
The memory, in GBs, to allocate for the selected shape. Applied only when `ocpus` is also set and the shape name ends with "Flex".
149+
ocpus: float
150+
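The number of OCPUs to allocate for the selected shape. Applied only when `memory_in_gbs` is also set and the shape name ends with "Flex".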
145151
Returns
146152
-------
147153
AquaDeployment
@@ -325,6 +331,11 @@ def create(
325331
log_id=predict_log_id,
326332
)
327333
)
334+
if memory_in_gbs and ocpus and infrastructure.shape_name.endswith("Flex"):
335+
infrastructure.with_shape_config_details(
336+
ocpus=ocpus,
337+
memory_in_gbs=memory_in_gbs,
338+
)
328339
# configure model deployment runtime
329340
container_runtime = (
330341
ModelDeploymentContainerRuntime()
@@ -338,6 +349,7 @@ def create(
338349
.with_overwrite_existing_artifact(True)
339350
.with_remove_existing_artifact(True)
340351
)
352+
341353
# configure model deployment and deploy model on container runtime
342354
deployment = (
343355
ModelDeployment()
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
kind: deployment
2+
spec:
3+
createdBy: ocid1.user.oc1..<OCID>
4+
displayName: model-deployment-name
5+
freeformTags:
6+
OCI_AQUA: active
7+
aqua_model_name: model-name
8+
id: "ocid1.datasciencemodeldeployment.oc1.<region>.<MD_OCID>"
9+
infrastructure:
10+
kind: infrastructure
11+
spec:
12+
bandwidthMbps: 10
13+
compartmentId: ocid1.compartment.oc1..<OCID>
14+
deploymentType: SINGLE_MODEL
15+
policyType: FIXED_SIZE
16+
projectId: ocid1.datascienceproject.oc1.iad.<OCID>
17+
replica: 1
18+
shapeName: "VM.Standard.A1.Flex"
19+
shapeConfigDetails:
20+
memoryInGBs: 60.0
21+
ocpus: 10
22+
type: datascienceModelDeployment
23+
lifecycleState: CREATING
24+
modelDeploymentUrl: "https://modeldeployment.customer-oci.com/ocid1.datasciencemodeldeployment.oc1.<region>.<MD_OCID>"
25+
runtime:
26+
kind: runtime
27+
spec:
28+
env:
29+
BASE_MODEL: service_models/model-name/artifact
30+
BASE_MODEL_FILE: model-name.gguf
31+
MODEL_DEPLOY_ENABLE_STREAMING: 'true'
32+
MODEL_DEPLOY_PREDICT_ENDPOINT: /v1/completions
33+
MODEL_DEPLOY_HEALTH_ENDPOINT: /v1/models
34+
healthCheckPort: 8080
35+
image: "dsmc://image-name:1.0.0.0"
36+
modelUri: "ocid1.datasciencemodeldeployment.oc1.<region>.<MODEL_OCID>"
37+
serverPort: 8080
38+
type: container
39+
timeCreated: 2024-01-01T00:00:00.000000+00:00
40+
type: modelDeployment

0 commit comments

Comments
 (0)