Skip to content

Commit 3aae663

Browse files
committed
Black reformatting + minor reformatting
1 parent b9d3aa3 commit 3aae663

File tree

2 files changed

+41
-31
lines changed

2 files changed

+41
-31
lines changed

src/sasctl/pzmm/write_score_code.py

Lines changed: 39 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -522,11 +522,11 @@ def _viya4_model_load(
522522
Flag to indicate that the model is a H2O.ai binary model. The default value
523523
is None.
524524
tf_keras_model : boolean, optional
525-
Flag to indicate that the model is a tensorflow keras model. The default value is
526-
None.
525+
Flag to indicate that the model is a tensorflow keras model. The default
526+
value is False.
527527
tf_core_model : boolean, optional
528-
Flag to indicate that the model is a tensorflow core model. The default value is
529-
None.
528+
Flag to indicate that the model is a tensorflow core model. The default
529+
value is False.
530530
"""
531531
pickle_type = pickle_type if pickle_type else "pickle"
532532

@@ -551,11 +551,13 @@ def _viya4_model_load(
551551
elif tf_keras_model:
552552
cls.score_code += (
553553
f"model = tf.keras.models.load_model(Path(settings.pickle_path) / "
554-
f"\"{str(Path(model_file_name).with_suffix('.h5'))}\", safe_mode=True)\n"
554+
f"\"{str(Path(model_file_name).with_suffix('.h5'))}\", "
555+
f"safe_mode=True)\n"
555556
)
556557
return (
557-
f"{'':8}model = tf.keras.models.load_model(Path(settings.pickle_path) / "
558-
f"\"{str(Path(model_file_name).with_suffix('.h5'))}\", safe_mode=True)\n"
558+
f"{'':8}model = tf.keras.models.load_model(Path(settings.pickle_path) "
559+
f"/ \"{str(Path(model_file_name).with_suffix('.h5'))}\", "
560+
f"safe_mode=True)\n"
559561
)
560562
else:
561563
cls.score_code += (
@@ -687,10 +689,12 @@ def _predict_method(
687689
column_types += f'"{var}" : "{col_type}", '
688690
column_types = column_types.rstrip(", ")
689691
column_types += "}"
690-
input_dict = [f"\"{var}\": {var}" for var in var_list]
691-
cls.score_code += (f"{'':4}index=None\n"
692-
f"{'':4}if not isinstance({var_list[0]}, pd.Series):\n" +
693-
f"{'':8}index=[0]\n")
692+
input_dict = [f'"{var}": {var}' for var in var_list]
693+
cls.score_code += (
694+
f"{'':4}index=None\n"
695+
f"{'':4}if not isinstance({var_list[0]}, pd.Series):\n"
696+
+ f"{'':8}index=[0]\n"
697+
)
694698

695699
cls.score_code += f"{'':4}input_array = pd.DataFrame(\n"
696700
input_frame = f'{{{", ".join(input_dict)}}}, index=index'
@@ -710,13 +714,15 @@ def _predict_method(
710714
# Statsmodels models
711715
elif statsmodels_model:
712716
var_list.insert(0, "const")
713-
input_dict = [f"\"{var}\": {var}" for var in var_list]
714-
cls.score_code += (f"{'':4}index=None\n"
715-
f"{'':4}if not isinstance({var_list[0]}, pd.Series):\n"
716-
f"{'':8}index=[0]\n"
717-
f"{'':8}const = 1\n"
718-
f"{'':4}else:\n"
719-
f"{'':8}const = pd.Series([1 for x in len({var_list[0]})])")
717+
input_dict = [f'"{var}": {var}' for var in var_list]
718+
cls.score_code += (
719+
f"{'':4}index=None\n"
720+
f"{'':4}if not isinstance({var_list[0]}, pd.Series):\n"
721+
f"{'':8}index=[0]\n"
722+
f"{'':8}const = 1\n"
723+
f"{'':4}else:\n"
724+
f"{'':8}const = pd.Series([1 for x in len({var_list[0]})])"
725+
)
720726

721727
cls.score_code += f"{'':4}input_array = pd.DataFrame(\n"
722728
input_frame = f'{{{", ".join(input_dict)}}}, index=index'
@@ -730,10 +736,12 @@ def _predict_method(
730736
f"{'':4}prediction = model.{method.__name__}(input_array)\n"
731737
)
732738
elif tf_model:
733-
input_dict = [f"\"{var}\": {var}" for var in var_list]
734-
cls.score_code += (f"{'':4}index=None\n"
735-
f"{'':4}if not isinstance({var_list[0]}, pd.Series):\n"
736-
f"{'':8}index=[0]\n")
739+
input_dict = [f'"{var}": {var}' for var in var_list]
740+
cls.score_code += (
741+
f"{'':4}index=None\n"
742+
f"{'':4}if not isinstance({var_list[0]}, pd.Series):\n"
743+
f"{'':8}index=[0]\n"
744+
)
737745

738746
cls.score_code += f"{'':4}input_array = pd.DataFrame(\n"
739747
input_frame = f'{{{", ".join(input_dict)}}}, index=index'
@@ -747,15 +755,17 @@ def _predict_method(
747755
f"{'':4}prediction = model.{method.__name__}(input_array)\n"
748756
f"{'':4} # Check if model returns logits or probabilities\n"
749757
f"{'':4}if not math.isclose(sum(predictions[0]), 1, rel_tol=.01):\n"
750-
f"{'':8}predictions = [tf.nn.softmax(p).numpy().tolist() for p in predictions]\n"
751-
f"{'':4}else:\n"
758+
f"{'':8}predictions = [tf.nn.softmax(p).numpy().tolist() for p in "
759+
f"predictions]\n{'':4}else:\n"
752760
f"{'':8}predictions = [p.tolist() for p in predictions]\n"
753761
)
754762
else:
755-
input_dict = [f"\"{var}\": {var}" for var in var_list]
756-
cls.score_code += (f"{'':4}index=None\n"
757-
f"{'':4}if not isinstance({var_list[0]}, pd.Series):\n"
758-
f"{'':8}index=[0]\n")
763+
input_dict = [f'"{var}": {var}' for var in var_list]
764+
cls.score_code += (
765+
f"{'':4}index=None\n"
766+
f"{'':4}if not isinstance({var_list[0]}, pd.Series):\n"
767+
f"{'':8}index=[0]\n"
768+
)
759769

760770
cls.score_code += f"{'':4}input_array = pd.DataFrame(\n"
761771
input_frame = f'{{{", ".join(input_dict)}}}, index=index'
@@ -1134,7 +1144,7 @@ def _binary_target(
11341144
"the target event to occur."
11351145
)
11361146
cls.score_code += (
1137-
f"{'':4}return prediction[1][0], " f"float(prediction[1][2])"
1147+
f"{'':4}return prediction[1][0], float(prediction[1][2])"
11381148
)
11391149
# Calculate the classification; return the classification and probability
11401150
elif sum(returns) == 0 and len(returns) == 1:

tests/unit/test_write_score_code.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,15 +198,15 @@ def test_predict_method():
198198
var_list = ["first", "second", "third"]
199199
dtype_list = ["str", "int", "float"]
200200
sc._predict_method(predict_proba, var_list)
201-
assert "{\"first\": first, \"second\": second" in sc.score_code
201+
assert '{"first": first, "second": second' in sc.score_code
202202
sc.score_code = ""
203203

204204
sc._predict_method(predict_proba, var_list, dtype_list=dtype_list)
205205
assert "column_types = " in sc.score_code
206206
sc.score_code = ""
207207

208208
sc._predict_method(predict_proba, var_list, statsmodels_model=True)
209-
assert "{\"const\": const, \"first\": first" in sc.score_code
209+
assert '{"const": const, "first": first' in sc.score_code
210210
sc.score_code = ""
211211

212212

0 commit comments

Comments
 (0)