Commit 2a57246
Stats generation done using MLM Insights and transformation query support for SQL
1 parent 8e8fa68 commit 2a57246

File tree

12 files changed: +412 −102 lines changed


ads/common/decorator/runtime_dependency.py

Lines changed: 2 additions & 1 deletion
@@ -65,8 +65,9 @@ class OptionalDependency:
     SPARK = "oracle-ads[spark]"
     HUGGINGFACE = "oracle-ads[huggingface]"
     GREAT_EXPECTATIONS = "oracle-ads[great-expectations]"
-    PYDEEQU = "oracle-ads[pydeequ]"
     GRAPHVIZ = "oracle-ads[graphviz]"
+    MLM_INSIGHTS = "oracle-ads[mlm_insights]"
+    PYARROW = "oracle-ads[pyarrow]"


 def runtime_dependency(

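For context, a minimal sketch of how these optional-dependency constants are typically consumed: the constant is interpolated into the install hint raised when an optional import is missing. The surrounding try/except here is illustrative only; it mirrors the pattern used in the new statistics service further below.

from ads.common.decorator.runtime_dependency import OptionalDependency

try:
    import mlm_insights  # optional extra: oracle-ads[mlm_insights]
except ModuleNotFoundError:
    # point the user at the extra that provides the missing module
    raise ModuleNotFoundError(
        f"The `mlm_insights` module was not found. Please run `pip install "
        f"{OptionalDependency.MLM_INSIGHTS}`."
    )
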
ads/feature_store/common/spark_session_singleton.py

Lines changed: 2 additions & 4 deletions
@@ -80,16 +80,14 @@ def __init__(self, metastore_id: str = None):
         if developer_enabled():
             # Configure spark session with delta jars only in developer mode. In other cases,
             # jars should be part of the conda pack
-            spark_builder.config(
-                "spark.jars",
-                "https://repo1.maven.org/maven2/com/amazon/deequ/deequ/2.0.1-spark-3.2/deequ-2.0.1-spark-3.2.jar",
-            )
             self.spark_session = configure_spark_with_delta_pip(
                 spark_builder
             ).getOrCreate()
         else:
             self.spark_session = spark_builder.getOrCreate()

+        self.spark_session.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
+
     def get_spark_session(self):
         """Access method to get the spark session."""
         return self.spark_session

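The new `spark.sql.execution.arrow.pyspark.enabled` setting turns on Arrow-backed transfers between Spark and pandas, presumably in support of the pyarrow dependency and MLM-based profiling added in this commit. A small illustration of what the setting accelerates (the sample DataFrame is hypothetical):

import pandas as pd

from ads.feature_store.common.spark_session_singleton import SparkSessionSingleton

spark = SparkSessionSingleton().get_spark_session()

# pandas -> Spark and Spark -> pandas conversions go through Arrow when enabled
sdf = spark.createDataFrame(pd.DataFrame({"id": [1, 2, 3]}))
pdf = sdf.toPandas()
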
ads/feature_store/common/utils/feature_schema_mapper.py

Lines changed: 42 additions & 0 deletions
@@ -4,12 +4,15 @@
 # Copyright (c) 2023 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

+import logging
 from typing import List

 import numpy as np
 import pandas as pd

 from ads.common.decorator.runtime_dependency import OptionalDependency
+from mlm_insights.constants import types
+
 from ads.feature_store.common.enums import FeatureType

 try:

@@ -22,6 +25,8 @@
 except Exception as e:
     raise

+logger = logging.getLogger(__name__)
+

 def map_spark_type_to_feature_type(spark_type):
     """Returns the feature type corresponding to SparkType

@@ -243,3 +248,40 @@ def convert_pandas_datatype_with_schema(
         .astype(pandas_type)
         .where(pd.notnull(input_df[column]), None)
     )
+
+
+def map_spark_type_to_stats_data_type(spark_type):
+    """Returns the MLM data type corresponding to SparkType
+    :param spark_type:
+    :return:
+    """
+    spark_type_to_feature_type = {
+        StringType(): types.DataType.STRING,
+        IntegerType(): types.DataType.INTEGER,
+        FloatType(): types.DataType.FLOAT,
+        DoubleType(): types.DataType.FLOAT,
+        BooleanType(): types.DataType.BOOLEAN,
+        DecimalType(): types.DataType.FLOAT,
+        ShortType(): types.DataType.INTEGER,
+        LongType(): types.DataType.INTEGER,
+    }
+
+    return spark_type_to_feature_type.get(spark_type)
+
+
+def map_spark_type_to_stats_variable_type(spark_type):
+    """Returns the MLM variable type corresponding to SparkType
+    :param spark_type:
+    :return:
+    """
+    spark_type_to_feature_type = {
+        StringType(): types.VariableType.NOMINAL,
+        IntegerType(): types.VariableType.CONTINUOUS,
+        FloatType(): types.VariableType.CONTINUOUS,
+        DoubleType(): types.VariableType.CONTINUOUS,
+        BooleanType(): types.VariableType.BINARY,
+        DecimalType(): types.VariableType.CONTINUOUS,
+        ShortType(): types.VariableType.CONTINUOUS,
+        LongType(): types.VariableType.CONTINUOUS,
+    }
+
+    return spark_type_to_feature_type.get(spark_type)
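
A small illustration (hypothetical field) of how the two new mappers combine into the `FeatureType` consumed by MLM Insights; the dictionary lookups work because simple Spark types such as `StringType()` compare equal by class:

from pyspark.sql.types import StringType
from mlm_insights.constants.types import FeatureType

spark_type = StringType()
feature_type = FeatureType(
    map_spark_type_to_stats_data_type(spark_type),      # types.DataType.STRING
    map_spark_type_to_stats_variable_type(spark_type),  # types.VariableType.NOMINAL
)
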
Lines changed: 77 additions & 0 deletions
import json
import re

from pyparsing import ParseException

from ads.feature_store.common.spark_session_singleton import SparkSessionSingleton

USER_TRANSFORMATION_FUNCTION = \
    """def {function_name}(input):
    sql_query = f\"\"\"{query}\"\"\"
    return sql_query"""


class TransformationQueryValidator:

    @staticmethod
    def __verify_sql_query_plan(parser_plan, input_symbol: str):
        """
        Once the SQL parser has parsed the query, this method takes the parser plan as input,
        collects the referenced table names, and verifies that only a single table is used
        and that it carries the placeholder name.

        Args:
            parser_plan: A Spark sqlParser ParsePlan object.
            input_symbol (str): The table name to be matched.
        """
        plan_items = json.loads(parser_plan.toJSON())
        plan_string = parser_plan.toString()
        cte = re.findall(r"CTE \[(.*?)\]", plan_string)
        table_names = []
        for plan_item in plan_items:
            if plan_item['class'] == 'org.apache.spark.sql.catalyst.analysis.UnresolvedRelation':
                table = plan_item['multipartIdentifier']
                res = table.strip('][').split(', ')
                if len(res) >= 2:
                    raise ValueError(
                        "FROM Clause has invalid input {0}".format(table))
                else:
                    if res[0].lower() != input_symbol.lower():
                        raise ValueError(
                            f"Incorrect table template name, It should be {input_symbol}")
                if table not in cte:
                    table_names.append(f"{table}")
        if len(table_names) > 1:
            raise ValueError(
                "Multiple tables are not supported ")

    @staticmethod
    def verify_sql_input(query_input: str, input_symbol: str):
        """
        Verifies that the query provided by the user is a valid Spark SQL query.

        Args:
            query_input: A Spark SQL query.
            input_symbol (str): The table name to be matched.
        """
        spark = SparkSessionSingleton().get_spark_session()
        parser = spark._jsparkSession.sessionState().sqlParser()
        try:
            parser_plan = parser.parsePlan(query_input)
            # verify that the parser plan only reads FROM the DATA_SOURCE_INPUT template
            TransformationQueryValidator.__verify_sql_query_plan(parser_plan, input_symbol)
        except ParseException as pe:
            raise ParseException(f"Unable to parse the sql expression, exception occurred: {pe}")

    @staticmethod
    def create_transformation_template(query: str, input_symbol: str, function_name: str):
        """
        Creates the query transformation function to ensure backend integrity.

        Args:
            query: A Spark SQL query.
            input_symbol (str): The table name to be used.
            function_name: The name of the transformation function.
        """
        transformation_query = query.replace(input_symbol, "{input}")
        output = USER_TRANSFORMATION_FUNCTION.format(query=transformation_query, function_name=function_name)
        return output

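A usage sketch for the validator above; the query and placeholder name are illustrative (the `DATA_SOURCE_INPUT` name is taken from the comment in `verify_sql_input`):

query = "SELECT col_a, col_b FROM DATA_SOURCE_INPUT WHERE col_a IS NOT NULL"

# raises ParseException / ValueError if the query is malformed or references other tables
TransformationQueryValidator.verify_sql_input(query, input_symbol="DATA_SOURCE_INPUT")

# wrap the query into a transformation function whose FROM clause becomes an {input} placeholder
source = TransformationQueryValidator.create_transformation_template(
    query, input_symbol="DATA_SOURCE_INPUT", function_name="sql_transformation"
)
# source:
# def sql_transformation(input):
#     sql_query = f"""SELECT col_a, col_b FROM {input} WHERE col_a IS NOT NULL"""
#     return sql_query
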
ads/feature_store/execution_strategy/spark/spark_execution.py

Lines changed: 5 additions & 7 deletions
@@ -40,7 +40,7 @@
 from ads.feature_store.feature_group_job import FeatureGroupJob
 from ads.feature_store.transformation import Transformation

-from ads.feature_store.feature_statistics.pydeequ_service import StatisticsService
+from ads.feature_store.feature_statistics.statistics_service import StatisticsService

 logger = logging.getLogger(__name__)

@@ -227,11 +227,10 @@ def _save_offline_dataframe(

             logger.info(f"output features for the FeatureGroup: {output_features}")
             # Compute Feature Statistics
-            feature_statistics = StatisticsService.compute_statistics(
-                spark=self._spark_session,
+
+            feature_statistics = StatisticsService.compute_stats_with_mlm(
                 statistics_config=feature_group.statistics_config,
-                input_df=featured_data,
-            )
+                input_df=featured_data)

         except Exception as ex:
             error_details = str(ex)

@@ -347,8 +346,7 @@ def _save_dataset_input(self, dataset, dataset_job: DatasetJob):
             logger.info(f"output features for the dataset: {output_features}")

             # Compute Feature Statistics
-            feature_statistics = StatisticsService.compute_statistics(
-                spark=self._spark_session,
+            feature_statistics = StatisticsService.compute_stats_with_mlm(
                 statistics_config=dataset.statistics_config,
                 input_df=dataset_dataframe,
             )

ads/feature_store/feature_statistics/pydeequ_service.py

Lines changed: 0 additions & 66 deletions
This file was deleted.
Lines changed: 84 additions & 0 deletions
#!/usr/bin/env python
# -*- coding: utf-8; -*-

# Copyright (c) 2023 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
import json
import logging

from ads.feature_store.statistics_config import StatisticsConfig
from ads.feature_store.common.utils.feature_schema_mapper import *

try:
    from pyspark.sql import DataFrame
except ModuleNotFoundError:
    raise ModuleNotFoundError(
        f"The `pyspark` module was not found. Please run `pip install "
        f"{OptionalDependency.SPARK}`."
    )

try:
    from mlm_insights.builder.builder_component import EngineDetail
    from mlm_insights.builder.insights_builder import InsightsBuilder
    from mlm_insights.constants.types import FeatureType
except ModuleNotFoundError:
    raise ModuleNotFoundError(
        f"The `mlm_insights` module was not found. Please run `pip install "
        f"{OptionalDependency.MLM_INSIGHTS}`."
    )

logger = logging.getLogger(__name__)
CONST_FEATURE_METRICS = "feature_metrics"


class StatisticsService:
    """StatisticsService computes feature statistics using the MLM Insights column profiler."""

    @staticmethod
    def compute_stats_with_mlm(
        statistics_config: StatisticsConfig, input_df: DataFrame
    ):
        feature_metrics = None
        if bool(input_df.head(1)) and statistics_config and statistics_config.get("isEnabled"):
            feature_schema = {}
            if input_df.schema:
                StatisticsService.__get_mlm_supported_schema(feature_schema, input_df, statistics_config)
                feature_metrics = StatisticsService.__get_feature_metric(feature_schema, input_df)
            else:
                raise ValueError("Dataframe schema is missing")
        return feature_metrics

    @staticmethod
    def __get_feature_metric(feature_schema: dict, data_frame: DataFrame):
        feature_metrics = None
        runner = InsightsBuilder(). \
            with_input_schema(input_schema=feature_schema). \
            with_data_frame(data_frame=data_frame). \
            with_engine(engine=EngineDetail(engine_name="spark")). \
            build()
        result = runner.run()
        if result and result.profile:
            profile = result.profile
            feature_metrics = json.dumps(profile.to_json()[CONST_FEATURE_METRICS])
        else:
            logger.warning(f"stats computation failed with MLM for schema {feature_schema}")
        return feature_metrics

    @staticmethod
    def __get_mlm_supported_schema(feature_schema: dict, input_df: DataFrame, statistics_config: StatisticsConfig):
        relevant_columns = statistics_config.get("columns")
        for field in input_df.schema.fields:
            data_type = map_spark_type_to_stats_data_type(field.dataType)
            if not data_type:
                logger.warning(f"Unable to map spark data type fields to MLM fields, "
                               f"Actual data type {field.dataType}")
            elif relevant_columns:
                if field.name in relevant_columns:
                    feature_schema[field.name] = FeatureType(
                        data_type, map_spark_type_to_stats_variable_type(field.dataType)
                    )
            else:
                feature_schema[field.name] = FeatureType(
                    data_type, map_spark_type_to_stats_variable_type(field.dataType)
                )

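A minimal usage sketch for the new service. The sample DataFrame is hypothetical and the `StatisticsConfig` keyword arguments are assumed; the service only relies on the config answering `get("isEnabled")` and `get("columns")`:

from ads.feature_store.statistics_config import StatisticsConfig
from ads.feature_store.common.spark_session_singleton import SparkSessionSingleton

spark = SparkSessionSingleton().get_spark_session()
input_df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "label"])

# assumed constructor; statistics can be limited to specific columns
stats_config = StatisticsConfig(is_enabled=True, columns=["id", "label"])

# returns a JSON string with the profile's "feature_metrics" section, or None
feature_metrics = StatisticsService.compute_stats_with_mlm(
    statistics_config=stats_config, input_df=input_df
)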