Skip to content

Commit 7dc5167

Browse files
committed
Make numpy and sklearn optional requirements; include errors for functions requiring these packages #140
1 parent 16b5d5a commit 7dc5167

File tree

3 files changed

+14
-10
lines changed

3 files changed

+14
-10
lines changed

setup.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,13 @@ def get_file(filename):
4141
include_package_data=True,
4242
packages=find_packages(where="src"),
4343
package_dir={"": "src"},
44-
python_requires=">=3.5",
45-
install_requires=["pandas", "scikit-learn", "requests", "pyyaml", "packaging"],
44+
python_requires=">=3.6",
45+
install_requires=["pandas", "requests", "pyyaml", "packaging"],
4646
extras_require={
4747
"swat": ["swat"],
4848
"GitPython": ["GitPython"],
49+
"numpy": ["numpy"],
50+
"scikit-learn": ["scikit-learn"],
4951
"kerberos": [
5052
'kerberos ; platform_system != "Windows"',
5153
'winkerberos ; platform_system == "Windows"',
@@ -66,7 +68,6 @@ def get_file(filename):
6668
"Intended Audience :: Developers",
6769
"Programming Language :: Python",
6870
"Programming Language :: Python :: 3",
69-
"Programming Language :: Python :: 3.5",
7071
"Programming Language :: Python :: 3.6",
7172
"Programming Language :: Python :: 3.7",
7273
"Programming Language :: Python :: 3.8",

src/sasctl/pzmm/writeScoreCode.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
from pathlib import Path
5-
import numpy as np
65
import re
76
from ..core import platform_version
87
from .._services.model_repository import ModelRepository as modelRepo
@@ -66,8 +65,8 @@ def writeScoreCode(
6665
columns. The writeScoreCode function currently supports int(64), float(64),
6766
and string data types for scoring. Providing a list of dict objects signals
6867
that the model files are being created from an MLFlow model.
69-
targetDF : DataFrame
70-
The `DataFrame` object contains the training data for the target variable. Note that
68+
targetDF : pandas Series
69+
The `DataFrame Series` object contains the training data for the target variable. Note that
7170
for MLFlow models, this can be set as None.
7271
modelPrefix : string
7372
The variable for the model name that is used when naming model files.
@@ -492,7 +491,7 @@ def score{modelPrefix}({inputVarList}):
492491
)
493492
)
494493
if threshPrediction is None:
495-
threshPrediction = np.mean(targetDF)
494+
threshPrediction = targetDF.mean()
496495
cls.pyFile.write(
497496
"""\n
498497
if ({metric0} >= {threshold}):

src/sasctl/pzmm/write_json_files.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,9 @@
44
# %%
55
from pathlib import Path
66
import sys
7-
87
import getpass
98
import json
109
import pandas as pd
11-
from sklearn import metrics
1210
import math
1311

1412

@@ -476,11 +474,17 @@ def calculateFitStat(
476474
"""
477475
# If numpy inputs are supplied, then it is assumed that numpy is installed in the environment
478476
try:
479-
# noinspection PyPackageRequirements
480477
import numpy as np
481478
except ImportError:
482479
np = None
483480

481+
try:
482+
from sklearn import metrics
483+
except ImportError:
484+
raise RuntimeError(
485+
"The 'scikit-learn' package is required to use the calculateFitStat function."
486+
)
487+
484488
nullJSONPath = Path(__file__).resolve().parent / "null_dmcas_fitstat.json"
485489
nullJSONDict = self.readJSONFile(nullJSONPath)
486490

0 commit comments

Comments
 (0)