Skip to content

Commit 5f79aeb

Browse files
committed
fix: use sanitized model names in file names.
1 parent 0770141 commit 5f79aeb

File tree

4 files changed: 27 additions and 18 deletions.

src/sasctl/pzmm/pickle_model.py

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,6 @@
22
# SPDX-License-Identifier: Apache-2.0
33
# %%
44
import codecs
5-
import gzip
65
import pickle
76
import shutil
87
from pathlib import Path
@@ -77,6 +76,9 @@ def pickle_trained_model(
7776
models.
7877
7978
"""
79+
from .write_score_code import ScoreCode
80+
sanitized_prefix = ScoreCode.sanitize_model_prefix(model_prefix)
81+
8082
if is_binary_string:
8183
# For models that use a binary string representation
8284
binary_string = codecs.encode(
@@ -91,25 +93,25 @@ def pickle_trained_model(
9193
# For models imported from MLFlow
9294
shutil.copy(ml_pickle_path, pickle_path)
9395
pzmm_pickle_path = Path(pickle_path) / mlflow_details["model_path"]
94-
pzmm_pickle_path.rename(Path(pickle_path) / (model_prefix + PICKLE))
96+
pzmm_pickle_path.rename(Path(pickle_path) / (sanitized_prefix + PICKLE))
9597
else:
9698
with open(ml_pickle_path, "rb") as pickle_file:
97-
return {model_prefix + PICKLE: pickle.load(pickle_file)}
99+
return {sanitized_prefix + PICKLE: pickle.load(pickle_file)}
98100
else:
99101
# For all other model types
100102
if not is_h2o_model:
101103
if pickle_path:
102104
with open(
103-
Path(pickle_path) / (model_prefix + PICKLE), "wb"
105+
Path(pickle_path) / (sanitized_prefix + PICKLE), "wb"
104106
) as pickle_file:
105107
pickle.dump(trained_model, pickle_file)
106108
if cls.notebook_output:
107109
print(
108110
f"Model {model_prefix} was successfully pickled and saved "
109-
f"to {Path(pickle_path) / (model_prefix + PICKLE)}."
111+
f"to {Path(pickle_path) / (sanitized_prefix + PICKLE)}."
110112
)
111113
else:
112-
return {model_prefix + PICKLE: pickle.dumps(trained_model)}
114+
return {sanitized_prefix + PICKLE: pickle.dumps(trained_model)}
113115
# For binary H2O models, save the binary file as a "pickle" file
114116
elif is_h2o_model and is_binary_model and pickle_path:
115117
if not h2o:
@@ -121,7 +123,7 @@ def pickle_trained_model(
121123
model=trained_model,
122124
force=True,
123125
path=str(pickle_path),
124-
filename=f"{model_prefix}.pickle",
126+
filename=f"{sanitized_prefix}.pickle",
125127
)
126128
# For MOJO H2O models, save as a mojo file and adjust the extension to .mojo
127129
elif is_h2o_model and pickle_path:
@@ -130,7 +132,7 @@ def pickle_trained_model(
130132
"The h2o package is required to save the model as a mojo model."
131133
)
132134
trained_model.save_mojo(
133-
force=True, path=str(pickle_path), filename=f"{model_prefix}.mojo"
135+
force=True, path=str(pickle_path), filename=f"{sanitized_prefix}.mojo"
134136
)
135137
elif is_binary_model or is_h2o_model:
136138
raise ValueError(

src/sasctl/pzmm/write_json_files.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -498,18 +498,22 @@ def write_file_metadata_json(
498498
Dictionary containing a key-value pair representing the file name and json
499499
dump respectively.
500500
"""
501+
502+
from .write_score_code import ScoreCode
503+
sanitized_prefix = ScoreCode.sanitize_model_prefix(model_prefix)
504+
501505
dict_list = [
502506
{"role": "inputVariables", "name": INPUT},
503507
{"role": "outputVariables", "name": OUTPUT},
504-
{"role": "score", "name": f"score_{model_prefix}.py"},
508+
{"role": "score", "name": f"score_{sanitized_prefix}.py"},
505509
]
506510
if is_h2o_model:
507-
dict_list.append({"role": "scoreResource", "name": model_prefix + ".mojo"})
511+
dict_list.append({"role": "scoreResource", "name": sanitized_prefix + ".mojo"})
508512
elif is_tf_keras_model:
509-
dict_list.append({"role": "scoreResource", "name": model_prefix + ".h5"})
513+
dict_list.append({"role": "scoreResource", "name": sanitized_prefix + ".h5"})
510514
else:
511515
dict_list.append(
512-
{"role": "scoreResource", "name": model_prefix + ".pickle"}
516+
{"role": "scoreResource", "name": sanitized_prefix + ".pickle"}
513517
)
514518

515519
if json_path:

src/sasctl/pzmm/write_score_code.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -153,6 +153,8 @@ def write_score_code(
153153

154154
model_id = cls._check_viya_version(model)
155155

156+
sanitized_model_prefix = cls.sanitize_model_prefix(model_prefix)
157+
156158
# Set the model_file_name based on kwargs input
157159
if "model_file_name" in kwargs and "binary_string" in kwargs:
158160
raise ValueError(
@@ -203,7 +205,7 @@ def write_score_code(
203205
else:
204206
model_load = None
205207

206-
model_prefix = cls._check_valid_model_prefix(model_prefix)
208+
207209

208210
# Define the score function using the variables found in input_data
209211
cls.score_code += f"def score({', '.join(input_var_list)}):\n"
@@ -295,7 +297,7 @@ def score(var1, var2, var3, var4):
295297
)
296298

297299
if score_code_path:
298-
py_code_path = Path(score_code_path) / f"score_{model_prefix}.py"
300+
py_code_path = Path(score_code_path) / f"score_{sanitized_model_prefix}.py"
299301
with open(py_code_path, "w") as py_file:
300302
py_file.write(cls.score_code)
301303
if model_id and score_cas:
@@ -306,7 +308,7 @@ def score(var1, var2, var3, var4):
306308
# noinspection PyUnboundLocalVariable
307309
sas_file.write(cas_code)
308310
else:
309-
output_dict = {f"score_{model_prefix}.py": cls.score_code}
311+
output_dict = {f"score_{sanitized_model_prefix}.py": cls.score_code}
310312
if model_id and score_cas:
311313
# noinspection PyUnboundLocalVariable
312314
output_dict[MAS_CODE_NAME] = mas_code
@@ -2139,7 +2141,7 @@ def _check_viya_version(cls, model: Union[str, dict, RestObj]) -> Union[str, Non
21392141
return None
21402142

21412143
@staticmethod
2142-
def _check_valid_model_prefix(prefix: str) -> str:
2144+
def sanitize_model_prefix(prefix: str) -> str:
21432145
"""
21442146
Check the model_prefix for a valid Python function name.
21452147
@@ -2153,6 +2155,7 @@ def _check_valid_model_prefix(prefix: str) -> str:
21532155
-------
21542156
model_prefix : str
21552157
Returns a model_prefix, adjusted as needed for valid Python function names.
2158+
21562159
"""
21572160
# Replace model_prefix if a valid function name is not provided
21582161
if not prefix.isidentifier():

tests/unit/test_write_score_code.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1170,8 +1170,8 @@ def test_check_valid_model_prefix():
11701170
- check model_prefix validity
11711171
- raise warning and replace if invalid
11721172
"""
1173-
assert sc._check_valid_model_prefix("TestPrefix") == "TestPrefix"
1174-
assert sc._check_valid_model_prefix("Test Prefix") == "Test_Prefix"
1173+
assert sc.sanitize_model_prefix("TestPrefix") == "TestPrefix"
1174+
assert sc.sanitize_model_prefix("Test Prefix") == "Test_Prefix"
11751175

11761176

11771177
def test_write_score_code(score_code_mocks):

0 commit comments

Comments
 (0)