Skip to content

Commit 6ea37e5

Browse files
authored
Merge branch 'main' into ODSC-64829/enhanced-auto-select-report
2 parents 90b7fa8 + eabb40e commit 6ea37e5

File tree

25 files changed: 248 additions (+248) and 111 deletions (-111).

ads/aqua/extension/ui_handler.py

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -68,6 +68,8 @@ def get(self, id=""):
6868
return self.list_buckets()
6969
elif paths.startswith("aqua/job/shapes"):
7070
return self.list_job_shapes()
71+
elif paths.startswith("aqua/modeldeployment/shapes"):
72+
return self.list_model_deployment_shapes()
7173
elif paths.startswith("aqua/vcn"):
7274
return self.list_vcn()
7375
elif paths.startswith("aqua/subnets"):
@@ -160,6 +162,15 @@ def list_job_shapes(self, **kwargs):
160162
AquaUIApp().list_job_shapes(compartment_id=compartment_id, **kwargs)
161163
)
162164

165+
def list_model_deployment_shapes(self, **kwargs):
166+
"""Lists model deployment shapes available in the specified compartment."""
167+
compartment_id = self.get_argument("compartment_id", default=COMPARTMENT_OCID)
168+
return self.finish(
169+
AquaUIApp().list_model_deployment_shapes(
170+
compartment_id=compartment_id, **kwargs
171+
)
172+
)
173+
163174
def list_vcn(self, **kwargs):
164175
"""Lists the virtual cloud networks (VCNs) in the specified compartment."""
165176
compartment_id = self.get_argument("compartment_id", default=COMPARTMENT_OCID)
@@ -255,8 +266,9 @@ def post(self, *args, **kwargs):
255266
__handlers__ = [
256267
("logging/?([^/]*)", AquaUIHandler),
257268
("compartments/?([^/]*)", AquaUIHandler),
258-
# TODO: change url to evaluation/experiements/?([^/]*)
269+
# TODO: change url to evaluation/experiments/?([^/]*)
259270
("experiment/?([^/]*)", AquaUIHandler),
271+
("modeldeployment/?([^/]*)", AquaUIHandler),
260272
("versionsets/?([^/]*)", AquaUIHandler),
261273
("buckets/?([^/]*)", AquaUIHandler),
262274
("job/shapes/?([^/]*)", AquaUIHandler),

ads/aqua/modeldeployment/deployment.py

Lines changed: 15 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -185,7 +185,7 @@ def create(
185185
tags[tag] = aqua_model.freeform_tags[tag]
186186

187187
tags.update({Tags.AQUA_MODEL_NAME_TAG: aqua_model.display_name})
188-
tags.update({Tags.TASK: aqua_model.freeform_tags.get(Tags.TASK, None)})
188+
tags.update({Tags.TASK: aqua_model.freeform_tags.get(Tags.TASK, UNKNOWN)})
189189

190190
# Set up info to get deployment config
191191
config_source_id = model_id
@@ -533,16 +533,22 @@ def list(self, **kwargs) -> List["AquaDeployment"]:
533533
return results
534534

535535
@telemetry(entry_point="plugin=deployment&action=delete", name="aqua")
536-
def delete(self,model_deployment_id:str):
537-
return self.ds_client.delete_model_deployment(model_deployment_id=model_deployment_id).data
536+
def delete(self, model_deployment_id: str):
537+
return self.ds_client.delete_model_deployment(
538+
model_deployment_id=model_deployment_id
539+
).data
538540

539-
@telemetry(entry_point="plugin=deployment&action=deactivate",name="aqua")
540-
def deactivate(self,model_deployment_id:str):
541-
return self.ds_client.deactivate_model_deployment(model_deployment_id=model_deployment_id).data
541+
@telemetry(entry_point="plugin=deployment&action=deactivate", name="aqua")
542+
def deactivate(self, model_deployment_id: str):
543+
return self.ds_client.deactivate_model_deployment(
544+
model_deployment_id=model_deployment_id
545+
).data
542546

543-
@telemetry(entry_point="plugin=deployment&action=activate",name="aqua")
544-
def activate(self,model_deployment_id:str):
545-
return self.ds_client.activate_model_deployment(model_deployment_id=model_deployment_id).data
547+
@telemetry(entry_point="plugin=deployment&action=activate", name="aqua")
548+
def activate(self, model_deployment_id: str):
549+
return self.ds_client.activate_model_deployment(
550+
model_deployment_id=model_deployment_id
551+
).data
546552

547553
@telemetry(entry_point="plugin=deployment&action=get", name="aqua")
548554
def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail":

ads/aqua/ui.py

Lines changed: 24 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -481,12 +481,12 @@ def _is_bucket_versioned(response):
481481

482482
@telemetry(entry_point="plugin=ui&action=list_job_shapes", name="aqua")
483483
def list_job_shapes(self, **kwargs) -> list:
484-
"""Lists all availiable job shapes for the specified compartment.
484+
"""Lists all available job shapes for the specified compartment.
485485
486486
Parameters
487487
----------
488488
**kwargs
489-
Addtional arguments, such as `compartment_id`,
489+
Additional arguments, such as `compartment_id`,
490490
for `list_job_shapes <https://docs.oracle.com/en-us/iaas/tools/python/2.122.0/api/data_science/client/oci.data_science.DataScienceClient.html#oci.data_science.DataScienceClient.list_job_shapes>`_
491491
492492
Returns
@@ -500,6 +500,28 @@ def list_job_shapes(self, **kwargs) -> list:
500500
).data
501501
return sanitize_response(oci_client=self.ds_client, response=res)
502502

503+
@telemetry(entry_point="plugin=ui&action=list_model_deployment_shapes", name="aqua")
504+
def list_model_deployment_shapes(self, **kwargs) -> list:
505+
"""Lists all available shapes for model deployment in the specified compartment.
506+
507+
Parameters
508+
----------
509+
**kwargs
510+
Additional arguments, such as `compartment_id`,
511+
for `list_model_deployment_shapes <https://docs.oracle.com/en-us/iaas/api/#/en/data-science/20190101/ModelDeploymentShapeSummary/ListModelDeploymentShapes>`_
512+
513+
Returns
514+
-------
515+
str has json representation of `oci.data_science.models.ModelDeploymentShapeSummary`."""
516+
compartment_id = kwargs.pop("compartment_id", COMPARTMENT_OCID)
517+
logger.info(
518+
f"Loading model deployment shape summary from compartment: {compartment_id}"
519+
)
520+
res = self.ds_client.list_model_deployment_shapes(
521+
compartment_id=compartment_id, **kwargs
522+
).data
523+
return sanitize_response(oci_client=self.ds_client, response=res)
524+
503525
@telemetry(entry_point="plugin=ui&action=list_vcn", name="aqua")
504526
def list_vcn(self, **kwargs) -> list:
505527
"""Lists the virtual cloud networks (VCNs) in the specified compartment.

ads/opctl/operator/common/utils.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*--
32

43
# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
@@ -18,23 +17,26 @@
1817
from cerberus import Validator
1918

2019
from ads.opctl import logger, utils
21-
from ads.opctl.operator import __operators__
2220

2321
CONTAINER_NETWORK = "CONTAINER_NETWORK"
2422

2523

2624
class OperatorValidator(Validator):
2725
"""The custom validator class."""
2826

29-
pass
27+
def validate(self, obj_dict, **kwargs):
28+
# Model should be case insensitive
29+
if "model" in obj_dict["spec"]:
30+
obj_dict["spec"]["model"] = str(obj_dict["spec"]["model"]).lower()
31+
return super().validate(obj_dict, **kwargs)
3032

3133

3234
def create_output_folder(name):
3335
output_folder = name
3436
protocol = fsspec.utils.get_protocol(output_folder)
3537
storage_options = {}
3638
if protocol != "file":
37-
storage_options = auth or default_signer()
39+
storage_options = default_signer()
3840

3941
fs = fsspec.filesystem(protocol, **storage_options)
4042
name_suffix = 1

ads/opctl/operator/lowcode/anomaly/model/base_model.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -166,9 +166,8 @@ def generate_report(self):
166166
yaml_appendix = rc.Yaml(self.config.to_dict())
167167
summary = rc.Block(
168168
rc.Group(
169-
rc.Text(
170-
f"You selected the **`{self.spec.model}`** model.\n{model_description.text}\n"
171-
),
169+
rc.Text(f"You selected the **`{self.spec.model}`** model.\n"),
170+
model_description,
172171
rc.Text(
173172
"Based on your dataset, you could have also selected "
174173
f"any of the models: `{'`, `'.join(SupportedModels.keys() if self.spec.datetime_column else NonTimeADSupportedModels.keys())}`."

ads/opctl/operator/lowcode/anomaly/model/factory.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,9 +26,9 @@ class UnSupportedModelError(Exception):
2626

2727
def __init__(self, operator_config: AnomalyOperatorConfig, model_type: str):
2828
supported_models = (
29-
SupportedModels.values
29+
SupportedModels.values()
3030
if operator_config.spec.datetime_column
31-
else NonTimeADSupportedModels.values
31+
else NonTimeADSupportedModels.values()
3232
)
3333
message = (
3434
f"Model: `{model_type}` is not supported. "

ads/opctl/operator/lowcode/common/transformations.py

Lines changed: 14 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1,18 +1,19 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*--
32

4-
# Copyright (c) 2023 Oracle and/or its affiliates.
3+
# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

6+
from abc import ABC
7+
8+
import pandas as pd
9+
710
from ads.opctl import logger
11+
from ads.opctl.operator.lowcode.common.const import DataColumns
812
from ads.opctl.operator.lowcode.common.errors import (
9-
InvalidParameterError,
1013
DataMismatchError,
14+
InvalidParameterError,
1115
)
12-
from ads.opctl.operator.lowcode.common.const import DataColumns
1316
from ads.opctl.operator.lowcode.common.utils import merge_category_columns
14-
import pandas as pd
15-
from abc import ABC
1617

1718

1819
class Transformations(ABC):
@@ -58,6 +59,7 @@ def run(self, data):
5859
5960
"""
6061
clean_df = self._remove_trailing_whitespace(data)
62+
# clean_df = self._normalize_column_names(clean_df)
6163
if self.name == "historical_data":
6264
self._check_historical_dataset(clean_df)
6365
clean_df = self._set_series_id_column(clean_df)
@@ -95,8 +97,11 @@ def run(self, data):
9597
def _remove_trailing_whitespace(self, df):
9698
return df.apply(lambda x: x.str.strip() if x.dtype == "object" else x)
9799

100+
# def _normalize_column_names(self, df):
101+
# return df.rename(columns=lambda x: re.sub("[^A-Za-z0-9_]+", "", x))
102+
98103
def _set_series_id_column(self, df):
99-
self._target_category_columns_map = dict()
104+
self._target_category_columns_map = {}
100105
if not self.target_category_columns:
101106
df[DataColumns.Series] = "Series 1"
102107
self.has_artificial_series = True
@@ -125,10 +130,10 @@ def _format_datetime_col(self, df):
125130
df[self.dt_column_name] = pd.to_datetime(
126131
df[self.dt_column_name], format=self.dt_column_format
127132
)
128-
except:
133+
except Exception as ee:
129134
raise InvalidParameterError(
130135
f"Unable to determine the datetime type for column: {self.dt_column_name} in dataset: {self.name}. Please specify the format explicitly. (For example adding 'format: %d/%m/%Y' underneath 'name: {self.dt_column_name}' in the datetime_column section of the yaml file if you haven't already. For reference, here is the first datetime given: {df[self.dt_column_name].values[0]}"
131-
)
136+
) from ee
132137
return df
133138

134139
def _set_multi_index(self, df):
@@ -242,7 +247,6 @@ def _check_historical_dataset(self, df):
242247
"Class": "A",
243248
"Num": 2
244249
},
245-
246250
}
247251
"""
248252

ads/opctl/operator/lowcode/common/utils.py

Lines changed: 37 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -1,42 +1,32 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*--
32

43
# Copyright (c) 2024 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

7-
import argparse
86
import logging
97
import os
108
import shutil
119
import sys
1210
import tempfile
13-
import time
14-
from string import Template
15-
from typing import Any, Dict, List, Tuple
16-
import pandas as pd
17-
from ads.opctl import logger
18-
import oracledb
11+
from typing import List, Union
1912

2013
import fsspec
21-
import yaml
22-
from typing import Union
14+
import oracledb
15+
import pandas as pd
2316

17+
from ads.common.object_storage_details import ObjectStorageDetails
2418
from ads.opctl import logger
19+
from ads.opctl.operator.common.operator_config import OutputDirectory
2520
from ads.opctl.operator.lowcode.common.errors import (
26-
InputDataError,
2721
InvalidParameterError,
28-
PermissionsError,
29-
DataMismatchError,
3022
)
31-
from ads.opctl.operator.common.operator_config import OutputDirectory
32-
from ads.common.object_storage_details import ObjectStorageDetails
3323
from ads.secrets import ADBSecretKeeper
3424

3525

3626
def call_pandas_fsspec(pd_fn, filename, storage_options, **kwargs):
37-
if fsspec.utils.get_protocol(filename) == "file":
38-
return pd_fn(filename, **kwargs)
39-
elif fsspec.utils.get_protocol(filename) in ["http", "https"]:
27+
if fsspec.utils.get_protocol(filename) == "file" or fsspec.utils.get_protocol(
28+
filename
29+
) in ["http", "https"]:
4030
return pd_fn(filename, **kwargs)
4131

4232
storage_options = storage_options or (
@@ -48,7 +38,7 @@ def call_pandas_fsspec(pd_fn, filename, storage_options, **kwargs):
4838

4939
def load_data(data_spec, storage_options=None, **kwargs):
5040
if data_spec is None:
51-
raise InvalidParameterError(f"No details provided for this data source.")
41+
raise InvalidParameterError("No details provided for this data source.")
5242
filename = data_spec.url
5343
format = data_spec.format
5444
columns = data_spec.columns
@@ -67,7 +57,7 @@ def load_data(data_spec, storage_options=None, **kwargs):
6757
if not format:
6858
_, format = os.path.splitext(filename)
6959
format = format[1:]
70-
if format in ["json", "clipboard", "excel", "csv", "feather", "hdf"]:
60+
if format in ["json", "clipboard", "excel", "csv", "feather", "hdf", "parquet"]:
7161
read_fn = getattr(pd, f"read_{format}")
7262
data = call_pandas_fsspec(
7363
read_fn, filename, storage_options=storage_options
@@ -84,19 +74,31 @@ def load_data(data_spec, storage_options=None, **kwargs):
8474
with tempfile.TemporaryDirectory() as temp_dir:
8575
if vault_secret_id is not None:
8676
try:
87-
with ADBSecretKeeper.load_secret(vault_secret_id, wallet_dir=temp_dir) as adwsecret:
88-
if 'wallet_location' in adwsecret and 'wallet_location' not in connect_args:
89-
shutil.unpack_archive(adwsecret["wallet_location"], temp_dir)
90-
connect_args['wallet_location'] = temp_dir
91-
if 'user_name' in adwsecret and 'user' not in connect_args:
92-
connect_args['user'] = adwsecret['user_name']
93-
if 'password' in adwsecret and 'password' not in connect_args:
94-
connect_args['password'] = adwsecret['password']
95-
if 'service_name' in adwsecret and 'service_name' not in connect_args:
96-
connect_args['service_name'] = adwsecret['service_name']
77+
with ADBSecretKeeper.load_secret(
78+
vault_secret_id, wallet_dir=temp_dir
79+
) as adwsecret:
80+
if (
81+
"wallet_location" in adwsecret
82+
and "wallet_location" not in connect_args
83+
):
84+
shutil.unpack_archive(
85+
adwsecret["wallet_location"], temp_dir
86+
)
87+
connect_args["wallet_location"] = temp_dir
88+
if "user_name" in adwsecret and "user" not in connect_args:
89+
connect_args["user"] = adwsecret["user_name"]
90+
if "password" in adwsecret and "password" not in connect_args:
91+
connect_args["password"] = adwsecret["password"]
92+
if (
93+
"service_name" in adwsecret
94+
and "service_name" not in connect_args
95+
):
96+
connect_args["service_name"] = adwsecret["service_name"]
9797

9898
except Exception as e:
99-
raise Exception(f"Could not retrieve database credentials from vault {vault_secret_id}: {e}")
99+
raise Exception(
100+
f"Could not retrieve database credentials from vault {vault_secret_id}: {e}"
101+
)
100102

101103
con = oracledb.connect(**connect_args)
102104
if table_name is not None:
@@ -105,11 +107,11 @@ def load_data(data_spec, storage_options=None, **kwargs):
105107
data = pd.read_sql(sql, con)
106108
else:
107109
raise InvalidParameterError(
108-
f"Database `connect_args` provided without sql query or table name. Please specify either `sql` or `table_name`."
110+
"Database `connect_args` provided without sql query or table name. Please specify either `sql` or `table_name`."
109111
)
110112
else:
111113
raise InvalidParameterError(
112-
f"No filename/url provided, and no connect_args provided. Please specify one of these if you want to read data from a file or a database respectively."
114+
"No filename/url provided, and no connect_args provided. Please specify one of these if you want to read data from a file or a database respectively."
113115
)
114116
if columns:
115117
# keep only these columns, done after load because only CSV supports stream filtering
@@ -232,7 +234,7 @@ def human_time_friendly(seconds):
232234
accumulator.append(
233235
"{} {}{}".format(int(amount), unit, "" if amount == 1 else "s")
234236
)
235-
accumulator.append("{} secs".format(round(seconds, 2)))
237+
accumulator.append(f"{round(seconds, 2)} secs")
236238
return ", ".join(accumulator)
237239

238240

@@ -248,9 +250,7 @@ def find_output_dirname(output_dir: OutputDirectory):
248250
unique_output_dir = f"{output_dir}_{counter}"
249251
counter += 1
250252
logger.warn(
251-
"Since the output directory was not specified, the output will be saved to {} directory.".format(
252-
unique_output_dir
253-
)
253+
f"Since the output directory was not specified, the output will be saved to {unique_output_dir} directory."
254254
)
255255
return unique_output_dir
256256

0 commit comments

Comments
 (0)