Skip to content

Commit ee23c18

Browse files
committed
added capability to have preprocessing function in score code
1 parent 9b5e4c1 commit ee23c18

File tree

3 files changed

+74
-0
lines changed

3 files changed

+74
-0
lines changed

src/sasctl/pzmm/import_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ def import_model(
213213
target_values: Optional[List[str]] = None,
214214
overwrite_project_properties: Optional[bool] = False,
215215
target_index: Optional[int] = None,
216+
preprocess_function: Optional[Callable[DataFrame, DataFrame]] = None,
216217
**kwargs,
217218
) -> Tuple[RestObj, Union[dict, str, Path]]:
218219
"""
@@ -371,6 +372,7 @@ def import_model(
371372
missing_values=missing_values,
372373
score_cas=score_cas,
373374
target_index=target_index,
375+
preprocess_function=preprocess_function,
374376
**kwargs,
375377
)
376378
if score_code_dict:
@@ -471,6 +473,7 @@ def import_model(
471473
missing_values=missing_values,
472474
score_cas=score_cas,
473475
target_index=target_index,
476+
preprocess_function=preprocess_function,
474477
**kwargs,
475478
)
476479
if score_code_dict:

src/sasctl/pzmm/write_score_code.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def write_score_code(
3636
score_cas: Optional[bool] = True,
3737
score_code_path: Union[Path, str, None] = None,
3838
target_index: Optional[int] = None,
39+
preprocess_function: Optional[Callable[DataFrame, DataFrame]] = None,
3940
**kwargs,
4041
) -> Union[dict, None]:
4142
"""
@@ -292,6 +293,9 @@ def score(var1, var2, var3, var4):
292293
if missing_values:
293294
self._impute_missing_values(input_data, missing_values)
294295

296+
if preprocess_function:
297+
self._add_preprocess_code(preprocess_function)
298+
295299
# SAS Viya 3.5 model
296300
if model_id:
297301
mas_code, cas_code = self._viya35_score_code_import(
@@ -759,6 +763,7 @@ def _predict_method(
759763
missing_values: Optional[Any] = None,
760764
statsmodels_model: Optional[bool] = False,
761765
tf_model: Optional[bool] = False,
766+
preprocess_function: Optional[Callable[DataFrame, DataFrame]] = None,
762767
) -> None:
763768
"""
764769
Write the model prediction section of the score code.
@@ -809,6 +814,10 @@ def _predict_method(
809814
input_frame = f'{{{", ".join(input_dict)}}}, index=index'
810815
self.score_code += self._wrap_indent_string(input_frame, 8)
811816
self.score_code += f"\n{'':4})\n"
817+
if preprocess_function:
818+
self.score_code += (
819+
f"{'':4}input_array = {preprocess_function.__name__}(input_array)"
820+
)
812821
if missing_values:
813822
self.score_code += (
814823
f"{'':4}input_array = impute_missing_values(input_array)\n"
@@ -851,6 +860,10 @@ def _predict_method(
851860
input_frame = f'{{{", ".join(input_dict)}}}, index=index'
852861
self.score_code += self._wrap_indent_string(input_frame, 8)
853862
self.score_code += f"\n{'':4})\n"
863+
if preprocess_function:
864+
self.score_code += (
865+
f"{'':4}input_array = {preprocess_function.__name__}(input_array)"
866+
)
854867
if missing_values:
855868
self.score_code += (
856869
f"{'':4}input_array = impute_missing_values(input_array)\n"
@@ -872,6 +885,10 @@ def _predict_method(
872885
input_frame = f'{{{", ".join(input_dict)}}}, index=index'
873886
self.score_code += self._wrap_indent_string(input_frame, 8)
874887
self.score_code += f"\n{'':4})\n"
888+
if preprocess_function:
889+
self.score_code += (
890+
f"{'':4}input_array = {preprocess_function.__name__}(input_array)"
891+
)
875892
if missing_values:
876893
self.score_code += (
877894
f"{'':4}input_array = impute_missing_values(input_array)\n"
@@ -904,6 +921,10 @@ def _predict_method(
904921
input_frame = f'{{{", ".join(input_dict)}}}, index=index'
905922
self.score_code += self._wrap_indent_string(input_frame, 8)
906923
self.score_code += f"\n{'':4})\n"
924+
if preprocess_function:
925+
self.score_code += (
926+
f"{'':4}input_array = {preprocess_function.__name__}(input_array)"
927+
)
907928
if missing_values:
908929
self.score_code += (
909930
f"{'':4}input_array = impute_missing_values(input_array)\n"
@@ -2238,3 +2259,30 @@ def _viya35_score_code_import(
22382259
model["scoreCodeType"] = "ds2MultiType"
22392260
mr.update_model(model)
22402261
return mas_code, cas_code
2262+
2263+
def _add_preprocess_code(
2264+
self,
2265+
preprocess_function: Callable[DataFrame, DataFrame]
2266+
):
2267+
"""
2268+
Places the given preprocess function, which must both take a DataFrame as an argument
2269+
and return a DataFrame, into the score code. If the preprocess function does not
2270+
return anything, an error is thrown.
2271+
2272+
Parameters
2273+
----------
2274+
preprocess_function: function
2275+
The preprocess function to be added to the score code.
2276+
"""
2277+
import inspect
2278+
preprocess_code = inspect.getsource(preprocess_function)
2279+
if not "return" in preprocess_code:
2280+
raise ValueError(
2281+
"The given score code does not return a value. " +
2282+
"To allow for the score code to work correctly, please ensure the preprocessed " +
2283+
"data is returned."
2284+
)
2285+
if self.score_code[-1] == '\n':
2286+
self.score_code += preprocess_code
2287+
else:
2288+
self.score_code += '\n' + preprocess_code

tests/unit/test_write_score_code.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,29 @@ def test_impute_missing_values():
188188
assert "'b': 'test'" in sc.score_code
189189
assert "'c': 1" in sc.score_code or "'c': np.int64(1)" in sc.score_code
190190

191+
def test_preprocess_function():
    """
    Test Cases:
    - function that returns its input is copied into the score code
    - function with no return raises ValueError and adds nothing
    """
    sc = ScoreCode()

    def preprocess_function_one(data: pd.DataFrame):
        print("preprocessing happens here")
        return data

    sc._add_preprocess_code(preprocess_function_one)
    # The whole function source, not just its name, must be present.
    assert "preprocessing happens here" in sc.score_code
    assert "preprocess_function_one" in sc.score_code
    assert "return data" in sc.score_code

    sc = ScoreCode()

    def preprocess_function_two(data: pd.DataFrame):
        print("preprocessing happens here?")

    with pytest.raises(ValueError):
        sc._add_preprocess_code(preprocess_function_two)
    # A rejected function must not leak into the score code.
    assert "preprocess_function_two" not in sc.score_code
213+
191214

192215
def test_predict_method():
193216
"""

0 commit comments

Comments
 (0)