From 06939a95b5f44d27df274ade55c7951077ed2d97 Mon Sep 17 00:00:00 2001 From: Jonathan Walker Date: Tue, 27 Aug 2024 15:14:34 -0400 Subject: [PATCH 01/28] feat: initial onnx model info --- src/sasctl/utils/model_info.py | 152 ++++++++++++++++++++++++++++- tests/unit/test_model_info_onnx.py | 44 +++++++++ 2 files changed, 191 insertions(+), 5 deletions(-) create mode 100644 tests/unit/test_model_info_onnx.py diff --git a/src/sasctl/utils/model_info.py b/src/sasctl/utils/model_info.py index 57c733db..cb53c180 100644 --- a/src/sasctl/utils/model_info.py +++ b/src/sasctl/utils/model_info.py @@ -15,6 +15,16 @@ except ImportError: torch = None +try: + import onnx +except ImportError: + onnx = None + +try: + import onnxruntime +except ImportError: + onnxruntime = None + def get_model_info(model, X, y=None): """Extracts metadata about the model and associated data sets. @@ -40,9 +50,12 @@ def get_model_info(model, X, y=None): """ # Don't need to import sklearn, just check if the class is part of that module. - if model.__class__.__module__.startswith("sklearn."): + if type(model).__module__.startswith("sklearn."): return SklearnModelInfo(model, X, y) + if type(model).__module__.startswith("onnx"): + return _load_onnx_model(model, X, y) + # Most PyTorch models are actually subclasses of torch.nn.Module, so checking module # name alone is not sufficient. elif torch and isinstance(model, torch.nn.Module): @@ -51,17 +64,29 @@ def get_model_info(model, X, y=None): raise ValueError(f"Unrecognized model type {type(model)} received.") +def _load_onnx_model(model, X, y=None): + # TODO: unnecessary? static analysis of onnx file sufficient? + if onnxruntime: + return OnnxModelInfo(model, X, y) + + return OnnxModelInfo(model, X, y) + + class ModelInfo(ABC): """Base class for storing model metadata. Attributes ---------- algorithm : str + Will appear in the "Algorithm" drop-down menu in Model Manager. + Example: "Forest", "Neural networks", "Binning", etc. 
analytic_function : str + Will appear in the "Function" drop-down menu in Model Manager. + Example: "Classification", "Clustering", "Prediction" is_binary_classifier : bool - is_classifier - is_regressor - is_clusterer + is_classifier : bool + is_regressor : bool + is_clusterer : bool model : object The model instance that the information was extracted from. model_params : {str: any} @@ -166,13 +191,130 @@ def y(self) -> pd.DataFrame: return +class OnnxModelInfo(ModelInfo): + def __init__(self, model, X, y=None): + if onnx is None: + raise RuntimeError( + "The onnx package must be installed to work with ONNX models. Please `pip install onnx`." + ) + + # ONNX serializes models using protobuf, so this should be safe + from google.protobuf import json_format + + # TODO: size of X should match size of graph.input + + self._model = model + self._X = X + self._y = y + + inferred_model = onnx.shape_inference.infer_shapes(model) + + inputs = [json_format.MessageToDict(i) for i in inferred_model.graph.input] + outputs = [json_format.MessageToDict(o) for o in inferred_model.graph.output] + + if len(inputs) > 1: + pass # TODO: warn that only first input will be captured + + if len(outputs) > 1: + pass # TODO: warn that only the first output will be captured + + inputs[0]["type"]["tensorType"]["elemType"] + inputs[0]["type"]["tensorType"]["shape"] + + self._properties = { + "description": model.doc_string, + "opset": model.opset_import + } + # initializer (static params) + + # for field in model.ListFields(): + # doc_string + # domain + # metadata_props + # model_author + # model_license + # model_version + # producer_name + # producer_version + # training_info + + # irVersion + # producerName + # producerVersion + # opsetImport + + + # # list of (FieldDescriptor, value) + # fields = model.ListFields() + # inferred_model = onnx.shape_inference.infer_shapes(model) + # + # inputs = model.graph.input + # assert len(inputs) == 1 + # i = inputs[0] + # print(i.name) + # 
print(i.type) + # print(i.type.tensor_type.shape) + + @property + def algorithm(self) -> str: + return "neural network" + + @property + def is_binary_classifier(self) -> bool: + return False + + @property + def is_classifier(self) -> bool: + return False + + @property + def is_clusterer(self) -> bool: + return False + + @property + def is_regressor(self) -> bool: + return False + + @property + def model(self) -> object: + return self._model + + @property + def model_params(self) -> Dict[str, Any]: + return {} + + @property + def predict_function(self) -> Callable: + return None + + @property + def target_column(self): + return None + + @property + def target_values(self): + return None + + @property + def threshold(self) -> Union[str, None]: + return None + + @property + def X(self): + return self._X + + @property + def y(self): + return self._y + + class PyTorchModelInfo(ModelInfo): """Stores model information for a PyTorch model instance.""" def __init__(self, model, X, y=None): if torch is None: raise RuntimeError( - "The PyTorch library must be installed to work with PyTorch models. Please `pip install torch`." + "The PyTorch package must be installed to work with PyTorch models. Please `pip install torch`." 
) if not isinstance(model, torch.nn.Module): diff --git a/tests/unit/test_model_info_onnx.py b/tests/unit/test_model_info_onnx.py new file mode 100644 index 00000000..ad3b0b9a --- /dev/null +++ b/tests/unit/test_model_info_onnx.py @@ -0,0 +1,44 @@ +import pytest + +import sasctl.utils.model_info + +onnx = pytest.importorskip("onnx") +torch = pytest.importorskip("torch") + +from sasctl.utils import get_model_info + +# mnist +# get input/output shapes +# get var names if available +# classification/regression/etc +# + +@pytest.fixture +def mnist_model(tmp_path): + class Net(torch.nn.Module): + def __init__(self): + super(Net, self).__init__() + self.fc1 = torch.nn.Linear(14 * 14, 128) + self.fc2 = torch.nn.Linear(128, 10) + + def forward(self, x): + x = torch.nn.functional.max_pool2d(x, 2) + x = x.reshape(-1, 1 * 14 * 14) + x = self.fc1(x) + x = torch.nn.functional.relu(x) + x = self.fc2(x) + output = torch.nn.functional.softmax(x, dim=1) + return output + + model = Net() + + path = tmp_path / "model.onnx" + X = torch.randn(1, 1, 28, 28) + torch.onnx.export(model, X, path, input_names=["image"], output_names=["digit"]) + yield onnx.load(path), X + + +def test_get_info(mnist_model): + info = get_model_info(*mnist_model) + assert isinstance(info, sasctl.utils.model_info.OnnxModelInfo) + print(mnist_model) \ No newline at end of file From 91f9b15fe5c3e522df1feaa164a7565d0d88c8ab Mon Sep 17 00:00:00 2001 From: Jonathan Walker Date: Wed, 28 Aug 2024 14:07:53 -0400 Subject: [PATCH 02/28] feat: extract onnx info --- src/sasctl/utils/model_info.py | 84 ++++++++++++++++++++---------- tests/unit/test_model_info_onnx.py | 17 ++++-- 2 files changed, 70 insertions(+), 31 deletions(-) diff --git a/src/sasctl/utils/model_info.py b/src/sasctl/utils/model_info.py index cb53c180..562938c3 100644 --- a/src/sasctl/utils/model_info.py +++ b/src/sasctl/utils/model_info.py @@ -4,6 +4,7 @@ # Copyright © 2023, SAS Institute Inc., Cary, NC, USA. All Rights Reserved. 
# SPDX-License-Identifier: Apache-2.0 +import math from abc import ABC, abstractmethod from typing import Any, Callable, Dict, List, Union @@ -17,6 +18,9 @@ try: import onnx + + # ONNX serializes models using protobuf, so this should be safe + from google.protobuf import json_format except ImportError: onnx = None @@ -198,9 +202,6 @@ def __init__(self, model, X, y=None): "The onnx package must be installed to work with ONNX models. Please `pip install onnx`." ) - # ONNX serializes models using protobuf, so this should be safe - from google.protobuf import json_format - # TODO: size of X should match size of graph.input self._model = model @@ -209,8 +210,8 @@ def __init__(self, model, X, y=None): inferred_model = onnx.shape_inference.infer_shapes(model) - inputs = [json_format.MessageToDict(i) for i in inferred_model.graph.input] - outputs = [json_format.MessageToDict(o) for o in inferred_model.graph.output] + inputs = [self._tensor_to_dataframe(i) for i in inferred_model.graph.input] + outputs = [self._tensor_to_dataframe(o) for o in inferred_model.graph.output] if len(inputs) > 1: pass # TODO: warn that only first input will be captured @@ -218,13 +219,9 @@ def __init__(self, model, X, y=None): if len(outputs) > 1: pass # TODO: warn that only the first output will be captured - inputs[0]["type"]["tensorType"]["elemType"] - inputs[0]["type"]["tensorType"]["shape"] + self._X_df = inputs[0] + self._y_df = outputs[0] - self._properties = { - "description": model.doc_string, - "opset": model.opset_import - } # initializer (static params) # for field in model.ListFields(): @@ -243,29 +240,60 @@ def __init__(self, model, X, y=None): # producerVersion # opsetImport - # # list of (FieldDescriptor, value) # fields = model.ListFields() - # inferred_model = onnx.shape_inference.infer_shapes(model) - # - # inputs = model.graph.input - # assert len(inputs) == 1 - # i = inputs[0] - # print(i.name) - # print(i.type) - # print(i.type.tensor_type.shape) + + @staticmethod + def 
_tensor_to_dataframe(tensor): + """ + + Parameters + ---------- + tensor : onnx.onnx_ml_pb2.ValueInfoProto or dict + A protobuf `Message` containing information + + Returns + ------- + pandas.DataFrame + + Examples + -------- + df = _tensor_to_dataframe(model.graph.input[0]) + + """ + if isinstance(tensor, onnx.onnx_ml_pb2.ValueInfoProto): + tensor = json_format.MessageToDict(tensor) + elif not isinstance(tensor, dict): + raise ValueError(f"Unexpected type {type(tensor)}.") + + name = tensor.get("name", "Var") + type_ = tensor["type"] + + if not "tensorType" in type_: + raise ValueError(f"Received an unexpected ONNX input type: {type_}.") + + dtype = onnx.helper.tensor_dtype_to_np_dtype(type_["tensorType"]["elemType"]) + + # Tuple of tensor dimensions e.g. (1, 1, 24) + input_dims = tuple(int(d["dimValue"]) for d in type_["tensorType"]["shape"]["dim"]) + + return pd.DataFrame(dtype=dtype, columns=[f"{name}{i+1}" for i in range(math.prod(input_dims))]) @property def algorithm(self) -> str: return "neural network" + @property + def description(self) -> str: + return self.model.doc_string + @property def is_binary_classifier(self) -> bool: - return False + return len(self.output_column_names) == 2 @property def is_classifier(self) -> bool: - return False + return len(self.output_column_names) > 1 @property def is_clusterer(self) -> bool: @@ -273,7 +301,7 @@ def is_clusterer(self) -> bool: @property def is_regressor(self) -> bool: - return False + return len(self.output_column_names) == 1 @property def model(self) -> object: @@ -281,7 +309,7 @@ def model(self) -> object: @property def model_params(self) -> Dict[str, Any]: - return {} + return {k: getattr(self.model, k, None) for k in ("ir_version", "model_version", "opset_import", "producer_name", "producer_version")} @property def predict_function(self) -> Callable: @@ -300,12 +328,12 @@ def threshold(self) -> Union[str, None]: return None @property - def X(self): - return self._X + def X(self) -> pd.DataFrame: + 
return self._X_df @property - def y(self): - return self._y + def y(self) -> pd.DataFrame: + return self._y_df class PyTorchModelInfo(ModelInfo): diff --git a/tests/unit/test_model_info_onnx.py b/tests/unit/test_model_info_onnx.py index ad3b0b9a..17902de6 100644 --- a/tests/unit/test_model_info_onnx.py +++ b/tests/unit/test_model_info_onnx.py @@ -1,10 +1,10 @@ +import pandas as pd import pytest -import sasctl.utils.model_info - onnx = pytest.importorskip("onnx") torch = pytest.importorskip("torch") +import sasctl.utils.model_info from sasctl.utils import get_model_info # mnist @@ -41,4 +41,15 @@ def forward(self, x): def test_get_info(mnist_model): info = get_model_info(*mnist_model) assert isinstance(info, sasctl.utils.model_info.OnnxModelInfo) - print(mnist_model) \ No newline at end of file + + # Output should be a classification into 10 digits + assert len(info.output_column_names) == 10 + assert all(c.startswith("digit") for c in info.output_column_names) + + assert isinstance(info.X, pd.DataFrame) + assert len(info.X.columns) == 28 * 28 + + assert info.is_classifier + assert not info.is_binary_classifier + assert not info.is_regressor + assert not info.is_clusterer \ No newline at end of file From 24102f7b876ef6249bb8fe9f2ddbbb62c2973d15 Mon Sep 17 00:00:00 2001 From: Jonathan Walker Date: Wed, 28 Aug 2024 14:34:08 -0400 Subject: [PATCH 03/28] test: ignore classifier --- tests/unit/test_model_info_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_model_info_torch.py b/tests/unit/test_model_info_torch.py index 46bf42ec..6d470736 100644 --- a/tests/unit/test_model_info_torch.py +++ b/tests/unit/test_model_info_torch.py @@ -83,4 +83,4 @@ def test_mnist(): info = get_model_info(model, X) meta = prepare_model_for_sas(model, "MnistLogistic") - assert info.is_classifier + # assert info.is_classifier From 073fd78b2dd928aa07675db178c5ae6f22167b1c Mon Sep 17 00:00:00 2001 From: Jonathan Walker Date: Wed, 28 Aug 2024 14:36:23 -0400 
Subject: [PATCH 04/28] test: install deep learning libraries --- tox.ini | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 7dc2558f..8f145d85 100644 --- a/tox.ini +++ b/tox.ini @@ -51,6 +51,10 @@ deps = tests: urllib3 < 2.0.0 tests: nbconvert tests: nbformat + tests: torch + tests: onnx + tests: h2o + tests: tensorflow # tests: lightgbm ; platform_system != "Darwin" # lightgmb seems to have build issues on MacOS # doc skips install, so explicitly add minimum packages doc: sphinx @@ -69,7 +73,7 @@ passenv = commands = clean: coverage erase - unit: {posargs:pytest --cov={envsitepackagesdir}/sasctl --cov-report=xml:./.reports/unit.xml --cov-append tests/unit/} + unit: {posargs:pytest -rsx --cov={envsitepackagesdir}/sasctl --cov-report=xml:./.reports/unit.xml --cov-append tests/unit/} integration: {posargs:pytest --cov={envsitepackagesdir}/sasctl --cov-report=xml:./.reports/integration.xml --cov-append tests/integration/} # Uncomment when tests are working again for scenarios # scenarios: {posargs:pytest --cov={envsitepackagesdir}/sasctl --cov-report=xml:./.reports/scenarios.xml --cov-append tests/scenarios/} From 9d068536173755ec7bf670d42544b962d72a9dad Mon Sep 17 00:00:00 2001 From: Jonathan Walker Date: Wed, 28 Aug 2024 15:02:19 -0400 Subject: [PATCH 05/28] test: skip tf and h2o install --- tox.ini | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tox.ini b/tox.ini index 8f145d85..9e0e6eec 100644 --- a/tox.ini +++ b/tox.ini @@ -53,8 +53,8 @@ deps = tests: nbformat tests: torch tests: onnx - tests: h2o - tests: tensorflow +# tests: h2o +# tests: tensorflow # tests: lightgbm ; platform_system != "Darwin" # lightgmb seems to have build issues on MacOS # doc skips install, so explicitly add minimum packages doc: sphinx From 08804896a1a33f2505798fb40c007e8254c00421 Mon Sep 17 00:00:00 2001 From: Jonathan Walker Date: Wed, 28 Aug 2024 15:38:28 -0400 Subject: [PATCH 06/28] test: don't check callspec on 
functions --- tests/conftest.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 609b68b5..a6172985 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -639,10 +639,11 @@ def pytest_runtest_makereport(item, call): # elif "cas_session" in item.callspec.params: # key = item.callspec.params["cas_session"] # else: - key = item.callspec.id + if hasattr(item, "callspec"): + key = item.callspec.id - # Track that this test was the last test to fail for this Viya version - parent._previousfailed[key] = item + # Track that this test was the last test to fail for this Viya version + parent._previousfailed[key] = item def pytest_runtest_setup(item): From 51ef8ae417f716673c50a5c270a0628fb5173685 Mon Sep 17 00:00:00 2001 From: Jonathan Walker Date: Wed, 28 Aug 2024 16:19:18 -0400 Subject: [PATCH 07/28] chore: black --- src/sasctl/utils/model_info.py | 32 ++++++++++++++++++++++-------- tests/unit/test_model_info_onnx.py | 3 ++- 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/src/sasctl/utils/model_info.py b/src/sasctl/utils/model_info.py index 562938c3..e307d26a 100644 --- a/src/sasctl/utils/model_info.py +++ b/src/sasctl/utils/model_info.py @@ -5,6 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 import math +import warnings from abc import ABC, abstractmethod from typing import Any, Callable, Dict, List, Union @@ -202,8 +203,6 @@ def __init__(self, model, X, y=None): "The onnx package must be installed to work with ONNX models. Please `pip install onnx`." ) - # TODO: size of X should match size of graph.input - self._model = model self._X = X self._y = y @@ -214,10 +213,14 @@ def __init__(self, model, X, y=None): outputs = [self._tensor_to_dataframe(o) for o in inferred_model.graph.output] if len(inputs) > 1: - pass # TODO: warn that only first input will be captured + warnings.warn( + f"The ONNX model has {len(inputs)} inputs but only the first input will be captured in Model Manager." 
+ ) if len(outputs) > 1: - pass # TODO: warn that only the first output will be captured + warnings.warn( + f"The ONNX model has {len(outputs)} outputs but only the first output will be captured in Model Manager." + ) self._X_df = inputs[0] self._y_df = outputs[0] @@ -263,7 +266,7 @@ def _tensor_to_dataframe(tensor): """ if isinstance(tensor, onnx.onnx_ml_pb2.ValueInfoProto): tensor = json_format.MessageToDict(tensor) - elif not isinstance(tensor, dict): + elif not isinstance(tensor, dict): raise ValueError(f"Unexpected type {type(tensor)}.") name = tensor.get("name", "Var") @@ -275,9 +278,13 @@ def _tensor_to_dataframe(tensor): dtype = onnx.helper.tensor_dtype_to_np_dtype(type_["tensorType"]["elemType"]) # Tuple of tensor dimensions e.g. (1, 1, 24) - input_dims = tuple(int(d["dimValue"]) for d in type_["tensorType"]["shape"]["dim"]) + input_dims = tuple( + int(d["dimValue"]) for d in type_["tensorType"]["shape"]["dim"] + ) - return pd.DataFrame(dtype=dtype, columns=[f"{name}{i+1}" for i in range(math.prod(input_dims))]) + return pd.DataFrame( + dtype=dtype, columns=[f"{name}{i+1}" for i in range(math.prod(input_dims))] + ) @property def algorithm(self) -> str: @@ -309,7 +316,16 @@ def model(self) -> object: @property def model_params(self) -> Dict[str, Any]: - return {k: getattr(self.model, k, None) for k in ("ir_version", "model_version", "opset_import", "producer_name", "producer_version")} + return { + k: getattr(self.model, k, None) + for k in ( + "ir_version", + "model_version", + "opset_import", + "producer_name", + "producer_version", + ) + } @property def predict_function(self) -> Callable: diff --git a/tests/unit/test_model_info_onnx.py b/tests/unit/test_model_info_onnx.py index 17902de6..e44f3711 100644 --- a/tests/unit/test_model_info_onnx.py +++ b/tests/unit/test_model_info_onnx.py @@ -13,6 +13,7 @@ # classification/regression/etc # + @pytest.fixture def mnist_model(tmp_path): class Net(torch.nn.Module): @@ -52,4 +53,4 @@ def test_get_info(mnist_model): 
assert info.is_classifier assert not info.is_binary_classifier assert not info.is_regressor - assert not info.is_clusterer \ No newline at end of file + assert not info.is_clusterer From f07f066392817d97915da595683f628ba6c36cb5 Mon Sep 17 00:00:00 2001 From: Jonathan Walker Date: Wed, 28 Aug 2024 16:30:55 -0400 Subject: [PATCH 08/28] test: explicit fixture scope --- tests/unit/test_pageiterator.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/unit/test_pageiterator.py b/tests/unit/test_pageiterator.py index dd66cab6..53583987 100644 --- a/tests/unit/test_pageiterator.py +++ b/tests/unit/test_pageiterator.py @@ -11,7 +11,7 @@ from sasctl.core import PageIterator, RestObj -@pytest.fixture(params=[(6, 2, 2), (6, 1, 4), (6, 5, 4), (6, 6, 2), (100, 10, 20)]) +@pytest.fixture(scope="function", params=[(6, 2, 2), (6, 1, 4), (6, 5, 4), (6, 6, 2), (100, 10, 20)]) def paging(request): """Create a RestObj designed to page through a collection of items and the collection itself. @@ -64,7 +64,7 @@ def test_no_paging_required(self): items = [{"name": "a"}, {"name": "b"}, {"name": "c"}] obj = RestObj(items=items, count=len(items)) - with mock.patch("sasctl.core.request") as request: + with mock.patch("sasctl.core.request") as req: pager = PageIterator(obj) # Returned page of items should preserve item order @@ -72,12 +72,12 @@ def test_no_paging_required(self): for idx, item in enumerate(items): assert item.name == RestObj(items[idx]).name - # No request should have been made to retrieve additional data. + # No req should have been made to retrieve additional data. 
try: - request.assert_not_called() + req.assert_not_called() except AssertionError as e: raise AssertionError( - f"method_calls={request.mock_calls} call_args={request.call_args_list}" + f"method_calls={req.mock_calls} call_args={req.call_args_list}" ) def test_paging_required(self, paging): From 1c3df3d3b2ccf3933093d19079106261204c1f3d Mon Sep 17 00:00:00 2001 From: Jonathan Walker Date: Thu, 29 Aug 2024 10:22:17 -0400 Subject: [PATCH 09/28] test: remove incremental marker --- tests/unit/test_pageiterator.py | 77 ++++++++++++++++----------------- 1 file changed, 38 insertions(+), 39 deletions(-) diff --git a/tests/unit/test_pageiterator.py b/tests/unit/test_pageiterator.py index 53583987..12963c9e 100644 --- a/tests/unit/test_pageiterator.py +++ b/tests/unit/test_pageiterator.py @@ -56,44 +56,43 @@ def side_effect(_, link, **kwargs): assert req.call_count >= math.ceil(call_count) -@pytest.mark.incremental -class TestPageIterator: - def test_no_paging_required(self): - """If "next" link not present, current items should be included.""" - - items = [{"name": "a"}, {"name": "b"}, {"name": "c"}] - obj = RestObj(items=items, count=len(items)) - - with mock.patch("sasctl.core.request") as req: - pager = PageIterator(obj) - - # Returned page of items should preserve item order - items = next(pager) - for idx, item in enumerate(items): - assert item.name == RestObj(items[idx]).name - - # No req should have been made to retrieve additional data. 
- try: - req.assert_not_called() - except AssertionError as e: - raise AssertionError( - f"method_calls={req.mock_calls} call_args={req.call_args_list}" - ) - - def test_paging_required(self, paging): - """Requests should be made to retrieve additional pages.""" - obj, items, _ = paging +def test_no_paging_required(self): + """If "next" link not present, current items should be included.""" + items = [{"name": "a"}, {"name": "b"}, {"name": "c"}] + obj = RestObj(items=items, count=len(items)) + + with mock.patch("sasctl.core.request") as req: pager = PageIterator(obj) - init_count = pager._start - - for i, page in enumerate(pager): - for j, item in enumerate(page): - if i == 0: - item_idx = j - else: - # Account for initial page size not necessarily being same size - # as additional pages - item_idx = init_count + (i - 1) * pager._limit + j - target = RestObj(items[item_idx]) - assert item.name == target.name + + # Returned page of items should preserve item order + items = next(pager) + for idx, item in enumerate(items): + assert item.name == RestObj(items[idx]).name + + # No req should have been made to retrieve additional data. 
+ try: + req.assert_not_called() + except AssertionError as e: + raise AssertionError( + f"method_calls={req.mock_calls} call_args={req.call_args_list}" + ) + + +def test_paging_required(self, paging): + """Requests should be made to retrieve additional pages.""" + obj, items, _ = paging + + pager = PageIterator(obj) + init_count = pager._start + + for i, page in enumerate(pager): + for j, item in enumerate(page): + if i == 0: + item_idx = j + else: + # Account for initial page size not necessarily being same size + # as additional pages + item_idx = init_count + (i - 1) * pager._limit + j + target = RestObj(items[item_idx]) + assert item.name == target.name From 2e534afb51bc5421ab99dcbcfc11fda691b5f388 Mon Sep 17 00:00:00 2001 From: Jonathan Walker Date: Thu, 29 Aug 2024 10:28:18 -0400 Subject: [PATCH 10/28] test: remove incremental marker --- tests/unit/test_pageiterator.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_pageiterator.py b/tests/unit/test_pageiterator.py index 12963c9e..156f8b6d 100644 --- a/tests/unit/test_pageiterator.py +++ b/tests/unit/test_pageiterator.py @@ -11,7 +11,9 @@ from sasctl.core import PageIterator, RestObj -@pytest.fixture(scope="function", params=[(6, 2, 2), (6, 1, 4), (6, 5, 4), (6, 6, 2), (100, 10, 20)]) +@pytest.fixture( + scope="function", params=[(6, 2, 2), (6, 1, 4), (6, 5, 4), (6, 6, 2), (100, 10, 20)] +) def paging(request): """Create a RestObj designed to page through a collection of items and the collection itself. 
@@ -56,7 +58,7 @@ def side_effect(_, link, **kwargs): assert req.call_count >= math.ceil(call_count) -def test_no_paging_required(self): +def test_no_paging_required(): """If "next" link not present, current items should be included.""" items = [{"name": "a"}, {"name": "b"}, {"name": "c"}] @@ -79,7 +81,7 @@ def test_no_paging_required(self): ) -def test_paging_required(self, paging): +def test_paging_required(paging): """Requests should be made to retrieve additional pages.""" obj, items, _ = paging From 28cd97027ba20ec75fa88b11d68e01089256cd71 Mon Sep 17 00:00:00 2001 From: Jonathan Walker Date: Thu, 29 Aug 2024 10:37:02 -0400 Subject: [PATCH 11/28] test: rest mock --- tests/unit/test_pageiterator.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/unit/test_pageiterator.py b/tests/unit/test_pageiterator.py index 156f8b6d..68c85c31 100644 --- a/tests/unit/test_pageiterator.py +++ b/tests/unit/test_pageiterator.py @@ -65,6 +65,10 @@ def test_no_paging_required(): obj = RestObj(items=items, count=len(items)) with mock.patch("sasctl.core.request") as req: + # Mock appears to be *sometimes* shared with mock created in `paging` fixture + # above. Depending on execution order, this can result in calls to the mock + # made by other tests to be counted here. Explicitly reset to prevent this. + req.reset_mock() pager = PageIterator(obj) # Returned page of items should preserve item order From 39cc8af789245d768de0222a9e493c65b1914b41 Mon Sep 17 00:00:00 2001 From: Jonathan Walker Date: Thu, 29 Aug 2024 10:44:52 -0400 Subject: [PATCH 12/28] test: debug mock reset --- tests/unit/test_pageiterator.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/unit/test_pageiterator.py b/tests/unit/test_pageiterator.py index 68c85c31..d621df32 100644 --- a/tests/unit/test_pageiterator.py +++ b/tests/unit/test_pageiterator.py @@ -69,6 +69,13 @@ def test_no_paging_required(): # above. 
Depending on execution order, this can result in calls to the mock # made by other tests to be counted here. Explicitly reset to prevent this. req.reset_mock() + try: + req.assert_not_called() + except AssertionError as e: + raise AssertionError( + f"method_calls={req.mock_calls} call_args={req.call_args_list}" + ) + pager = PageIterator(obj) # Returned page of items should preserve item order From 4419d87216185a10a90d49faf7be672010eebdb7 Mon Sep 17 00:00:00 2001 From: Jonathan Walker Date: Thu, 29 Aug 2024 11:04:07 -0400 Subject: [PATCH 13/28] test: separate non-paging tests --- tests/unit/test_pageiterator.py | 34 ----------------------- tests/unit/test_pageiterator_no_paging.py | 26 +++++++++++++++++ 2 files changed, 26 insertions(+), 34 deletions(-) create mode 100644 tests/unit/test_pageiterator_no_paging.py diff --git a/tests/unit/test_pageiterator.py b/tests/unit/test_pageiterator.py index d621df32..3050f986 100644 --- a/tests/unit/test_pageiterator.py +++ b/tests/unit/test_pageiterator.py @@ -58,40 +58,6 @@ def side_effect(_, link, **kwargs): assert req.call_count >= math.ceil(call_count) -def test_no_paging_required(): - """If "next" link not present, current items should be included.""" - - items = [{"name": "a"}, {"name": "b"}, {"name": "c"}] - obj = RestObj(items=items, count=len(items)) - - with mock.patch("sasctl.core.request") as req: - # Mock appears to be *sometimes* shared with mock created in `paging` fixture - # above. Depending on execution order, this can result in calls to the mock - # made by other tests to be counted here. Explicitly reset to prevent this. 
- req.reset_mock() - try: - req.assert_not_called() - except AssertionError as e: - raise AssertionError( - f"method_calls={req.mock_calls} call_args={req.call_args_list}" - ) - - pager = PageIterator(obj) - - # Returned page of items should preserve item order - items = next(pager) - for idx, item in enumerate(items): - assert item.name == RestObj(items[idx]).name - - # No req should have been made to retrieve additional data. - try: - req.assert_not_called() - except AssertionError as e: - raise AssertionError( - f"method_calls={req.mock_calls} call_args={req.call_args_list}" - ) - - def test_paging_required(paging): """Requests should be made to retrieve additional pages.""" obj, items, _ = paging diff --git a/tests/unit/test_pageiterator_no_paging.py b/tests/unit/test_pageiterator_no_paging.py new file mode 100644 index 00000000..de26a71d --- /dev/null +++ b/tests/unit/test_pageiterator_no_paging.py @@ -0,0 +1,26 @@ +from unittest import mock + +from sasctl.core import PageIterator, RestObj + + +def test_no_paging_required(): + """If "next" link not present, current items should be included.""" + + items = [{"name": "a"}, {"name": "b"}, {"name": "c"}] + obj = RestObj(items=items, count=len(items)) + + with mock.patch("sasctl.core.request") as req: + pager = PageIterator(obj) + + # Returned page of items should preserve item order + items = next(pager) + for idx, item in enumerate(items): + assert item.name == RestObj(items[idx]).name + + # No req should have been made to retrieve additional data. 
+ try: + req.assert_not_called() + except AssertionError as e: + raise AssertionError( + f"method_calls={req.mock_calls} call_args={req.call_args_list}" + ) \ No newline at end of file From 7f5a3ac5396e89fa062baaebf1f85fa069e2cd99 Mon Sep 17 00:00:00 2001 From: Jonathan Walker Date: Thu, 29 Aug 2024 11:16:17 -0400 Subject: [PATCH 14/28] test: reset mock --- tests/unit/test_pageiterator_no_paging.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_pageiterator_no_paging.py b/tests/unit/test_pageiterator_no_paging.py index de26a71d..88bde15d 100644 --- a/tests/unit/test_pageiterator_no_paging.py +++ b/tests/unit/test_pageiterator_no_paging.py @@ -10,6 +10,7 @@ def test_no_paging_required(): obj = RestObj(items=items, count=len(items)) with mock.patch("sasctl.core.request") as req: + req.reset_mock(visited=True, side_effect=True) pager = PageIterator(obj) # Returned page of items should preserve item order @@ -23,4 +24,4 @@ def test_no_paging_required(): except AssertionError as e: raise AssertionError( f"method_calls={req.mock_calls} call_args={req.call_args_list}" - ) \ No newline at end of file + ) From 9248073a731fea115a0487ef6eec680133b42f5d Mon Sep 17 00:00:00 2001 From: Jonathan Walker Date: Thu, 29 Aug 2024 11:24:14 -0400 Subject: [PATCH 15/28] test: reset mock --- tests/unit/test_pageiterator_no_paging.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_pageiterator_no_paging.py b/tests/unit/test_pageiterator_no_paging.py index 88bde15d..a815be9a 100644 --- a/tests/unit/test_pageiterator_no_paging.py +++ b/tests/unit/test_pageiterator_no_paging.py @@ -10,7 +10,7 @@ def test_no_paging_required(): obj = RestObj(items=items, count=len(items)) with mock.patch("sasctl.core.request") as req: - req.reset_mock(visited=True, side_effect=True) + req.reset_mock(side_effect=True) pager = PageIterator(obj) # Returned page of items should preserve item order From 2dcf5b2263546e9fea773c0acdd9abe9009ced57 
Mon Sep 17 00:00:00 2001 From: Jonathan Walker Date: Thu, 29 Aug 2024 13:30:29 -0400 Subject: [PATCH 16/28] test: debug mock calls --- tests/unit/test_pageiterator_no_paging.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_pageiterator_no_paging.py b/tests/unit/test_pageiterator_no_paging.py index a815be9a..990d117b 100644 --- a/tests/unit/test_pageiterator_no_paging.py +++ b/tests/unit/test_pageiterator_no_paging.py @@ -11,6 +11,13 @@ def test_no_paging_required(): with mock.patch("sasctl.core.request") as req: req.reset_mock(side_effect=True) + try: + req.assert_not_called() + except AssertionError: + raise AssertionError( + f"Previous calls: method_calls={req.mock_calls} call_args={req.call_args_list}" + ) + pager = PageIterator(obj) # Returned page of items should preserve item order @@ -21,7 +28,7 @@ def test_no_paging_required(): # No req should have been made to retrieve additional data. try: req.assert_not_called() - except AssertionError as e: + except AssertionError: raise AssertionError( f"method_calls={req.mock_calls} call_args={req.call_args_list}" ) From 0810d1da0e9fd97ec28e7e511002468dba67bf52 Mon Sep 17 00:00:00 2001 From: Jonathan Walker Date: Thu, 29 Aug 2024 13:44:21 -0400 Subject: [PATCH 17/28] test: dont install onnx --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 9e0e6eec..25eacb9b 100644 --- a/tox.ini +++ b/tox.ini @@ -52,7 +52,7 @@ deps = tests: nbconvert tests: nbformat tests: torch - tests: onnx +# tests: onnx # tests: h2o # tests: tensorflow # tests: lightgbm ; platform_system != "Darwin" # lightgmb seems to have build issues on MacOS From 82a9becba3f6ac8adebe323ae38104dadbf0f363 Mon Sep 17 00:00:00 2001 From: Jonathan Walker Date: Thu, 29 Aug 2024 13:50:23 -0400 Subject: [PATCH 18/28] test: dont install pytorch --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 25eacb9b..65658cc9 100644 --- 
a/tox.ini +++ b/tox.ini @@ -51,7 +51,7 @@ deps = tests: urllib3 < 2.0.0 tests: nbconvert tests: nbformat - tests: torch +# tests: torch # tests: onnx # tests: h2o # tests: tensorflow From 02ceda23a251196bcfc1adbd80c1a5036547fd1f Mon Sep 17 00:00:00 2001 From: Jonathan Walker Date: Thu, 29 Aug 2024 14:03:45 -0400 Subject: [PATCH 19/28] feat: cleanup threadpool --- src/sasctl/core.py | 7 +++++++ tests/unit/test_pageiterator.py | 26 +++++++++++++------------- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/src/sasctl/core.py b/src/sasctl/core.py index 3940ac2a..a26e2851 100644 --- a/src/sasctl/core.py +++ b/src/sasctl/core.py @@ -1529,6 +1529,13 @@ def __init__(self, obj, session=None, threads=4): # Store the current items to iterate over self._obj = obj + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._pool is not None: + self._pool.shutdown(wait=False, cancel_futures=True) + def __next__(self): if self._pool is None: self._pool = concurrent.futures.ThreadPoolExecutor( diff --git a/tests/unit/test_pageiterator.py b/tests/unit/test_pageiterator.py index 3050f986..4c49c7d2 100644 --- a/tests/unit/test_pageiterator.py +++ b/tests/unit/test_pageiterator.py @@ -62,16 +62,16 @@ def test_paging_required(paging): """Requests should be made to retrieve additional pages.""" obj, items, _ = paging - pager = PageIterator(obj) - init_count = pager._start - - for i, page in enumerate(pager): - for j, item in enumerate(page): - if i == 0: - item_idx = j - else: - # Account for initial page size not necessarily being same size - # as additional pages - item_idx = init_count + (i - 1) * pager._limit + j - target = RestObj(items[item_idx]) - assert item.name == target.name + with PageIterator(obj) as pager: + init_count = pager._start + + for i, page in enumerate(pager): + for j, item in enumerate(page): + if i == 0: + item_idx = j + else: + # Account for initial page size not necessarily being same size + # as 
additional pages + item_idx = init_count + (i - 1) * pager._limit + j + target = RestObj(items[item_idx]) + assert item.name == target.name From 63a58168adc708e0f20211931d57685b6dcab081 Mon Sep 17 00:00:00 2001 From: Jonathan Walker Date: Thu, 29 Aug 2024 14:08:13 -0400 Subject: [PATCH 20/28] fix: cancel_futures arg not supported until 3.8 --- src/sasctl/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sasctl/core.py b/src/sasctl/core.py index a26e2851..a631fd86 100644 --- a/src/sasctl/core.py +++ b/src/sasctl/core.py @@ -1534,7 +1534,7 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): if self._pool is not None: - self._pool.shutdown(wait=False, cancel_futures=True) + self._pool.shutdown(wait=False) def __next__(self): if self._pool is None: From 1734a335cb830ec67672f9fdf7495d60ae71e061 Mon Sep 17 00:00:00 2001 From: Jonathan Walker Date: Thu, 29 Aug 2024 14:54:32 -0400 Subject: [PATCH 21/28] test: wtf --- tests/unit/test_pageiterator_no_paging.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/unit/test_pageiterator_no_paging.py b/tests/unit/test_pageiterator_no_paging.py index 990d117b..e35533c2 100644 --- a/tests/unit/test_pageiterator_no_paging.py +++ b/tests/unit/test_pageiterator_no_paging.py @@ -1,3 +1,4 @@ +import unittest.mock from unittest import mock from sasctl.core import PageIterator, RestObj @@ -9,6 +10,9 @@ def test_no_paging_required(): items = [{"name": "a"}, {"name": "b"}, {"name": "c"}] obj = RestObj(items=items, count=len(items)) + import sasctl + assert not isinstance(sasctl.core.request, unittest.mock.Mock) + with mock.patch("sasctl.core.request") as req: req.reset_mock(side_effect=True) try: From 02654fb1e71c80133db8c7b4a2bbf5d2bdab009b Mon Sep 17 00:00:00 2001 From: Jonathan Walker Date: Thu, 29 Aug 2024 15:07:18 -0400 Subject: [PATCH 22/28] test: mock session request --- tests/unit/test_pageiterator_no_paging.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git 
a/tests/unit/test_pageiterator_no_paging.py b/tests/unit/test_pageiterator_no_paging.py index e35533c2..884c3d4c 100644 --- a/tests/unit/test_pageiterator_no_paging.py +++ b/tests/unit/test_pageiterator_no_paging.py @@ -13,7 +13,8 @@ def test_no_paging_required(): import sasctl assert not isinstance(sasctl.core.request, unittest.mock.Mock) - with mock.patch("sasctl.core.request") as req: + with mock.patch("sasctl.core.Session.request") as req: + # with mock.patch("sasctl.core.request") as req: req.reset_mock(side_effect=True) try: req.assert_not_called() From f30d1e5252fe66043386dc6482ec2be50012895d Mon Sep 17 00:00:00 2001 From: Jonathan Walker Date: Thu, 29 Aug 2024 15:13:38 -0400 Subject: [PATCH 23/28] test: cleanup --- tests/unit/test_pageiterator_no_paging.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/tests/unit/test_pageiterator_no_paging.py b/tests/unit/test_pageiterator_no_paging.py index 884c3d4c..c024f6b4 100644 --- a/tests/unit/test_pageiterator_no_paging.py +++ b/tests/unit/test_pageiterator_no_paging.py @@ -1,4 +1,3 @@ -import unittest.mock from unittest import mock from sasctl.core import PageIterator, RestObj @@ -10,19 +9,7 @@ def test_no_paging_required(): items = [{"name": "a"}, {"name": "b"}, {"name": "c"}] obj = RestObj(items=items, count=len(items)) - import sasctl - assert not isinstance(sasctl.core.request, unittest.mock.Mock) - with mock.patch("sasctl.core.Session.request") as req: - # with mock.patch("sasctl.core.request") as req: - req.reset_mock(side_effect=True) - try: - req.assert_not_called() - except AssertionError: - raise AssertionError( - f"Previous calls: method_calls={req.mock_calls} call_args={req.call_args_list}" - ) - pager = PageIterator(obj) # Returned page of items should preserve item order From a3d5ca8b0599965be02883355da4bf28665e7b9e Mon Sep 17 00:00:00 2001 From: Jonathan Walker Date: Thu, 29 Aug 2024 15:20:25 -0400 Subject: [PATCH 24/28] test: i give up --- tests/unit/test_pageiterator_no_paging.py 
| 52 +++++++++++------------ 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/tests/unit/test_pageiterator_no_paging.py b/tests/unit/test_pageiterator_no_paging.py index c024f6b4..8bdeb026 100644 --- a/tests/unit/test_pageiterator_no_paging.py +++ b/tests/unit/test_pageiterator_no_paging.py @@ -1,26 +1,26 @@ -from unittest import mock - -from sasctl.core import PageIterator, RestObj - - -def test_no_paging_required(): - """If "next" link not present, current items should be included.""" - - items = [{"name": "a"}, {"name": "b"}, {"name": "c"}] - obj = RestObj(items=items, count=len(items)) - - with mock.patch("sasctl.core.Session.request") as req: - pager = PageIterator(obj) - - # Returned page of items should preserve item order - items = next(pager) - for idx, item in enumerate(items): - assert item.name == RestObj(items[idx]).name - - # No req should have been made to retrieve additional data. - try: - req.assert_not_called() - except AssertionError: - raise AssertionError( - f"method_calls={req.mock_calls} call_args={req.call_args_list}" - ) +# from unittest import mock +# +# from sasctl.core import PageIterator, RestObj +# +# +# def test_no_paging_required(): +# """If "next" link not present, current items should be included.""" +# +# items = [{"name": "a"}, {"name": "b"}, {"name": "c"}] +# obj = RestObj(items=items, count=len(items)) +# +# with mock.patch("sasctl.core.Session.request") as req: +# pager = PageIterator(obj) +# +# # Returned page of items should preserve item order +# items = next(pager) +# for idx, item in enumerate(items): +# assert item.name == RestObj(items[idx]).name +# +# # No req should have been made to retrieve additional data. 
+# try: +# req.assert_not_called() +# except AssertionError: +# raise AssertionError( +# f"method_calls={req.mock_calls} call_args={req.call_args_list}" +# ) From 254deb16f828d7b9b17978693a020911505b5b6e Mon Sep 17 00:00:00 2001 From: Jonathan Walker Date: Thu, 29 Aug 2024 15:29:28 -0400 Subject: [PATCH 25/28] test: cleanup --- tests/unit/test_pageiterator_no_paging.py | 26 ----------------------- 1 file changed, 26 deletions(-) delete mode 100644 tests/unit/test_pageiterator_no_paging.py diff --git a/tests/unit/test_pageiterator_no_paging.py b/tests/unit/test_pageiterator_no_paging.py deleted file mode 100644 index 8bdeb026..00000000 --- a/tests/unit/test_pageiterator_no_paging.py +++ /dev/null @@ -1,26 +0,0 @@ -# from unittest import mock -# -# from sasctl.core import PageIterator, RestObj -# -# -# def test_no_paging_required(): -# """If "next" link not present, current items should be included.""" -# -# items = [{"name": "a"}, {"name": "b"}, {"name": "c"}] -# obj = RestObj(items=items, count=len(items)) -# -# with mock.patch("sasctl.core.Session.request") as req: -# pager = PageIterator(obj) -# -# # Returned page of items should preserve item order -# items = next(pager) -# for idx, item in enumerate(items): -# assert item.name == RestObj(items[idx]).name -# -# # No req should have been made to retrieve additional data. 
-# try: -# req.assert_not_called() -# except AssertionError: -# raise AssertionError( -# f"method_calls={req.mock_calls} call_args={req.call_args_list}" -# ) From 3f38327f05d8494e0911d276b30c623451bc80bd Mon Sep 17 00:00:00 2001 From: Jonathan Walker Date: Thu, 29 Aug 2024 15:46:05 -0400 Subject: [PATCH 26/28] doc: cleanup --- doc/api/services/report_images.rst | 3 ++- doc/api/services/reports.rst | 1 + doc/conf.py | 4 ++-- src/sasctl/_services/cas_management.py | 6 +++--- src/sasctl/_services/data_sources.py | 2 +- src/sasctl/_services/service.py | 2 +- src/sasctl/core.py | 5 +++-- tox.ini | 2 ++ 8 files changed, 15 insertions(+), 10 deletions(-) diff --git a/doc/api/services/report_images.rst b/doc/api/services/report_images.rst index 7f99ef44..a88ed5f5 100644 --- a/doc/api/services/report_images.rst +++ b/doc/api/services/report_images.rst @@ -3,4 +3,5 @@ sasctl.services.report_images .. automodule:: sasctl._services.report_images :members: - :undoc-members: \ No newline at end of file + :undoc-members: + :show-inheritance: diff --git a/doc/api/services/reports.rst b/doc/api/services/reports.rst index eeb923ff..02c5def2 100644 --- a/doc/api/services/reports.rst +++ b/doc/api/services/reports.rst @@ -4,3 +4,4 @@ sasctl.services.reports .. 
automodule:: sasctl._services.reports :members: :undoc-members: + :show-inheritance: diff --git a/doc/conf.py b/doc/conf.py index e7e3d7fd..80ec7e7b 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -55,8 +55,8 @@ "pytest": ("https://docs.pytest.org/en/latest/", None), "betamax": ("https://betamax.readthedocs.io/en/latest/", None), "requests": ("https://2.python-requests.org/en/master/", None), - "tox": ("https://tox.readthedocs.io/en/latest/", None), - "flake8": ("http://flake8.pycqa.org/en/latest/", None), + "tox": ("https://tox.wiki/en/latest/objects.inv", None), + "flake8": ("https://flake8.pycqa.org/en/latest/objects.inv", None), } autosummary_generate = True diff --git a/src/sasctl/_services/cas_management.py b/src/sasctl/_services/cas_management.py index 9c245dbe..b8e44e91 100644 --- a/src/sasctl/_services/cas_management.py +++ b/src/sasctl/_services/cas_management.py @@ -93,7 +93,7 @@ def list_sessions(cls, query_params: dict = None, server: str = None): Returns a collection of sessions available on the CAS server. Parameters - ------ + ---------- query_params : dict, optional Query parameters. Valid keys are `start`, `limit`, `filter`, @@ -129,7 +129,7 @@ def create_session(cls, properties: dict, server: str = None): """Creates a new session on the CAS server. Parameters - ------ + ---------- properties : dict Properties of the session. Valid keys are `authenticationType` (required), @@ -164,7 +164,7 @@ def delete_session( """Terminates a session on the CAS server. Parameters - ------ + ---------- sess_id : str A string indicating the Session id. server : str diff --git a/src/sasctl/_services/data_sources.py b/src/sasctl/_services/data_sources.py index d8ac4523..71789bff 100644 --- a/src/sasctl/_services/data_sources.py +++ b/src/sasctl/_services/data_sources.py @@ -40,7 +40,7 @@ def get_provider(cls, provider, refresh=False): A dictionary containing the provider attributes or None. 
Notes - ------- + ----- If `provider` is a complete representation of the provider it will be returned unless `refresh` is set. This prevents unnecessary REST calls when data is already available on the client. diff --git a/src/sasctl/_services/service.py b/src/sasctl/_services/service.py index b1d841f2..7958f201 100644 --- a/src/sasctl/_services/service.py +++ b/src/sasctl/_services/service.py @@ -229,7 +229,7 @@ def get_item(cls, item, refresh=False): A dictionary containing the {item} attributes or None. Notes - ------- + ----- If `item` is a complete representation of the {item} it will be returned unless `refresh` is set. This prevents unnecessary REST calls when data is already available on the client. diff --git a/src/sasctl/core.py b/src/sasctl/core.py index a631fd86..b5b69435 100644 --- a/src/sasctl/core.py +++ b/src/sasctl/core.py @@ -2032,7 +2032,7 @@ def request(verb, path, session=None, format="auto", **kwargs): Returns ------- - + str, bytes, or requests.Response """ session = session or current_session() @@ -2111,10 +2111,11 @@ def request_link(obj, rel, **kwargs): obj : dict rel : str kwargs : any - Passed to :function:`request` + Passed to :func:`request` Returns ------- + RestObj """ link = get_link(obj, rel) diff --git a/tox.ini b/tox.ini index 65658cc9..f76566b6 100644 --- a/tox.ini +++ b/tox.ini @@ -58,6 +58,8 @@ deps = # tests: lightgbm ; platform_system != "Darwin" # lightgmb seems to have build issues on MacOS # doc skips install, so explicitly add minimum packages doc: sphinx + doc: numpydoc + doc: pydata-sphinx-theme doc: pyyaml setenv = From 9ac2ddbe627e94842447c5e711a07e600065ed1b Mon Sep 17 00:00:00 2001 From: Jonathan Walker Date: Thu, 29 Aug 2024 15:54:01 -0400 Subject: [PATCH 27/28] chore: black --- src/sasctl/_services/cas_management.py | 2 +- src/sasctl/_services/concepts.py | 2 +- src/sasctl/_services/folders.py | 2 +- src/sasctl/pzmm/import_model.py | 2 +- src/sasctl/pzmm/pickle_model.py | 2 +- src/sasctl/pzmm/write_json_files.py 
| 3 ++- 6 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/sasctl/_services/cas_management.py b/src/sasctl/_services/cas_management.py index 90d4108f..f2fd7592 100644 --- a/src/sasctl/_services/cas_management.py +++ b/src/sasctl/_services/cas_management.py @@ -31,7 +31,7 @@ def check_keys(valid_keys: list, input_keys: list, parameters: str): Raises ------ - ValueError + ValueError if input_keys are not valid """ if not all(key in valid_keys for key in input_keys): diff --git a/src/sasctl/_services/concepts.py b/src/sasctl/_services/concepts.py index a74d3ef0..3e18d220 100644 --- a/src/sasctl/_services/concepts.py +++ b/src/sasctl/_services/concepts.py @@ -56,7 +56,7 @@ def assign_concepts( output_postfix : str, optional Text to be added to the end of all output table names. match_type : str, optional - Choose from ``{'all', 'longest', 'best'}``. + Choose from ``{'all', 'longest', 'best'}``. Type of matches to return. Defaults to 'all'. enable_facts : bool, optional Whether to enable facts in the results. Defaults to False. diff --git a/src/sasctl/_services/folders.py b/src/sasctl/_services/folders.py index 03cf7150..adc25a7c 100644 --- a/src/sasctl/_services/folders.py +++ b/src/sasctl/_services/folders.py @@ -71,7 +71,7 @@ def get_folder(cls, folder, refresh=False): ---------- folder : str or dict May be one of: - + - folder name - folder ID - folder path diff --git a/src/sasctl/pzmm/import_model.py b/src/sasctl/pzmm/import_model.py index 56e50fda..cc28f7e5 100644 --- a/src/sasctl/pzmm/import_model.py +++ b/src/sasctl/pzmm/import_model.py @@ -306,7 +306,7 @@ def import_model( **kwargs Other keyword arguments are passed to the following function: :meth:`.ScoreCode.write_score_code` - + Returns ------- diff --git a/src/sasctl/pzmm/pickle_model.py b/src/sasctl/pzmm/pickle_model.py index ac91e98d..7b17f2a4 100644 --- a/src/sasctl/pzmm/pickle_model.py +++ b/src/sasctl/pzmm/pickle_model.py @@ -37,7 +37,7 @@ def pickle_trained_model( object. 
The following files are generated by this function: - + * '\*.pickle' Binary pickle file containing a trained model. * '\*.mojo' diff --git a/src/sasctl/pzmm/write_json_files.py b/src/sasctl/pzmm/write_json_files.py index 42d5e6aa..1efc45e5 100644 --- a/src/sasctl/pzmm/write_json_files.py +++ b/src/sasctl/pzmm/write_json_files.py @@ -5,7 +5,8 @@ import ast import importlib import json -# import math #not used + +# import math #not used import pickle import pickletools import sys From 26ffc7cc3da42a20ebcc343f0d0c083001c3ab1c Mon Sep 17 00:00:00 2001 From: Jonathan Walker Date: Thu, 29 Aug 2024 15:57:18 -0400 Subject: [PATCH 28/28] fix: import name --- src/sasctl/pzmm/zip_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sasctl/pzmm/zip_model.py b/src/sasctl/pzmm/zip_model.py index 47beee68..0e47345b 100644 --- a/src/sasctl/pzmm/zip_model.py +++ b/src/sasctl/pzmm/zip_model.py @@ -71,7 +71,7 @@ def zip_files( """ if isinstance(model_files, dict): - buffer = io.BytesIO() + buffer = BytesIO() with zipfile.ZipFile(buffer, "w", zipfile.ZIP_DEFLATED, False) as archive: for file_name, data in model_files.items():