Skip to content

Commit 9d839b8

Browse files
committed
changes to ensure all unit tests pass
1 parent cdb3a52 commit 9d839b8

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

src/sasctl/pzmm/write_score_code.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,18 @@ def _write_imports(
475475
import tensorflow as tf
476476
477477
"""
478+
elif binary_string:
479+
cls.score_code += (
480+
f'import codecs\n\nbinary_string = "{binary_string}"'
481+
f"\nmodel = {pickle_type}.loads(codecs.decode(binary_string"
482+
'.encode(), "base64"))\n\n'
483+
)
484+
"""
485+
import codecs
486+
487+
binary_string = "<binary string>"
488+
model = pickle.load(codecs.decode(binary_string.encode(), "base64"))
489+
"""
478490

479491

480492
@classmethod

tests/unit/test_write_json_files.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -700,7 +700,6 @@ def test_create_requirements_json(change_dir):
700700
assert (Path(tmp_dir) / "requirements.json").exists()
701701

702702
json_dict = jf.create_requirements_json(tmp_dir)
703-
assert "requirements.json" in json_dict
704703
expected = [
705704
{
706705
"step": "install numpy",
@@ -713,7 +712,7 @@ def test_create_requirements_json(change_dir):
713712
]
714713
unittest.TestCase.maxDiff = None
715714
unittest.TestCase().assertCountEqual(
716-
json.loads(json_dict["requirements.json"]), expected
715+
json_dict, expected
717716
)
718717

719718

@@ -889,7 +888,7 @@ class TestModelCardGeneration(unittest.TestCase):
889888
def test_generate_outcome_average_interval(self):
890889
df = pd.DataFrame({"input": [3, 2, 1], "output": [1, 2, 3]})
891890
assert (
892-
jf.generate_outcome_average(df, ["input"], "interval") ==
891+
jf.generate_outcome_average(df, ["input"], "prediction") ==
893892
{'eventAverage': 2.0}
894893
)
895894

@@ -901,7 +900,7 @@ def test_generate_outcome_average_classification(self):
901900
def test_generate_outcome_average_interval_non_numeric_output(self):
902901
df = pd.DataFrame({"input": [3, 2, 1], "output": ["one", "two", "three"]})
903902
with pytest.raises(ValueError):
904-
jf.generate_outcome_average(df, ["input"], "interval")
903+
jf.generate_outcome_average(df, ["input"], "prediction")
905904

906905

907906
class TestGetSelectionStatisticValue(unittest.TestCase):

0 commit comments

Comments
 (0)