Skip to content

Commit 8c1e68c

Browse files
addressed review comments
1 parent 5b5ce60 commit 8c1e68c

File tree

5 files changed

+42
-18
lines changed

5 files changed

+42
-18
lines changed

THIRD_PARTY_LICENSES.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,12 @@ pyspark
247247
* Source code: https://github.com/apache/spark/tree/master/python
248248
* Project home: https://spark.apache.org/
249249

250+
pyarrow
251+
* Copyright 2004 and onwards The Apache Software Foundation.
252+
* License: Apache-2.0 LICENSE
253+
* Source code: https://github.com/apache/arrow/tree/main/python
254+
* Project home: https://arrow.apache.org/
255+
250256
python_jsonschema_objects
251257
* Copyright (c) 2014 Chris Wacek
252258
* License: MIT License

ads/feature_store/common/utils/feature_schema_mapper.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ def map_pandas_type_to_feature_type(feature_name, values):
8787
inferred_dtype = current_dtype
8888
else:
8989
if (
90-
current_dtype != inferred_dtype
91-
and current_dtype is not FeatureType.UNKNOWN
90+
current_dtype != inferred_dtype
91+
and current_dtype is not FeatureType.UNKNOWN
9292
):
9393
raise TypeError(
9494
f"Input feature '{feature_name}' has mixed types, {current_dtype} and {inferred_dtype}. "
@@ -233,7 +233,7 @@ def map_feature_type_to_pandas(feature_type):
233233

234234

235235
def convert_pandas_datatype_with_schema(
236-
raw_feature_details: List[dict], input_df: pd.DataFrame
236+
raw_feature_details: List[dict], input_df: pd.DataFrame
237237
):
238238
feature_detail_map = {}
239239
for feature_details in raw_feature_details:
@@ -245,16 +245,19 @@ def convert_pandas_datatype_with_schema(
245245
pandas_type = map_feature_type_to_pandas(feature_type)
246246
input_df[column] = (
247247
input_df[column]
248-
.astype(pandas_type)
249-
.where(pd.notnull(input_df[column]), None)
248+
.astype(pandas_type)
249+
.where(pd.notnull(input_df[column]), None)
250250
)
251251

252+
252253
def map_spark_type_to_stats_data_type(spark_type):
253-
"""Returns the MLM data type corresponding to SparkType
254-
:param spark_type:
254+
""" Maps the spark data types to MLM library data types
255+
Args:
256+
param spark_type: Spark data type input from the feature dataframe on which we need stats
255257
:return:
258+
Returns the MLM data type corresponding to SparkType
256259
"""
257-
spark_type_to_feature_type = {
260+
spark_type_to_mlm_data_type = {
258261
StringType(): types.DataType.STRING,
259262
IntegerType(): types.DataType.INTEGER,
260263
FloatType(): types.DataType.FLOAT,
@@ -265,13 +268,15 @@ def map_spark_type_to_stats_data_type(spark_type):
265268
LongType(): types.DataType.INTEGER,
266269
}
267270

268-
return spark_type_to_feature_type.get(spark_type)
271+
return spark_type_to_mlm_data_type.get(spark_type)
269272

270273

271274
def map_spark_type_to_stats_variable_type(spark_type):
272-
"""Returns the MLM variable type corresponding to SparkType
273-
:param spark_type:
275+
""" Maps the spark data types to MLM library variable types
276+
Args:
277+
param spark_type: Spark data type input from the feature dataframe on which we need stats
274278
:return:
279+
Returns the MLM variable type corresponding to SparkType
275280
"""
276281
spark_type_to_feature_type = {
277282
StringType(): types.VariableType.NOMINAL,

ads/feature_store/common/utils/transformation_query_validator.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,14 @@
55

66
from ads.feature_store.common.spark_session_singleton import SparkSessionSingleton
77

8+
"""
9+
USER_TRANSFORMATION_FUNCTION template: It is used to transform the user-provided SQL query into the
10+
transformation function
11+
12+
Args:
13+
function_name: Transformation function name
14+
input : The input placeholder for the FROM clause
15+
"""
816
USER_TRANSFORMATION_FUNCTION = \
917
"""def {function_name}(input):
1018
sql_query = f\"\"\"{query}\"\"\"
@@ -19,9 +27,13 @@ def __verify_sql_query_plan(parser_plan, input_symbol: str):
1927
Once the sql parser has parsed the query,
2028
This function takes the parser plan as an input, It checks for the table names
2129
and verifies that there is only a single table, and that it uses the placeholder name
22-
30+
A regex has been added to cater to common table expressions
2331
Args:
2432
parser_plan: A Spark sqlParser ParsePlan object.
33+
parser_plan contains the project and unresolved relation items
34+
project: list of unresolved attributes - table field names
35+
UnresolvedRelation: list of unresolved relation attributes - table names
36+
e.g. : Project ['user_id, 'credit_score], 'UnresolvedRelation [DATA_SOURCE_INPUT], [], false
2537
input_symbol (Transformation): The table name to be matched.
2638
"""
2739
plan_items = json.loads(parser_plan.toJSON())
@@ -58,11 +70,12 @@ def verify_sql_input(query_input: str, input_symbol: str):
5870
parser = spark._jsparkSession.sessionState().sqlParser()
5971
try:
6072
parser_plan = parser.parsePlan(query_input)
61-
# verify if the parser plan has only FROM DATA_SOURCE_INPUT template
62-
TransformationQueryValidator.__verify_sql_query_plan(parser_plan, input_symbol)
6373
except ParseException as pe:
6474
raise ParseException(f"Unable to parse the sql expression, exception occurred: {pe}")
6575

76+
# verify if the parser plan has only FROM DATA_SOURCE_INPUT template
77+
TransformationQueryValidator.__verify_sql_query_plan(parser_plan, input_symbol)
78+
6679
@staticmethod
6780
def create_transformation_template(query: str, input_symbol: str, function_name: str):
6881
"""

ads/feature_store/execution_strategy/spark/spark_execution.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def _save_offline_dataframe(
229229
# Compute Feature Statistics
230230

231231
feature_statistics = StatisticsService.compute_stats_with_mlm(
232-
statistics_config=feature_group.statistics_config,
232+
statistics_config=feature_group.oci_feature_group.statistics_config,
233233
input_df=featured_data)
234234

235235
except Exception as ex:
@@ -347,7 +347,7 @@ def _save_dataset_input(self, dataset, dataset_job: DatasetJob):
347347

348348
# Compute Feature Statistics
349349
feature_statistics = StatisticsService.compute_stats_with_mlm(
350-
statistics_config=dataset.statistics_config,
350+
statistics_config=dataset.oci_dataset.statistics_config,
351351
input_df=dataset_dataframe,
352352
)
353353

ads/feature_store/feature_statistics/statistics_service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def compute_stats_with_mlm(
3939
statistics_config: StatisticsConfig, input_df: DataFrame
4040
):
4141
feature_metrics = None
42-
if bool(input_df.head(1)) and statistics_config and statistics_config.get("isEnabled"):
42+
if bool(input_df.head(1)) and statistics_config and statistics_config.is_enabled:
4343
feature_schema = {}
4444
if input_df.schema:
4545
StatisticsService.__get_mlm_supported_schema(feature_schema, input_df, statistics_config)
@@ -66,7 +66,7 @@ def __get_feature_metric(feature_schema: dict, data_frame: DataFrame):
6666

6767
@staticmethod
6868
def __get_mlm_supported_schema(feature_schema: dict, input_df: DataFrame, statistics_config: StatisticsConfig):
69-
relevant_columns = statistics_config.get("columns")
69+
relevant_columns = statistics_config.columns
7070
for field in input_df.schema.fields:
7171
data_type = map_spark_type_to_stats_data_type(field.dataType)
7272
if not data_type:

0 commit comments

Comments
 (0)