Skip to content

Commit 6f3c523

Browse files
committed
changed imports in test_model_parameters.py to allow tests to work in GitHub
1 parent 14ddd33 commit 6f3c523

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

tests/unit/test_model_parameters.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,6 @@
1616
from sasctl.pzmm import ModelParameters as mp
1717
import unittest
1818
import uuid
19-
import xgboost
20-
import h2o
21-
import tensorflow as tf
22-
import statsmodels.formula.api as smf
2319
import numpy as np
2420

2521

@@ -297,6 +293,7 @@ def test_overwrite(self):
297293

298294
class TestGenerateHyperparameters(unittest.TestCase):
299295
def test_xgboost(self):
296+
xgboost = pytest.importorskip("xgboost")
300297
model = unittest.mock.Mock()
301298
model.__class__ = xgboost.Booster
302299
attrs = {"save_config.return_value": json.dumps({"test": "passed"})}
@@ -306,6 +303,7 @@ def test_xgboost(self):
306303
assert Path(Path(tmp_dir.name) / f"./prefixHyperparameters.json").exists()
307304

308305
def test_xgboost_sklearn(self):
306+
xgboost = pytest.importorskip("xgboost")
309307
model = unittest.mock.Mock()
310308
model.__class__ = xgboost.XGBModel
311309
attrs = {"get_params.return_value": json.dumps({"test": "passed"})}
@@ -315,6 +313,7 @@ def test_xgboost_sklearn(self):
315313
assert Path(Path(tmp_dir.name) / f"./prefixHyperparameters.json").exists()
316314

317315
def test_h2o(self):
316+
h2o = pytest.importorskip("h2o")
318317
model = unittest.mock.Mock()
319318
model.__class__ = h2o.H2OFrame
320319
attrs = {"get_params.return_value": json.dumps({"test": "passed"})}
@@ -324,6 +323,7 @@ def test_h2o(self):
324323
assert Path(Path(tmp_dir.name) / f"./prefixHyperparameters.json").exists()
325324

326325
def test_tensorflow(self):
326+
tf = pytest.importorskip("tensorflow")
327327
model = unittest.mock.Mock()
328328
model.__class__ = tf.keras.Sequential
329329
attrs = {"get_config.return_value": json.dumps({"test": "passed"})}
@@ -333,6 +333,7 @@ def test_tensorflow(self):
333333
assert Path(Path(tmp_dir.name) / f"./prefixHyperparameters.json").exists()
334334

335335
def test_statsmodels(self):
336+
smf = pytest.importorskip("statsmodels.formula.api")
336337
model = unittest.mock.Mock(
337338
exog_names=["test", "exog"], weights=np.array([0, 1])
338339
)

0 commit comments

Comments
 (0)