Skip to content

Commit 3105ac6

Browse files
committed
Added unit testing for write_score_code.py, tasks.py, and model_parameters.py
1 parent cb1ce9a commit 3105ac6

File tree

5 files changed

+339
-265
lines changed

5 files changed

+339
-265
lines changed

src/sasctl/pzmm/model_parameters.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -171,13 +171,13 @@ def statsmodels_params():
171171

172172
if model.__class__.__module__.__contains__("sklearn"):
173173
sklearn_params()
174-
elif model.__class__.__module__.startswith("keras."):
174+
elif model.__class__.__module__.startswith("keras"):
175175
tf_params()
176-
elif model.__class__.__module__.startswith("xgboost."):
176+
elif model.__class__.__module__.startswith("xgboost"):
177177
xg_params()
178-
elif model.__class__.__module__.startswith("h2o."):
178+
elif model.__class__.__module__.startswith("h2o"):
179179
h2o_params()
180-
elif model.__class__.__module__.startswith("statsmodels."):
180+
elif model.__class__.__module__.startswith("statsmodels"):
181181
statsmodels_params()
182182

183183
else:
@@ -397,6 +397,7 @@ def get_project_kpis(
397397

398398
return kpi_table_df
399399

400+
@staticmethod
400401
def sync_model_properties(
401402
project: Union[str, dict, RestObj], overrwrite: Optional[bool] = False
402403
):
@@ -419,4 +420,4 @@ def sync_model_properties(
419420
# If property is set in project, check if it's set in model, and update model accordingly
420421
if model_property not in model or overrwrite:
421422
model[model_property] = project[project_property]
422-
mr.update_model(model)
423+
mr.update_model(model)

src/sasctl/tasks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def _format_properties(model, input_vars=None, output_vars=None):
192192
# Get input & output variable lists
193193
# Note: copying lists to avoid altering original
194194
input_vars = input_vars or model.get("inputVariables", [])
195-
output_vars = output_vars or model.get("outputVariables", [])[:]
195+
output_vars = output_vars or model.get("outputVariables", [])
196196
input_vars = input_vars[:]
197197
output_vars = output_vars[:]
198198
unformatted_variables = input_vars + output_vars
@@ -227,7 +227,7 @@ def _format_properties(model, input_vars=None, output_vars=None):
227227

228228

229229
def _compare_properties(project_name, model, input_vars=None, output_vars=None):
230-
properties, variables = _format_properties(model, input_vars, output_vars)
230+
properties, _ = _format_properties(model, input_vars, output_vars)
231231
project = mr.get_project(project_name)
232232
same_properties = True
233233
for p in properties:

tests/unit/test_model_parameters.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,13 @@
1414

1515
from sasctl import RestObj, current_session
1616
from sasctl.pzmm import ModelParameters as mp
17+
import unittest
18+
import uuid
19+
import xgboost
20+
import h2o
21+
import tensorflow as tf
22+
import statsmodels.formula.api as smf
23+
import numpy as np
1724

1825

1926
class BadModel:
@@ -213,3 +220,104 @@ def test_get_project_kpis(self):
213220
assert len(kpi_table.columns) == 1
214221

215222
assert not kpi_table.loc[1]["testName"]
223+
224+
225+
class TestSyncModelProperties(unittest.TestCase):
226+
MODEL_PROPERTIES = [
227+
("targetVariable", "targetVariable"),
228+
("targetLevel", "targetLevel"),
229+
("targetEventValue", "targetEvent"),
230+
("eventProbabilityVariable", "eventProbVar"),
231+
("function", "function"),
232+
]
233+
234+
def test_project_id(self):
235+
with mock.patch('sasctl._services.model_repository.ModelRepository.get_project') as get_project:
236+
with mock.patch('sasctl._services.model_repository.ModelRepository.get') as get:
237+
with mock.patch('sasctl._services.model_repository.ModelRepository.get_model') as get_model:
238+
with mock.patch('sasctl._services.model_repository.ModelRepository.update_model') as update:
239+
pUUID = uuid.uuid4()
240+
mp.sync_model_properties(pUUID, False)
241+
get.assert_called_with(f"/projects/{pUUID}/models")
242+
243+
project_dict = {'id': 'projectID'}
244+
mp.sync_model_properties(project_dict, False)
245+
get.assert_called_with("/projects/projectID/models")
246+
247+
project_name = 'project'
248+
get_project.return_value = {'id': "pid"}
249+
mp.sync_model_properties(project_name, False)
250+
get.assert_called_with("/projects/pid/models")
251+
252+
def test_overwrite(self):
253+
with mock.patch('sasctl._services.model_repository.ModelRepository.get_project') as get_project:
254+
with mock.patch('sasctl._services.model_repository.ModelRepository.get') as get:
255+
with mock.patch('sasctl._services.model_repository.ModelRepository.get_model') as get_model:
256+
with mock.patch('sasctl._services.model_repository.ModelRepository.update_model') as update:
257+
project_dict = {'id': 'projectID', 'function': 'project_function', 'targetLevel': '1'}
258+
get.return_value = [RestObj({'id': 'modelID'})]
259+
get_model.return_value = {'function': 'classification'}
260+
mp.sync_model_properties(project_dict)
261+
update.assert_called_with({'function': 'classification', 'targetLevel': '1'})
262+
263+
project_dict = {'id': 'projectID', 'function': 'project_function', 'targetLevel': '1'}
264+
get.return_value = [RestObj({'id': 'modelID'})]
265+
get_model.return_value = {'function': 'classification'}
266+
mp.sync_model_properties(project_dict, True)
267+
update.assert_called_with({'function': 'project_function', 'targetLevel': '1'})
268+
269+
270+
class TestGenerateHyperparameters(unittest.TestCase):
271+
272+
def test_xgboost(self):
273+
model = unittest.mock.Mock()
274+
model.__class__ = xgboost.Booster
275+
attrs = {'save_config.return_value': json.dumps({'test': 'passed'})}
276+
model.configure_mock(**attrs)
277+
tmp_dir = tempfile.TemporaryDirectory()
278+
mp.generate_hyperparameters(model, 'prefix', Path(tmp_dir.name))
279+
assert Path(
280+
Path(tmp_dir.name) / f"./prefixHyperparameters.json"
281+
).exists()
282+
283+
def test_xgboost_sklearn(self):
284+
model = unittest.mock.Mock()
285+
model.__class__ = xgboost.XGBModel
286+
attrs = {'get_params.return_value': json.dumps({'test': 'passed'})}
287+
model.configure_mock(**attrs)
288+
tmp_dir = tempfile.TemporaryDirectory()
289+
mp.generate_hyperparameters(model, 'prefix', Path(tmp_dir.name))
290+
assert Path(
291+
Path(tmp_dir.name) / f"./prefixHyperparameters.json"
292+
).exists()
293+
294+
def test_h2o(self):
295+
model = unittest.mock.Mock()
296+
model.__class__ = h2o.H2OFrame
297+
attrs = {'get_params.return_value': json.dumps({'test': 'passed'})}
298+
model.configure_mock(**attrs)
299+
tmp_dir = tempfile.TemporaryDirectory()
300+
mp.generate_hyperparameters(model, 'prefix', Path(tmp_dir.name))
301+
assert Path(
302+
Path(tmp_dir.name) / f"./prefixHyperparameters.json"
303+
).exists()
304+
305+
def test_tensorflow(self):
306+
model = unittest.mock.Mock()
307+
model.__class__ = tf.keras.Sequential
308+
attrs = {'get_config.return_value': json.dumps({'test': 'passed'})}
309+
model.configure_mock(**attrs)
310+
tmp_dir = tempfile.TemporaryDirectory()
311+
mp.generate_hyperparameters(model, 'prefix', Path(tmp_dir.name))
312+
assert Path(
313+
Path(tmp_dir.name) / f"./prefixHyperparameters.json"
314+
).exists()
315+
316+
def test_statsmodels(self):
317+
model = unittest.mock.Mock(exog_names = ['test', 'exog'], weights = np.array([0, 1]))
318+
model.__class__ = smf.ols
319+
tmp_dir = tempfile.TemporaryDirectory()
320+
mp.generate_hyperparameters(model, 'prefix', Path(tmp_dir.name))
321+
assert Path(
322+
Path(tmp_dir.name) / f"./prefixHyperparameters.json"
323+
).exists()

0 commit comments

Comments
 (0)