Skip to content

Commit 7f7c900

Browse files
committed
chore: remove obsolete code
1 parent 7ced3f6 commit 7f7c900

File tree

2 files changed

+6
-95
lines changed

2 files changed

+6
-95
lines changed

src/sasctl/tasks.py

Lines changed: 6 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import logging
1111
import math
1212
import os
13-
import pickle # skipcq BAN-B301
1413
import re
1514
import sys
1615
from warnings import warn
@@ -49,59 +48,6 @@ def _property(k, v):
4948
return {"name": str(k)[:_PROP_NAME_MAXLEN], "value": str(v)[:_PROP_VALUE_MAXLEN]}
5049

5150

52-
def _sklearn_to_dict(model):
53-
# Convert Scikit-learn values to built-in Model Manager values
54-
mappings = {
55-
"LogisticRegression": "Logistic regression",
56-
"LinearRegression": "Linear regression",
57-
"SVC": "Support vector machine",
58-
"GradientBoostingClassifier": "Gradient boosting",
59-
"GradientBoostingRegressor": "Gradient boosting",
60-
"XGBClassifier": "Gradient boosting",
61-
"XGBRegressor": "Gradient boosting",
62-
"RandomForestClassifier": "Forest",
63-
"DecisionTreeClassifier": "Decision tree",
64-
"DecisionTreeRegressor": "Decision tree",
65-
"classifier": "classification",
66-
"regressor": "prediction",
67-
}
68-
69-
if hasattr(model, "_final_estimator"):
70-
estimator = model._final_estimator
71-
else:
72-
estimator = model
73-
estimator = type(estimator).__name__
74-
75-
# Standardize algorithm names
76-
algorithm = mappings.get(estimator, estimator)
77-
78-
# Standardize regression/classification terms
79-
analytic_function = mappings.get(model._estimator_type, model._estimator_type)
80-
81-
if analytic_function == "classification" and "logistic" in algorithm.lower():
82-
target_level = "Binary"
83-
elif analytic_function == "prediction" and (
84-
"regressor" in estimator.lower() or "regression" in algorithm.lower()
85-
):
86-
target_level = "Interval"
87-
else:
88-
target_level = None
89-
90-
# Can tell if multi-class .multi_class
91-
result = dict(
92-
description=str(model)[:_DESC_MAXLEN],
93-
algorithm=algorithm,
94-
scoreCodeType="ds2MultiType",
95-
trainCodeType="Python",
96-
targetLevel=target_level,
97-
function=analytic_function,
98-
tool="Python %s.%s" % (sys.version_info.major, sys.version_info.minor),
99-
properties=[_property(k, v) for k, v in model.get_params().items()],
100-
)
101-
102-
return result
103-
104-
10551
def _create_project(project_name, model, repo, input_vars=None, output_vars=None):
10652
"""Creates a project based on the model specifications.
10753
@@ -468,14 +414,14 @@ def register_model(
468414

469415
pzmm_files = {}
470416

471-
pickled_model = PickleModel.pickle_trained_model(name, info.model)
417+
pickled_model = PickleModel().pickle_trained_model(name, info.model)
472418
# Returns dict with "prefix.pickle": bytes
473419
assert len(pickled_model) == 1
474420

475-
input_vars = JSONFiles.write_var_json(info.X, is_input=True)
476-
output_vars = JSONFiles.write_var_json(info.y, is_input=False)
477-
metadata = JSONFiles.write_file_metadata_json(model_prefix=name)
478-
properties = JSONFiles.write_model_properties_json(
421+
input_vars = JSONFiles().write_var_json(info.X, is_input=True)
422+
output_vars = JSONFiles().write_var_json(info.y, is_input=False)
423+
metadata = JSONFiles().write_file_metadata_json(model_prefix=name)
424+
properties = JSONFiles().write_model_properties_json(
479425
model_name=name,
480426
model_desc=info.description,
481427
model_algorithm=info.algorithm,
@@ -489,7 +435,7 @@ def register_model(
489435
pzmm_files.update(metadata)
490436
pzmm_files.update(properties)
491437

492-
model_obj, _ = ImportModel.import_model(model_files=pzmm_files,
438+
model_obj, _ = ImportModel().import_model(model_files=pzmm_files,
493439
model_prefix=name,
494440
project=project,
495441
input_data=info.X,

tests/unit/test_tasks.py

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -19,41 +19,6 @@
1919
)
2020

2121

22-
def test_sklearn_metadata():
23-
pytest.importorskip("sklearn")
24-
25-
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
26-
from sklearn.linear_model import LinearRegression, LogisticRegression
27-
from sklearn.svm import SVC
28-
from sklearn.tree import DecisionTreeClassifier
29-
30-
from sasctl.tasks import _sklearn_to_dict
31-
32-
info = _sklearn_to_dict(LinearRegression())
33-
assert info["algorithm"] == "Linear regression"
34-
assert info["function"] == "prediction"
35-
36-
info = _sklearn_to_dict(LogisticRegression())
37-
assert info["algorithm"] == "Logistic regression"
38-
assert info["function"] == "classification"
39-
40-
info = _sklearn_to_dict(SVC())
41-
assert info["algorithm"] == "Support vector machine"
42-
assert info["function"] == "classification"
43-
44-
info = _sklearn_to_dict(GradientBoostingClassifier())
45-
assert info["algorithm"] == "Gradient boosting"
46-
assert info["function"] == "classification"
47-
48-
info = _sklearn_to_dict(DecisionTreeClassifier())
49-
assert info["algorithm"] == "Decision tree"
50-
assert info["function"] == "classification"
51-
52-
info = _sklearn_to_dict(RandomForestClassifier())
53-
assert info["algorithm"] == "Forest"
54-
assert info["function"] == "classification"
55-
56-
5722
def test_parse_module_url():
5823
from sasctl.tasks import _parse_module_url
5924

0 commit comments

Comments
 (0)