
Commit 820b5bd

added as_of interface to the feature store (#279)
2 parents 0e1082f + fa56768 commit 820b5bd

File tree: 8 files changed (+432, -8 lines)

ads/feature_store/common/spark_session_singleton.py

Lines changed: 8 additions & 3 deletions

@@ -84,6 +84,7 @@ def __init__(self, metastore_id: str = None):
             )
             .enableHiveSupport()
         )
+        _managed_table_location = None
 
         if not developer_enabled() and metastore_id:
             # Get the authentication credentials for the OCI data catalog service
@@ -94,12 +95,11 @@ def __init__(self, metastore_id: str = None):
 
             data_catalog_client = OCIClientFactory(**auth).data_catalog
             metastore = data_catalog_client.get_metastore(metastore_id).data
+            _managed_table_location = metastore.default_managed_table_location
             # Configure the Spark session builder object to use the specified metastore
             spark_builder.config(
                 "spark.hadoop.oracle.dcat.metastore.id", metastore_id
-            ).config(
-                "spark.sql.warehouse.dir", metastore.default_managed_table_location
-            ).config(
+            ).config("spark.sql.warehouse.dir", _managed_table_location).config(
                 "spark.driver.memory", "16G"
             )
@@ -114,7 +114,12 @@ def __init__(self, metastore_id: str = None):
 
         self.spark_session.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
         self.spark_session.sparkContext.setLogLevel("OFF")
+        self.managed_table_location = _managed_table_location
 
     def get_spark_session(self):
         """Access method to get the spark session."""
         return self.spark_session
+
+    def get_managed_table_location(self):
+        """Returns the managed table location for the spark"""
+        return self.managed_table_location
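With this change, callers can read the managed table location back from the session singleton instead of re-deriving it from the OCI metastore. A minimal consumption sketch, assuming ADS with feature store support is installed; the metastore OCID below is a placeholder:

from ads.feature_store.common.spark_session_singleton import SparkSessionSingleton

# Placeholder OCID; substitute a real Data Catalog metastore OCID.
session = SparkSessionSingleton(metastore_id="ocid1.datacatalogmetastore.oc1..example")
spark = session.get_spark_session()

# The location stays None when no metastore is configured (e.g. developer mode),
# so guard the call.
location = session.get_managed_table_location()
if location:
    print(f"Managed tables are stored under {location}")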

ads/feature_store/common/utils/transformation_utils.py

Lines changed: 7 additions & 2 deletions

@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 # -*- coding: utf-8; -*-
 import json
+
 # Copyright (c) 2023 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
@@ -64,9 +65,13 @@ def apply_transformation(
         dataframe.createOrReplaceTempView(temporary_table_view)
 
         transformed_data = spark.sql(
-            transformation_function_caller(temporary_table_view, **transformation_kwargs_dict)
+            transformation_function_caller(
+                temporary_table_view, **transformation_kwargs_dict
+            )
         )
     elif transformation.transformation_mode == TransformationMode.PANDAS.value:
-        transformed_data = transformation_function_caller(dataframe, **transformation_kwargs_dict)
+        transformed_data = transformation_function_caller(
+            dataframe, **transformation_kwargs_dict
+        )
 
     return transformed_data

ads/feature_store/dataset.py

Lines changed: 64 additions & 1 deletion

@@ -8,13 +8,16 @@
 
 import pandas
 from great_expectations.core import ExpectationSuite
+
+from ads import deprecated
 from ads.common import utils
 from ads.common.oci_mixin import OCIModelMixin
 from ads.feature_store.common.enums import (
     ExecutionEngine,
     ExpectationType,
     EntityType,
 )
+from ads.feature_store.common.exceptions import NotMaterializedError
 from ads.feature_store.common.utils.utility import (
     get_metastore_id,
     validate_delta_format_parameters,
@@ -475,6 +478,20 @@ def with_statistics_config(
             self.CONST_STATISTICS_CONFIG, statistics_config_in.to_dict()
         )
 
+    def target_delta_table(self):
+        """
+        Returns the fully-qualified name of the target table for storing delta data.
+
+        The name of the target table is constructed by concatenating the entity ID
+        and the name of the table, separated by a dot. The resulting string has the
+        format 'entity_id.table_name'.
+
+        Returns:
+            str: The fully-qualified name of the target delta table.
+        """
+        target_table = f"{self.entity_id}.{self.name}"
+        return target_table
+
     @property
     def model_details(self) -> "ModelDetails":
         return self.get_spec(self.CONST_MODEL_DETAILS)
@@ -560,7 +577,9 @@ def add_models(self, model_details: ModelDetails) -> "Dataset":
                 f"Dataset update Failed with : {type(ex)} with error message: {ex}"
             )
         if existing_model_details:
-            self.with_model_details(ModelDetails().with_items(existing_model_details["items"]))
+            self.with_model_details(
+                ModelDetails().with_items(existing_model_details["items"])
+            )
         else:
             self.with_model_details(ModelDetails().with_items([]))
         return self
@@ -773,6 +792,7 @@ def materialise(
 
         dataset_execution_strategy.ingest_dataset(self, dataset_job)
 
+    @deprecated(details="preview functionality is deprecated. Please use as_of.")
     def preview(
         self,
         row_count: int = 10,
@@ -797,6 +817,8 @@ def preview(
         spark dataframe
             The preview result in spark dataframe
         """
+        self.check_resource_materialization()
+
         validate_delta_format_parameters(timestamp, version_number)
         target_table = f"{self.entity_id}.{self.name}"
 
@@ -806,6 +828,43 @@ def preview(
 
         return self.spark_engine.sql(sql_query)
 
+    def check_resource_materialization(self):
+        """Checks whether the target Delta table for this resource has been materialized in Spark.
+        If the target Delta table doesn't exist, raises a NotMaterializedError with the type and name of this resource.
+        """
+        if not self.spark_engine.is_delta_table_exists(self.target_delta_table()):
+            raise NotMaterializedError(self.type, self.name)
+
+    def as_of(
+        self,
+        version_number: int = None,
+        commit_timestamp: datetime = None,
+    ):
+        """preview the feature definition and return the response in dataframe.
+
+        Parameters
+        ----------
+        commit_timestamp: datetime
+            commit date time to preview in format yyyy-MM-dd or yyyy-MM-dd HH:mm:ss
+            commit date time is maintained for every ingestion commit using delta lake
+        version_number: int
+            commit version number for the preview. Version numbers are automatically versioned for every ingestion
+            commit using delta lake
+
+        Returns
+        -------
+        spark dataframe
+            The preview result in spark dataframe
+        """
+        self.check_resource_materialization()
+
+        validate_delta_format_parameters(commit_timestamp, version_number)
+        target_table = self.target_delta_table()
+
+        return self.spark_engine.get_time_version_data(
+            target_table, version_number, commit_timestamp
+        )
+
     def profile(self):
@@ -814,6 +873,8 @@ def profile(self):
         spark dataframe
             The profile result in spark dataframe
         """
+        self.check_resource_materialization()
+
         target_table = f"{self.entity_id}.{self.name}"
         sql_query = f"DESCRIBE DETAIL {target_table}"
 
@@ -835,6 +896,8 @@ def restore(self, version_number: int = None, timestamp: datetime = None):
         spark dataframe
             The restore output as spark dataframe
         """
+        self.check_resource_materialization()
+
        validate_delta_format_parameters(timestamp, version_number, True)
        target_table = f"{self.entity_id}.{self.name}"
        if version_number is not None:
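Because preview, profile, restore, and as_of now call check_resource_materialization up front, a missing Delta table surfaces as NotMaterializedError rather than a raw Spark failure. A usage sketch under stated assumptions: dataset is a hypothetical, already-defined Dataset instance, and invoking materialise() with defaults is assumed to be valid here:

from ads.feature_store.common.exceptions import NotMaterializedError

try:
    df = dataset.as_of(version_number=1)
except NotMaterializedError:
    # The backing Delta table does not exist yet; materialise it first.
    dataset.materialise()
    df = dataset.as_of(version_number=0)  # Delta versions start at 0
df.show()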

ads/feature_store/docs/source/dataset.rst

Lines changed: 18 additions & 0 deletions

@@ -248,6 +248,9 @@ You can call the ``get_input_schema_dataframe()`` method of the Dataset instance
 Preview
 ========
 
+.. deprecated:: 1.0.3
+   Use :func:`as_of` instead.
+
 You can call the ``preview()`` method of the Dataset instance to preview the dataset.
 
 The ``.preview()`` method takes the following optional parameter:
@@ -261,6 +264,21 @@ The ``.preview()`` method takes the following optional parameter:
     df = dataset.preview(row_count=50)
     df.show()
 
+as_of
+=======
+
+You can call the ``as_of()`` method of the Dataset instance to get specified point in time and time traveled data.
+
+The ``.as_of()`` method takes the following optional parameter:
+
+- ``commit_timestamp: date-time``. Commit timestamp for dataset
+- ``version_number: int``. Version number for dataset
+
+.. code-block:: python3
+
+  # as_of feature group
+  df = dataset.as_of(version_number=1)
+
 
 Restore
 =======
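The documented example pins a version number; the commit_timestamp parameter drives the same Delta time travel by date. A sketch with an illustrative timestamp, assuming the dataset has an ingestion commit at or before that time:

from datetime import datetime

# Time travel by commit timestamp instead of version number (illustrative value).
df = dataset.as_of(commit_timestamp=datetime(2023, 6, 1))
df.show()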

ads/feature_store/docs/source/feature_group.rst

Lines changed: 19 additions & 0 deletions

@@ -290,9 +290,13 @@ Feature store provides an API similar to Pandas to join feature groups together
     # Filter feature group
     feature_group.filter(feature_group.col1 > 10).show()
 
+
 Preview
 =======
 
+.. deprecated:: 1.0.3
+   Use :func:`as_of` instead.
+
 You can call the ``preview()`` method of the FeatureGroup instance to preview the feature group.
 
 The ``.preview()`` method takes the following optional parameter:
@@ -306,6 +310,21 @@ The ``.preview()`` method takes the following optional parameter:
     # Preview feature group
     df = feature_group.preview(row_count=50)
 
+as_of
+=======
+
+You can call the ``as_of()`` method of the FeatureGroup instance to get specified point in time and time traveled data.
+
+The ``.as_of()`` method takes the following optional parameter:
+
+- ``commit_timestamp: date-time``. Commit timestamp for feature group
+- ``version_number: int``. Version number for feature group
+
+.. code-block:: python3
+
+  # as_of feature group
+  df = feature_group.as_of(version_number=1)
+
 Restore
 =======

ads/feature_store/execution_strategy/engine/spark_engine.py

Lines changed: 67 additions & 1 deletion

@@ -5,6 +5,7 @@
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
 import logging
+from datetime import datetime
 
 from ads.common.decorator.runtime_dependency import OptionalDependency
 
@@ -17,7 +18,7 @@
     )
 except Exception as e:
     raise
-from typing import List
+from typing import List, Dict
 
 from ads.feature_store.common.utils.feature_schema_mapper import (
     map_spark_type_to_feature_type,
@@ -36,6 +37,71 @@ def __init__(self, metastore_id: str = None, spark_session: SparkSession = None)
         else:
             self.spark = SparkSessionSingleton(metastore_id).get_spark_session()
 
+        self.managed_table_location = (
+            SparkSessionSingleton().get_managed_table_location()
+        )
+
+    def get_time_version_data(
+        self,
+        delta_table_name: str,
+        version_number: int = None,
+        timestamp: datetime = None,
+    ):
+        split_db_name = delta_table_name.split(".")
+
+        # Get the Delta table path
+        delta_table_path = (
+            f"{self.managed_table_location}/{split_db_name[0].lower()}.db/{split_db_name[1]}"
+            if self.managed_table_location
+            else self._get_delta_table_path(delta_table_name)
+        )
+
+        # Set read options based on version_number and timestamp
+        read_options = {}
+        if version_number is not None:
+            read_options["versionAsOf"] = version_number
+        if timestamp:
+            read_options["timestampAsOf"] = timestamp
+
+        # Load the data from the Delta table using specified read options
+        df = self._read_delta_table(delta_table_path, read_options)
+        return df
+
+    def _get_delta_table_path(self, delta_table_name: str) -> str:
+        """
+        Get the path of the Delta table using DESCRIBE EXTENDED SQL command.
+
+        Args:
+            delta_table_name (str): The name of the Delta table.
+
+        Returns:
+            str: The path of the Delta table.
+        """
+        delta_table_path = (
+            self.spark.sql(f"DESCRIBE EXTENDED {delta_table_name}")
+            .filter("col_name = 'Location'")
+            .collect()[0][1]
+        )
+        return delta_table_path
+
+    def _read_delta_table(self, delta_table_path: str, read_options: Dict):
+        """
+        Read the Delta table using specified read options.
+
+        Args:
+            delta_table_path (str): The path of the Delta table.
+            read_options (dict): Dictionary of read options for Delta table.
+
+        Returns:
+            DataFrame: The loaded DataFrame from the Delta table.
+        """
+        df = (
+            self.spark.read.format("delta")
+            .options(**read_options)
+            .load(delta_table_path)
+        )
+        return df
+
     def sql(
         self,
         query: str,
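The versionAsOf and timestampAsOf options assembled by get_time_version_data are standard Delta Lake time-travel read options, so the engine's read can be reproduced in plain PySpark. A sketch, assuming an active Delta-enabled SparkSession named spark and a placeholder table path:

# Equivalent time-travel read outside the feature store (placeholder path).
df = (
    spark.read.format("delta")
    .option("versionAsOf", 1)  # or: .option("timestampAsOf", "2023-06-01")
    .load("/warehouse/entity_id.db/table_name")
)
df.show()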
