diff --git a/src/sasctl/tasks.py b/src/sasctl/tasks.py index d466c10f..dfa07e64 100644 --- a/src/sasctl/tasks.py +++ b/src/sasctl/tasks.py @@ -15,6 +15,7 @@ from pathlib import Path from typing import Union from warnings import warn +import zipfile import pandas as pd @@ -998,3 +999,55 @@ def score_model_with_cas( print(score_execution_poll) score_results = se.get_score_execution_results(score_execution, use_cas_gateway) return score_results + + +def upload_local_model( + path: Union[str, Path], + model_name: str, + project_name: str, + repo_name: Union[str, dict] = None, + version: str = "latest", +): + """A barebones function to upload a model and any associated files to the model repository. + Parameters + ---------- + path : Union[str, Path] + The path to the model and any associated files. + model_name : str + The name of the model. + project_name : str + The name of the project to which the model will be uploaded. + """ + # Use default repository if not specified + try: + if repo_name is None: + repository = mr.default_repository() + else: + repository = mr.get_repository(repo_name) + except HTTPError as e: + if e.code == 403: + raise AuthorizationError( + "Unable to register model. User account does not have read permissions " + "for the /modelRepository/repositories/ URL. Please contact your SAS " + "Viya administrator." + ) + raise e + + # Unable to find or create the repo. + if not repository and not repo_name: + raise ValueError("Unable to find a default repository") + elif not repository: + raise ValueError(f"Unable to find repository '{repo_name}'") + p = mr.get_project(project_name) + if p is None: + mr.create_project(project_name, repository) + zip_name = str(Path(path) / (model_name + ".zip")) + file_names = sorted(Path(path).glob("*[!zip]")) + with zipfile.ZipFile(str(zip_name), mode="w") as zFile: + for file in file_names: + zFile.write(str(file), arcname=file.name) + with open(zip_name, "rb") as zip_file: + model = mr.import_model_from_zip( + model_name, project_name, zip_file, version=version + ) + return model