Skip to content

Commit 68e2f6d

Browse files
committed
Updated pr.
1 parent 7dc3d15 commit 68e2f6d

File tree

1 file changed

+277
-4
lines changed

1 file changed

+277
-4
lines changed

tests/unitary/with_extras/aqua/test_model.py

Lines changed: 277 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,15 @@
1010
from importlib import reload
1111
from unittest.mock import MagicMock
1212

13+
from mock import patch
1314
import oci
1415
from parameterized import parameterized
1516

1617
import ads.aqua.model
1718
import ads.config
1819
from ads.aqua.model import AquaModelApp, AquaModelSummary
20+
from ads.model.datascience_model import DataScienceModel
21+
from ads.model.model_metadata import ModelCustomMetadata, ModelProvenanceMetadata, ModelTaxonomyMetadata
1922

2023

2124
class TestDataset:
@@ -71,6 +74,8 @@ class TestAquaModel(unittest.TestCase):
7174
"""Contains unittests for AquaModelApp."""
7275

7376
def setUp(self):
77+
import ads
78+
ads.set_auth("security_token")
7479
self.app = AquaModelApp()
7580

7681
@classmethod
@@ -89,11 +94,279 @@ def tearDownClass(cls):
8994
reload(ads.aqua)
9095
reload(ads.aqua.model)
9196

92-
def test_create_model(self):
93-
pass
97+
@patch.object(DataScienceModel, "create")
98+
@patch("ads.model.datascience_model.validate")
99+
@patch.object(DataScienceModel, "from_id")
100+
def test_create_model(self, mock_from_id, mock_validate, mock_create):
101+
service_model = MagicMock()
102+
service_model.model_file_description = {"test_key":"test_value"}
103+
service_model.display_name = "test_display_name"
104+
service_model.description = "test_description"
105+
service_model.freeform_tags = {"test_key":"test_value"}
106+
custom_metadata_list = ModelCustomMetadata()
107+
custom_metadata_list.add(
108+
key="test_metadata_item_key",
109+
value="test_metadata_item_value"
110+
)
111+
service_model.custom_metadata_list = custom_metadata_list
112+
service_model.provenance_metadata = ModelProvenanceMetadata(
113+
training_id="test_training_id"
114+
)
115+
mock_from_id.return_value = service_model
116+
117+
# will not copy service model
118+
self.app.create(
119+
model_id="test_model_id",
120+
project_id="test_project_id",
121+
compartment_id="test_compartment_id",
122+
)
123+
124+
mock_from_id.assert_called_with("test_model_id")
125+
mock_validate.assert_not_called()
126+
mock_create.assert_not_called()
127+
128+
service_model.compartment_id = TestDataset.SERVICE_COMPARTMENT_ID
129+
mock_from_id.return_value = service_model
130+
131+
# will copy service model
132+
self.app.create(
133+
model_id="test_model_id",
134+
project_id="test_project_id",
135+
compartment_id="test_compartment_id"
136+
)
137+
138+
mock_from_id.assert_called_with("test_model_id")
139+
mock_validate.assert_called()
140+
mock_create.assert_called_with(
141+
model_by_reference=True
142+
)
143+
144+
@patch("ads.aqua.model.read_file")
145+
@patch.object(DataScienceModel, "from_id")
146+
def test_get_model_not_fine_tuned(self, mock_from_id, mock_read_file):
147+
ds_model = MagicMock()
148+
ds_model.id = "test_id"
149+
ds_model.compartment_id = "test_compartment_id"
150+
ds_model.project_id = "test_project_id"
151+
ds_model.display_name = "test_display_name"
152+
ds_model.description = "test_description"
153+
ds_model.freeform_tags = {
154+
"OCI_AQUA":"ACTIVE",
155+
"license":"test_license",
156+
"organization":"test_organization",
157+
"task":"test_task"
158+
}
159+
ds_model.time_created = "2024-01-19T17:57:39.158000+00:00"
160+
custom_metadata_list = ModelCustomMetadata()
161+
custom_metadata_list.add(
162+
key="artifact_location",
163+
value="oci://bucket@namespace/prefix/"
164+
)
165+
ds_model.custom_metadata_list = custom_metadata_list
166+
167+
mock_from_id.return_value = ds_model
168+
mock_read_file.return_value = "test_model_card"
169+
170+
aqua_model = self.app.get(model_id="test_model_id")
171+
172+
mock_from_id.assert_called_with("test_model_id")
173+
mock_read_file.assert_called_with(
174+
file_path="oci://bucket@namespace/prefix/README.md",
175+
auth=self.app._auth,
176+
)
177+
178+
assert asdict(aqua_model) == {
179+
'compartment_id': f'{ds_model.compartment_id}',
180+
'console_link': (
181+
f'https://cloud.oracle.com/data-science/models/{ds_model.id}?region={self.app.region}',
182+
),
183+
'icon': '',
184+
'id': f'{ds_model.id}',
185+
'is_fine_tuned_model': False,
186+
'license': f'{ds_model.freeform_tags["license"]}',
187+
'model_card': f'{mock_read_file.return_value}',
188+
'name': f'{ds_model.display_name}',
189+
'organization': f'{ds_model.freeform_tags["organization"]}',
190+
'project_id': f'{ds_model.project_id}',
191+
'ready_to_deploy': True,
192+
'ready_to_finetune': False,
193+
'search_text': 'ACTIVE,test_license,test_organization,test_task',
194+
'tags': ds_model.freeform_tags,
195+
'task': f'{ds_model.freeform_tags["task"]}',
196+
'time_created': f'{ds_model.time_created}'
197+
}
198+
199+
@patch("ads.aqua.utils.query_resource")
200+
@patch("ads.aqua.model.read_file")
201+
@patch.object(DataScienceModel, "from_id")
202+
def test_get_model_fine_tuned(self, mock_from_id, mock_read_file, mock_query_resource):
203+
ds_model = MagicMock()
204+
ds_model.id = "test_id"
205+
ds_model.compartment_id = "test_model_compartment_id"
206+
ds_model.project_id = "test_project_id"
207+
ds_model.display_name = "test_display_name"
208+
ds_model.description = "test_description"
209+
ds_model.model_version_set_id = "test_model_version_set_id"
210+
ds_model.model_version_set_name = "test_model_version_set_name"
211+
ds_model.freeform_tags = {
212+
"OCI_AQUA":"ACTIVE",
213+
"license":"test_license",
214+
"organization":"test_organization",
215+
"task":"test_task",
216+
"aqua_fine_tuned_model":"test_finetuned_model"
217+
}
218+
ds_model.time_created = "2024-01-19T17:57:39.158000+00:00"
219+
ds_model.lifecycle_state = "ACTIVE"
220+
custom_metadata_list = ModelCustomMetadata()
221+
custom_metadata_list.add(
222+
key="artifact_location",
223+
value="oci://bucket@namespace/prefix/"
224+
)
225+
custom_metadata_list.add(
226+
key="fine_tune_source",
227+
value="test_fine_tuned_source_id"
228+
)
229+
custom_metadata_list.add(
230+
key="fine_tune_source_name",
231+
value="test_fine_tuned_source_name"
232+
)
233+
ds_model.custom_metadata_list = custom_metadata_list
234+
defined_metadata_list = ModelTaxonomyMetadata()
235+
defined_metadata_list["Hyperparameters"].value = {
236+
"training_data" : "test_training_data",
237+
"val_set_size" : "test_val_set_size"
238+
}
239+
ds_model.defined_metadata_list = defined_metadata_list
240+
ds_model.provenance_metadata = ModelProvenanceMetadata(
241+
training_id="test_training_job_run_id"
242+
)
243+
244+
mock_from_id.return_value = ds_model
245+
mock_read_file.return_value = "test_model_card"
246+
247+
response = MagicMock()
248+
job_run = MagicMock()
249+
job_run.id = "test_job_run_id"
250+
job_run.lifecycle_state = "SUCCEEDED"
251+
job_run.lifecycle_details = "test lifecycle details"
252+
job_run.identifier = "test_job_id",
253+
job_run.display_name = "test_job_name"
254+
job_run.compartment_id = "test_job_run_compartment_id"
255+
job_infrastructure_configuration_details = MagicMock()
256+
job_infrastructure_configuration_details.shape_name = "test_shape_name"
257+
258+
job_configuration_override_details = MagicMock()
259+
job_configuration_override_details.environment_variables = {
260+
"NODE_COUNT" : 1
261+
}
262+
job_run.job_infrastructure_configuration_details = job_infrastructure_configuration_details
263+
job_run.job_configuration_override_details = job_configuration_override_details
264+
log_details = MagicMock()
265+
log_details.log_id = "test_log_id"
266+
log_details.log_group_id = "test_log_group_id"
267+
job_run.log_details = log_details
268+
response.data = job_run
269+
self.app.ds_client.get_job_run = MagicMock(
270+
return_value = response
271+
)
272+
273+
query_resource = MagicMock()
274+
query_resource.display_name = "test_display_name"
275+
mock_query_resource.return_value = query_resource
276+
277+
model = self.app.get(model_id="test_model_id")
278+
279+
mock_from_id.assert_called_with("test_model_id")
280+
mock_read_file.assert_called_with(
281+
file_path="oci://bucket@namespace/prefix/README.md",
282+
auth=self.app._auth,
283+
)
284+
mock_query_resource.assert_called()
285+
286+
assert asdict(model) == {
287+
'compartment_id': f'{ds_model.compartment_id}',
288+
'console_link': (
289+
f'https://cloud.oracle.com/data-science/models/{ds_model.id}?region={self.app.region}',
290+
),
291+
'dataset': 'test_training_data',
292+
'experiment': {'id': '', 'name': '', 'url': ''},
293+
'icon': '',
294+
'id': f'{ds_model.id}',
295+
'is_fine_tuned_model': True,
296+
'job': {'id': '', 'name': '', 'url': ''},
297+
'license': 'test_license',
298+
'lifecycle_details': f'{job_run.lifecycle_details}',
299+
'lifecycle_state': f'{ds_model.lifecycle_state}',
300+
'log': {
301+
'id': f'{log_details.log_id}',
302+
'name': f'{query_resource.display_name}',
303+
'url': 'https://cloud.oracle.com/logging/search?searchQuery=search '
304+
f'"{job_run.compartment_id}/{log_details.log_group_id}/{log_details.log_id}" | '
305+
f"source='{job_run.id}' | sort by datetime desc&regions={self.app.region}"
306+
},
307+
'log_group': {
308+
'id': f'{log_details.log_group_id}',
309+
'name': f'{query_resource.display_name}',
310+
'url': f'https://cloud.oracle.com/logging/log-groups/{log_details.log_group_id}?region={self.app.region}'
311+
},
312+
'metrics': [
313+
{
314+
'category': 'validation',
315+
'name': 'validation_metrics',
316+
'scores': []
317+
},
318+
{
319+
'category': 'training',
320+
'name': 'training_metrics',
321+
'scores': []
322+
},
323+
{
324+
'category': 'validation',
325+
'name': 'validation_metrics_final',
326+
'scores': []
327+
},
328+
{
329+
'category': 'training',
330+
'name': 'training_metrics_final',
331+
'scores': []
332+
}
333+
],
334+
'model_card': f'{mock_read_file.return_value}',
335+
'name': f'{ds_model.display_name}',
336+
'organization': 'test_organization',
337+
'project_id': f'{ds_model.project_id}',
338+
'ready_to_deploy': True,
339+
'ready_to_finetune': False,
340+
'search_text': 'ACTIVE,test_license,test_organization,test_task,test_finetuned_model',
341+
'shape_info': {
342+
'instance_shape': f'{job_infrastructure_configuration_details.shape_name}',
343+
'replica': 1,
344+
},
345+
'source': {'id': '', 'name': '', 'url': ''},
346+
'tags': ds_model.freeform_tags,
347+
'task': 'test_task',
348+
'time_created': f'{ds_model.time_created}',
349+
'validation': {
350+
'type': 'Automatic split',
351+
'value': 'test_val_set_size'
352+
}
353+
}
354+
355+
@patch("ads.aqua.model.read_file")
356+
@patch("ads.aqua.model.get_artifact_path")
357+
def test_load_license(self, mock_get_artifact_path, mock_read_file):
358+
self.app.ds_client.get_model = MagicMock()
359+
mock_get_artifact_path.return_value = "oci://bucket@namespace/prefix/config/LICENSE.txt"
360+
mock_read_file.return_value = "test_license"
361+
362+
license = self.app.load_license(model_id="test_model_id")
363+
364+
mock_get_artifact_path.assert_called()
365+
mock_read_file.assert_called()
94366

95-
def test_get_model(self):
96-
pass
367+
assert asdict(license) == {
368+
'id': 'test_model_id', 'license': 'test_license'
369+
}
97370

98371
def test_list_service_models(self):
99372
"""Tests listing service models succesfully."""

0 commit comments

Comments
 (0)