Skip to content

Commit 5c42cb8

Browse files
committed
Fixes unit tests
1 parent 8c7006d commit 5c42cb8

File tree

1 file changed

+17
-17
lines changed

1 file changed

+17
-17
lines changed

tests/unitary/with_extras/aqua/test_model.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -45,31 +45,29 @@
4545
from ads.model.service.oci_datascience_model import OCIDataScienceModel
4646

4747

48-
# Fixture that reloads the module before any patching is applied.
49-
@pytest.fixture(autouse=True, scope="class")
50-
def reload_model_module():
51-
reload(ads.aqua.model.model)
52-
yield
53-
54-
5548
@pytest.fixture(autouse=True, scope="class")
5649
def mock_auth():
5750
with patch("ads.common.auth.default_signer") as mock_default_signer:
5851
yield mock_default_signer
5952

6053

54+
def get_container_config():
55+
with open(
56+
os.path.join(
57+
os.path.dirname(os.path.abspath(__file__)),
58+
"test_data/ui/container_index.json",
59+
),
60+
"r",
61+
) as _file:
62+
container_index_json = json.load(_file)
63+
64+
return container_index_json
65+
66+
6167
@pytest.fixture(autouse=True, scope="class")
6268
def mock_get_container_config():
6369
with patch("ads.aqua.model.model.get_container_config") as mock_config:
64-
with open(
65-
os.path.join(
66-
os.path.dirname(os.path.abspath(__file__)),
67-
"test_data/ui/container_index.json",
68-
),
69-
"r",
70-
) as _file:
71-
container_index_json = json.load(_file)
72-
mock_config.return_value = container_index_json
70+
mock_config.return_value = get_container_config()
7371
yield mock_config
7472

7573

@@ -283,7 +281,7 @@ def setup_class(cls):
283281
os.environ["ODSC_MODEL_COMPARTMENT_OCID"] = TestDataset.SERVICE_COMPARTMENT_ID
284282
reload(ads.config)
285283
reload(ads.aqua)
286-
# reload(ads.aqua.model.model)
284+
reload(ads.aqua.model.model)
287285

288286
@classmethod
289287
def teardown_class(cls):
@@ -382,6 +380,7 @@ def test_get_foundation_models(
382380
mock_get_container_config,
383381
mock_auth,
384382
):
383+
mock_get_container_config.return_value = get_container_config()
385384
ds_model = MagicMock()
386385
ds_model.id = "test_id"
387386
ds_model.compartment_id = "test_compartment_id"
@@ -496,6 +495,7 @@ def test_get_model_fine_tuned(
496495
mock_get_container_config,
497496
mock_auth,
498497
):
498+
mock_get_container_config.return_value = get_container_config()
499499
ds_model = MagicMock()
500500
ds_model.id = "test_id"
501501
ds_model.compartment_id = "test_model_compartment_id"

0 commit comments

Comments
 (0)