Skip to content

Commit 8072892

Browse files
authored
Fix/GenericModel.from_model_deployment() fails to load model back (#322)
1 parent 3329287 commit 8072892

File tree

1 file changed

+19
-16
lines changed

1 file changed

+19
-16
lines changed

ads/model/deployment/model_deployment.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1304,7 +1304,8 @@ def from_id(cls, id: str) -> "ModelDeployment":
13041304
ModelDeployment
13051305
The ModelDeployment instance (self).
13061306
"""
1307-
return cls()._update_from_oci_model(OCIDataScienceModelDeployment.from_id(id))
1307+
oci_model = OCIDataScienceModelDeployment.from_id(id)
1308+
return cls(properties=oci_model)._update_from_oci_model(oci_model)
13081309

13091310
@classmethod
13101311
def from_dict(cls, obj_dict: Dict) -> "ModelDeployment":
@@ -1503,7 +1504,9 @@ def _build_model_deployment_details(self) -> CreateModelDeploymentDetails:
15031504
**create_model_deployment_details
15041505
).to_oci_model(CreateModelDeploymentDetails)
15051506

1506-
def _update_model_deployment_details(self, **kwargs) -> UpdateModelDeploymentDetails:
1507+
def _update_model_deployment_details(
1508+
self, **kwargs
1509+
) -> UpdateModelDeploymentDetails:
15071510
"""Builds UpdateModelDeploymentDetails from model deployment instance.
15081511
15091512
Returns
@@ -1527,7 +1530,7 @@ def _update_model_deployment_details(self, **kwargs) -> UpdateModelDeploymentDet
15271530
return OCIDataScienceModelDeployment(
15281531
**update_model_deployment_details
15291532
).to_oci_model(UpdateModelDeploymentDetails)
1530-
1533+
15311534
def _update_spec(self, **kwargs) -> "ModelDeployment":
15321535
"""Updates model deployment specs from kwargs.
15331536
@@ -1542,7 +1545,7 @@ def _update_spec(self, **kwargs) -> "ModelDeployment":
15421545
Model deployment freeform tags
15431546
defined_tags: (dict)
15441547
Model deployment defined tags
1545-
1548+
15461549
Additional kwargs arguments.
15471550
Can be any attribute that `ads.model.deployment.ModelDeploymentCondaRuntime`, `ads.model.deployment.ModelDeploymentContainerRuntime`
15481551
and `ads.model.deployment.ModelDeploymentInfrastructure` accepts.
@@ -1559,20 +1562,22 @@ def _update_spec(self, **kwargs) -> "ModelDeployment":
15591562
specs = {
15601563
"self": self._spec,
15611564
"runtime": self.runtime._spec,
1562-
"infrastructure": self.infrastructure._spec
1565+
"infrastructure": self.infrastructure._spec,
15631566
}
15641567
sub_set = {
15651568
self.infrastructure.CONST_ACCESS_LOG,
15661569
self.infrastructure.CONST_PREDICT_LOG,
1567-
self.infrastructure.CONST_SHAPE_CONFIG_DETAILS
1570+
self.infrastructure.CONST_SHAPE_CONFIG_DETAILS,
15681571
}
15691572
for spec_value in specs.values():
15701573
for key in spec_value:
15711574
if key in converted_specs:
15721575
if key in sub_set:
15731576
for sub_key in converted_specs[key]:
15741577
converted_sub_key = ads_utils.snake_to_camel(sub_key)
1575-
spec_value[key][converted_sub_key] = converted_specs[key][sub_key]
1578+
spec_value[key][converted_sub_key] = converted_specs[key][
1579+
sub_key
1580+
]
15761581
else:
15771582
spec_value[key] = copy.deepcopy(converted_specs[key])
15781583
self = (
@@ -1616,14 +1621,14 @@ def _build_model_deployment_configuration_details(self) -> Dict:
16161621
infrastructure.CONST_MEMORY_IN_GBS: infrastructure.shape_config_details.get(
16171622
"memory_in_gbs", None
16181623
)
1619-
or infrastructure.shape_config_details.get(
1620-
"memoryInGBs", None
1621-
)
1624+
or infrastructure.shape_config_details.get("memoryInGBs", None)
16221625
or DEFAULT_MEMORY_IN_GBS,
16231626
}
16241627

16251628
if infrastructure.subnet_id:
1626-
instance_configuration[infrastructure.CONST_SUBNET_ID] = infrastructure.subnet_id
1629+
instance_configuration[
1630+
infrastructure.CONST_SUBNET_ID
1631+
] = infrastructure.subnet_id
16271632

16281633
scaling_policy = {
16291634
infrastructure.CONST_POLICY_TYPE: "FIXED_SIZE",
@@ -1638,13 +1643,11 @@ def _build_model_deployment_configuration_details(self) -> Dict:
16381643

16391644
model_id = runtime.model_uri
16401645
if not model_id.startswith("ocid"):
1641-
16421646
from ads.model.datascience_model import DataScienceModel
1643-
1647+
16441648
dsc_model = DataScienceModel(
16451649
name=self.display_name,
1646-
compartment_id=self.infrastructure.compartment_id
1647-
or COMPARTMENT_OCID,
1650+
compartment_id=self.infrastructure.compartment_id or COMPARTMENT_OCID,
16481651
project_id=self.infrastructure.project_id or PROJECT_OCID,
16491652
artifact=runtime.model_uri,
16501653
).create(
@@ -1653,7 +1656,7 @@ def _build_model_deployment_configuration_details(self) -> Dict:
16531656
region=runtime.region,
16541657
overwrite_existing_artifact=runtime.overwrite_existing_artifact,
16551658
remove_existing_artifact=runtime.remove_existing_artifact,
1656-
timeout=runtime.timeout
1659+
timeout=runtime.timeout,
16571660
)
16581661
model_id = dsc_model.id
16591662

0 commit comments

Comments
 (0)