Skip to content

Commit 512d2d0

Browse files
committed
Prioritize generic model deploy parameters.
1 parent 7f49d7e commit 512d2d0

File tree

2 files changed: +197 additions, -85 deletions (lines changed)

ads/model/generic_model.py

Lines changed: 71 additions & 85 deletions
Original file line number | Diff line number | Diff line change
@@ -2101,7 +2101,7 @@ def deploy(
21012101
if not display_name:
21022102
display_name = utils.get_random_name_for_resource()
21032103
# populates properties from args and kwargs. Empty values will be ignored.
2104-
self.properties.with_dict(_extract_locals(locals()))
2104+
override_properties = _extract_locals(locals())
21052105
# clears out project_id and compartment_id from kwargs, to prevent passing
21062106
# these params to the deployment via kwargs.
21072107
kwargs.pop("project_id", None)
@@ -2110,19 +2110,51 @@ def deploy(
21102110
max_wait_time = kwargs.pop("max_wait_time", DEFAULT_WAIT_TIME)
21112111
poll_interval = kwargs.pop("poll_interval", DEFAULT_POLL_INTERVAL)
21122112

2113-
self.properties.compartment_id = (
2114-
self.properties.compartment_id or _COMPARTMENT_OCID
2115-
)
2116-
self.properties.project_id = self.properties.project_id or PROJECT_OCID
2117-
self.properties.deployment_instance_shape = (
2118-
self.properties.deployment_instance_shape or MODEL_DEPLOYMENT_INSTANCE_SHAPE
2119-
)
2120-
self.properties.deployment_instance_count = (
2121-
self.properties.deployment_instance_count or MODEL_DEPLOYMENT_INSTANCE_COUNT
2122-
)
2123-
self.properties.deployment_bandwidth_mbps = (
2124-
self.properties.deployment_bandwidth_mbps or MODEL_DEPLOYMENT_BANDWIDTH_MBPS
2125-
)
2113+
existing_infrastructure = self.model_deployment.infrastructure
2114+
existing_runtime = self.model_deployment.runtime
2115+
properties = {
2116+
"compartment_id": existing_infrastructure.compartment_id
2117+
or self.properties.compartment_id
2118+
or _COMPARTMENT_OCID,
2119+
"project_id": existing_infrastructure.project_id
2120+
or self.properties.project_id
2121+
or PROJECT_OCID,
2122+
"deployment_instance_shape": existing_infrastructure.shape_name
2123+
or self.properties.deployment_instance_shape
2124+
or MODEL_DEPLOYMENT_INSTANCE_SHAPE,
2125+
"deployment_instance_count": existing_infrastructure.replica
2126+
or self.properties.deployment_instance_count
2127+
or MODEL_DEPLOYMENT_INSTANCE_COUNT,
2128+
"deployment_bandwidth_mbps": existing_infrastructure.bandwidth_mbps
2129+
or self.properties.deployment_bandwidth_mbps
2130+
or MODEL_DEPLOYMENT_BANDWIDTH_MBPS,
2131+
"deployment_ocpus": existing_infrastructure.shape_config_details.get(
2132+
"ocpus", None
2133+
)
2134+
or self.properties.deployment_ocpus
2135+
or MODEL_DEPLOYMENT_INSTANCE_OCPUS,
2136+
"deployment_memory_in_gbs": existing_infrastructure.shape_config_details.get(
2137+
"memory_in_gbs", None
2138+
)
2139+
or self.properties.deployment_memory_in_gbs
2140+
or MODEL_DEPLOYMENT_INSTANCE_MEMORY_IN_GBS,
2141+
"deployment_log_group_id": existing_infrastructure.log_group_id
2142+
or self.properties.deployment_log_group_id,
2143+
"deployment_access_log_id": existing_infrastructure.access_log.get(
2144+
"log_id", None
2145+
)
2146+
or self.properties.deployment_access_log_id,
2147+
"deployment_predict_log_id": existing_infrastructure.predict_log.get(
2148+
"log_id", None
2149+
)
2150+
or self.properties.deployment_predict_log_id,
2151+
"deployment_image": existing_runtime.image
2152+
or self.properties.deployment_image,
2153+
"deployment_instance_subnet_id": existing_infrastructure.subnet_id
2154+
or self.properties.deployment_instance_subnet_id
2155+
}
2156+
properties.update(override_properties)
2157+
self.properties.with_dict(properties)
21262158

21272159
if not self.model_id:
21282160
raise ValueError(
@@ -2140,104 +2172,58 @@ def deploy(
21402172
"cannot be used without `deployment_log_group_id`."
21412173
)
21422174

2143-
existing_infrastructure = self.model_deployment.infrastructure
2144-
existing_runtime = self.model_deployment.runtime
2145-
2146-
web_concurrency = (
2147-
kwargs.pop("web_concurrency", None)
2148-
or existing_infrastructure.web_concurrency
2149-
)
2150-
if not (
2151-
self.properties.compartment_id or existing_infrastructure.compartment_id
2152-
):
2175+
if not self.properties.compartment_id:
21532176
raise ValueError("`compartment_id` has to be provided.")
2154-
if not (self.properties.project_id or existing_infrastructure.project_id):
2177+
if not self.properties.project_id:
21552178
raise ValueError("`project_id` has to be provided.")
21562179
infrastructure = (
21572180
ModelDeploymentInfrastructure()
2158-
.with_compartment_id(
2159-
self.properties.compartment_id or existing_infrastructure.compartment_id
2160-
)
2161-
.with_project_id(
2162-
self.properties.project_id or existing_infrastructure.project_id
2163-
)
2164-
.with_bandwidth_mbps(
2165-
self.properties.deployment_bandwidth_mbps
2166-
or existing_infrastructure.bandwidth_mbps
2167-
)
2168-
.with_shape_name(
2169-
self.properties.deployment_instance_shape
2170-
or existing_infrastructure.shape_name
2171-
)
2172-
.with_subnet_id(
2173-
self.properties.deployment_instance_subnet_id
2174-
or existing_infrastructure.subnet_id
2175-
)
2176-
.with_replica(
2177-
self.properties.deployment_instance_count
2178-
or existing_infrastructure.replica
2179-
)
2180-
.with_web_concurrency(web_concurrency)
2181+
.with_compartment_id(self.properties.compartment_id)
2182+
.with_project_id(self.properties.project_id)
2183+
.with_bandwidth_mbps(self.properties.deployment_bandwidth_mbps)
2184+
.with_shape_name(self.properties.deployment_instance_shape)
2185+
.with_replica(self.properties.deployment_instance_count)
2186+
.with_subnet_id(self.properties.deployment_instance_subnet_id)
21812187
)
21822188

2183-
ocpus = (
2184-
self.properties.deployment_ocpus
2185-
or existing_infrastructure.shape_config_details.get("ocpus")
2186-
)
2187-
memory_in_gbs = (
2188-
self.properties.deployment_memory_in_gbs
2189-
or existing_infrastructure.shape_config_details.get("memory_in_gbs")
2189+
web_concurrency = (
2190+
kwargs.pop("web_concurrency", None)
2191+
or existing_infrastructure.web_concurrency
21902192
)
2193+
if web_concurrency:
2194+
infrastructure.with_web_concurrency(web_concurrency)
21912195

21922196
if infrastructure.shape_name.endswith("Flex"):
21932197
infrastructure.with_shape_config_details(
2194-
ocpus=ocpus or MODEL_DEPLOYMENT_INSTANCE_OCPUS,
2195-
memory_in_gbs=memory_in_gbs or MODEL_DEPLOYMENT_INSTANCE_MEMORY_IN_GBS,
2198+
ocpus=self.properties.deployment_ocpus,
2199+
memory_in_gbs=self.properties.deployment_memory_in_gbs,
21962200
)
21972201

2198-
access_log_id = (
2199-
self.properties.deployment_access_log_id
2200-
or existing_infrastructure.access_log.get("log_id")
2201-
)
2202-
access_log_group_id = (
2203-
self.properties.deployment_log_group_id
2204-
or existing_infrastructure.access_log.get("log_group_id")
2205-
)
2206-
22072202
# specifies the access log id
2208-
if access_log_id:
2203+
if self.properties.deployment_access_log_id:
22092204
infrastructure.with_access_log(
2210-
log_group_id=access_log_group_id,
2211-
log_id=access_log_id,
2205+
log_group_id=self.properties.deployment_log_group_id,
2206+
log_id=self.properties.deployment_access_log_id,
22122207
)
22132208

2214-
predict_log_id = (
2215-
self.properties.deployment_predict_log_id
2216-
or existing_infrastructure.predict_log.get("log_id")
2217-
)
2218-
predict_log_group_id = (
2219-
self.properties.deployment_log_group_id
2220-
or existing_infrastructure.predict_log.get("log_group_id")
2221-
)
2222-
22232209
# specifies the predict log id
2224-
if predict_log_id:
2210+
if self.properties.deployment_predict_log_id:
22252211
infrastructure.with_predict_log(
2226-
log_group_id=predict_log_group_id,
2227-
log_id=predict_log_id,
2212+
log_group_id=self.properties.deployment_log_group_id,
2213+
log_id=self.properties.deployment_predict_log_id,
22282214
)
22292215

22302216
environment_variables = (
22312217
kwargs.pop("environment_variables", {}) or existing_runtime.env
22322218
)
22332219
deployment_mode = (
2234-
kwargs.pop("deployment_mode", ModelDeploymentMode.HTTPS)
2220+
kwargs.pop("deployment_mode", None)
22352221
or existing_runtime.deployment_mode
2222+
or ModelDeploymentMode.HTTPS
22362223
)
22372224

22382225
runtime = None
2239-
image = self.properties.deployment_image or existing_runtime.image
2240-
if image:
2226+
if self.properties.deployment_image:
22412227
image_digest = (
22422228
kwargs.pop("image_digest", None) or existing_runtime.image_digest
22432229
)
@@ -2252,7 +2238,7 @@ def deploy(
22522238
)
22532239
runtime = (
22542240
ModelDeploymentContainerRuntime()
2255-
.with_image(image)
2241+
.with_image(self.properties.deployment_image)
22562242
.with_image_digest(image_digest)
22572243
.with_cmd(cmd)
22582244
.with_entrypoint(entrypoint)

tests/unitary/with_extras/model/test_generic_model.py

Lines changed: 126 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1917,6 +1917,132 @@ def test_update(
19171917
self.generic_model.dsc_model.schema_output == self.mock_dsc_model.output_schema
19181918
self.generic_model.dsc_model.metadata_provenance == self.mock_dsc_model.provenance_metadata
19191919

1920+
@patch.object(ModelDeployment, "deploy")
1921+
def test_deploy_combined_initialization(self, mock_deploy):
1922+
self.generic_model.properties = ModelProperties(
1923+
deployment_image="default_test_docker_image",
1924+
compartment_id="default_test_compartment_id",
1925+
project_id="default_test_project_id",
1926+
)
1927+
test_model_id = "ocid.test_model_id"
1928+
self.generic_model.dsc_model = MagicMock(id=test_model_id)
1929+
self.generic_model.ignore_conda_error = True
1930+
infrastructure = ModelDeploymentInfrastructure(
1931+
**{
1932+
"shape_name": "test_deployment_instance_shape",
1933+
"replica": 10,
1934+
"bandwidth_mbps": 100,
1935+
"shape_config_details": {"memory_in_gbs": 10, "ocpus": 1},
1936+
"access_log": {
1937+
"log_group_id": "test_deployment_log_group_id",
1938+
"log_id": "test_deployment_access_log_id",
1939+
},
1940+
"predict_log": {
1941+
"log_group_id": "test_deployment_log_group_id",
1942+
"log_id": "test_deployment_predict_log_id",
1943+
},
1944+
"project_id": "project_id_passed_using_with",
1945+
"compartment_id": "compartment_id_passed_using_with",
1946+
}
1947+
)
1948+
runtime = ModelDeploymentContainerRuntime(
1949+
**{
1950+
"image": "image_passed_using_with",
1951+
"image_digest": "test_image_digest",
1952+
"cmd": ["test_cmd"],
1953+
"entrypoint": ["test_entrypoint"],
1954+
"server_port": 8080,
1955+
"health_check_port": 8080,
1956+
"env": {"test_key": "test_value"},
1957+
"deployment_mode": "HTTPS_ONLY",
1958+
}
1959+
)
1960+
mock_deploy.return_value = ModelDeployment(
1961+
display_name="test_display_name",
1962+
description="test_description",
1963+
infrastructure=infrastructure,
1964+
runtime=runtime,
1965+
model_deployment_url="test_model_deployment_url",
1966+
model_deployment_id="test_model_deployment_id",
1967+
)
1968+
input_dict = {
1969+
"wait_for_completion": True,
1970+
"display_name": "test_display_name",
1971+
"description": "test_description",
1972+
"deployment_instance_shape": "test_deployment_instance_shape",
1973+
"deployment_instance_count": 10,
1974+
"deployment_bandwidth_mbps": 100,
1975+
"deployment_memory_in_gbs": 10,
1976+
"deployment_ocpus": 1,
1977+
"deployment_log_group_id": "test_deployment_log_group_id",
1978+
"deployment_access_log_id": "test_deployment_access_log_id",
1979+
"deployment_predict_log_id": "test_deployment_predict_log_id",
1980+
"cmd": ["test_cmd"],
1981+
"entrypoint": ["test_entrypoint"],
1982+
"server_port": 8080,
1983+
"health_check_port": 8080,
1984+
"environment_variables": {"test_key": "test_value"},
1985+
"max_wait_time": 100,
1986+
"poll_interval": 200,
1987+
}
1988+
1989+
self.generic_model.model_deployment.infrastructure.with_compartment_id(
1990+
"compartment_id_passed_using_with"
1991+
).with_project_id("project_id_passed_using_with")
1992+
self.generic_model.model_deployment.runtime.with_image(
1993+
"image_passed_using_with"
1994+
)
1995+
1996+
result = self.generic_model.deploy(
1997+
**input_dict,
1998+
)
1999+
assert result == mock_deploy.return_value
2000+
assert result.infrastructure.access_log == {
2001+
"log_id": input_dict["deployment_access_log_id"],
2002+
"log_group_id": input_dict["deployment_log_group_id"],
2003+
}
2004+
assert result.infrastructure.predict_log == {
2005+
"log_id": input_dict["deployment_predict_log_id"],
2006+
"log_group_id": input_dict["deployment_log_group_id"],
2007+
}
2008+
assert (
2009+
result.infrastructure.bandwidth_mbps
2010+
== input_dict["deployment_bandwidth_mbps"]
2011+
)
2012+
assert (
2013+
result.infrastructure.compartment_id == "compartment_id_passed_using_with"
2014+
)
2015+
assert result.infrastructure.project_id == "project_id_passed_using_with"
2016+
assert (
2017+
result.infrastructure.shape_name == input_dict["deployment_instance_shape"]
2018+
)
2019+
assert result.infrastructure.shape_config_details == {
2020+
"ocpus": input_dict["deployment_ocpus"],
2021+
"memory_in_gbs": input_dict["deployment_memory_in_gbs"],
2022+
}
2023+
assert result.runtime.image == "image_passed_using_with"
2024+
assert result.runtime.entrypoint == input_dict["entrypoint"]
2025+
assert result.runtime.server_port == input_dict["server_port"]
2026+
assert result.runtime.health_check_port == input_dict["health_check_port"]
2027+
assert result.runtime.env == {"test_key": "test_value"}
2028+
assert result.runtime.deployment_mode == "HTTPS_ONLY"
2029+
mock_deploy.assert_called_with(
2030+
wait_for_completion=input_dict["wait_for_completion"],
2031+
max_wait_time=input_dict["max_wait_time"],
2032+
poll_interval=input_dict["poll_interval"],
2033+
)
2034+
2035+
assert (
2036+
self.generic_model.properties.compartment_id
2037+
== "compartment_id_passed_using_with"
2038+
)
2039+
assert (
2040+
self.generic_model.properties.project_id == "project_id_passed_using_with"
2041+
)
2042+
assert (
2043+
self.generic_model.properties.deployment_image == "image_passed_using_with"
2044+
)
2045+
19202046

19212047
class TestCommonMethods:
19222048
"""Tests common methods presented in the generic_model module."""

0 commit comments

Comments (0)