Skip to content

Commit f7b4b14

Browse files
committed
Fix requirements.json test to use scikit-learn version available for pickle file
1 parent 7ecc7ce commit f7b4b14

File tree

1 file changed

+51
-15
lines changed

1 file changed

+51
-15
lines changed

tests/unit/test_write_json_files.py

Lines changed: 51 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
import os
1010
import pickle
1111
import random
12+
import shutil
1213
import sys
1314
import tempfile
1415
import unittest
@@ -19,6 +20,8 @@
1920
import numpy as np
2021
import pandas as pd
2122
import pytest
23+
from sklearn.model_selection import train_test_split
24+
from sklearn.tree import DecisionTreeClassifier
2225

2326
import sasctl.pzmm as pzmm
2427
from sasctl.pzmm.write_json_files import JSONFiles as jf
@@ -628,23 +631,56 @@ def test_create_requirements_json(change_dir):
628631

629632
example_model = (Path.cwd() / "data/hmeqModels/DecisionTreeClassifier").resolve()
630633
with tempfile.TemporaryDirectory() as tmp_dir:
631-
jf.create_requirements_json(example_model, Path(tmp_dir))
634+
tmp_dir = Path(tmp_dir)
635+
for item in example_model.iterdir():
636+
if item.is_file() and item.name != "DecisionTreeClassifier.pickle":
637+
shutil.copy(item, tmp_dir / tmp_dir.name)
638+
data = pd.read_csv("data/hmeq.csv")
639+
predictor_columns = [
640+
"LOAN",
641+
"MORTDUE",
642+
"VALUE",
643+
"YOJ",
644+
"DEROG",
645+
"DELINQ",
646+
"CLAGE",
647+
"NINQ",
648+
"CLNO",
649+
"DEBTINC",
650+
]
651+
target_column = "BAD"
652+
x = data[predictor_columns]
653+
y = data[target_column]
654+
x_train, x_test, y_train, y_test = train_test_split(
655+
x, y, test_size=0.3, random_state=42
656+
)
657+
x_test.fillna(x_test.mean(), inplace=True)
658+
x_train.fillna(x_train.mean(), inplace=True)
659+
dtc = DecisionTreeClassifier(
660+
max_depth=7, min_samples_split=2, min_samples_leaf=2, max_leaf_nodes=500
661+
)
662+
dtc = dtc.fit(x_train, y_train)
663+
with open(tmp_dir / "DecisionTreeClassifier.pickle", "wb") as pkl_file:
664+
pickle.dump(dtc, pkl_file)
665+
jf.create_requirements_json(tmp_dir, Path(tmp_dir))
632666
assert (Path(tmp_dir) / "requirements.json").exists()
633667

634-
json_dict = jf.create_requirements_json(example_model)
635-
assert "requirements.json" in json_dict
636-
expected = [
637-
{"step": "install pandas", "command": f"pip install pandas=={pd.__version__}"},
638-
{"step": "install numpy", "command": f"pip install numpy=={np.__version__}"},
639-
{
640-
"step": "install sklearn",
641-
"command": f"pip install sklearn=={sk.__version__}",
642-
},
643-
]
644-
unittest.TestCase.maxDiff = None
645-
unittest.TestCase().assertCountEqual(
646-
json.loads(json_dict["requirements.json"]), expected
647-
)
668+
json_dict = jf.create_requirements_json(tmp_dir)
669+
assert "requirements.json" in json_dict
670+
expected = [
671+
{
672+
"step": "install numpy",
673+
"command": f"pip install numpy=={np.__version__}",
674+
},
675+
{
676+
"step": "install sklearn",
677+
"command": f"pip install sklearn=={sk.__version__}",
678+
},
679+
]
680+
unittest.TestCase.maxDiff = None
681+
unittest.TestCase().assertCountEqual(
682+
json.loads(json_dict["requirements.json"]), expected
683+
)
648684

649685

650686
class TestAssessBiasHelpers(unittest.TestCase):

0 commit comments

Comments (0)