Skip to content

Commit 81494c2

Browse files
authored
Fix override container family (#938)
1 parent 2821e30 commit 81494c2

File tree

1 file changed

+31
-51
lines changed

1 file changed

+31
-51
lines changed

ads/aqua/modeldeployment/deployment.py

Lines changed: 31 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def create(
146146
env_var : dict, optional
147147
Environment variable for the deployment, by default None.
148148
container_family: str
149-
The image family of model deployment container runtime. Required for unverified Aqua models.
149+
The image family of model deployment container runtime.
150150
memory_in_gbs: float
151151
The memory in gbs for the shape selected.
152152
ocpus: float
@@ -230,41 +230,14 @@ def create(
230230

231231
env_var.update({"FT_MODEL": f"{fine_tune_output_path}"})
232232

233-
is_custom_container = False
234-
try:
235-
container_type_key = aqua_model.custom_metadata_list.get(
236-
AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME
237-
).value
238-
except ValueError as err:
239-
message = (
240-
f"{AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME} key is not available in the custom metadata field "
241-
f"for model {aqua_model.id}."
242-
)
243-
logger.debug(message)
244-
if not container_family:
245-
raise AquaValueError(
246-
f"{message}. For unverified Aqua models, container_family parameter should be "
247-
f"set and value can be one of {', '.join(InferenceContainerTypeFamily.values())}."
248-
) from err
249-
container_type_key = container_family
250-
try:
251-
# Check if the container override flag is set. If set, then the user has chosen custom image
252-
if aqua_model.custom_metadata_list.get(
253-
AQUA_DEPLOYMENT_CONTAINER_OVERRIDE_FLAG_METADATA_NAME
254-
).value:
255-
is_custom_container = True
256-
except Exception:
257-
pass
233+
container_type_key = self._get_container_type_key(
234+
model=aqua_model,
235+
container_family=container_family
236+
)
258237

259238
# fetch image name from config
260-
# If the image is of type custom, then `container_type_key` is the inference image
261-
container_image = (
262-
get_container_image(
263-
container_type=container_type_key,
264-
)
265-
if not is_custom_container
266-
else container_type_key
267-
)
239+
container_image = get_container_image(container_type=container_type_key)
240+
268241
logging.info(
269242
f"Aqua Image used for deploying {aqua_model.id} : {container_image}"
270243
)
@@ -433,6 +406,26 @@ def create(
433406
deployment.dsc_model_deployment, self.region
434407
)
435408

409+
@staticmethod
410+
def _get_container_type_key(model: DataScienceModel, container_family: str) -> str:
411+
container_type_key = UNKNOWN
412+
if container_family:
413+
container_type_key = container_family
414+
else:
415+
try:
416+
container_type_key = model.custom_metadata_list.get(
417+
AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME
418+
).value
419+
except ValueError as err:
420+
raise AquaValueError(
421+
f"{AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME} key is not available in the custom metadata field "
422+
f"for model {model.id}. For unverified Aqua models, {AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME} should be"
423+
f"set and value can be one of {', '.join(InferenceContainerTypeFamily.values())}."
424+
) from err
425+
426+
return container_type_key
427+
428+
436429
@telemetry(entry_point="plugin=deployment&action=list", name="aqua")
437430
def list(self, **kwargs) -> List["AquaDeployment"]:
438431
"""List Aqua model deployments in a given compartment and under certain project.
@@ -672,23 +665,10 @@ def validate_deployment_params(
672665
restricted_params = []
673666
if params:
674667
model = DataScienceModel.from_id(model_id)
675-
try:
676-
container_type_key = model.custom_metadata_list.get(
677-
AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME
678-
).value
679-
except ValueError as err:
680-
message = (
681-
f"{AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME} key is not available in the custom metadata field "
682-
f"for model {model_id}."
683-
)
684-
logger.debug(message)
685-
686-
if not container_family:
687-
raise AquaValueError(
688-
f"{message}. For unverified Aqua models, container_family parameter should be "
689-
f"set and value can be one of {', '.join(InferenceContainerTypeFamily.values())}."
690-
) from err
691-
container_type_key = container_family
668+
container_type_key = self._get_container_type_key(
669+
model=model,
670+
container_family=container_family
671+
)
692672

693673
container_config = get_container_config()
694674
container_spec = container_config.get(ContainerSpec.CONTAINER_SPEC, {}).get(

0 commit comments

Comments
 (0)