Skip to content

Commit 2e620e2

Browse files
committed
search models and update
1 parent 4d30232 commit 2e620e2

File tree

5 files changed

+211
-155
lines changed

5 files changed

+211
-155
lines changed

ads/feature_store/common/utils/utility.py

Lines changed: 49 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88

99
from typing import Union, List
1010

11+
import oci.regions
1112
from great_expectations.core import ExpectationSuite
1213

1314
from ads.common.decorator.runtime_dependency import OptionalDependency
15+
from ads.common.oci_resource import OCIResource, SEARCH_TYPE
1416
from ads.feature_store.common.utils.feature_schema_mapper import (
1517
map_spark_type_to_feature_type,
1618
map_feature_type_to_pandas,
@@ -19,7 +21,7 @@
1921
from ads.feature_store.feature_group_expectation import Rule, Expectation
2022
from ads.feature_store.input_feature_detail import FeatureDetail
2123
from ads.feature_store.common.spark_session_singleton import SparkSessionSingleton
22-
import re
24+
2325
try:
2426
from pyspark.pandas import DataFrame
2527
except ModuleNotFoundError:
@@ -47,7 +49,7 @@
4749

4850

4951
def get_execution_engine_type(
50-
data_frame: Union[DataFrame, pd.DataFrame]
52+
data_frame: Union[DataFrame, pd.DataFrame]
5153
) -> ExecutionEngine:
5254
"""
5355
Determines the execution engine type for a given DataFrame.
@@ -87,7 +89,7 @@ def get_metastore_id(feature_store_id: str):
8789

8890

8991
def validate_delta_format_parameters(
90-
timestamp: datetime = None, version_number: int = None, is_restore: bool = False
92+
timestamp: datetime = None, version_number: int = None, is_restore: bool = False
9193
):
9294
"""
9395
Validate the user input provided as part of preview, restore APIs for ingested data, Ingested data is
@@ -121,9 +123,9 @@ def validate_delta_format_parameters(
121123

122124

123125
def show_ingestion_summary(
124-
entity_id: str,
125-
entity_type: EntityType = EntityType.FEATURE_GROUP,
126-
error_details: str = None,
126+
entity_id: str,
127+
entity_type: EntityType = EntityType.FEATURE_GROUP,
128+
error_details: str = None,
127129
):
128130
"""
129131
Displays a ingestion summary table with the given entity type and error details.
@@ -163,7 +165,7 @@ def show_validation_summary(ingestion_status: str, validation_output, expectatio
163165
statistics = validation_output["statistics"]
164166

165167
table_headers = (
166-
["expectation_type"] + list(statistics.keys()) + ["ingestion_status"]
168+
["expectation_type"] + list(statistics.keys()) + ["ingestion_status"]
167169
)
168170

169171
table_values = [expectation_type] + list(statistics.values()) + [ingestion_status]
@@ -207,9 +209,9 @@ def show_validation_summary(ingestion_status: str, validation_output, expectatio
207209

208210

209211
def get_features(
210-
output_columns: List[dict],
211-
parent_id: str,
212-
entity_type: EntityType = EntityType.FEATURE_GROUP,
212+
output_columns: List[dict],
213+
parent_id: str,
214+
entity_type: EntityType = EntityType.FEATURE_GROUP,
213215
) -> List[Feature]:
214216
"""
215217
Returns a list of features, given a list of output_columns and a feature_group_id.
@@ -266,7 +268,7 @@ def get_schema_from_spark_df(df: DataFrame):
266268

267269

268270
def get_schema_from_df(
269-
data_frame: Union[DataFrame, pd.DataFrame], feature_store_id: str
271+
data_frame: Union[DataFrame, pd.DataFrame], feature_store_id: str
270272
) -> List[dict]:
271273
"""
272274
Given a DataFrame, returns a list of dictionaries that describe its schema.
@@ -280,7 +282,7 @@ def get_schema_from_df(
280282

281283

282284
def get_input_features_from_df(
283-
data_frame: Union[DataFrame, pd.DataFrame], feature_store_id: str
285+
data_frame: Union[DataFrame, pd.DataFrame], feature_store_id: str
284286
) -> List[FeatureDetail]:
285287
"""
286288
Given a DataFrame, returns a list of FeatureDetail objects that represent its input features.
@@ -297,7 +299,7 @@ def get_input_features_from_df(
297299

298300

299301
def convert_expectation_suite_to_expectation(
300-
expectation_suite: ExpectationSuite, expectation_type: ExpectationType
302+
expectation_suite: ExpectationSuite, expectation_type: ExpectationType
301303
):
302304
"""
303305
Convert an ExpectationSuite object to an Expectation object with detailed rule information.
@@ -356,7 +358,7 @@ def largest_matching_subset_of_primary_keys(left_feature_group, right_feature_gr
356358

357359

358360
def convert_pandas_datatype_with_schema(
359-
raw_feature_details: List[dict], input_df: pd.DataFrame
361+
raw_feature_details: List[dict], input_df: pd.DataFrame
360362
) -> pd.DataFrame:
361363
feature_detail_map = {}
362364
columns_to_remove = []
@@ -381,7 +383,7 @@ def convert_pandas_datatype_with_schema(
381383

382384

383385
def convert_spark_dataframe_with_schema(
384-
raw_feature_details: List[dict], input_df: DataFrame
386+
raw_feature_details: List[dict], input_df: DataFrame
385387
) -> DataFrame:
386388
feature_detail_map = {}
387389
columns_to_remove = []
@@ -403,10 +405,35 @@ def validate_input_feature_details(input_feature_details, data_frame):
403405
return convert_spark_dataframe_with_schema(input_feature_details, data_frame)
404406

405407

406-
def validate_model_ocid(model_ocid):
407-
pattern = r'^ocid1\.datasciencemodel\.oc(?P<realm>[0-17]+)\.(?P<region>[A-Za-z0-9]+)?\.?(?P<future_use>[A-Za-z0-9]+)?\.(?P<unique_id>[A-Za-z0-9]+)$'
408-
match = re.match(pattern, model_ocid)
409-
if match:
410-
# groups = match.groupdict()
411-
return True
412-
return False
408+
def validate_model_ocid_format(model_ocid):
409+
split_words = model_ocid.split('.')
410+
region = split_words[3]
411+
realm = split_words[2]
412+
print(split_words)
413+
# region = auth.get("signer").region will not work for config
414+
# TODO: try to get current region if possible??
415+
if region in oci.regions.REGIONS_SHORT_NAMES:
416+
region = oci.regions.REGIONS_SHORT_NAMES[region]
417+
elif region not in oci.regions.REGIONS:
418+
return False
419+
if realm not in oci.regions.REGION_REALMS[region]:
420+
return False
421+
return True
422+
423+
424+
def search_model_ocids(model_ids: list) -> list:
425+
query = "query datasciencemodel resources where "
426+
items = model_ids
427+
for item in items:
428+
query = query + f"identifier='{item}'||"
429+
list_models = OCIResource.search(
430+
query[:-2]
431+
, type=SEARCH_TYPE.STRUCTURED,
432+
)
433+
list_models_ids = []
434+
for model in list_models:
435+
list_models_ids.append(model.identifier)
436+
for model_id in model_ids:
437+
if model_id not in list_models_ids:
438+
logger.warning(model_id + " doesnt exist")
439+
return list_models_ids

ads/feature_store/dataset.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
get_metastore_id,
2020
validate_delta_format_parameters,
2121
convert_expectation_suite_to_expectation,
22+
validate_model_ocid_format, search_model_ocids
2223
)
2324
from ads.feature_store.dataset_job import DatasetJob, IngestionMode
2425
from ads.feature_store.execution_strategy.engine.spark_engine import SparkEngine
@@ -481,12 +482,13 @@ def model_details(self) -> "ModelDetails":
481482
def model_details(self, model_details: ModelDetails):
482483
self.with_model_details(model_details)
483484

484-
def with_model_details(self, model_details: ModelDetails) -> "Dataset":
485+
def with_model_details(self, model_details: ModelDetails,validate: bool = True) -> "Dataset":
485486
"""Sets the model details for the dataset.
486487
487488
Parameters
488489
----------
489490
model_details: ModelDetails
491+
validate: if to validate the model ids
490492
491493
Returns
492494
-------
@@ -498,6 +500,18 @@ def with_model_details(self, model_details: ModelDetails) -> "Dataset":
498500
"The argument `model_details` has to be of type `ModelDetails`"
499501
"but is of type: `{}`".format(type(model_details))
500502
)
503+
#TODO: we could eiher use validate_model_ocid_format if thats enough or search_model_ocids
504+
#search_model_ocids should be used after set_auth without url
505+
# items = model_details["items"]
506+
# for item in items:
507+
# if not validate_model_ocid(item):
508+
# raise ValueError(
509+
# f"the model ocid {item} is not valid"
510+
# )
511+
if validate:
512+
model_ids_list = search_model_ocids(model_details.items)
513+
model_details = ModelDetails(model_ids_list)
514+
501515
return self.set_spec(self.CONST_MODEL_DETAILS, model_details.to_dict())
502516

503517
def add_models(self, model_details: ModelDetails) -> "Dataset":

ads/feature_store/mixin/oci_feature_store.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@ def init_client(
1515
cls, **kwargs
1616
) -> oci.feature_store.feature_store_client.FeatureStoreClient:
1717

18+
# TODO: Getting the endpoint from authorizer
1819
fs_service_endpoint = os.environ.get("OCI_FS_SERVICE_ENDPOINT")
19-
kwargs = {"service_endpoint": fs_service_endpoint}
20-
20+
if fs_service_endpoint:
21+
kwargs = {"service_endpoint": fs_service_endpoint}
2122

2223
client = cls._init_client(
2324
client=oci.feature_store.feature_store_client.FeatureStoreClient, **kwargs

0 commit comments

Comments
 (0)