Skip to content

Commit 074c01d

Browse files
committed
black formatting
1 parent aea7a68 commit 074c01d

File tree

8 files changed

+685
-569
lines changed

8 files changed

+685
-569
lines changed

src/sasctl/pzmm/model_parameters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,4 +420,4 @@ def sync_model_properties(
420420
# If property is set in project, check if it's set in model, and update model accordingly
421421
if model_property not in model or overrwrite:
422422
model[model_property] = project[project_property]
423-
mr.update_model(model)
423+
mr.update_model(model)

src/sasctl/pzmm/write_json_files.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1146,7 +1146,9 @@ def bias_dataframes_to_json(
11461146
json_dict[1]["parameterMap"]["predict"]["label"] = pred_values
11471147
json_dict[1]["parameterMap"]["predict"]["parameter"] = pred_values
11481148
json_dict[1]["parameterMap"]["predict"]["values"] = [pred_values]
1149-
json_dict[1]["parameterMap"][pred_values] = json_dict[1]["parameterMap"]["predict"]
1149+
json_dict[1]["parameterMap"][pred_values] = json_dict[1]["parameterMap"][
1150+
"predict"
1151+
]
11501152
del json_dict[1]["parameterMap"]["predict"]
11511153

11521154
if json_path:
@@ -1164,8 +1166,6 @@ def bias_dataframes_to_json(
11641166
GROUPMETRICS: json.dumps(json_dict[1], indent=4, cls=NpEncoder),
11651167
}
11661168

1167-
1168-
11691169
@classmethod
11701170
def calculate_model_statistics(
11711171
cls,

src/sasctl/pzmm/write_score_code.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -469,9 +469,9 @@ def _write_imports(
469469
"""
470470
elif binary_string:
471471
cls.score_code += (
472-
f"import codecs\n\nbinary_string = \"{binary_string}\""
472+
f'import codecs\n\nbinary_string = "{binary_string}"'
473473
f"\nmodel = {pickle_type}.loads(codecs.decode(binary_string"
474-
".encode(), \"base64\"))\n\n"
474+
'.encode(), "base64"))\n\n'
475475
)
476476
"""
477477
import codecs
@@ -532,8 +532,8 @@ def _viya35_model_load(
532532
)
533533
elif binary_h2o_model:
534534
cls.score_code += (
535-
f"model = h2o.load(str(Path(\"/models/resources/viya/"
536-
f"{model_id}/{model_file_name}\")))\n\n"
535+
f'model = h2o.load(str(Path("/models/resources/viya/'
536+
f'{model_id}/{model_file_name}")))\n\n'
537537
)
538538
"""
539539
model = h2o.load(str(Path("/models/resources/viya/<UUID>/model.h2o")))
@@ -545,8 +545,8 @@ def _viya35_model_load(
545545
)
546546
else:
547547
cls.score_code += (
548-
f"model_path = Path(\"/models/resources/viya/{model_id}"
549-
f"\")\nwith open(model_path / \"{model_file_name}\", "
548+
f'model_path = Path("/models/resources/viya/{model_id}'
549+
f'")\nwith open(model_path / "{model_file_name}", '
550550
f"\"rb\") as pickle_model:\n{'':4}model = {pickle_type}"
551551
".load(pickle_model)\n\n"
552552
)
@@ -602,7 +602,7 @@ def _viya4_model_load(
602602
if mojo_model:
603603
cls.score_code += (
604604
f"model = h2o.import_mojo(str(Path(settings.pickle_path"
605-
f") / \"{model_file_name}\"))\n\n"
605+
f') / "{model_file_name}"))\n\n'
606606
)
607607
"""
608608
model = h2o.import_mojo(str(Path(settings.pickle_path) / "model.mojo"))
@@ -1201,7 +1201,8 @@ def _no_targets_no_thresholds(
12011201
if h2o_model:
12021202
cls.score_code += (
12031203
f"{'':4}if input_array.shape[0] == 1:\n"
1204-
f"{'':8}{metrics[0]} = prediction[1][0]\n")
1204+
f"{'':8}{metrics[0]} = prediction[1][0]\n"
1205+
)
12051206
for i in range(len(metrics) - 1):
12061207
cls.score_code += (
12071208
f"{'':8}{metrics[i + 1]} = float(prediction[1][{i + 1}])\n"
@@ -1326,7 +1327,8 @@ def _binary_target(
13261327
f"{'':4}if input_array.shape[0] == 1:\n"
13271328
f"{'':8}return prediction\n"
13281329
f"{'':4}else:\n"
1329-
f"{'':8}return pd.DataFrame({{'{metrics}': prediction}})")
1330+
f"{'':8}return pd.DataFrame({{'{metrics}': prediction}})"
1331+
)
13301332
"""
13311333
if input_array.shape[0] == 1:
13321334
return prediction
@@ -1498,7 +1500,8 @@ def _binary_target(
14981500
f"{'':4}if input_array.shape[0] == 1:\n"
14991501
f"{'':8}return prediction[0], prediction[1]\n"
15001502
f"{'':4}else:\n"
1501-
f"{'':8}return pd.DataFrame(prediction, columns={metrics})")
1503+
f"{'':8}return pd.DataFrame(prediction, columns={metrics})"
1504+
)
15021505
"""
15031506
if input_array.shape[0] == 1:
15041507
return prediction[0], prediction[1]
@@ -1520,7 +1523,8 @@ def _binary_target(
15201523
f"{'':8}return prediction[0], prediction[1]\n"
15211524
f"{'':4}else:\n"
15221525
f"{'':8}output_table = pd.DataFrame(prediction, columns=[{metric_list}])\n"
1523-
f"{'':8}return output_table.drop('drop', axis=1)")
1526+
f"{'':8}return output_table.drop('drop', axis=1)"
1527+
)
15241528

15251529
"""
15261530
if input_array.shape[0] == 1:
@@ -1537,7 +1541,8 @@ def _binary_target(
15371541
f"{'':4}else:\n"
15381542
f"{'':8}output_table = pd.DataFrame(prediction, columns=[{metric_list}])\n"
15391543
f"{'':8}output_table = output_table[output_table.columns[::-1]]\n"
1540-
f"{'':8}return output_table.drop('drop', axis=1)")
1544+
f"{'':8}return output_table.drop('drop', axis=1)"
1545+
)
15411546
"""
15421547
if input_array.shape[0] == 1:
15431548
return prediction[2], prediction[0]
@@ -1782,7 +1787,7 @@ def _nonbinary_targets(
17821787
f"{'':8}return prediction[{class_index}]\n"
17831788
f"{'':4}else:\n"
17841789
f"{'':8}return pd.DataFrame({{'{metrics}': [p[{class_index}] for p in prediction]}})"
1785-
)
1790+
)
17861791
"""
17871792
if input_array.shape[0] == 1:
17881793
return prediction[0]

tests/conftest.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -532,9 +532,7 @@ def sklearn_classification_model(iris_dataset):
532532
with warnings.catch_warnings():
533533
warnings.simplefilter("ignore")
534534
model = sk.LogisticRegression(
535-
multi_class="multinomial",
536-
solver="lbfgs",
537-
max_iter=1000
535+
multi_class="multinomial", solver="lbfgs", max_iter=1000
538536
)
539537
model.fit(X, y)
540538
return model

tests/unit/test_model_parameters.py

Lines changed: 65 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -232,92 +232,111 @@ class TestSyncModelProperties(unittest.TestCase):
232232
]
233233

234234
def test_project_id(self):
235-
with mock.patch('sasctl._services.model_repository.ModelRepository.get_project') as get_project:
236-
with mock.patch('sasctl._services.model_repository.ModelRepository.get') as get:
237-
with mock.patch('sasctl._services.model_repository.ModelRepository.get_model') as get_model:
238-
with mock.patch('sasctl._services.model_repository.ModelRepository.update_model') as update:
235+
with mock.patch(
236+
"sasctl._services.model_repository.ModelRepository.get_project"
237+
) as get_project:
238+
with mock.patch(
239+
"sasctl._services.model_repository.ModelRepository.get"
240+
) as get:
241+
with mock.patch(
242+
"sasctl._services.model_repository.ModelRepository.get_model"
243+
) as get_model:
244+
with mock.patch(
245+
"sasctl._services.model_repository.ModelRepository.update_model"
246+
) as update:
239247
pUUID = uuid.uuid4()
240248
mp.sync_model_properties(pUUID, False)
241249
get.assert_called_with(f"/projects/{pUUID}/models")
242250

243-
project_dict = {'id': 'projectID'}
251+
project_dict = {"id": "projectID"}
244252
mp.sync_model_properties(project_dict, False)
245253
get.assert_called_with("/projects/projectID/models")
246254

247-
project_name = 'project'
248-
get_project.return_value = {'id': "pid"}
255+
project_name = "project"
256+
get_project.return_value = {"id": "pid"}
249257
mp.sync_model_properties(project_name, False)
250258
get.assert_called_with("/projects/pid/models")
251259

252260
def test_overwrite(self):
253-
with mock.patch('sasctl._services.model_repository.ModelRepository.get_project') as get_project:
254-
with mock.patch('sasctl._services.model_repository.ModelRepository.get') as get:
255-
with mock.patch('sasctl._services.model_repository.ModelRepository.get_model') as get_model:
256-
with mock.patch('sasctl._services.model_repository.ModelRepository.update_model') as update:
257-
project_dict = {'id': 'projectID', 'function': 'project_function', 'targetLevel': '1'}
258-
get.return_value = [RestObj({'id': 'modelID'})]
259-
get_model.return_value = {'function': 'classification'}
261+
with mock.patch(
262+
"sasctl._services.model_repository.ModelRepository.get_project"
263+
) as get_project:
264+
with mock.patch(
265+
"sasctl._services.model_repository.ModelRepository.get"
266+
) as get:
267+
with mock.patch(
268+
"sasctl._services.model_repository.ModelRepository.get_model"
269+
) as get_model:
270+
with mock.patch(
271+
"sasctl._services.model_repository.ModelRepository.update_model"
272+
) as update:
273+
project_dict = {
274+
"id": "projectID",
275+
"function": "project_function",
276+
"targetLevel": "1",
277+
}
278+
get.return_value = [RestObj({"id": "modelID"})]
279+
get_model.return_value = {"function": "classification"}
260280
mp.sync_model_properties(project_dict)
261-
update.assert_called_with({'function': 'classification', 'targetLevel': '1'})
262-
263-
project_dict = {'id': 'projectID', 'function': 'project_function', 'targetLevel': '1'}
264-
get.return_value = [RestObj({'id': 'modelID'})]
265-
get_model.return_value = {'function': 'classification'}
281+
update.assert_called_with(
282+
{"function": "classification", "targetLevel": "1"}
283+
)
284+
285+
project_dict = {
286+
"id": "projectID",
287+
"function": "project_function",
288+
"targetLevel": "1",
289+
}
290+
get.return_value = [RestObj({"id": "modelID"})]
291+
get_model.return_value = {"function": "classification"}
266292
mp.sync_model_properties(project_dict, True)
267-
update.assert_called_with({'function': 'project_function', 'targetLevel': '1'})
293+
update.assert_called_with(
294+
{"function": "project_function", "targetLevel": "1"}
295+
)
268296

269297

270298
class TestGenerateHyperparameters(unittest.TestCase):
271-
272299
def test_xgboost(self):
273300
model = unittest.mock.Mock()
274301
model.__class__ = xgboost.Booster
275-
attrs = {'save_config.return_value': json.dumps({'test': 'passed'})}
302+
attrs = {"save_config.return_value": json.dumps({"test": "passed"})}
276303
model.configure_mock(**attrs)
277304
tmp_dir = tempfile.TemporaryDirectory()
278-
mp.generate_hyperparameters(model, 'prefix', Path(tmp_dir.name))
279-
assert Path(
280-
Path(tmp_dir.name) / f"./prefixHyperparameters.json"
281-
).exists()
305+
mp.generate_hyperparameters(model, "prefix", Path(tmp_dir.name))
306+
assert Path(Path(tmp_dir.name) / f"./prefixHyperparameters.json").exists()
282307

283308
def test_xgboost_sklearn(self):
284309
model = unittest.mock.Mock()
285310
model.__class__ = xgboost.XGBModel
286-
attrs = {'get_params.return_value': json.dumps({'test': 'passed'})}
311+
attrs = {"get_params.return_value": json.dumps({"test": "passed"})}
287312
model.configure_mock(**attrs)
288313
tmp_dir = tempfile.TemporaryDirectory()
289-
mp.generate_hyperparameters(model, 'prefix', Path(tmp_dir.name))
290-
assert Path(
291-
Path(tmp_dir.name) / f"./prefixHyperparameters.json"
292-
).exists()
314+
mp.generate_hyperparameters(model, "prefix", Path(tmp_dir.name))
315+
assert Path(Path(tmp_dir.name) / f"./prefixHyperparameters.json").exists()
293316

294317
def test_h2o(self):
295318
model = unittest.mock.Mock()
296319
model.__class__ = h2o.H2OFrame
297-
attrs = {'get_params.return_value': json.dumps({'test': 'passed'})}
320+
attrs = {"get_params.return_value": json.dumps({"test": "passed"})}
298321
model.configure_mock(**attrs)
299322
tmp_dir = tempfile.TemporaryDirectory()
300-
mp.generate_hyperparameters(model, 'prefix', Path(tmp_dir.name))
301-
assert Path(
302-
Path(tmp_dir.name) / f"./prefixHyperparameters.json"
303-
).exists()
323+
mp.generate_hyperparameters(model, "prefix", Path(tmp_dir.name))
324+
assert Path(Path(tmp_dir.name) / f"./prefixHyperparameters.json").exists()
304325

305326
def test_tensorflow(self):
306327
model = unittest.mock.Mock()
307328
model.__class__ = tf.keras.Sequential
308-
attrs = {'get_config.return_value': json.dumps({'test': 'passed'})}
329+
attrs = {"get_config.return_value": json.dumps({"test": "passed"})}
309330
model.configure_mock(**attrs)
310331
tmp_dir = tempfile.TemporaryDirectory()
311-
mp.generate_hyperparameters(model, 'prefix', Path(tmp_dir.name))
312-
assert Path(
313-
Path(tmp_dir.name) / f"./prefixHyperparameters.json"
314-
).exists()
332+
mp.generate_hyperparameters(model, "prefix", Path(tmp_dir.name))
333+
assert Path(Path(tmp_dir.name) / f"./prefixHyperparameters.json").exists()
315334

316335
def test_statsmodels(self):
317-
model = unittest.mock.Mock(exog_names = ['test', 'exog'], weights = np.array([0, 1]))
336+
model = unittest.mock.Mock(
337+
exog_names=["test", "exog"], weights=np.array([0, 1])
338+
)
318339
model.__class__ = smf.ols
319340
tmp_dir = tempfile.TemporaryDirectory()
320-
mp.generate_hyperparameters(model, 'prefix', Path(tmp_dir.name))
321-
assert Path(
322-
Path(tmp_dir.name) / f"./prefixHyperparameters.json"
323-
).exists()
341+
mp.generate_hyperparameters(model, "prefix", Path(tmp_dir.name))
342+
assert Path(Path(tmp_dir.name) / f"./prefixHyperparameters.json").exists()

0 commit comments

Comments
 (0)