Skip to content

Commit bedaadb

Browse files
committed
fix: update for tree-based models
1 parent 793a1fe commit bedaadb

File tree

1 file changed

+24
-6
lines changed

1 file changed

+24
-6
lines changed

tests/integration/test_astore_models.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,10 @@ def test_forest_regression(cas_session, boston_dataset):
208208
props = _get_model_properties(desc)
209209

210210
for k, v in target.items():
211-
assert props[k] == v
211+
if k == "algorithm":
212+
assert props[k] in ("Random forest", "Tree-based model")
213+
else:
214+
assert props[k] == v
212215

213216
files = create_files_from_astore(cas_session.CASTable("astore"))
214217
check_input_variables(files, BOSTON_INPUT_VARS)
@@ -239,7 +242,10 @@ def test_forest_regression_with_nominals(cas_session, boston_dataset):
239242
props = _get_model_properties(desc)
240243

241244
for k, v in target.items():
242-
assert props[k] == v
245+
if k == "algorithm":
246+
assert props[k] in ("Random forest", "Tree-based model")
247+
else:
248+
assert props[k] == v
243249

244250
files = create_files_from_astore(cas_session.CASTable("astore"))
245251
check_input_variables(files, BOSTON_INPUT_VARS)
@@ -268,7 +274,10 @@ def test_gradboost_binary_classification(cas_session, cancer_dataset):
268274
props = _get_model_properties(desc)
269275

270276
for k, v in target.items():
271-
assert props[k] == v
277+
if k == "algorithm":
278+
assert props[k] in ("Gradient boosting", "Tree-based model")
279+
else:
280+
assert props[k] == v
272281

273282
files = create_files_from_astore(cas_session.CASTable("astore"))
274283
check_input_variables(files, CANCER_INPUT_VARS)
@@ -296,7 +305,10 @@ def test_gradboost_classification(cas_session, iris_dataset):
296305
props = _get_model_properties(desc)
297306

298307
for k, v in target.items():
299-
assert props[k] == v
308+
if k == "algorithm":
309+
assert props[k] in ("Gradient boosting", "Tree-based model")
310+
else:
311+
assert props[k] == v
300312

301313
files = create_files_from_astore(cas_session.CASTable("astore"))
302314
check_input_variables(files, IRIS_INPUT_VARS)
@@ -324,7 +336,10 @@ def test_gradboost_regression(cas_session, boston_dataset):
324336
props = _get_model_properties(desc)
325337

326338
for k, v in target.items():
327-
assert props[k] == v
339+
if k == "algorithm":
340+
assert props[k] in ("Gradient boosting", "Tree-based model")
341+
else:
342+
assert props[k] == v
328343

329344
files = create_files_from_astore(cas_session.CASTable("astore"))
330345
check_input_variables(files, BOSTON_INPUT_VARS)
@@ -355,7 +370,10 @@ def test_gradboost_regression_with_nominals(cas_session, boston_dataset):
355370
props = _get_model_properties(desc)
356371

357372
for k, v in target.items():
358-
assert props[k] == v
373+
if k == "algorithm":
374+
assert props[k] in ("Gradient boosting", "Tree-based model")
375+
else:
376+
assert props[k] == v
359377

360378
files = create_files_from_astore(cas_session.CASTable("astore"))
361379
check_input_variables(files, BOSTON_INPUT_VARS)

0 commit comments

Comments
 (0)