
Commit 9fb835e

Stats generation done using MLM and transformation query support for SQL (#233)
Verified with Integration Test Run
2 parents 8e8fa68 + dd1f33e commit 9fb835e

16 files changed: +461 −112 lines changed

THIRD_PARTY_LICENSES.txt

Lines changed: 6 additions & 0 deletions
@@ -247,6 +247,12 @@ pyspark
 * Source code: https://github.com/apache/spark/tree/master/python
 * Project home: https://spark.apache.org/

+pyarrow
+* Copyright 2004 and onwards The Apache Software Foundation.
+* License: Apache-2.0 LICENSE
+* Source code: https://github.com/apache/arrow/tree/main/python
+* Project home: https://arrow.apache.org/
+
 python_jsonschema_objects
 * Copyright (c) 2014 Chris Wacek
 * License: MIT License

ads/common/decorator/runtime_dependency.py

Lines changed: 2 additions & 1 deletion
@@ -65,8 +65,9 @@ class OptionalDependency:
     SPARK = "oracle-ads[spark]"
     HUGGINGFACE = "oracle-ads[huggingface]"
     GREAT_EXPECTATIONS = "oracle-ads[great-expectations]"
-    PYDEEQU = "oracle-ads[pydeequ]"
     GRAPHVIZ = "oracle-ads[graphviz]"
+    MLM_INSIGHTS = "oracle-ads[mlm_insights]"
+    PYARROW = "oracle-ads[pyarrow]"


 def runtime_dependency(
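The two new entries extend the optional-dependency registry that the `runtime_dependency` decorator consults when an import is missing. A minimal usage sketch, assuming the decorator's usual `module`/`install_from` keyword arguments; `profile_features` is a made-up function name, not part of the PR:

from ads.common.decorator.runtime_dependency import (
    OptionalDependency,
    runtime_dependency,
)


@runtime_dependency(module="mlm_insights", install_from=OptionalDependency.MLM_INSIGHTS)
@runtime_dependency(module="pyarrow", install_from=OptionalDependency.PYARROW)
def profile_features(df):
    # If either optional package is absent at call time, the decorator raises
    # an error that points at the matching "oracle-ads[...]" extra to install.
    ...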

ads/feature_store/common/spark_session_singleton.py

Lines changed: 2 additions & 4 deletions
@@ -80,16 +80,14 @@ def __init__(self, metastore_id: str = None):
         if developer_enabled():
             # Configure spark session with delta jars only in developer mode. In other cases,
             # jars should be part of the conda pack
-            spark_builder.config(
-                "spark.jars",
-                "https://repo1.maven.org/maven2/com/amazon/deequ/deequ/2.0.1-spark-3.2/deequ-2.0.1-spark-3.2.jar",
-            )
             self.spark_session = configure_spark_with_delta_pip(
                 spark_builder
             ).getOrCreate()
         else:
             self.spark_session = spark_builder.getOrCreate()

+        self.spark_session.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
+
     def get_spark_session(self):
         """Access method to get the spark session."""
         return self.spark_session
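Setting `spark.sql.execution.arrow.pyspark.enabled` makes Spark use Apache Arrow for columnar transfer between the JVM and Python, which speeds up pandas interchange. A small illustrative sketch, assuming a session obtained through the singleton; the DataFrame contents are made up:

import pandas as pd

from ads.feature_store.common.spark_session_singleton import SparkSessionSingleton

spark = SparkSessionSingleton().get_spark_session()

# Arrow accelerates both directions of the pandas <-> Spark conversion.
pdf = pd.DataFrame({"user_id": [1, 2, 3], "credit_score": [700, 640, 810]})
sdf = spark.createDataFrame(pdf)   # pandas -> Spark, Arrow-backed
roundtrip = sdf.toPandas()         # Spark -> pandas, Arrow-backed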

ads/feature_store/common/utils/feature_schema_mapper.py

Lines changed: 52 additions & 5 deletions
@@ -4,12 +4,15 @@
 # Copyright (c) 2023 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

+import logging
 from typing import List

 import numpy as np
 import pandas as pd

 from ads.common.decorator.runtime_dependency import OptionalDependency
+from mlm_insights.constants import types
+
 from ads.feature_store.common.enums import FeatureType

 try:
@@ -22,6 +25,8 @@
 except Exception as e:
     raise

+logger = logging.getLogger(__name__)
+

 def map_spark_type_to_feature_type(spark_type):
     """Returns the feature type corresponding to SparkType
@@ -82,8 +87,8 @@ def map_pandas_type_to_feature_type(feature_name, values):
             inferred_dtype = current_dtype
         else:
             if (
-                    current_dtype != inferred_dtype
-                    and current_dtype is not FeatureType.UNKNOWN
+                current_dtype != inferred_dtype
+                and current_dtype is not FeatureType.UNKNOWN
             ):
                 raise TypeError(
                     f"Input feature '{feature_name}' has mixed types, {current_dtype} and {inferred_dtype}. "
@@ -228,7 +233,7 @@ def map_feature_type_to_pandas(feature_type):


 def convert_pandas_datatype_with_schema(
-        raw_feature_details: List[dict], input_df: pd.DataFrame
+    raw_feature_details: List[dict], input_df: pd.DataFrame
 ):
     feature_detail_map = {}
     for feature_details in raw_feature_details:
@@ -240,6 +245,48 @@
             pandas_type = map_feature_type_to_pandas(feature_type)
             input_df[column] = (
                 input_df[column]
-                    .astype(pandas_type)
-                    .where(pd.notnull(input_df[column]), None)
+                .astype(pandas_type)
+                .where(pd.notnull(input_df[column]), None)
             )
+
+
+def map_spark_type_to_stats_data_type(spark_type):
+    """Maps the Spark data types to MLM library data types.
+    Args:
+        spark_type: Spark data type of the feature dataframe column for which stats are needed.
+    Returns:
+        The MLM data type corresponding to the Spark type.
+    """
+    spark_type_to_mlm_data_type = {
+        StringType(): types.DataType.STRING,
+        IntegerType(): types.DataType.INTEGER,
+        FloatType(): types.DataType.FLOAT,
+        DoubleType(): types.DataType.FLOAT,
+        BooleanType(): types.DataType.BOOLEAN,
+        DecimalType(): types.DataType.FLOAT,
+        ShortType(): types.DataType.INTEGER,
+        LongType(): types.DataType.INTEGER,
+    }
+
+    return spark_type_to_mlm_data_type.get(spark_type)
+
+
+def map_spark_type_to_stats_variable_type(spark_type):
+    """Maps the Spark data types to MLM library variable types.
+    Args:
+        spark_type: Spark data type of the feature dataframe column for which stats are needed.
+    Returns:
+        The MLM variable type corresponding to the Spark type.
+    """
+    spark_type_to_feature_type = {
+        StringType(): types.VariableType.NOMINAL,
+        IntegerType(): types.VariableType.CONTINUOUS,
+        FloatType(): types.VariableType.CONTINUOUS,
+        DoubleType(): types.VariableType.CONTINUOUS,
+        BooleanType(): types.VariableType.BINARY,
+        DecimalType(): types.VariableType.CONTINUOUS,
+        ShortType(): types.VariableType.CONTINUOUS,
+        LongType(): types.VariableType.CONTINUOUS,
+    }
+
+    return spark_type_to_feature_type.get(spark_type)
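Together, the two new mappers let the statistics layer describe each Spark column to the MLM insights library as a (data type, variable type) pair. A minimal sketch of assembling such a description from a DataFrame schema; the `build_feature_schema` helper is hypothetical and not part of the PR:

from pyspark.sql.types import StructType

from ads.feature_store.common.utils.feature_schema_mapper import (
    map_spark_type_to_stats_data_type,
    map_spark_type_to_stats_variable_type,
)


def build_feature_schema(spark_schema: StructType) -> dict:
    # For each column, pair the MLM data type with its variable type;
    # Spark types not covered by the mappers come back as None and can be skipped.
    return {
        field.name: (
            map_spark_type_to_stats_data_type(field.dataType),
            map_spark_type_to_stats_variable_type(field.dataType),
        )
        for field in spark_schema.fields
    }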
Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
+import json
+import re
+
+from pyparsing import ParseException
+
+from ads.feature_store.common.spark_session_singleton import SparkSessionSingleton
+
+"""
+USER_TRANSFORMATION_FUNCTION template: used to transform the user-provided SQL query into the
+transformation function.
+
+Args:
+    function_name: Transformation function name
+    input: The input placeholder for the FROM clause
+"""
+USER_TRANSFORMATION_FUNCTION = \
+    """def {function_name}(input):
+    sql_query = f\"\"\"{query}\"\"\"
+    return sql_query"""
+
+
+class TransformationQueryValidator:
+
+    @staticmethod
+    def __verify_sql_query_plan(parser_plan, input_symbol: str):
+        """
+        Once the SQL parser has parsed the query, this function takes the parser plan as input,
+        checks the table names, and verifies that there is only a single table and that it uses
+        the placeholder name. A regex is used to account for common table expressions.
+        Args:
+            parser_plan: A Spark sqlParser ParsePlan object.
+                parser_plan contains the project and unresolved relation items:
+                Project: list of unresolved attributes - table field names
+                UnresolvedRelation: list of unresolved relation attributes - table names
+                e.g.: Project ['user_id, 'credit_score], 'UnresolvedRelation [DATA_SOURCE_INPUT], [], false
+            input_symbol (Transformation): The table name to be matched.
+        """
+        plan_items = json.loads(parser_plan.toJSON())
+        plan_string = parser_plan.toString()
+        cte = re.findall(r"CTE \[(.*?)\]", plan_string)
+        table_names = []
+        for plan_item in plan_items:
+            if plan_item['class'] == 'org.apache.spark.sql.catalyst.analysis.UnresolvedRelation':
+                table = plan_item['multipartIdentifier']
+                res = table.strip('][').split(', ')
+                if len(res) >= 2:
+                    raise ValueError(
+                        "FROM clause has invalid input {0}".format(table))
+                else:
+                    if res[0].lower() != input_symbol.lower():
+                        raise ValueError(
+                            f"Incorrect table template name, it should be {input_symbol}")
+                if table not in cte:
+                    table_names.append(f"{table}")
+        if len(table_names) > 1:
+            raise ValueError(
+                "Multiple tables are not supported")
+
+    @staticmethod
+    def verify_sql_input(query_input: str, input_symbol: str):
+        """
+        Verifies the query provided by the user to ensure that it is a valid SQL query.
+
+        Args:
+            query_input: A Spark SQL query
+            input_symbol (Transformation): The table name to be matched.
+        """
+        spark = SparkSessionSingleton().get_spark_session()
+        parser = spark._jsparkSession.sessionState().sqlParser()
+        try:
+            parser_plan = parser.parsePlan(query_input)
+        except ParseException as pe:
+            raise ParseException(f"Unable to parse the sql expression, exception occurred: {pe}")
+
+        # verify that the parser plan references only the FROM DATA_SOURCE_INPUT template
+        TransformationQueryValidator.__verify_sql_query_plan(parser_plan, input_symbol)
+
+    @staticmethod
+    def create_transformation_template(query: str, input_symbol: str, function_name: str):
+        """
+        Creates the query transformation function to ensure backend integrity.
+        Args:
+            query: A Spark SQL query
+            input_symbol (Transformation): The table name to be used.
+            function_name: The name of the transformation function
+        """
+        transformation_query = query.replace(input_symbol, "{input}")
+        output = USER_TRANSFORMATION_FUNCTION.format(query=transformation_query, function_name=function_name)
+        return output
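A short usage sketch of the validator: the query and function name below are made up, the placeholder name follows the DATA_SOURCE_INPUT example visible in the docstring, and the import path is an assumption since the new file's name is not shown in this view:

# Hypothetical module path for the new file added in this commit.
from ads.feature_store.common.utils.transformation_query_validator import (
    TransformationQueryValidator,
)

query = "SELECT user_id, credit_score FROM DATA_SOURCE_INPUT WHERE credit_score > 500"

# Raises a ParseException for malformed SQL, or a ValueError if the query
# references anything other than the single DATA_SOURCE_INPUT placeholder table.
TransformationQueryValidator.verify_sql_input(query, "DATA_SOURCE_INPUT")

# Wrap the validated query into a transformation function body, e.g.:
#   def credit_score_filter(input):
#       sql_query = f"""SELECT user_id, credit_score FROM {input} WHERE credit_score > 500"""
#       return sql_query
template = TransformationQueryValidator.create_transformation_template(
    query, "DATA_SOURCE_INPUT", "credit_score_filter"
)
print(template)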

ads/feature_store/execution_strategy/spark/spark_execution.py

Lines changed: 7 additions & 9 deletions
@@ -40,7 +40,7 @@
 from ads.feature_store.feature_group_job import FeatureGroupJob
 from ads.feature_store.transformation import Transformation

-from ads.feature_store.feature_statistics.pydeequ_service import StatisticsService
+from ads.feature_store.feature_statistics.statistics_service import StatisticsService

 logger = logging.getLogger(__name__)

@@ -227,11 +227,10 @@ def _save_offline_dataframe(

             logger.info(f"output features for the FeatureGroup: {output_features}")
             # Compute Feature Statistics
-            feature_statistics = StatisticsService.compute_statistics(
-                spark=self._spark_session,
-                statistics_config=feature_group.statistics_config,
-                input_df=featured_data,
-            )
+
+            feature_statistics = StatisticsService.compute_stats_with_mlm(
+                statistics_config=feature_group.oci_feature_group.statistics_config,
+                input_df=featured_data)

         except Exception as ex:
             error_details = str(ex)
@@ -347,9 +346,8 @@ def _save_dataset_input(self, dataset, dataset_job: DatasetJob):
             logger.info(f"output features for the dataset: {output_features}")

             # Compute Feature Statistics
-            feature_statistics = StatisticsService.compute_statistics(
-                spark=self._spark_session,
-                statistics_config=dataset.statistics_config,
+            feature_statistics = StatisticsService.compute_stats_with_mlm(
+                statistics_config=dataset.oci_dataset.statistics_config,
                 input_df=dataset_dataframe,
             )
ads/feature_store/feature_statistics/pydeequ_service.py

Lines changed: 0 additions & 66 deletions
This file was deleted.
