diff --git a/doc/api/services/report_images.rst b/doc/api/services/report_images.rst index 7f99ef44..a88ed5f5 100644 --- a/doc/api/services/report_images.rst +++ b/doc/api/services/report_images.rst @@ -3,4 +3,5 @@ sasctl.services.report_images .. automodule:: sasctl._services.report_images :members: - :undoc-members: \ No newline at end of file + :undoc-members: + :show-inheritance: diff --git a/doc/api/services/reports.rst b/doc/api/services/reports.rst index eeb923ff..02c5def2 100644 --- a/doc/api/services/reports.rst +++ b/doc/api/services/reports.rst @@ -4,3 +4,4 @@ sasctl.services.reports .. automodule:: sasctl._services.reports :members: :undoc-members: + :show-inheritance: diff --git a/doc/conf.py b/doc/conf.py index 3d588087..5fef0d6f 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -109,7 +109,7 @@ ('py:class','Response'), ('py:class','Request'), ('py:class','_io.BytesIO'), - ('py:class','sasctl.utils.pymas.ds2.Ds2Variable'), # not sure what is wrong + ('py:class','sasctl.utils.pymas.ds2.Ds2Variable'), # not sure what is wrong ('py:class','sasctl._services.service.Service') # should the Service class be documented? ] diff --git a/src/sasctl/_services/cas_management.py b/src/sasctl/_services/cas_management.py index 90d4108f..f2fd7592 100644 --- a/src/sasctl/_services/cas_management.py +++ b/src/sasctl/_services/cas_management.py @@ -31,7 +31,7 @@ def check_keys(valid_keys: list, input_keys: list, parameters: str): Raises ------ - ValueError + ValueError if input_keys are not valid """ if not all(key in valid_keys for key in input_keys): diff --git a/src/sasctl/_services/concepts.py b/src/sasctl/_services/concepts.py index a74d3ef0..3e18d220 100644 --- a/src/sasctl/_services/concepts.py +++ b/src/sasctl/_services/concepts.py @@ -56,7 +56,7 @@ def assign_concepts( output_postfix : str, optional Text to be added to the end of all output table names. match_type : str, optional - Choose from ``{'all', 'longest', 'best'}``. + Choose from ``{'all', 'longest', 'best'}``. Type of matches to return. Defaults to 'all'. enable_facts : bool, optional Whether to enable facts in the results. Defaults to False. diff --git a/src/sasctl/_services/data_sources.py b/src/sasctl/_services/data_sources.py index 1d8cf6b0..92140952 100644 --- a/src/sasctl/_services/data_sources.py +++ b/src/sasctl/_services/data_sources.py @@ -40,7 +40,7 @@ def get_provider(cls, provider, refresh=False): A dictionary containing the provider attributes or None. Notes - ------- + ----- If `provider` is a complete representation of the provider it will be returned unless `refresh` is set. This prevents unnecessary REST calls when data is already available on the client. diff --git a/src/sasctl/_services/folders.py b/src/sasctl/_services/folders.py index 03cf7150..adc25a7c 100644 --- a/src/sasctl/_services/folders.py +++ b/src/sasctl/_services/folders.py @@ -71,7 +71,7 @@ def get_folder(cls, folder, refresh=False): ---------- folder : str or dict May be one of: - + - folder name - folder ID - folder path diff --git a/src/sasctl/_services/service.py b/src/sasctl/_services/service.py index 6af9dc9d..b1fd1497 100644 --- a/src/sasctl/_services/service.py +++ b/src/sasctl/_services/service.py @@ -229,7 +229,7 @@ def get_item(cls, item, refresh=False): A dictionary containing the {item} attributes or None. Notes - ------- + ----- If `item` is a complete representation of the {item} it will be returned unless `refresh` is set. This prevents unnecessary REST calls when data is already available on the client. 
diff --git a/src/sasctl/core.py b/src/sasctl/core.py index 3cdc9226..c8fc3f7a 100644 --- a/src/sasctl/core.py +++ b/src/sasctl/core.py @@ -1254,7 +1254,7 @@ def _request_token_with_oauth( """Request a token from the SAS SASLogon service. Supports four different flows: - + - authenticate with a username & password and receive a token - authenticate with a client id & secret and receive a token - provide an authorization code and receive a token @@ -1531,6 +1531,13 @@ def __init__(self, obj, session=None, threads=4): # Store the current items to iterate over self._obj = obj + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._pool is not None: + self._pool.shutdown(wait=False) + def __next__(self): if self._pool is None: self._pool = concurrent.futures.ThreadPoolExecutor( @@ -1778,7 +1785,7 @@ class VersionInfo: Release cadence for Viya 4. Should be one of 'stable' or 'LTS'. release : str, optional Release number for Viya 4. Two formats are currently possible: - + - YYYY.R.U where R is the LTS release number in YYYY and U is the updates since R - YYYY.MM where MM is the month of the release. @@ -2028,6 +2035,7 @@ def request(verb, path, session=None, format="auto", **kwargs): Returns ------- + str, bytes, or requests.Response """ session = session or current_session() @@ -2111,6 +2119,7 @@ def request_link(obj, rel, **kwargs): Returns ------- + RestObj """ link = get_link(obj, rel) @@ -2359,7 +2368,7 @@ def platform_version(): Returns ------- - str + str SAS Viya version number '3.5' or '4.0' """ diff --git a/src/sasctl/pzmm/import_model.py b/src/sasctl/pzmm/import_model.py index 56e50fda..cc28f7e5 100644 --- a/src/sasctl/pzmm/import_model.py +++ b/src/sasctl/pzmm/import_model.py @@ -306,7 +306,7 @@ def import_model( **kwargs Other keyword arguments are passed to the following function: :meth:`.ScoreCode.write_score_code` - + Returns ------- diff --git a/src/sasctl/pzmm/pickle_model.py b/src/sasctl/pzmm/pickle_model.py index ac91e98d..7b17f2a4 100644 --- a/src/sasctl/pzmm/pickle_model.py +++ b/src/sasctl/pzmm/pickle_model.py @@ -37,7 +37,7 @@ def pickle_trained_model( object. The following files are generated by this function: - + * '\*.pickle' Binary pickle file containing a trained model. * '\*.mojo' diff --git a/src/sasctl/pzmm/write_json_files.py b/src/sasctl/pzmm/write_json_files.py index 42d5e6aa..1efc45e5 100644 --- a/src/sasctl/pzmm/write_json_files.py +++ b/src/sasctl/pzmm/write_json_files.py @@ -5,7 +5,8 @@ import ast import importlib import json -# import math #not used + +# import math #not used import pickle import pickletools import sys diff --git a/src/sasctl/pzmm/zip_model.py b/src/sasctl/pzmm/zip_model.py index 47beee68..0e47345b 100644 --- a/src/sasctl/pzmm/zip_model.py +++ b/src/sasctl/pzmm/zip_model.py @@ -71,7 +71,7 @@ def zip_files( """ if isinstance(model_files, dict): - buffer = io.BytesIO() + buffer = BytesIO() with zipfile.ZipFile(buffer, "w", zipfile.ZIP_DEFLATED, False) as archive: for file_name, data in model_files.items(): diff --git a/src/sasctl/utils/model_info.py b/src/sasctl/utils/model_info.py index 57c733db..e307d26a 100644 --- a/src/sasctl/utils/model_info.py +++ b/src/sasctl/utils/model_info.py @@ -4,6 +4,8 @@ # Copyright © 2023, SAS Institute Inc., Cary, NC, USA. All Rights Reserved. 
# SPDX-License-Identifier: Apache-2.0 +import math +import warnings from abc import ABC, abstractmethod from typing import Any, Callable, Dict, List, Union @@ -15,6 +17,19 @@ except ImportError: torch = None +try: + import onnx + + # ONNX serializes models using protobuf, so this should be safe + from google.protobuf import json_format +except ImportError: + onnx = None + +try: + import onnxruntime +except ImportError: + onnxruntime = None + def get_model_info(model, X, y=None): """Extracts metadata about the model and associated data sets. @@ -40,9 +55,12 @@ """ # Don't need to import sklearn, just check if the class is part of that module. - if model.__class__.__module__.startswith("sklearn."): + if type(model).__module__.startswith("sklearn."): return SklearnModelInfo(model, X, y) + if type(model).__module__.startswith("onnx"): + return _load_onnx_model(model, X, y) + # Most PyTorch models are actually subclasses of torch.nn.Module, so checking module # name alone is not sufficient. elif torch and isinstance(model, torch.nn.Module): @@ -51,17 +69,29 @@ raise ValueError(f"Unrecognized model type {type(model)} received.") +def _load_onnx_model(model, X, y=None): + # TODO: unnecessary? static analysis of onnx file sufficient? + if onnxruntime: + return OnnxModelInfo(model, X, y) + + return OnnxModelInfo(model, X, y) + + class ModelInfo(ABC): """Base class for storing model metadata. Attributes ---------- algorithm : str + Will appear in the "Algorithm" drop-down menu in Model Manager. + Example: "Forest", "Neural networks", "Binning", etc. analytic_function : str + Will appear in the "Function" drop-down menu in Model Manager. + Example: "Classification", "Clustering", "Prediction" is_binary_classifier : bool - is_classifier - is_regressor - is_clusterer + is_classifier : bool + is_regressor : bool + is_clusterer : bool model : object The model instance that the information was extracted from. model_params : {str: any} @@ -166,13 +196,169 @@ def y(self) -> pd.DataFrame: return +class OnnxModelInfo(ModelInfo): + def __init__(self, model, X, y=None): + if onnx is None: + raise RuntimeError( + "The onnx package must be installed to work with ONNX models. Please `pip install onnx`." + ) + + self._model = model + self._X = X + self._y = y + + inferred_model = onnx.shape_inference.infer_shapes(model) + + inputs = [self._tensor_to_dataframe(i) for i in inferred_model.graph.input] + outputs = [self._tensor_to_dataframe(o) for o in inferred_model.graph.output] + + if len(inputs) > 1: + warnings.warn( + f"The ONNX model has {len(inputs)} inputs but only the first input will be captured in Model Manager." + ) + + if len(outputs) > 1: + warnings.warn( + f"The ONNX model has {len(outputs)} outputs but only the first output will be captured in Model Manager."
+ ) + + self._X_df = inputs[0] + self._y_df = outputs[0] + + # initializer (static params) + + # for field in model.ListFields(): + # doc_string + # domain + # metadata_props + # model_author + # model_license + # model_version + # producer_name + # producer_version + # training_info + + # irVersion + # producerName + # producerVersion + # opsetImport + + # # list of (FieldDescriptor, value) + # fields = model.ListFields() + + @staticmethod + def _tensor_to_dataframe(tensor): + """Convert an ONNX tensor definition to an empty DataFrame with one column per tensor element. + + Parameters + ---------- + tensor : onnx.onnx_ml_pb2.ValueInfoProto or dict + A protobuf `Message` containing information about the tensor's name, type, and shape. + + Returns + ------- + pandas.DataFrame + + Examples + -------- + >>> df = OnnxModelInfo._tensor_to_dataframe(model.graph.input[0]) + + """ + if isinstance(tensor, onnx.onnx_ml_pb2.ValueInfoProto): + tensor = json_format.MessageToDict(tensor) + elif not isinstance(tensor, dict): + raise ValueError(f"Unexpected type {type(tensor)}.") + + name = tensor.get("name", "Var") + type_ = tensor["type"] + + if "tensorType" not in type_: + raise ValueError(f"Received an unexpected ONNX input type: {type_}.") + + dtype = onnx.helper.tensor_dtype_to_np_dtype(type_["tensorType"]["elemType"]) + + # Tuple of tensor dimensions e.g. (1, 1, 24) + input_dims = tuple( + int(d["dimValue"]) for d in type_["tensorType"]["shape"]["dim"] + ) + + return pd.DataFrame( + dtype=dtype, columns=[f"{name}{i+1}" for i in range(math.prod(input_dims))] + ) + + @property + def algorithm(self) -> str: + return "neural network" + + @property + def description(self) -> str: + return self.model.doc_string + + @property + def is_binary_classifier(self) -> bool: + return len(self.output_column_names) == 2 + + @property + def is_classifier(self) -> bool: + return len(self.output_column_names) > 1 + + @property + def is_clusterer(self) -> bool: + return False + + @property + def is_regressor(self) -> bool: + return len(self.output_column_names) == 1 + + @property + def model(self) -> object: + return self._model + + @property + def model_params(self) -> Dict[str, Any]: + return { + k: getattr(self.model, k, None) + for k in ( + "ir_version", + "model_version", + "opset_import", + "producer_name", + "producer_version", + ) + } + + @property + def predict_function(self) -> Callable: + return None + + @property + def target_column(self): + return None + + @property + def target_values(self): + return None + + @property + def threshold(self) -> Union[str, None]: + return None + + @property + def X(self) -> pd.DataFrame: + return self._X_df + + @property + def y(self) -> pd.DataFrame: + return self._y_df + + class PyTorchModelInfo(ModelInfo): """Stores model information for a PyTorch model instance.""" def __init__(self, model, X, y=None): if torch is None: raise RuntimeError( - "The PyTorch library must be installed to work with PyTorch models. Please `pip install torch`." + "The PyTorch package must be installed to work with PyTorch models. Please `pip install torch`."
) if not isinstance(model, torch.nn.Module): diff --git a/tests/conftest.py b/tests/conftest.py index 609b68b5..a6172985 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -639,10 +639,11 @@ def pytest_runtest_makereport(item, call): # elif "cas_session" in item.callspec.params: # key = item.callspec.params["cas_session"] # else: - key = item.callspec.id + if hasattr(item, "callspec"): + key = item.callspec.id - # Track that this test was the last test to fail for this Viya version - parent._previousfailed[key] = item + # Track that this test was the last test to fail for this Viya version + parent._previousfailed[key] = item def pytest_runtest_setup(item): diff --git a/tests/unit/test_model_info_onnx.py b/tests/unit/test_model_info_onnx.py new file mode 100644 index 00000000..e44f3711 --- /dev/null +++ b/tests/unit/test_model_info_onnx.py @@ -0,0 +1,56 @@ +import pandas as pd +import pytest + +onnx = pytest.importorskip("onnx") +torch = pytest.importorskip("torch") + +import sasctl.utils.model_info +from sasctl.utils import get_model_info + +# mnist +# get input/output shapes +# get var names if available +# classification/regression/etc +# + + +@pytest.fixture +def mnist_model(tmp_path): + class Net(torch.nn.Module): + def __init__(self): + super(Net, self).__init__() + self.fc1 = torch.nn.Linear(14 * 14, 128) + self.fc2 = torch.nn.Linear(128, 10) + + def forward(self, x): + x = torch.nn.functional.max_pool2d(x, 2) + x = x.reshape(-1, 1 * 14 * 14) + x = self.fc1(x) + x = torch.nn.functional.relu(x) + x = self.fc2(x) + output = torch.nn.functional.softmax(x, dim=1) + return output + + model = Net() + + path = tmp_path / "model.onnx" + X = torch.randn(1, 1, 28, 28) + torch.onnx.export(model, X, path, input_names=["image"], output_names=["digit"]) + yield onnx.load(path), X + + +def test_get_info(mnist_model): + info = get_model_info(*mnist_model) + assert isinstance(info, sasctl.utils.model_info.OnnxModelInfo) + + # Output should be classification into 10 digits + assert len(info.output_column_names) == 10 + assert all(c.startswith("digit") for c in info.output_column_names) + + assert isinstance(info.X, pd.DataFrame) + assert len(info.X.columns) == 28 * 28 + + assert info.is_classifier + assert not info.is_binary_classifier + assert not info.is_regressor + assert not info.is_clusterer diff --git a/tests/unit/test_model_info_torch.py b/tests/unit/test_model_info_torch.py index 46bf42ec..6d470736 100644 --- a/tests/unit/test_model_info_torch.py +++ b/tests/unit/test_model_info_torch.py @@ -83,4 +83,4 @@ def test_mnist(): info = get_model_info(model, X) meta = prepare_model_for_sas(model, "MnistLogistic") - assert info.is_classifier + # assert info.is_classifier diff --git a/tests/unit/test_pageiterator.py b/tests/unit/test_pageiterator.py index dd66cab6..4c49c7d2 100644 --- a/tests/unit/test_pageiterator.py +++ b/tests/unit/test_pageiterator.py @@ -11,7 +11,9 @@ from sasctl.core import PageIterator, RestObj -@pytest.fixture(params=[(6, 2, 2), (6, 1, 4), (6, 5, 4), (6, 6, 2), (100, 10, 20)]) +@pytest.fixture( + scope="function", params=[(6, 2, 2), (6, 1, 4), (6, 5, 4), (6, 6, 2), (100, 10, 20)] +) def paging(request): """Create a RestObj designed to page through a collection of items and the collection itself.
@@ -56,35 +58,11 @@ def side_effect(_, link, **kwargs): assert req.call_count >= math.ceil(call_count) -@pytest.mark.incremental -class TestPageIterator: - def test_no_paging_required(self): - """If "next" link not present, current items should be included.""" +def test_paging_required(paging): + """Requests should be made to retrieve additional pages.""" + obj, items, _ = paging - items = [{"name": "a"}, {"name": "b"}, {"name": "c"}] - obj = RestObj(items=items, count=len(items)) - - with mock.patch("sasctl.core.request") as request: - pager = PageIterator(obj) - - # Returned page of items should preserve item order - items = next(pager) - for idx, item in enumerate(items): - assert item.name == RestObj(items[idx]).name - - # No request should have been made to retrieve additional data. - try: - request.assert_not_called() - except AssertionError as e: - raise AssertionError( - f"method_calls={request.mock_calls} call_args={request.call_args_list}" - ) - - def test_paging_required(self, paging): - """Requests should be made to retrieve additional pages.""" - obj, items, _ = paging - - pager = PageIterator(obj) + with PageIterator(obj) as pager: init_count = pager._start for i, page in enumerate(pager): diff --git a/tox.ini b/tox.ini index 10f5c911..f776b4a5 100644 --- a/tox.ini +++ b/tox.ini @@ -51,6 +51,10 @@ deps = tests: urllib3 < 2.0.0 tests: nbconvert tests: nbformat +# tests: torch +# tests: onnx +# tests: h2o +# tests: tensorflow # tests: lightgbm ; platform_system != "Darwin" # lightgbm seems to have build issues on MacOS # doc skips install, so explicitly add minimum packages doc: sphinx
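Note for reviewers: a minimal sketch of the context-manager usage enabled by the new PageIterator.__enter__/__exit__ methods in core.py, mirroring the reworked test_paging_required. The RestObj contents below are illustrative only; with no "next" link present, iteration stops after the first page and no REST calls are made.

    from sasctl.core import PageIterator, RestObj

    # Illustrative first page of a collection; field names follow the unit tests.
    obj = RestObj(items=[{"name": "a"}, {"name": "b"}], count=2)

    # Exiting the block shuts down the iterator's internal ThreadPoolExecutor,
    # so worker threads are no longer leaked when iteration ends early.
    with PageIterator(obj) as pager:
        for page in pager:
            for item in page:
                print(item.name)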