Skip to content

Commit ee8e1b1

Browse files
committed
added unit test and log level off
1 parent 064cb7d commit ee8e1b1

File tree

7 files changed

+91
-57
lines changed

7 files changed

+91
-57
lines changed

ads/feature_store/common/spark_session_singleton.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def __init__(self, metastore_id: str = None):
8787
self.spark_session = spark_builder.getOrCreate()
8888

8989
self.spark_session.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
90+
self.spark_session.sparkContext.setLogLevel("OFF")
9091

9192
def get_spark_session(self):
9293
"""Access method to get the spark session."""

ads/feature_store/common/utils/feature_schema_mapper.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ def map_pandas_type_to_feature_type(feature_name, values):
8787
inferred_dtype = current_dtype
8888
else:
8989
if (
90-
current_dtype != inferred_dtype
91-
and current_dtype is not FeatureType.UNKNOWN
90+
current_dtype != inferred_dtype
91+
and current_dtype is not FeatureType.UNKNOWN
9292
):
9393
raise TypeError(
9494
f"Input feature '{feature_name}' has mixed types, {current_dtype} and {inferred_dtype}. "
@@ -233,7 +233,7 @@ def map_feature_type_to_pandas(feature_type):
233233

234234

235235
def convert_pandas_datatype_with_schema(
236-
raw_feature_details: List[dict], input_df: pd.DataFrame
236+
raw_feature_details: List[dict], input_df: pd.DataFrame
237237
):
238238
feature_detail_map = {}
239239
for feature_details in raw_feature_details:
@@ -245,13 +245,13 @@ def convert_pandas_datatype_with_schema(
245245
pandas_type = map_feature_type_to_pandas(feature_type)
246246
input_df[column] = (
247247
input_df[column]
248-
.astype(pandas_type)
249-
.where(pd.notnull(input_df[column]), None)
248+
.astype(pandas_type)
249+
.where(pd.notnull(input_df[column]), None)
250250
)
251251

252252

253253
def map_spark_type_to_stats_data_type(spark_type):
254-
""" Maps the spark data types to MLM library data types
254+
"""Maps the spark data types to MLM library data types
255255
args:
256256
param spark_type: Spark data type input from the feature dataframe on which we need stats
257257
:return:
@@ -272,7 +272,7 @@ def map_spark_type_to_stats_data_type(spark_type):
272272

273273

274274
def map_spark_type_to_stats_variable_type(spark_type):
275-
""" Maps the spark data types to MLM library variable types
275+
"""Maps the spark data types to MLM library variable types
276276
args:
277277
param spark_type: Spark data type input from the feature dataframe on which we need stats
278278
:return:

ads/feature_store/common/utils/transformation_query_validator.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,12 @@
1313
function_name: Transformation function name
1414
input : The input placeholder for the FROM clause
1515
"""
16-
USER_TRANSFORMATION_FUNCTION = \
17-
"""def {function_name}(input):
16+
USER_TRANSFORMATION_FUNCTION = """def {function_name}(input):
1817
sql_query = f\"\"\"{query}\"\"\"
1918
return sql_query"""
2019

2120

2221
class TransformationQueryValidator:
23-
2422
@staticmethod
2523
def __verify_sql_query_plan(parser_plan, input_symbol: str):
2624
"""
@@ -41,21 +39,23 @@ def __verify_sql_query_plan(parser_plan, input_symbol: str):
4139
cte = re.findall(r"CTE \[(.*?)\]", plan_string)
4240
table_names = []
4341
for plan_item in plan_items:
44-
if plan_item['class'] == 'org.apache.spark.sql.catalyst.analysis.UnresolvedRelation':
45-
table = plan_item['multipartIdentifier']
46-
res = table.strip('][').split(', ')
42+
if (
43+
plan_item["class"]
44+
== "org.apache.spark.sql.catalyst.analysis.UnresolvedRelation"
45+
):
46+
table = plan_item["multipartIdentifier"]
47+
res = table.strip("][").split(", ")
4748
if len(res) >= 2:
48-
raise ValueError(
49-
"FROM Clause has invalid input {0}".format(table))
49+
raise ValueError("FROM Clause has invalid input {0}".format(table))
5050
else:
5151
if res[0].lower() != input_symbol.lower():
5252
raise ValueError(
53-
f"Incorrect table template name, It should be {input_symbol}")
53+
f"Incorrect table template name, It should be {input_symbol}"
54+
)
5455
if table not in cte:
5556
table_names.append(f"{table}")
5657
if len(table_names) > 1:
57-
raise ValueError(
58-
"Multiple tables are not supported ")
58+
raise ValueError("Multiple tables are not supported ")
5959

6060
@staticmethod
6161
def verify_sql_input(query_input: str, input_symbol: str):
@@ -71,13 +71,17 @@ def verify_sql_input(query_input: str, input_symbol: str):
7171
try:
7272
parser_plan = parser.parsePlan(query_input)
7373
except ParseException as pe:
74-
raise ParseException(f"Unable to parse the sql expression, exception occurred: {pe}")
74+
raise ParseException(
75+
f"Unable to parse the sql expression, exception occurred: {pe}"
76+
)
7577

7678
# verify if the parser plan has only FROM DATA_SOURCE_INPUT template
7779
TransformationQueryValidator.__verify_sql_query_plan(parser_plan, input_symbol)
7880

7981
@staticmethod
80-
def create_transformation_template(query: str, input_symbol: str, function_name: str):
82+
def create_transformation_template(
83+
query: str, input_symbol: str, function_name: str
84+
):
8185
"""
8286
Creates the query transformation function to ensure backend integrity
8387
Args:
@@ -86,5 +90,7 @@ def create_transformation_template(query: str, input_symbol: str, function_name:
8690
function_name : The name of the transformation function
8791
"""
8892
transformation_query = query.replace(input_symbol, "{input}")
89-
output = USER_TRANSFORMATION_FUNCTION.format(query=transformation_query, function_name=function_name)
93+
output = USER_TRANSFORMATION_FUNCTION.format(
94+
query=transformation_query, function_name=function_name
95+
)
9096
return output

ads/feature_store/feature_statistics/statistics_service.py

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -36,49 +36,64 @@ class StatisticsService:
3636

3737
@staticmethod
3838
def compute_stats_with_mlm(
39-
statistics_config: StatisticsConfig, input_df: DataFrame
39+
statistics_config: StatisticsConfig, input_df: DataFrame
4040
):
4141
feature_metrics = None
42-
if bool(input_df.head(1)) and statistics_config and statistics_config.is_enabled:
42+
if (
43+
bool(input_df.head(1))
44+
and statistics_config
45+
and statistics_config.is_enabled
46+
):
4347
feature_schema = {}
4448
if input_df.schema:
45-
StatisticsService.__get_mlm_supported_schema(feature_schema, input_df, statistics_config)
46-
feature_metrics = StatisticsService.__get_feature_metric(feature_schema, input_df)
49+
StatisticsService.__get_mlm_supported_schema(
50+
feature_schema, input_df, statistics_config
51+
)
52+
feature_metrics = StatisticsService.__get_feature_metric(
53+
feature_schema, input_df
54+
)
4755
else:
4856
raise ValueError("Dataframe schema is missing")
4957
return feature_metrics
5058

5159
@staticmethod
5260
def __get_feature_metric(feature_schema: dict, data_frame: DataFrame):
5361
feature_metrics = None
54-
runner = InsightsBuilder(). \
55-
with_input_schema(input_schema=feature_schema). \
56-
with_data_frame(data_frame=data_frame). \
57-
with_engine(engine=EngineDetail(engine_name="spark")). \
58-
build()
62+
runner = (
63+
InsightsBuilder()
64+
.with_input_schema(input_schema=feature_schema)
65+
.with_data_frame(data_frame=data_frame)
66+
.with_engine(engine=EngineDetail(engine_name="spark"))
67+
.build()
68+
)
5969
result = runner.run()
6070
if result and result.profile:
6171
profile = result.profile
6272
feature_metrics = json.dumps(profile.to_json()[CONST_FEATURE_METRICS])
6373
else:
64-
logger.warning(f"stats computation failed with MLM for schema {feature_schema}")
74+
logger.warning(
75+
f"stats computation failed with MLM for schema {feature_schema}"
76+
)
6577
return feature_metrics
6678

6779
@staticmethod
68-
def __get_mlm_supported_schema(feature_schema: dict, input_df: DataFrame, statistics_config: StatisticsConfig):
80+
def __get_mlm_supported_schema(
81+
feature_schema: dict, input_df: DataFrame, statistics_config: StatisticsConfig
82+
):
6983
relevant_columns = statistics_config.columns
7084
for field in input_df.schema.fields:
7185
data_type = map_spark_type_to_stats_data_type(field.dataType)
7286
if not data_type:
73-
logger.warning(f"Unable to map spark data type fields to MLM fields, "
74-
f"Actual data type {field.dataType}")
87+
logger.warning(
88+
f"Unable to map spark data type fields to MLM fields, "
89+
f"Actual data type {field.dataType}"
90+
)
7591
elif relevant_columns:
7692
if field.name in relevant_columns:
77-
feature_schema[field.name] = FeatureType(data_type,
78-
map_spark_type_to_stats_variable_type(
79-
field.dataType))
93+
feature_schema[field.name] = FeatureType(
94+
data_type, map_spark_type_to_stats_variable_type(field.dataType)
95+
)
8096
else:
81-
feature_schema[field.name] = FeatureType(data_type,
82-
map_spark_type_to_stats_variable_type(
83-
field.dataType))
84-
97+
feature_schema[field.name] = FeatureType(
98+
data_type, map_spark_type_to_stats_variable_type(field.dataType)
99+
)

ads/feature_store/feature_store.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ def _build_transformation(
431431
display_name: str = None,
432432
description: str = None,
433433
compartment_id: str = None,
434-
sql_query: str = None
434+
sql_query: str = None,
435435
):
436436
transformation = (
437437
Transformation()
@@ -455,7 +455,7 @@ def create_transformation(
455455
display_name: str = None,
456456
description: str = None,
457457
compartment_id: str = None,
458-
sql_query: str = None
458+
sql_query: str = None,
459459
) -> "Transformation":
460460
"""Creates transformation resource from feature store.
461461
@@ -491,7 +491,7 @@ def create_transformation(
491491
display_name,
492492
description,
493493
compartment_id,
494-
sql_query
494+
sql_query,
495495
)
496496

497497
return self.oci_transformation.create()

ads/feature_store/feature_store_registrar.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,9 @@ def _create_transformations(self) -> List[Transformation]:
252252
)
253253
if transformation.source_code_function:
254254
# to encode to base64
255-
transformation.source_code_function = transformation.source_code_function
255+
transformation.source_code_function = (
256+
transformation.source_code_function
257+
)
256258
return self._transformations.create_models(self._progress)
257259

258260
def _create_feature_groups(self) -> List[FeatureGroup]:

ads/feature_store/transformation.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
from ads.common.oci_mixin import OCIModelMixin
2020
from ads.feature_store.service.oci_transformation import OCITransformation
2121
from ads.jobs.builders.base import Builder
22-
from ads.feature_store.common.utils.transformation_query_validator import TransformationQueryValidator
22+
from ads.feature_store.common.utils.transformation_query_validator import (
23+
TransformationQueryValidator,
24+
)
2325

2426
logger = logging.getLogger(__name__)
2527

@@ -278,16 +280,18 @@ def with_source_code_function(self, source_code_func) -> "Transformation":
278280
Base64EncoderDecoder.encode(source_code),
279281
)
280282
else:
281-
return self.set_spec(
282-
self.CONST_SOURCE_CODE, None)
283+
return self.set_spec(self.CONST_SOURCE_CODE, None)
283284

284285
@property
285286
def transformation_query_input(self) -> str:
286287
return self.get_spec(self.CONST_TRANSFORMATION_QUERY_INPUT)
287288

288289
@transformation_query_input.setter
289290
def transformation_query_input(self, transformation_query):
290-
self.set_spec(self.CONST_TRANSFORMATION_QUERY_INPUT, Base64EncoderDecoder.encode(transformation_query))
291+
self.set_spec(
292+
self.CONST_TRANSFORMATION_QUERY_INPUT,
293+
Base64EncoderDecoder.encode(transformation_query),
294+
)
291295

292296
def with_transformation_query_input(self, transformation_query) -> "Transformation":
293297
"""Sets the source code function for the transformation.
@@ -303,7 +307,8 @@ def with_transformation_query_input(self, transformation_query) -> "Transformati
303307
The Transformation instance (self)
304308
"""
305309
return self.set_spec(
306-
self.CONST_TRANSFORMATION_QUERY_INPUT, transformation_query)
310+
self.CONST_TRANSFORMATION_QUERY_INPUT, transformation_query
311+
)
307312

308313
def with_id(self, id: str) -> "Transformation":
309314
return self.set_spec(self.CONST_ID, id)
@@ -380,7 +385,9 @@ def create(self, **kwargs) -> "Transformation":
380385

381386
if not self.transformation_query_input:
382387
if not self.source_code_function:
383-
raise ValueError("Transformation source code function must be provided.")
388+
raise ValueError(
389+
"Transformation source code function must be provided."
390+
)
384391

385392
if not self.display_name:
386393
self.display_name = self._transformation_function_name
@@ -390,13 +397,16 @@ def create(self, **kwargs) -> "Transformation":
390397
"Transformation display name and function name must be same."
391398
)
392399
elif self.transformation_mode.lower() == TransformationMode.SQL.value:
393-
TransformationQueryValidator.verify_sql_input(self.transformation_query_input,
394-
self.CONST_DATA_SOURCE_TRANSFORMATION_INPUT)
400+
TransformationQueryValidator.verify_sql_input(
401+
self.transformation_query_input,
402+
self.CONST_DATA_SOURCE_TRANSFORMATION_INPUT,
403+
)
395404
# convert it to transformation function to ensure the integrity with backend
396-
output = TransformationQueryValidator \
397-
.create_transformation_template(self.transformation_query_input,
398-
self.CONST_DATA_SOURCE_TRANSFORMATION_INPUT,
399-
self.CONST_TRANSFORMATION_FUNCTION_NAME)
405+
output = TransformationQueryValidator.create_transformation_template(
406+
self.transformation_query_input,
407+
self.CONST_DATA_SOURCE_TRANSFORMATION_INPUT,
408+
self.CONST_TRANSFORMATION_FUNCTION_NAME,
409+
)
400410
self.with_display_name(self.CONST_TRANSFORMATION_FUNCTION_NAME)
401411
self.with_source_code_function(output)
402412
else:

0 commit comments

Comments
 (0)