Commit cd2c8f3

ODSC 39392/triton (#128)

2 parents: a9df825 + b66c02f

4 files changed: +161 additions, -13 deletions

ads/model/deployment/common/utils.py

Lines changed: 5 additions & 10 deletions

@@ -119,19 +119,14 @@ def send_request(
     Returns:
         A JSON representative of a requests.Response object.
     """
-    headers = dict()
     if is_json_payload:
-        headers["Content-Type"] = (
-            header.get("content_type") or DEFAULT_CONTENT_TYPE_JSON
-        )
+        header["Content-Type"] = header.pop("content_type", DEFAULT_CONTENT_TYPE_JSON) or DEFAULT_CONTENT_TYPE_JSON
         request_kwargs = {"json": data}
     else:
-        headers["Content-Type"] = (
-            header.get("content_type") or DEFAULT_CONTENT_TYPE_BYTES
-        )
+        header["Content-Type"] = header.pop("content_type", DEFAULT_CONTENT_TYPE_BYTES) or DEFAULT_CONTENT_TYPE_BYTES
         request_kwargs = {"data": data}  # should pass bytes when using data
-
-    request_kwargs["headers"] = headers
+
+    request_kwargs["headers"] = header

     if dry_run:
         request_kwargs["headers"]["Accept"] = "*/*"
@@ -140,7 +135,7 @@ def send_request(
             return json.loads(req.body)
         return req.body
     else:
-        request_kwargs["auth"] = header.get("signer")
+        request_kwargs["auth"] = header.pop("signer")
         return requests.post(endpoint, **request_kwargs).json()
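The rewritten header handling forwards the caller's `header` dict to `requests` directly, using `dict.pop` so that bookkeeping keys (`content_type`, and `signer` in the non-dry-run branch) are consumed rather than sent as HTTP headers. A minimal, self-contained sketch of the pop-with-fallback pattern; the constant's value is an assumption mirroring the library's default:

    # Sketch only: DEFAULT_CONTENT_TYPE_JSON mirrors the constant in utils.py (assumed value).
    DEFAULT_CONTENT_TYPE_JSON = "application/json"

    def resolve_content_type(header: dict) -> dict:
        # pop() applies its default only when the key is absent;
        # the trailing `or` also catches an explicit None value.
        header["Content-Type"] = (
            header.pop("content_type", DEFAULT_CONTENT_TYPE_JSON) or DEFAULT_CONTENT_TYPE_JSON
        )
        return header

    print(resolve_content_type({}))                            # {'Content-Type': 'application/json'}
    print(resolve_content_type({"content_type": None}))        # {'Content-Type': 'application/json'}
    print(resolve_content_type({"content_type": "text/csv"}))  # {'Content-Type': 'text/csv'}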

ads/model/deployment/model_deployment.py

Lines changed: 21 additions & 3 deletions

@@ -17,7 +17,6 @@
 from ads.common import auth as authutil
 import pandas as pd
 from ads.model.serde.model_input import JsonModelInputSERDE
-from ads.common import auth, oci_client
 from ads.common.oci_logging import (
     LOG_INTERVAL,
     LOG_RECORDS_LIMIT,
@@ -63,6 +62,7 @@

 MODEL_DEPLOYMENT_KIND = "deployment"
 MODEL_DEPLOYMENT_TYPE = "modelDeployment"
+MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON = "TRITON"

 MODEL_DEPLOYMENT_INSTANCE_SHAPE = "VM.Standard.E4.Flex"
 MODEL_DEPLOYMENT_INSTANCE_OCPUS = 1
@@ -828,6 +828,8 @@ def predict(
         data: Any = None,
         serializer: "ads.model.ModelInputSerializer" = model_input_serializer,
         auto_serialize_data: bool = False,
+        model_name: str = None,
+        model_version: str = None,
         **kwargs,
     ) -> dict:
         """Returns prediction of input data run against the model deployment endpoint.
@@ -860,6 +862,10 @@ def predict(
             If `auto_serialize_data=False`, `data` required to be bytes or json serializable
             and `json_input` required to be json serializable. If `auto_serialize_data` set
             to True, data will be serialized before sending to model deployment endpoint.
+        model_name: str
+            Defaults to None. When `inference_server="triton"`, the name of the model to invoke.
+        model_version: str
+            Defaults to None. When `inference_server="triton"`, the version of the model to invoke.
         kwargs:
             content_type: str
                 Used to indicate the media type of the resource.
@@ -878,6 +884,7 @@ def predict(
             "signer": signer,
             "content_type": kwargs.get("content_type", None),
         }
+        header.update(kwargs.pop("headers", {}))

         if data is None and json_input is None:
             raise AttributeError(
@@ -916,9 +923,13 @@ def predict(
                 raise TypeError(
                     "`data` is not bytes or json serializable. Set `auto_serialize_data` to `True` to serialize the input data."
                 )
-
+        if model_name and model_version:
+            header["model-name"] = model_name
+            header["model-version"] = model_version
+        elif bool(model_version) ^ bool(model_name):
+            raise ValueError("`model_name` and `model_version` have to be provided together.")
         prediction = send_request(
-            data=data, endpoint=endpoint, is_json_payload=is_json_payload, header=header
+            data=data, endpoint=endpoint, is_json_payload=is_json_payload, header=header,
         )
         return prediction
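For Triton-backed deployments, `model_name` and `model_version` are forwarded as the `model-name` and `model-version` request headers; supplying only one of them trips the XOR check and raises `ValueError`. A hedged usage sketch, where the OCID, payload, and model identifiers are all placeholders:

    from ads.model.deployment import ModelDeployment

    deployment = ModelDeployment.from_id("<model_deployment_ocid>")  # placeholder OCID
    prediction = deployment.predict(
        data=b"...",            # placeholder bytes payload
        model_name="my_model",  # hypothetical Triton model name
        model_version="1",      # must accompany model_name, and vice versa
    )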

@@ -1390,6 +1401,10 @@ def _update_from_oci_model(self, oci_model_instance) -> "ModelDeployment":
             infrastructure.CONST_WEB_CONCURRENCY,
             runtime.env.get("WEB_CONCURRENCY", None),
         )
+        if runtime.env.get("CONTAINER_TYPE", None) == MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON:
+            runtime.set_spec(
+                runtime.CONST_INFERENCE_SERVER, MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON.lower()
+            )

         self.set_spec(self.CONST_INFRASTRUCTURE, infrastructure)
         self.set_spec(self.CONST_RUNTIME, runtime)
@@ -1566,6 +1581,9 @@ def _build_model_deployment_configuration_details(self) -> Dict:
                 infrastructure.web_concurrency
             )
             runtime.set_spec(runtime.CONST_ENV, environment_variables)
+        if hasattr(runtime, "inference_server") and runtime.inference_server and runtime.inference_server.upper() == MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON:
+            environment_variables["CONTAINER_TYPE"] = MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON
+            runtime.set_spec(runtime.CONST_ENV, environment_variables)
         environment_configuration_details = {
             runtime.CONST_ENVIRONMENT_CONFIG_TYPE: runtime.environment_config_type,
             runtime.CONST_ENVIRONMENT_VARIABLES: runtime.env,
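The Triton choice round-trips through an environment variable: when the runtime's `inference_server` is "triton", building the configuration injects `CONTAINER_TYPE=TRITON` into the runtime env, and `_update_from_oci_model` restores the runtime spec from that same variable when a deployment is loaded back from OCI. A dict-based sketch of the round-trip; this is standalone illustration, not the real runtime object:

    MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON = "TRITON"

    def build_env(spec: dict) -> dict:
        # Serialize: runtime spec -> env vars sent to the service.
        env = dict(spec.get("env", {}))
        if (spec.get("inferenceServer") or "").upper() == MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON:
            env["CONTAINER_TYPE"] = MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON
        return env

    def restore_spec(env: dict) -> dict:
        # Deserialize: env vars read back from OCI -> runtime spec.
        spec = {"env": dict(env)}
        if env.get("CONTAINER_TYPE") == MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON:
            spec["inferenceServer"] = MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON.lower()
        return spec

    assert restore_spec(build_env({"inferenceServer": "triton"}))["inferenceServer"] == "triton"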

ads/model/deployment/model_deployment_runtime.py

Lines changed: 56 additions & 0 deletions

@@ -330,6 +330,7 @@ class ModelDeploymentContainerRuntime(ModelDeploymentRuntime):
     CONST_ENTRYPOINT = "entrypoint"
     CONST_SERVER_PORT = "serverPort"
     CONST_HEALTH_CHECK_PORT = "healthCheckPort"
+    CONST_INFERENCE_SERVER = "inferenceServer"

     attribute_map = {
         **ModelDeploymentRuntime.attribute_map,
@@ -339,6 +340,7 @@ class ModelDeploymentContainerRuntime(ModelDeploymentRuntime):
         CONST_ENTRYPOINT: "entrypoint",
         CONST_SERVER_PORT: "server_port",
         CONST_HEALTH_CHECK_PORT: "health_check_port",
+        CONST_INFERENCE_SERVER: "inference_server"
     }

     payload_attribute_map = {
@@ -532,3 +534,57 @@ def with_health_check_port(
         The ModelDeploymentContainerRuntime instance (self).
         """
         return self.set_spec(self.CONST_HEALTH_CHECK_PORT, health_check_port)
+
+    @property
+    def inference_server(self) -> str:
+        """Returns the inference server.
+
+        Returns
+        -------
+        str
+            The inference server.
+        """
+        return self.get_spec(self.CONST_INFERENCE_SERVER, None)
+
+    def with_inference_server(self, inference_server: str = "triton") -> "ModelDeploymentRuntime":
+        """Sets the inference server. Currently the only supported inference server is "triton".
+        Note that if you are using BYOC, you do not need to set the inference server.
+
+        Parameters
+        ----------
+        inference_server: str
+            The inference server to set.
+
+        Returns
+        -------
+        ModelDeploymentRuntime
+            The ModelDeploymentRuntime instance (self).
+
+        Example
+        -------
+        >>> from ads.model.deployment import ModelDeployment, ModelDeploymentContainerRuntime, ModelDeploymentInfrastructure
+        >>> import ads
+        >>> ads.set_auth("resource_principal")
+        >>> infrastructure = ModelDeploymentInfrastructure()\
+        ...     .with_project_id(<project_id>)\
+        ...     .with_compartment_id(<compartment_id>)\
+        ...     .with_shape_name("VM.Standard.E4.Flex")\
+        ...     .with_replica(2)\
+        ...     .with_bandwidth_mbps(10)\
+        ...     .with_access_log(log_group_id=<deployment_log_group_id>, log_id=<deployment_access_log_id>)\
+        ...     .with_predict_log(log_group_id=<deployment_log_group_id>, log_id=<deployment_predict_log_id>)
+        >>> runtime = ModelDeploymentContainerRuntime()\
+        ...     .with_image(<container_image>)\
+        ...     .with_server_port(<server_port>)\
+        ...     .with_health_check_port(<health_check_port>)\
+        ...     .with_model_uri(<model_id>)\
+        ...     .with_env({"key":"value", "key2":"value2"})\
+        ...     .with_inference_server("triton")
+        >>> deployment = ModelDeployment()\
+        ...     .with_display_name("Triton Example")\
+        ...     .with_infrastructure(infrastructure)\
+        ...     .with_runtime(runtime)
+        >>> deployment.deploy()
+        """
+        return self.set_spec(self.CONST_INFERENCE_SERVER, inference_server.lower())
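Because `with_inference_server` lower-cases its argument before storing it under the `inferenceServer` spec key, the `inference_server` getter always returns the normalized value. A quick round-trip check, assuming a local ADS install that includes this commit:

    from ads.model.deployment import ModelDeploymentContainerRuntime

    runtime = ModelDeploymentContainerRuntime().with_inference_server("TRITON")
    assert runtime.inference_server == "triton"  # stored lower-cased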

tests/unitary/default_setup/model_deployment/test_model_deployment_v2.py

Lines changed: 79 additions & 0 deletions

@@ -308,6 +308,62 @@ def initialize_model_deployment_from_spec(self):
                 "runtime": runtime,
             }
         )
+
+    def initialize_model_deployment_triton_builder(self):
+        infrastructure = ModelDeploymentInfrastructure()\
+            .with_compartment_id("fakeid.compartment.oc1..xxx")\
+            .with_project_id("fakeid.datascienceproject.oc1.iad.xxx")\
+            .with_shape_name("VM.Standard.E4.Flex")\
+            .with_replica(2)\
+            .with_bandwidth_mbps(10)
+
+        runtime = ModelDeploymentContainerRuntime()\
+            .with_image("fake_image")\
+            .with_server_port(5000)\
+            .with_health_check_port(5000)\
+            .with_model_uri("fake_model_id")\
+            .with_env({"key":"value", "key2":"value2"})\
+            .with_inference_server("triton")
+
+        deployment = ModelDeployment()\
+            .with_display_name("triton case")\
+            .with_infrastructure(infrastructure)\
+            .with_runtime(runtime)
+        return deployment
+
+    def initialize_model_deployment_triton_yaml(self):
+        yaml_string = """
+        kind: deployment
+        spec:
+          displayName: triton
+          infrastructure:
+            kind: infrastructure
+            spec:
+              bandwidthMbps: 10
+              compartmentId: fake_compartment_id
+              deploymentType: SINGLE_MODEL
+              policyType: FIXED_SIZE
+              replica: 2
+              shapeConfigDetails:
+                memoryInGBs: 16.0
+                ocpus: 1.0
+              shapeName: VM.Standard.E4.Flex
+            type: datascienceModelDeployment
+          runtime:
+            kind: runtime
+            spec:
+              env:
+                key: value
+                key2: value2
+              inference_server: triton
+              healthCheckPort: 8000
+              image: fake_image
+              modelUri: fake_model_id
+              serverPort: 8000
+            type: container
+        """
+        deployment_from_yaml = ModelDeployment.from_yaml(yaml_string)
+        return deployment_from_yaml

     def initialize_model_deployment_from_kwargs(self):
         infrastructure = (
@@ -435,11 +491,34 @@ def test_initialize_model_deployment_with_error(self):
             },
         )

+
     def test_initialize_model_deployment_with_spec_kwargs(self):
         model_deployment_kwargs = self.initialize_model_deployment_from_kwargs()
         model_deployment_builder = self.initialize_model_deployment()

         assert model_deployment_kwargs.to_dict() == model_deployment_builder.to_dict()
+
+    def test_initialize_model_deployment_triton_builder(self):
+        temp_model_deployment = self.initialize_model_deployment_triton_builder()
+        assert isinstance(
+            temp_model_deployment.runtime, ModelDeploymentContainerRuntime
+        )
+        assert isinstance(
+            temp_model_deployment.infrastructure, ModelDeploymentInfrastructure
+        )
+        assert temp_model_deployment.runtime.inference_server == "triton"
+
+    def test_initialize_model_deployment_triton_yaml(self):
+        temp_model_deployment = self.initialize_model_deployment_triton_yaml()
+        assert isinstance(
+            temp_model_deployment.runtime, ModelDeploymentContainerRuntime
+        )
+        assert isinstance(
+            temp_model_deployment.infrastructure, ModelDeploymentInfrastructure
+        )
+        assert temp_model_deployment.runtime.inference_server == "triton"
+

     def test_model_deployment_to_dict(self):
         model_deployment = self.initialize_model_deployment()
