Skip to content

Commit 033ce8d

Browse files
committed
kwargs support for the transformation
1 parent 38e19d1 commit 033ce8d

File tree

4 files changed

+55
-4
lines changed

4 files changed

+55
-4
lines changed

ads/feature_store/common/utils/transformation_utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8; -*-
3-
3+
import json
44
# Copyright (c) 2023 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

@@ -28,6 +28,7 @@ def apply_transformation(
2828
spark: SparkSession,
2929
dataframe: Union[DataFrame, pd.DataFrame],
3030
transformation: Transformation,
31+
transformation_kwargs: str,
3132
):
3233
"""
3334
Perform data transformation using either SQL or Pandas, depending on the specified transformation mode.
@@ -36,6 +37,7 @@ def apply_transformation(
3637
spark: A SparkSession object.
3738
transformation (Transformation): A transformation object containing details of transformation to be performed.
3839
dataframe (DataFrame): The input dataframe to be transformed.
40+
transformation_kwargs(str): The transformation parameters as json string.
3941
4042
Returns:
4143
DataFrame: The resulting transformed data.
@@ -54,15 +56,17 @@ def apply_transformation(
5456
)
5557
transformed_data = None
5658

59+
transformation_kwargs_dict = json.loads(transformation_kwargs)
60+
5761
if transformation.transformation_mode == TransformationMode.SQL.value:
5862
# Register the temporary table
5963
temporary_table_view = "df_view"
6064
dataframe.createOrReplaceTempView(temporary_table_view)
6165

6266
transformed_data = spark.sql(
63-
transformation_function_caller(temporary_table_view)
67+
transformation_function_caller(temporary_table_view, **transformation_kwargs_dict)
6468
)
6569
elif transformation.transformation_mode == TransformationMode.PANDAS.value:
66-
transformed_data = transformation_function_caller(dataframe)
70+
transformed_data = transformation_function_caller(dataframe, **transformation_kwargs_dict)
6771

6872
return transformed_data

ads/feature_store/entity.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,7 @@ def _build_feature_group(
292292
name: str = None,
293293
description: str = None,
294294
compartment_id: str = None,
295+
transformation_kwargs: Dict = None,
295296
):
296297
feature_group_resource = (
297298
FeatureGroup()
@@ -304,6 +305,7 @@ def _build_feature_group(
304305
.with_entity_id(self.id)
305306
.with_transformation_id(transformation_id)
306307
.with_partition_keys(partition_keys)
308+
.with_transformation_kwargs(transformation_kwargs)
307309
.with_primary_keys(primary_keys)
308310
.with_input_feature_details(input_feature_details)
309311
.with_statistics_config(statistics_config)
@@ -328,6 +330,7 @@ def create_feature_group(
328330
name: str = None,
329331
description: str = None,
330332
compartment_id: str = None,
333+
transformation_kwargs: Dict = None,
331334
) -> "FeatureGroup":
332335
"""Creates FeatureGroup resource.
333336
@@ -355,6 +358,8 @@ def create_feature_group(
355358
Description about the Resource.
356359
compartment_id: str = None
357360
compartment_id
361+
transformation_kwargs: Dict
362+
Arguments for the transformation.
358363
359364
360365
Returns
@@ -391,6 +396,7 @@ def create_feature_group(
391396
name,
392397
description,
393398
compartment_id,
399+
transformation_kwargs,
394400
)
395401

396402
return self.oci_feature_group.create()

ads/feature_store/execution_strategy/spark/spark_execution.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import pandas as pd
1010

1111
from ads.common.decorator.runtime_dependency import OptionalDependency
12+
from ads.feature_store.common.utils.base64_encoder_decoder import Base64EncoderDecoder
1213
from ads.feature_store.common.utils.utility import (
1314
get_features,
1415
show_ingestion_summary,
@@ -238,11 +239,20 @@ def _save_offline_dataframe(
238239
# Apply the transformation
239240
if feature_group.transformation_id:
240241
logger.info("Dataframe is transformation enabled.")
242+
243+
# Get the Transformation Arguments if exists and pass to the transformation function.
244+
transformation_kwargs = Base64EncoderDecoder.decode(
245+
feature_group.transformation_kwargs
246+
)
247+
241248
# Loads the transformation resource
242249
transformation = Transformation.from_id(feature_group.transformation_id)
243250

244251
featured_data = TransformationUtils.apply_transformation(
245-
self._spark_session, data_frame, transformation
252+
self._spark_session,
253+
data_frame,
254+
transformation,
255+
transformation_kwargs,
246256
)
247257
else:
248258
logger.info("Transformation not defined.")

ads/feature_store/feature_group.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from ads.feature_store.common.exceptions import (
2121
NotMaterializedError,
2222
)
23+
from ads.feature_store.common.utils.base64_encoder_decoder import Base64EncoderDecoder
2324
from ads.feature_store.common.utils.utility import (
2425
get_metastore_id,
2526
get_execution_engine_type,
@@ -137,6 +138,7 @@ class FeatureGroup(Builder):
137138
CONST_LIFECYCLE_STATE = "lifecycleState"
138139
CONST_LAST_JOB_ID = "jobId"
139140
CONST_INFER_SCHEMA = "isInferSchema"
141+
CONST_TRANSFORMATION_KWARGS = "transformationParameters"
140142

141143
attribute_map = {
142144
CONST_ID: "id",
@@ -157,6 +159,7 @@ class FeatureGroup(Builder):
157159
CONST_STATISTICS_CONFIG: "statistics_config",
158160
CONST_INFER_SCHEMA: "is_infer_schema",
159161
CONST_PARTITION_KEYS: "partition_keys",
162+
CONST_TRANSFORMATION_KWARGS: "transformation_parameters",
160163
}
161164

162165
def __init__(self, spec: Dict = None, **kwargs) -> None:
@@ -325,6 +328,34 @@ def with_primary_keys(self, primary_keys: List[str]) -> "FeatureGroup":
325328
},
326329
)
327330

331+
@property
332+
def transformation_kwargs(self) -> str:
333+
return self.get_spec(self.CONST_TRANSFORMATION_KWARGS)
334+
335+
@transformation_kwargs.setter
336+
def transformation_kwargs(self, value: Dict):
337+
self.with_transformation_kwargs(value)
338+
339+
def with_transformation_kwargs(
340+
self, transformation_kwargs: Dict = {}
341+
) -> "FeatureGroup":
342+
"""Sets the primary keys of the feature group.
343+
344+
Parameters
345+
----------
346+
transformation_kwargs: Dict
347+
Dictionary containing the transformation arguments.
348+
349+
Returns
350+
-------
351+
FeatureGroup
352+
The FeatureGroup instance (self)
353+
"""
354+
return self.set_spec(
355+
self.CONST_TRANSFORMATION_KWARGS,
356+
Base64EncoderDecoder.encode(json.dumps(transformation_kwargs)),
357+
)
358+
328359
@property
329360
def partition_keys(self) -> List[str]:
330361
return self.get_spec(self.CONST_PARTITION_KEYS)

0 commit comments

Comments
 (0)