Skip to content

Commit b5c0e06

Browse files
committed
chore: use pzmm for open source models
1 parent 4b0016b commit b5c0e06

File tree

1 file changed

+117
-80
lines changed

1 file changed

+117
-80
lines changed

src/sasctl/tasks.py

Lines changed: 117 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -458,88 +458,125 @@ def register_model(
458458
)
459459
return model
460460

461-
# If the model is a scikit-learn model, generate the model dictionary
462-
# from it and pickle the model for storage
463-
if all(hasattr(model, attr) for attr in ["_estimator_type", "get_params"]):
464-
# Pickle the model so we can store it
465-
model_pkl = pickle.dumps(model)
466-
files.append({"name": "model.pkl", "file": model_pkl, "role": "Python Pickle"})
467-
468-
target_funcs = [f for f in ("predict", "predict_proba") if hasattr(model, f)]
469-
470-
# Extract model properties
471-
model = _sklearn_to_dict(model)
472-
model["name"] = name
473-
474-
# Get package versions in environment
475-
packages = installed_packages()
476-
if record_packages and packages is not None:
477-
model.setdefault("properties", [])
478-
479-
# Define a custom property to capture each package version
480-
# NOTE: some packages may not conform to the 'name==version' format
481-
# expected here (e.g those installed with pip install -e). Such
482-
# packages also generally contain characters that are not allowed
483-
# in custom properties, so they are excluded here.
484-
for p in packages:
485-
if "==" in p:
486-
n, v = p.split("==")
487-
model["properties"].append(_property("env_%s" % n, v))
488-
489-
# Generate and upload a requirements.txt file
490-
files.append({"name": "requirements.txt", "file": "\n".join(packages)})
491-
492-
# Generate PyMAS wrapper
461+
if not isinstance(model, dict):
493462
try:
494-
mas_module = from_pickle(
495-
model_pkl, target_funcs, input_types=input, array_input=True
496-
)
497-
498-
# Include score code files from ESP and MAS
499-
files.append(
500-
{
501-
"name": "dmcas_packagescorecode.sas",
502-
"file": mas_module.score_code(),
503-
"role": "Score Code",
504-
}
505-
)
506-
files.append(
507-
{
508-
"name": "dmcas_epscorecode.sas",
509-
"file": mas_module.score_code(dest="CAS"),
510-
"role": "score",
511-
}
512-
)
513-
files.append(
514-
{
515-
"name": "python_wrapper.py",
516-
"file": mas_module.score_code(dest="Python"),
517-
}
518-
)
463+
info = utils.get_model_info(model, X=input)
464+
except ValueError as e:
465+
logger.debug("Model of type %s could not be inspected: %s", type(model), e)
466+
raise
467+
468+
from .pzmm import ImportModel, JSONFiles, PickleModel
469+
470+
pzmm_files = {}
471+
472+
pickled_model = PickleModel.pickle_trained_model("model", info.model)
473+
# Returns dict with "prefix.pickle": bytes
474+
assert len(pickled_model) == 1
475+
476+
input_vars = JSONFiles.write_var_json(info.X, is_input=True)
477+
output_vars = JSONFiles.write_var_json(info.y, is_input=False)
478+
metadata = JSONFiles.write_file_metadata_json(model_prefix=name)
479+
properties = JSONFiles.write_model_properties_json(
480+
model_name=name,
481+
model_desc=info.description,
482+
model_algorithm=info.algorithm,
483+
target_variable=info.target_column,
484+
target_values=info.target_values
485+
)
519486

520-
model["inputVariables"] = [
521-
var.as_model_metadata() for var in mas_module.variables if not var.out
522-
]
523-
524-
model["outputVariables"] = [
525-
var.as_model_metadata() for var in mas_module.variables if var.out
526-
]
527-
except ValueError:
528-
# PyMAS creation failed, most likely because input data wasn't
529-
# provided
530-
logger.exception("Unable to inspect model %s", model)
531-
532-
warn(
533-
"Unable to determine input/output variables. "
534-
" Model variables will not be specified and some "
535-
"model functionality may not be available."
536-
)
537-
else:
538-
# Otherwise, the model better be a dictionary of metadata
539-
if not isinstance(model, dict):
540-
raise TypeError(
541-
"Expected an instance of '%r' but received '%r'." % ({}, model)
542-
)
487+
pzmm_files.update(pickled_model)
488+
pzmm_files.update(input_vars)
489+
pzmm_files.update(output_vars)
490+
pzmm_files.update(metadata)
491+
pzmm_files.update(properties)
492+
493+
model_obj, _ = ImportModel.import_model(pzmm_files, name, project)
494+
return model_obj
495+
496+
# # If the model is a scikit-learn model, generate the model dictionary
497+
# # from it and pickle the model for storage
498+
# if all(hasattr(model, attr) for attr in ["_estimator_type", "get_params"]):
499+
# # Pickle the model so we can store it
500+
# model_pkl = pickle.dumps(model)
501+
# files.append({"name": "model.pkl", "file": model_pkl, "role": "Python Pickle"})
502+
#
503+
# target_funcs = [f for f in ("predict", "predict_proba") if hasattr(model, f)]
504+
#
505+
# # Extract model properties
506+
# model = _sklearn_to_dict(model)
507+
# model["name"] = name
508+
#
509+
# # Get package versions in environment
510+
# packages = installed_packages()
511+
# if record_packages and packages is not None:
512+
# model.setdefault("properties", [])
513+
#
514+
# # Define a custom property to capture each package version
515+
# # NOTE: some packages may not conform to the 'name==version' format
516+
# # expected here (e.g those installed with pip install -e). Such
517+
# # packages also generally contain characters that are not allowed
518+
# # in custom properties, so they are excluded here.
519+
# for p in packages:
520+
# if "==" in p:
521+
# n, v = p.split("==")
522+
# model["properties"].append(_property("env_%s" % n, v))
523+
#
524+
# # Generate and upload a requirements.txt file
525+
# files.append({"name": "requirements.txt", "file": "\n".join(packages)})
526+
#
527+
# # Generate PyMAS wrapper
528+
# try:
529+
# mas_module = from_pickle(
530+
# model_pkl, target_funcs, input_types=input, array_input=True
531+
# )
532+
#
533+
# # Include score code files from ESP and MAS
534+
# files.append(
535+
# {
536+
# "name": "dmcas_packagescorecode.sas",
537+
# "file": mas_module.score_code(),
538+
# "role": "Score Code",
539+
# }
540+
# )
541+
# files.append(
542+
# {
543+
# "name": "dmcas_epscorecode.sas",
544+
# "file": mas_module.score_code(dest="CAS"),
545+
# "role": "score",
546+
# }
547+
# )
548+
# files.append(
549+
# {
550+
# "name": "python_wrapper.py",
551+
# "file": mas_module.score_code(dest="Python"),
552+
# }
553+
# )
554+
#
555+
# model["inputVariables"] = [
556+
# var.as_model_metadata() for var in mas_module.variables if not var.out
557+
# ]
558+
#
559+
# model["outputVariables"] = [
560+
# var.as_model_metadata() for var in mas_module.variables if var.out
561+
# ]
562+
# except ValueError:
563+
# # PyMAS creation failed, most likely because input data wasn't
564+
# # provided
565+
# logger.exception("Unable to inspect model %s", model)
566+
#
567+
# warn(
568+
# "Unable to determine input/output variables. "
569+
# " Model variables will not be specified and some "
570+
# "model functionality may not be available."
571+
# )
572+
# else:
573+
# # Otherwise, the model better be a dictionary of metadata
574+
# if not isinstance(model, dict):
575+
# raise TypeError(
576+
# "Expected an instance of '%r' but received '%r'." % ({}, model)
577+
# )
578+
579+
# If we got this far, then `model` is a dictionary of model metadata.
543580

544581
if create_project:
545582
project = _create_project(project, model, repo_obj)

0 commit comments

Comments
 (0)