10
10
import logging
11
11
import math
12
12
import os
13
- import pickle # skipcq BAN-B301
14
13
import re
15
14
import sys
16
15
from warnings import warn
@@ -49,59 +48,6 @@ def _property(k, v):
49
48
return {"name" : str (k )[:_PROP_NAME_MAXLEN ], "value" : str (v )[:_PROP_VALUE_MAXLEN ]}
50
49
51
50
52
- def _sklearn_to_dict (model ):
53
- # Convert Scikit-learn values to built-in Model Manager values
54
- mappings = {
55
- "LogisticRegression" : "Logistic regression" ,
56
- "LinearRegression" : "Linear regression" ,
57
- "SVC" : "Support vector machine" ,
58
- "GradientBoostingClassifier" : "Gradient boosting" ,
59
- "GradientBoostingRegressor" : "Gradient boosting" ,
60
- "XGBClassifier" : "Gradient boosting" ,
61
- "XGBRegressor" : "Gradient boosting" ,
62
- "RandomForestClassifier" : "Forest" ,
63
- "DecisionTreeClassifier" : "Decision tree" ,
64
- "DecisionTreeRegressor" : "Decision tree" ,
65
- "classifier" : "classification" ,
66
- "regressor" : "prediction" ,
67
- }
68
-
69
- if hasattr (model , "_final_estimator" ):
70
- estimator = model ._final_estimator
71
- else :
72
- estimator = model
73
- estimator = type (estimator ).__name__
74
-
75
- # Standardize algorithm names
76
- algorithm = mappings .get (estimator , estimator )
77
-
78
- # Standardize regression/classification terms
79
- analytic_function = mappings .get (model ._estimator_type , model ._estimator_type )
80
-
81
- if analytic_function == "classification" and "logistic" in algorithm .lower ():
82
- target_level = "Binary"
83
- elif analytic_function == "prediction" and (
84
- "regressor" in estimator .lower () or "regression" in algorithm .lower ()
85
- ):
86
- target_level = "Interval"
87
- else :
88
- target_level = None
89
-
90
- # Can tell if multi-class .multi_class
91
- result = dict (
92
- description = str (model )[:_DESC_MAXLEN ],
93
- algorithm = algorithm ,
94
- scoreCodeType = "ds2MultiType" ,
95
- trainCodeType = "Python" ,
96
- targetLevel = target_level ,
97
- function = analytic_function ,
98
- tool = "Python %s.%s" % (sys .version_info .major , sys .version_info .minor ),
99
- properties = [_property (k , v ) for k , v in model .get_params ().items ()],
100
- )
101
-
102
- return result
103
-
104
-
105
51
def _create_project (project_name , model , repo , input_vars = None , output_vars = None ):
106
52
"""Creates a project based on the model specifications.
107
53
@@ -468,14 +414,14 @@ def register_model(
468
414
469
415
pzmm_files = {}
470
416
471
- pickled_model = PickleModel .pickle_trained_model (name , info .model )
417
+ pickled_model = PickleModel () .pickle_trained_model (name , info .model )
472
418
# Returns dict with "prefix.pickle": bytes
473
419
assert len (pickled_model ) == 1
474
420
475
- input_vars = JSONFiles .write_var_json (info .X , is_input = True )
476
- output_vars = JSONFiles .write_var_json (info .y , is_input = False )
477
- metadata = JSONFiles .write_file_metadata_json (model_prefix = name )
478
- properties = JSONFiles .write_model_properties_json (
421
+ input_vars = JSONFiles () .write_var_json (info .X , is_input = True )
422
+ output_vars = JSONFiles () .write_var_json (info .y , is_input = False )
423
+ metadata = JSONFiles () .write_file_metadata_json (model_prefix = name )
424
+ properties = JSONFiles () .write_model_properties_json (
479
425
model_name = name ,
480
426
model_desc = info .description ,
481
427
model_algorithm = info .algorithm ,
@@ -489,7 +435,7 @@ def register_model(
489
435
pzmm_files .update (metadata )
490
436
pzmm_files .update (properties )
491
437
492
- model_obj , _ = ImportModel .import_model (model_files = pzmm_files ,
438
+ model_obj , _ = ImportModel () .import_model (model_files = pzmm_files ,
493
439
model_prefix = name ,
494
440
project = project ,
495
441
input_data = info .X ,
0 commit comments