Skip to content

Commit c351153

Browse files
VipulMascarenhasmayoor
authored andcommitted
Update unit tests
1 parent e8e731c commit c351153

File tree

1 file changed

+24
-1
lines changed

1 file changed

+24
-1
lines changed

tests/unitary/with_extras/aqua/test_finetuning.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from ads.aqua.finetuning.constants import FineTuneCustomMetadata
2323
from ads.aqua.finetuning.entities import AquaFineTuningParams
2424
from ads.jobs.ads_job import Job
25+
from ads.jobs.builders.infrastructure.dsc_job import DataScienceJobRun
2526
from ads.model.datascience_model import DataScienceModel
2627
from ads.model.model_metadata import ModelCustomMetadata
2728
from ads.aqua.common.errors import AquaValueError
@@ -49,6 +50,12 @@ def tearDownClass(cls):
4950
reload(ads.aqua)
5051
reload(ads.aqua.finetuning.finetuning)
5152

53+
@parameterized.expand(
54+
[
55+
("watch_logs", True),
56+
("watch_logs", False),
57+
]
58+
)
5259
@patch.object(Job, "run")
5360
@patch("ads.jobs.ads_job.Job.name", new_callable=PropertyMock)
5461
@patch("ads.jobs.ads_job.Job.id", new_callable=PropertyMock)
@@ -60,6 +67,8 @@ def tearDownClass(cls):
6067
@patch.object(AquaApp, "get_source")
6168
def test_create_fine_tuning(
6269
self,
70+
mock_watch_logs,
71+
mock_watch_logs_called,
6372
mock_get_source,
6473
mock_mvs_create,
6574
mock_ds_model_create,
@@ -117,6 +126,7 @@ def test_create_fine_tuning(
117126
ft_job_run.id = "test_ft_job_run_id"
118127
ft_job_run.lifecycle_details = "Job run artifact execution in progress."
119128
ft_job_run.lifecycle_state = "IN_PROGRESS"
129+
ft_job_run.watch = MagicMock()
120130
mock_job_run.return_value = ft_job_run
121131

122132
self.app.ds_client.update_model = MagicMock()
@@ -144,7 +154,20 @@ def test_create_fine_tuning(
144154
defined_tags=ft_model_defined_tags,
145155
)
146156

147-
aqua_ft_summary = self.app.create(**create_aqua_ft_details)
157+
inputs = {
158+
**create_aqua_ft_details,
159+
**{
160+
mock_watch_logs: mock_watch_logs_called,
161+
"log_id": "test_log_id",
162+
"log_group_id": "test_log_group_id",
163+
},
164+
}
165+
aqua_ft_summary = self.app.create(**inputs)
166+
167+
if mock_watch_logs_called:
168+
ft_job_run.watch.assert_called()
169+
else:
170+
ft_job_run.watch.assert_not_called()
148171

149172
assert aqua_ft_summary.to_dict() == {
150173
"console_url": f"https://cloud.oracle.com/data-science/models/{ft_model.id}?region={self.app.region}",

0 commit comments

Comments
 (0)