Skip to content

Commit b9d3aa3

Browse files
committed
Fixed broken unit tests for write_score_code
1 parent 2220d67 commit b9d3aa3

File tree

2 files changed

+32
-21
lines changed

2 files changed

+32
-21
lines changed

src/sasctl/pzmm/write_score_code.py

Lines changed: 30 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -640,7 +640,10 @@ def _wrap_indent_string(text, indent=0):
640640
Wrapped and indented string.
641641
"""
642642
wrapped_lines = textwrap.fill(str(text), width=88 - indent).split("\n")
643-
return "\n".join(f"{'':{indent}}" + line for line in wrapped_lines)
643+
if indent > 0:
644+
return "\n".join(f"{'':{indent}}" + line for line in wrapped_lines)
645+
else:
646+
return "\n".join(line for line in wrapped_lines)
644647

645648
@classmethod
646649
def _predict_method(
@@ -673,7 +676,6 @@ def _predict_method(
673676
Flag to indicate that the model is a tensorflow model. The default value is
674677
False.
675678
"""
676-
column_names = ", ".join(f'"{col}"' for col in var_list)
677679
# H2O models
678680
if dtype_list:
679681
column_types = "{"
@@ -685,13 +687,15 @@ def _predict_method(
685687
column_types += f'"{var}" : "{col_type}", '
686688
column_types = column_types.rstrip(", ")
687689
column_types += "}"
688-
input_dict = [f"'{var}': {var}" for var in var_list]
690+
input_dict = [f"\"{var}\": {var}" for var in var_list]
689691
cls.score_code += (f"{'':4}index=None\n"
690692
f"{'':4}if not isinstance({var_list[0]}, pd.Series):\n" +
691693
f"{'':8}index=[0]\n")
692-
693-
input_frame = f"{'':4}input_array = pd.DataFrame({{{','.join(input_dict)}}}, index=index)\n"
694-
cls.score_code += cls._wrap_indent_string(input_frame)
694+
695+
cls.score_code += f"{'':4}input_array = pd.DataFrame(\n"
696+
input_frame = f'{{{", ".join(input_dict)}}}, index=index'
697+
cls.score_code += cls._wrap_indent_string(input_frame, 8)
698+
cls.score_code += f"\n{'':4})\n"
695699
if missing_values:
696700
cls.score_code += (
697701
f"{'':4}input_array = impute_missing_values(input_array)"
@@ -705,31 +709,36 @@ def _predict_method(
705709
)
706710
# Statsmodels models
707711
elif statsmodels_model:
708-
input_dict = [f"'{var}': {var}" for var in var_list]
709-
input_dict.append("'const': const")
712+
var_list.insert(0, "const")
713+
input_dict = [f"\"{var}\": {var}" for var in var_list]
710714
cls.score_code += (f"{'':4}index=None\n"
711715
f"{'':4}if not isinstance({var_list[0]}, pd.Series):\n"
712716
f"{'':8}index=[0]\n"
713717
f"{'':8}const = 1\n"
714718
f"{'':4}else:\n"
715719
f"{'':8}const = pd.Series([1 for x in len({var_list[0]})])")
716-
717-
input_frame = f"{'':4}input_array = pd.DataFrame({{{','.join(input_dict)}}}, index=index)\n"
718-
cls.score_code += cls._wrap_indent_string(input_frame)
720+
721+
cls.score_code += f"{'':4}input_array = pd.DataFrame(\n"
722+
input_frame = f'{{{", ".join(input_dict)}}}, index=index'
723+
cls.score_code += cls._wrap_indent_string(input_frame, 8)
724+
cls.score_code += f"\n{'':4})\n"
719725
if missing_values:
720726
cls.score_code += (
721727
f"{'':4}input_array = impute_missing_values(input_array)"
722728
)
723729
cls.score_code += (
724-
f"{'':4}prediction = model.{method.__name__}" f"(input_array)\n"
730+
f"{'':4}prediction = model.{method.__name__}(input_array)\n"
725731
)
726732
elif tf_model:
727-
input_dict = [f"'{var}': {var}" for var in var_list]
733+
input_dict = [f"\"{var}\": {var}" for var in var_list]
728734
cls.score_code += (f"{'':4}index=None\n"
729735
f"{'':4}if not isinstance({var_list[0]}, pd.Series):\n"
730736
f"{'':8}index=[0]\n")
731-
732-
input_frame = f"{'':4}input_array = pd.DataFrame({{{','.join(input_dict)}}}, index=index)\n"
737+
738+
cls.score_code += f"{'':4}input_array = pd.DataFrame(\n"
739+
input_frame = f'{{{", ".join(input_dict)}}}, index=index'
740+
cls.score_code += cls._wrap_indent_string(input_frame, 8)
741+
cls.score_code += f"\n{'':4})\n"
733742
if missing_values:
734743
cls.score_code += (
735744
f"{'':4}input_array = impute_missing_values(input_array)"
@@ -743,13 +752,15 @@ def _predict_method(
743752
f"{'':8}predictions = [p.tolist() for p in predictions]\n"
744753
)
745754
else:
746-
input_dict = [f"'{var}': {var}" for var in var_list]
755+
input_dict = [f"\"{var}\": {var}" for var in var_list]
747756
cls.score_code += (f"{'':4}index=None\n"
748757
f"{'':4}if not isinstance({var_list[0]}, pd.Series):\n"
749758
f"{'':8}index=[0]\n")
750-
751-
input_frame = f"{'':4}input_array = pd.DataFrame({{{','.join(input_dict)}}}, index=index)\n"
752-
cls.score_code += cls._wrap_indent_string(input_frame)
759+
760+
cls.score_code += f"{'':4}input_array = pd.DataFrame(\n"
761+
input_frame = f'{{{", ".join(input_dict)}}}, index=index'
762+
cls.score_code += cls._wrap_indent_string(input_frame, 8)
763+
cls.score_code += f"\n{'':4})\n"
753764
if missing_values:
754765
cls.score_code += (
755766
f"{'':4}input_array = impute_missing_values(input_array)"

tests/unit/test_write_score_code.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -198,15 +198,15 @@ def test_predict_method():
198198
var_list = ["first", "second", "third"]
199199
dtype_list = ["str", "int", "float"]
200200
sc._predict_method(predict_proba, var_list)
201-
assert f"pd.DataFrame([[first, second, third]]," in sc.score_code
201+
assert "{\"first\": first, \"second\": second" in sc.score_code
202202
sc.score_code = ""
203203

204204
sc._predict_method(predict_proba, var_list, dtype_list=dtype_list)
205205
assert "column_types = " in sc.score_code
206206
sc.score_code = ""
207207

208208
sc._predict_method(predict_proba, var_list, statsmodels_model=True)
209-
assert f"pd.DataFrame([[1.0, first, second, third]]," in sc.score_code
209+
assert "{\"const\": const, \"first\": first" in sc.score_code
210210
sc.score_code = ""
211211

212212

0 commit comments

Comments
 (0)