@@ -36,6 +36,7 @@ def write_score_code(
36
36
score_cas : Optional [bool ] = True ,
37
37
score_code_path : Union [Path , str , None ] = None ,
38
38
target_index : Optional [int ] = None ,
39
+ preprocess_function : Optional [Callable [DataFrame , DataFrame ]] = None ,
39
40
** kwargs ,
40
41
) -> Union [dict , None ]:
41
42
"""
@@ -292,6 +293,9 @@ def score(var1, var2, var3, var4):
292
293
if missing_values :
293
294
self ._impute_missing_values (input_data , missing_values )
294
295
296
+ if preprocess_function :
297
+ self ._add_preprocess_code (preprocess_function )
298
+
295
299
# SAS Viya 3.5 model
296
300
if model_id :
297
301
mas_code , cas_code = self ._viya35_score_code_import (
@@ -759,6 +763,7 @@ def _predict_method(
759
763
missing_values : Optional [Any ] = None ,
760
764
statsmodels_model : Optional [bool ] = False ,
761
765
tf_model : Optional [bool ] = False ,
766
+ preprocess_function : Optional [Callable [DataFrame , DataFrame ]] = None ,
762
767
) -> None :
763
768
"""
764
769
Write the model prediction section of the score code.
@@ -809,6 +814,10 @@ def _predict_method(
809
814
input_frame = f'{{{ ", " .join (input_dict )} }}, index=index'
810
815
self .score_code += self ._wrap_indent_string (input_frame , 8 )
811
816
self .score_code += f"\n { '' :4} )\n "
817
+ if preprocess_function :
818
+ self .score_code += (
819
+ f"{ '' :4} input_array = { preprocess_function .__name__ } (input_array)"
820
+ )
812
821
if missing_values :
813
822
self .score_code += (
814
823
f"{ '' :4} input_array = impute_missing_values(input_array)\n "
@@ -851,6 +860,10 @@ def _predict_method(
851
860
input_frame = f'{{{ ", " .join (input_dict )} }}, index=index'
852
861
self .score_code += self ._wrap_indent_string (input_frame , 8 )
853
862
self .score_code += f"\n { '' :4} )\n "
863
+ if preprocess_function :
864
+ self .score_code += (
865
+ f"{ '' :4} input_array = { preprocess_function .__name__ } (input_array)"
866
+ )
854
867
if missing_values :
855
868
self .score_code += (
856
869
f"{ '' :4} input_array = impute_missing_values(input_array)\n "
@@ -872,6 +885,10 @@ def _predict_method(
872
885
input_frame = f'{{{ ", " .join (input_dict )} }}, index=index'
873
886
self .score_code += self ._wrap_indent_string (input_frame , 8 )
874
887
self .score_code += f"\n { '' :4} )\n "
888
+ if preprocess_function :
889
+ self .score_code += (
890
+ f"{ '' :4} input_array = { preprocess_function .__name__ } (input_array)"
891
+ )
875
892
if missing_values :
876
893
self .score_code += (
877
894
f"{ '' :4} input_array = impute_missing_values(input_array)\n "
@@ -904,6 +921,10 @@ def _predict_method(
904
921
input_frame = f'{{{ ", " .join (input_dict )} }}, index=index'
905
922
self .score_code += self ._wrap_indent_string (input_frame , 8 )
906
923
self .score_code += f"\n { '' :4} )\n "
924
+ if preprocess_function :
925
+ self .score_code += (
926
+ f"{ '' :4} input_array = { preprocess_function .__name__ } (input_array)"
927
+ )
907
928
if missing_values :
908
929
self .score_code += (
909
930
f"{ '' :4} input_array = impute_missing_values(input_array)\n "
@@ -2238,3 +2259,30 @@ def _viya35_score_code_import(
2238
2259
model ["scoreCodeType" ] = "ds2MultiType"
2239
2260
mr .update_model (model )
2240
2261
return mas_code , cas_code
2262
+
2263
+ def _add_preprocess_code (
2264
+ self ,
2265
+ preprocess_function : Callable [DataFrame , DataFrame ]
2266
+ ):
2267
+ """
2268
+ Places the given preprocess function, which must both take a DataFrame as an argument
2269
+ and return a DataFrame, into the score code. If the preprocess function does not
2270
+ return anything, an error is thrown.
2271
+
2272
+ Parameters
2273
+ ----------
2274
+ preprocess_function: function
2275
+ The preprocess function to be added to the score code.
2276
+ """
2277
+ import inspect
2278
+ preprocess_code = inspect .getsource (preprocess_function )
2279
+ if not "return" in preprocess_code :
2280
+ raise ValueError (
2281
+ "The given score code does not return a value. " +
2282
+ "To allow for the score code to work correctly, please ensure the preprocessed " +
2283
+ "data is returned."
2284
+ )
2285
+ if self .score_code [- 1 ] == '\n ' :
2286
+ self .score_code += preprocess_code
2287
+ else :
2288
+ self .score_code += '\n ' + preprocess_code
0 commit comments