Skip to content

Commit 58d32f1

Browse files
committed
feat: pass additional info when registering model
1 parent 816c0c2 commit 58d32f1

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

src/sasctl/tasks.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,7 @@ def register_model(
469469

470470
pzmm_files = {}
471471

472-
pickled_model = PickleModel.pickle_trained_model("model", info.model)
472+
pickled_model = PickleModel.pickle_trained_model(name, info.model)
473473
# Returns dict with "prefix.pickle": bytes
474474
assert len(pickled_model) == 1
475475

@@ -490,7 +490,13 @@ def register_model(
490490
pzmm_files.update(metadata)
491491
pzmm_files.update(properties)
492492

493-
model_obj, _ = ImportModel.import_model(pzmm_files, name, project)
493+
model_obj, _ = ImportModel.import_model(model_files=pzmm_files,
494+
model_prefix=name,
495+
project=project,
496+
input_data=info.X,
497+
predict_method=info.predict_function,
498+
predict_threshold=info.threshold,
499+
target_values=info.target_values)
494500
return model_obj
495501

496502
# # If the model is a scikit-learn model, generate the model dictionary

0 commit comments

Comments
 (0)