|
17 | 17 | from ads.common import auth as authutil
|
18 | 18 | import pandas as pd
|
19 | 19 | from ads.model.serde.model_input import JsonModelInputSERDE
|
20 |
| -from ads.common import auth, oci_client |
21 | 20 | from ads.common.oci_logging import (
|
22 | 21 | LOG_INTERVAL,
|
23 | 22 | LOG_RECORDS_LIMIT,
|
|
63 | 62 |
|
64 | 63 | MODEL_DEPLOYMENT_KIND = "deployment"
|
65 | 64 | MODEL_DEPLOYMENT_TYPE = "modelDeployment"
|
| 65 | +MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON = "TRITON" |
66 | 66 |
|
67 | 67 | MODEL_DEPLOYMENT_INSTANCE_SHAPE = "VM.Standard.E4.Flex"
|
68 | 68 | MODEL_DEPLOYMENT_INSTANCE_OCPUS = 1
|
@@ -926,10 +926,7 @@ def predict(
|
926 | 926 | if model_name and model_version:
|
927 | 927 | header['model-name'] = model_name
|
928 | 928 | header['model-version'] = model_version
|
929 |
| - elif not model_version and not model_name: |
930 |
| - |
931 |
| - pass |
932 |
| - else: |
| 929 | + elif bool(model_version) ^ bool(model_name): |
933 | 930 | raise ValueError("`model_name` and `model_version` have to be provided together.")
|
934 | 931 | prediction = send_request(
|
935 | 932 | data=data, endpoint=endpoint, is_json_payload=is_json_payload, header=header,
|
@@ -1404,9 +1401,9 @@ def _update_from_oci_model(self, oci_model_instance) -> "ModelDeployment":
|
1404 | 1401 | infrastructure.CONST_WEB_CONCURRENCY,
|
1405 | 1402 | runtime.env.get("WEB_CONCURRENCY", None),
|
1406 | 1403 | )
|
1407 |
| - if runtime.env.get("CONTAINER_TYPE", None) == "TRITON": |
| 1404 | + if runtime.env.get("CONTAINER_TYPE", None) == MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON: |
1408 | 1405 | runtime.set_spec(
|
1409 |
| - runtime.CONST_INFERENCE_SERVER, "triton" |
| 1406 | + runtime.CONST_INFERENCE_SERVER, MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON.lower() |
1410 | 1407 | )
|
1411 | 1408 |
|
1412 | 1409 | self.set_spec(self.CONST_INFRASTRUCTURE, infrastructure)
|
|
0 commit comments