|
14 | 14 |
|
15 | 15 | from sasctl import RestObj, current_session
|
16 | 16 | from sasctl.pzmm import ModelParameters as mp
|
| 17 | +import unittest |
| 18 | +import uuid |
| 19 | +import xgboost |
| 20 | +import h2o |
| 21 | +import tensorflow as tf |
| 22 | +import statsmodels.formula.api as smf |
| 23 | +import numpy as np |
17 | 24 |
|
18 | 25 |
|
19 | 26 | class BadModel:
|
@@ -213,3 +220,104 @@ def test_get_project_kpis(self):
|
213 | 220 | assert len(kpi_table.columns) == 1
|
214 | 221 |
|
215 | 222 | assert not kpi_table.loc[1]["testName"]
|
| 223 | + |
| 224 | + |
| 225 | +class TestSyncModelProperties(unittest.TestCase): |
| 226 | + MODEL_PROPERTIES = [ |
| 227 | + ("targetVariable", "targetVariable"), |
| 228 | + ("targetLevel", "targetLevel"), |
| 229 | + ("targetEventValue", "targetEvent"), |
| 230 | + ("eventProbabilityVariable", "eventProbVar"), |
| 231 | + ("function", "function"), |
| 232 | + ] |
| 233 | + |
| 234 | + def test_project_id(self): |
| 235 | + with mock.patch('sasctl._services.model_repository.ModelRepository.get_project') as get_project: |
| 236 | + with mock.patch('sasctl._services.model_repository.ModelRepository.get') as get: |
| 237 | + with mock.patch('sasctl._services.model_repository.ModelRepository.get_model') as get_model: |
| 238 | + with mock.patch('sasctl._services.model_repository.ModelRepository.update_model') as update: |
| 239 | + pUUID = uuid.uuid4() |
| 240 | + mp.sync_model_properties(pUUID, False) |
| 241 | + get.assert_called_with(f"/projects/{pUUID}/models") |
| 242 | + |
| 243 | + project_dict = {'id': 'projectID'} |
| 244 | + mp.sync_model_properties(project_dict, False) |
| 245 | + get.assert_called_with("/projects/projectID/models") |
| 246 | + |
| 247 | + project_name = 'project' |
| 248 | + get_project.return_value = {'id': "pid"} |
| 249 | + mp.sync_model_properties(project_name, False) |
| 250 | + get.assert_called_with("/projects/pid/models") |
| 251 | + |
| 252 | + def test_overwrite(self): |
| 253 | + with mock.patch('sasctl._services.model_repository.ModelRepository.get_project') as get_project: |
| 254 | + with mock.patch('sasctl._services.model_repository.ModelRepository.get') as get: |
| 255 | + with mock.patch('sasctl._services.model_repository.ModelRepository.get_model') as get_model: |
| 256 | + with mock.patch('sasctl._services.model_repository.ModelRepository.update_model') as update: |
| 257 | + project_dict = {'id': 'projectID', 'function': 'project_function', 'targetLevel': '1'} |
| 258 | + get.return_value = [RestObj({'id': 'modelID'})] |
| 259 | + get_model.return_value = {'function': 'classification'} |
| 260 | + mp.sync_model_properties(project_dict) |
| 261 | + update.assert_called_with({'function': 'classification', 'targetLevel': '1'}) |
| 262 | + |
| 263 | + project_dict = {'id': 'projectID', 'function': 'project_function', 'targetLevel': '1'} |
| 264 | + get.return_value = [RestObj({'id': 'modelID'})] |
| 265 | + get_model.return_value = {'function': 'classification'} |
| 266 | + mp.sync_model_properties(project_dict, True) |
| 267 | + update.assert_called_with({'function': 'project_function', 'targetLevel': '1'}) |
| 268 | + |
| 269 | + |
| 270 | +class TestGenerateHyperparameters(unittest.TestCase): |
| 271 | + |
| 272 | + def test_xgboost(self): |
| 273 | + model = unittest.mock.Mock() |
| 274 | + model.__class__ = xgboost.Booster |
| 275 | + attrs = {'save_config.return_value': json.dumps({'test': 'passed'})} |
| 276 | + model.configure_mock(**attrs) |
| 277 | + tmp_dir = tempfile.TemporaryDirectory() |
| 278 | + mp.generate_hyperparameters(model, 'prefix', Path(tmp_dir.name)) |
| 279 | + assert Path( |
| 280 | + Path(tmp_dir.name) / f"./prefixHyperparameters.json" |
| 281 | + ).exists() |
| 282 | + |
| 283 | + def test_xgboost_sklearn(self): |
| 284 | + model = unittest.mock.Mock() |
| 285 | + model.__class__ = xgboost.XGBModel |
| 286 | + attrs = {'get_params.return_value': json.dumps({'test': 'passed'})} |
| 287 | + model.configure_mock(**attrs) |
| 288 | + tmp_dir = tempfile.TemporaryDirectory() |
| 289 | + mp.generate_hyperparameters(model, 'prefix', Path(tmp_dir.name)) |
| 290 | + assert Path( |
| 291 | + Path(tmp_dir.name) / f"./prefixHyperparameters.json" |
| 292 | + ).exists() |
| 293 | + |
| 294 | + def test_h2o(self): |
| 295 | + model = unittest.mock.Mock() |
| 296 | + model.__class__ = h2o.H2OFrame |
| 297 | + attrs = {'get_params.return_value': json.dumps({'test': 'passed'})} |
| 298 | + model.configure_mock(**attrs) |
| 299 | + tmp_dir = tempfile.TemporaryDirectory() |
| 300 | + mp.generate_hyperparameters(model, 'prefix', Path(tmp_dir.name)) |
| 301 | + assert Path( |
| 302 | + Path(tmp_dir.name) / f"./prefixHyperparameters.json" |
| 303 | + ).exists() |
| 304 | + |
| 305 | + def test_tensorflow(self): |
| 306 | + model = unittest.mock.Mock() |
| 307 | + model.__class__ = tf.keras.Sequential |
| 308 | + attrs = {'get_config.return_value': json.dumps({'test': 'passed'})} |
| 309 | + model.configure_mock(**attrs) |
| 310 | + tmp_dir = tempfile.TemporaryDirectory() |
| 311 | + mp.generate_hyperparameters(model, 'prefix', Path(tmp_dir.name)) |
| 312 | + assert Path( |
| 313 | + Path(tmp_dir.name) / f"./prefixHyperparameters.json" |
| 314 | + ).exists() |
| 315 | + |
| 316 | + def test_statsmodels(self): |
| 317 | + model = unittest.mock.Mock(exog_names = ['test', 'exog'], weights = np.array([0, 1])) |
| 318 | + model.__class__ = smf.ols |
| 319 | + tmp_dir = tempfile.TemporaryDirectory() |
| 320 | + mp.generate_hyperparameters(model, 'prefix', Path(tmp_dir.name)) |
| 321 | + assert Path( |
| 322 | + Path(tmp_dir.name) / f"./prefixHyperparameters.json" |
| 323 | + ).exists() |
0 commit comments