Skip to content

Commit f8de160

Browse files
authored
Prioritize generic model deploy parameters. (#239)
2 parents 79ccd06 + 2052c8e commit f8de160

File tree

4 files changed

+216
-97
lines changed

4 files changed

+216
-97
lines changed

ads/jobs/builders/infrastructure/dsc_job.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -266,10 +266,9 @@ def load_properties_from_env(self) -> None:
266266
# This will skip loading the default configure.
267267
nb_session = None
268268
if nb_session:
269-
nb_config = getattr(
270-
nb_session,
271-
"notebook_session_config_details",
272-
getattr(nb_session, "notebook_session_configuration_details", None),
269+
nb_config = (
270+
getattr(nb_session, "notebook_session_config_details", None)
271+
or getattr(nb_session, "notebook_session_configuration_details", None)
273272
)
274273

275274
if nb_config:

ads/model/deployment/model_deployment_infrastructure.py

Lines changed: 12 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -225,7 +225,18 @@ def _load_default_properties(self) -> Dict:
225225
if NB_SESSION_OCID:
226226
try:
227227
nb_session = DSCNotebookSession.from_ocid(NB_SESSION_OCID)
228-
nb_config = nb_session.notebook_session_configuration_details
228+
except Exception as e:
229+
logger.warning(
230+
f"Error fetching details about Notebook "
231+
f"session: {NB_SESSION_OCID}. {e}"
232+
)
233+
logger.debug(traceback.format_exc())
234+
235+
nb_config = (
236+
getattr(nb_session, "notebook_session_config_details", None)
237+
or getattr(nb_session, "notebook_session_configuration_details", None)
238+
)
239+
if nb_config:
229240
defaults[self.CONST_SHAPE_NAME] = nb_config.shape
230241

231242
if nb_config.notebook_session_shape_config_details:
@@ -236,13 +247,6 @@ def _load_default_properties(self) -> Dict:
236247
notebook_shape_config_details
237248
)
238249

239-
except Exception as e:
240-
logger.warning(
241-
f"Error fetching details about Notebook "
242-
f"session: {NB_SESSION_OCID}. {e}"
243-
)
244-
logger.debug(traceback.format_exc())
245-
246250
return defaults
247251

248252
@property

ads/model/generic_model.py

Lines changed: 75 additions & 85 deletions
Original file line number | Diff line number | Diff line change
@@ -2101,7 +2101,7 @@ def deploy(
21012101
if not display_name:
21022102
display_name = utils.get_random_name_for_resource()
21032103
# populates properties from args and kwargs. Empty values will be ignored.
2104-
self.properties.with_dict(_extract_locals(locals()))
2104+
override_properties = _extract_locals(locals())
21052105
# clears out project_id and compartment_id from kwargs, to prevent passing
21062106
# these params to the deployment via kwargs.
21072107
kwargs.pop("project_id", None)
@@ -2110,19 +2110,55 @@ def deploy(
21102110
max_wait_time = kwargs.pop("max_wait_time", DEFAULT_WAIT_TIME)
21112111
poll_interval = kwargs.pop("poll_interval", DEFAULT_POLL_INTERVAL)
21122112

2113-
self.properties.compartment_id = (
2114-
self.properties.compartment_id or _COMPARTMENT_OCID
2115-
)
2116-
self.properties.project_id = self.properties.project_id or PROJECT_OCID
2117-
self.properties.deployment_instance_shape = (
2118-
self.properties.deployment_instance_shape or MODEL_DEPLOYMENT_INSTANCE_SHAPE
2119-
)
2120-
self.properties.deployment_instance_count = (
2121-
self.properties.deployment_instance_count or MODEL_DEPLOYMENT_INSTANCE_COUNT
2122-
)
2123-
self.properties.deployment_bandwidth_mbps = (
2124-
self.properties.deployment_bandwidth_mbps or MODEL_DEPLOYMENT_BANDWIDTH_MBPS
2125-
)
2113+
# GenericModel itself has a ModelDeployment instance. When calling deploy(),
2114+
# if there are parameters passed in they will override this ModelDeployment instance,
2115+
# otherwise the properties of the ModelDeployment instance will be applied for deployment.
2116+
existing_infrastructure = self.model_deployment.infrastructure
2117+
existing_runtime = self.model_deployment.runtime
2118+
property_dict = ModelProperties(
2119+
compartment_id = existing_infrastructure.compartment_id
2120+
or self.properties.compartment_id
2121+
or _COMPARTMENT_OCID,
2122+
project_id = existing_infrastructure.project_id
2123+
or self.properties.project_id
2124+
or PROJECT_OCID,
2125+
deployment_instance_shape = existing_infrastructure.shape_name
2126+
or self.properties.deployment_instance_shape
2127+
or MODEL_DEPLOYMENT_INSTANCE_SHAPE,
2128+
deployment_instance_count = existing_infrastructure.replica
2129+
or self.properties.deployment_instance_count
2130+
or MODEL_DEPLOYMENT_INSTANCE_COUNT,
2131+
deployment_bandwidth_mbps = existing_infrastructure.bandwidth_mbps
2132+
or self.properties.deployment_bandwidth_mbps
2133+
or MODEL_DEPLOYMENT_BANDWIDTH_MBPS,
2134+
deployment_ocpus = existing_infrastructure.shape_config_details.get(
2135+
"ocpus", None
2136+
)
2137+
or self.properties.deployment_ocpus
2138+
or MODEL_DEPLOYMENT_INSTANCE_OCPUS,
2139+
deployment_memory_in_gbs = existing_infrastructure.shape_config_details.get(
2140+
"memoryInGBs", None
2141+
)
2142+
or self.properties.deployment_memory_in_gbs
2143+
or MODEL_DEPLOYMENT_INSTANCE_MEMORY_IN_GBS,
2144+
deployment_log_group_id = existing_infrastructure.log_group_id
2145+
or self.properties.deployment_log_group_id,
2146+
deployment_access_log_id = existing_infrastructure.access_log.get(
2147+
"log_id", None
2148+
)
2149+
or self.properties.deployment_access_log_id,
2150+
deployment_predict_log_id = existing_infrastructure.predict_log.get(
2151+
"log_id", None
2152+
)
2153+
or self.properties.deployment_predict_log_id,
2154+
deployment_image = existing_runtime.image
2155+
or self.properties.deployment_image,
2156+
deployment_instance_subnet_id = existing_infrastructure.subnet_id
2157+
or self.properties.deployment_instance_subnet_id
2158+
).to_dict()
2159+
2160+
property_dict.update(override_properties)
2161+
self.properties.with_dict(property_dict)
21262162

21272163
if not self.model_id:
21282164
raise ValueError(
@@ -2140,104 +2176,58 @@ def deploy(
21402176
"cannot be used without `deployment_log_group_id`."
21412177
)
21422178

2143-
existing_infrastructure = self.model_deployment.infrastructure
2144-
existing_runtime = self.model_deployment.runtime
2145-
2146-
web_concurrency = (
2147-
kwargs.pop("web_concurrency", None)
2148-
or existing_infrastructure.web_concurrency
2149-
)
2150-
if not (
2151-
self.properties.compartment_id or existing_infrastructure.compartment_id
2152-
):
2179+
if not self.properties.compartment_id:
21532180
raise ValueError("`compartment_id` has to be provided.")
2154-
if not (self.properties.project_id or existing_infrastructure.project_id):
2181+
if not self.properties.project_id:
21552182
raise ValueError("`project_id` has to be provided.")
21562183
infrastructure = (
21572184
ModelDeploymentInfrastructure()
2158-
.with_compartment_id(
2159-
self.properties.compartment_id or existing_infrastructure.compartment_id
2160-
)
2161-
.with_project_id(
2162-
self.properties.project_id or existing_infrastructure.project_id
2163-
)
2164-
.with_bandwidth_mbps(
2165-
self.properties.deployment_bandwidth_mbps
2166-
or existing_infrastructure.bandwidth_mbps
2167-
)
2168-
.with_shape_name(
2169-
self.properties.deployment_instance_shape
2170-
or existing_infrastructure.shape_name
2171-
)
2172-
.with_subnet_id(
2173-
self.properties.deployment_instance_subnet_id
2174-
or existing_infrastructure.subnet_id
2175-
)
2176-
.with_replica(
2177-
self.properties.deployment_instance_count
2178-
or existing_infrastructure.replica
2179-
)
2180-
.with_web_concurrency(web_concurrency)
2185+
.with_compartment_id(self.properties.compartment_id)
2186+
.with_project_id(self.properties.project_id)
2187+
.with_bandwidth_mbps(self.properties.deployment_bandwidth_mbps)
2188+
.with_shape_name(self.properties.deployment_instance_shape)
2189+
.with_replica(self.properties.deployment_instance_count)
2190+
.with_subnet_id(self.properties.deployment_instance_subnet_id)
21812191
)
21822192

2183-
ocpus = (
2184-
self.properties.deployment_ocpus
2185-
or existing_infrastructure.shape_config_details.get("ocpus")
2186-
)
2187-
memory_in_gbs = (
2188-
self.properties.deployment_memory_in_gbs
2189-
or existing_infrastructure.shape_config_details.get("memory_in_gbs")
2193+
web_concurrency = (
2194+
kwargs.pop("web_concurrency", None)
2195+
or existing_infrastructure.web_concurrency
21902196
)
2197+
if web_concurrency:
2198+
infrastructure.with_web_concurrency(web_concurrency)
21912199

21922200
if infrastructure.shape_name.endswith("Flex"):
21932201
infrastructure.with_shape_config_details(
2194-
ocpus=ocpus or MODEL_DEPLOYMENT_INSTANCE_OCPUS,
2195-
memory_in_gbs=memory_in_gbs or MODEL_DEPLOYMENT_INSTANCE_MEMORY_IN_GBS,
2202+
ocpus=self.properties.deployment_ocpus,
2203+
memory_in_gbs=self.properties.deployment_memory_in_gbs,
21962204
)
21972205

2198-
access_log_id = (
2199-
self.properties.deployment_access_log_id
2200-
or existing_infrastructure.access_log.get("log_id")
2201-
)
2202-
access_log_group_id = (
2203-
self.properties.deployment_log_group_id
2204-
or existing_infrastructure.access_log.get("log_group_id")
2205-
)
2206-
22072206
# specifies the access log id
2208-
if access_log_id:
2207+
if self.properties.deployment_access_log_id:
22092208
infrastructure.with_access_log(
2210-
log_group_id=access_log_group_id,
2211-
log_id=access_log_id,
2209+
log_group_id=self.properties.deployment_log_group_id,
2210+
log_id=self.properties.deployment_access_log_id,
22122211
)
22132212

2214-
predict_log_id = (
2215-
self.properties.deployment_predict_log_id
2216-
or existing_infrastructure.predict_log.get("log_id")
2217-
)
2218-
predict_log_group_id = (
2219-
self.properties.deployment_log_group_id
2220-
or existing_infrastructure.predict_log.get("log_group_id")
2221-
)
2222-
22232213
# specifies the predict log id
2224-
if predict_log_id:
2214+
if self.properties.deployment_predict_log_id:
22252215
infrastructure.with_predict_log(
2226-
log_group_id=predict_log_group_id,
2227-
log_id=predict_log_id,
2216+
log_group_id=self.properties.deployment_log_group_id,
2217+
log_id=self.properties.deployment_predict_log_id,
22282218
)
22292219

22302220
environment_variables = (
22312221
kwargs.pop("environment_variables", {}) or existing_runtime.env
22322222
)
22332223
deployment_mode = (
2234-
kwargs.pop("deployment_mode", ModelDeploymentMode.HTTPS)
2224+
kwargs.pop("deployment_mode", None)
22352225
or existing_runtime.deployment_mode
2226+
or ModelDeploymentMode.HTTPS
22362227
)
22372228

22382229
runtime = None
2239-
image = self.properties.deployment_image or existing_runtime.image
2240-
if image:
2230+
if self.properties.deployment_image:
22412231
image_digest = (
22422232
kwargs.pop("image_digest", None) or existing_runtime.image_digest
22432233
)
@@ -2252,7 +2242,7 @@ def deploy(
22522242
)
22532243
runtime = (
22542244
ModelDeploymentContainerRuntime()
2255-
.with_image(image)
2245+
.with_image(self.properties.deployment_image)
22562246
.with_image_digest(image_digest)
22572247
.with_cmd(cmd)
22582248
.with_entrypoint(entrypoint)

tests/unitary/with_extras/model/test_generic_model.py

Lines changed: 126 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1917,6 +1917,132 @@ def test_update(
19171917
self.generic_model.dsc_model.schema_output == self.mock_dsc_model.output_schema
19181918
self.generic_model.dsc_model.metadata_provenance == self.mock_dsc_model.provenance_metadata
19191919

1920+
@patch.object(ModelDeployment, "deploy")
1921+
def test_deploy_combined_initialization(self, mock_deploy):
1922+
self.generic_model.properties = ModelProperties(
1923+
deployment_image="default_test_docker_image",
1924+
compartment_id="default_test_compartment_id",
1925+
project_id="default_test_project_id",
1926+
)
1927+
test_model_id = "ocid.test_model_id"
1928+
self.generic_model.dsc_model = MagicMock(id=test_model_id)
1929+
self.generic_model.ignore_conda_error = True
1930+
infrastructure = ModelDeploymentInfrastructure(
1931+
**{
1932+
"shape_name": "test_deployment_instance_shape",
1933+
"replica": 10,
1934+
"bandwidth_mbps": 100,
1935+
"shape_config_details": {"memory_in_gbs": 10, "ocpus": 1},
1936+
"access_log": {
1937+
"log_group_id": "test_deployment_log_group_id",
1938+
"log_id": "test_deployment_access_log_id",
1939+
},
1940+
"predict_log": {
1941+
"log_group_id": "test_deployment_log_group_id",
1942+
"log_id": "test_deployment_predict_log_id",
1943+
},
1944+
"project_id": "project_id_passed_using_with",
1945+
"compartment_id": "compartment_id_passed_using_with",
1946+
}
1947+
)
1948+
runtime = ModelDeploymentContainerRuntime(
1949+
**{
1950+
"image": "image_passed_using_with",
1951+
"image_digest": "test_image_digest",
1952+
"cmd": ["test_cmd"],
1953+
"entrypoint": ["test_entrypoint"],
1954+
"server_port": 8080,
1955+
"health_check_port": 8080,
1956+
"env": {"test_key": "test_value"},
1957+
"deployment_mode": "HTTPS_ONLY",
1958+
}
1959+
)
1960+
mock_deploy.return_value = ModelDeployment(
1961+
display_name="test_display_name",
1962+
description="test_description",
1963+
infrastructure=infrastructure,
1964+
runtime=runtime,
1965+
model_deployment_url="test_model_deployment_url",
1966+
model_deployment_id="test_model_deployment_id",
1967+
)
1968+
input_dict = {
1969+
"wait_for_completion": True,
1970+
"display_name": "test_display_name",
1971+
"description": "test_description",
1972+
"deployment_instance_shape": "test_deployment_instance_shape",
1973+
"deployment_instance_count": 10,
1974+
"deployment_bandwidth_mbps": 100,
1975+
"deployment_memory_in_gbs": 10,
1976+
"deployment_ocpus": 1,
1977+
"deployment_log_group_id": "test_deployment_log_group_id",
1978+
"deployment_access_log_id": "test_deployment_access_log_id",
1979+
"deployment_predict_log_id": "test_deployment_predict_log_id",
1980+
"cmd": ["test_cmd"],
1981+
"entrypoint": ["test_entrypoint"],
1982+
"server_port": 8080,
1983+
"health_check_port": 8080,
1984+
"environment_variables": {"test_key": "test_value"},
1985+
"max_wait_time": 100,
1986+
"poll_interval": 200,
1987+
}
1988+
1989+
self.generic_model.model_deployment.infrastructure.with_compartment_id(
1990+
"compartment_id_passed_using_with"
1991+
).with_project_id("project_id_passed_using_with")
1992+
self.generic_model.model_deployment.runtime.with_image(
1993+
"image_passed_using_with"
1994+
)
1995+
1996+
result = self.generic_model.deploy(
1997+
**input_dict,
1998+
)
1999+
assert result == mock_deploy.return_value
2000+
assert result.infrastructure.access_log == {
2001+
"log_id": input_dict["deployment_access_log_id"],
2002+
"log_group_id": input_dict["deployment_log_group_id"],
2003+
}
2004+
assert result.infrastructure.predict_log == {
2005+
"log_id": input_dict["deployment_predict_log_id"],
2006+
"log_group_id": input_dict["deployment_log_group_id"],
2007+
}
2008+
assert (
2009+
result.infrastructure.bandwidth_mbps
2010+
== input_dict["deployment_bandwidth_mbps"]
2011+
)
2012+
assert (
2013+
result.infrastructure.compartment_id == "compartment_id_passed_using_with"
2014+
)
2015+
assert result.infrastructure.project_id == "project_id_passed_using_with"
2016+
assert (
2017+
result.infrastructure.shape_name == input_dict["deployment_instance_shape"]
2018+
)
2019+
assert result.infrastructure.shape_config_details == {
2020+
"ocpus": input_dict["deployment_ocpus"],
2021+
"memory_in_gbs": input_dict["deployment_memory_in_gbs"],
2022+
}
2023+
assert result.runtime.image == "image_passed_using_with"
2024+
assert result.runtime.entrypoint == input_dict["entrypoint"]
2025+
assert result.runtime.server_port == input_dict["server_port"]
2026+
assert result.runtime.health_check_port == input_dict["health_check_port"]
2027+
assert result.runtime.env == {"test_key": "test_value"}
2028+
assert result.runtime.deployment_mode == "HTTPS_ONLY"
2029+
mock_deploy.assert_called_with(
2030+
wait_for_completion=input_dict["wait_for_completion"],
2031+
max_wait_time=input_dict["max_wait_time"],
2032+
poll_interval=input_dict["poll_interval"],
2033+
)
2034+
2035+
assert (
2036+
self.generic_model.properties.compartment_id
2037+
== "compartment_id_passed_using_with"
2038+
)
2039+
assert (
2040+
self.generic_model.properties.project_id == "project_id_passed_using_with"
2041+
)
2042+
assert (
2043+
self.generic_model.properties.deployment_image == "image_passed_using_with"
2044+
)
2045+
19202046

19212047
class TestCommonMethods:
19222048
"""Tests common methods presented in the generic_model module."""

0 commit comments

Comments
 (0)