Skip to content

Commit 0b37dd5

Browse files
authored
Added unit tests for get and create model (#771)
2 parents e6da4bf + bef2058 commit 0b37dd5

File tree

1 file changed

+299
-0
lines changed

1 file changed

+299
-0
lines changed

tests/unitary/with_extras/aqua/test_model.py

Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,15 @@
1010
from importlib import reload
1111
from unittest.mock import MagicMock
1212

13+
from mock import patch
1314
import oci
1415
from parameterized import parameterized
1516

1617
import ads.aqua.model
1718
import ads.config
1819
from ads.aqua.model import AquaModelApp, AquaModelSummary
20+
from ads.model.datascience_model import DataScienceModel
21+
from ads.model.model_metadata import ModelCustomMetadata, ModelProvenanceMetadata, ModelTaxonomyMetadata
1922

2023

2124
class TestDataset:
@@ -89,6 +92,302 @@ def tearDownClass(cls):
8992
reload(ads.aqua)
9093
reload(ads.aqua.model)
9194

95+
@patch.object(DataScienceModel, "create")
96+
@patch("ads.model.datascience_model.validate")
97+
@patch.object(DataScienceModel, "from_id")
98+
def test_create_model(self, mock_from_id, mock_validate, mock_create):
99+
mock_model = MagicMock()
100+
mock_model.model_file_description = {"test_key":"test_value"}
101+
mock_model.display_name = "test_display_name"
102+
mock_model.description = "test_description"
103+
mock_model.freeform_tags = {
104+
"OCI_AQUA":"ACTIVE",
105+
"license":"test_license",
106+
"organization":"test_organization",
107+
"task":"test_task",
108+
"ready_to_fine_tune":"true"
109+
}
110+
custom_metadata_list = ModelCustomMetadata()
111+
custom_metadata_list.add(
112+
key="test_metadata_item_key",
113+
value="test_metadata_item_value"
114+
)
115+
mock_model.custom_metadata_list = custom_metadata_list
116+
mock_model.provenance_metadata = ModelProvenanceMetadata(
117+
training_id="test_training_id"
118+
)
119+
mock_from_id.return_value = mock_model
120+
121+
# will not copy service model
122+
self.app.create(
123+
model_id="test_model_id",
124+
project_id="test_project_id",
125+
compartment_id="test_compartment_id",
126+
)
127+
128+
mock_from_id.assert_called_with("test_model_id")
129+
mock_validate.assert_not_called()
130+
mock_create.assert_not_called()
131+
132+
mock_model.compartment_id = TestDataset.SERVICE_COMPARTMENT_ID
133+
mock_from_id.return_value = mock_model
134+
mock_create.return_value = mock_model
135+
136+
# will copy service model
137+
model = self.app.create(
138+
model_id="test_model_id",
139+
project_id="test_project_id",
140+
compartment_id="test_compartment_id"
141+
)
142+
143+
mock_from_id.assert_called_with("test_model_id")
144+
mock_validate.assert_called()
145+
mock_create.assert_called_with(
146+
model_by_reference=True
147+
)
148+
149+
assert model.display_name == "test_display_name"
150+
assert model.description == "test_description"
151+
assert model.description == "test_description"
152+
assert model.freeform_tags == {
153+
"OCI_AQUA":"ACTIVE",
154+
"license":"test_license",
155+
"organization":"test_organization",
156+
"task":"test_task",
157+
"ready_to_fine_tune":"true"
158+
}
159+
assert model.custom_metadata_list.get(
160+
"test_metadata_item_key"
161+
).value == "test_metadata_item_value"
162+
assert model.provenance_metadata.training_id == "test_training_id"
163+
164+
@patch("ads.aqua.model.read_file")
165+
@patch.object(DataScienceModel, "from_id")
166+
def test_get_model_not_fine_tuned(self, mock_from_id, mock_read_file):
167+
ds_model = MagicMock()
168+
ds_model.id = "test_id"
169+
ds_model.compartment_id = "test_compartment_id"
170+
ds_model.project_id = "test_project_id"
171+
ds_model.display_name = "test_display_name"
172+
ds_model.description = "test_description"
173+
ds_model.freeform_tags = {
174+
"OCI_AQUA":"ACTIVE",
175+
"license":"test_license",
176+
"organization":"test_organization",
177+
"task":"test_task"
178+
}
179+
ds_model.time_created = "2024-01-19T17:57:39.158000+00:00"
180+
custom_metadata_list = ModelCustomMetadata()
181+
custom_metadata_list.add(
182+
key="artifact_location",
183+
value="oci://bucket@namespace/prefix/"
184+
)
185+
ds_model.custom_metadata_list = custom_metadata_list
186+
187+
mock_from_id.return_value = ds_model
188+
mock_read_file.return_value = "test_model_card"
189+
190+
aqua_model = self.app.get(model_id="test_model_id")
191+
192+
mock_from_id.assert_called_with("test_model_id")
193+
mock_read_file.assert_called_with(
194+
file_path="oci://bucket@namespace/prefix/README.md",
195+
auth=self.app._auth,
196+
)
197+
198+
assert asdict(aqua_model) == {
199+
'compartment_id': f'{ds_model.compartment_id}',
200+
'console_link': (
201+
f'https://cloud.oracle.com/data-science/models/{ds_model.id}?region={self.app.region}',
202+
),
203+
'icon': '',
204+
'id': f'{ds_model.id}',
205+
'is_fine_tuned_model': False,
206+
'license': f'{ds_model.freeform_tags["license"]}',
207+
'model_card': f'{mock_read_file.return_value}',
208+
'name': f'{ds_model.display_name}',
209+
'organization': f'{ds_model.freeform_tags["organization"]}',
210+
'project_id': f'{ds_model.project_id}',
211+
'ready_to_deploy': True,
212+
'ready_to_finetune': False,
213+
'search_text': 'ACTIVE,test_license,test_organization,test_task',
214+
'tags': ds_model.freeform_tags,
215+
'task': f'{ds_model.freeform_tags["task"]}',
216+
'time_created': f'{ds_model.time_created}'
217+
}
218+
219+
@patch("ads.aqua.utils.query_resource")
220+
@patch("ads.aqua.model.read_file")
221+
@patch.object(DataScienceModel, "from_id")
222+
def test_get_model_fine_tuned(self, mock_from_id, mock_read_file, mock_query_resource):
223+
ds_model = MagicMock()
224+
ds_model.id = "test_id"
225+
ds_model.compartment_id = "test_model_compartment_id"
226+
ds_model.project_id = "test_project_id"
227+
ds_model.display_name = "test_display_name"
228+
ds_model.description = "test_description"
229+
ds_model.model_version_set_id = "test_model_version_set_id"
230+
ds_model.model_version_set_name = "test_model_version_set_name"
231+
ds_model.freeform_tags = {
232+
"OCI_AQUA":"ACTIVE",
233+
"license":"test_license",
234+
"organization":"test_organization",
235+
"task":"test_task",
236+
"aqua_fine_tuned_model":"test_finetuned_model"
237+
}
238+
ds_model.time_created = "2024-01-19T17:57:39.158000+00:00"
239+
ds_model.lifecycle_state = "ACTIVE"
240+
custom_metadata_list = ModelCustomMetadata()
241+
custom_metadata_list.add(
242+
key="artifact_location",
243+
value="oci://bucket@namespace/prefix/"
244+
)
245+
custom_metadata_list.add(
246+
key="fine_tune_source",
247+
value="test_fine_tuned_source_id"
248+
)
249+
custom_metadata_list.add(
250+
key="fine_tune_source_name",
251+
value="test_fine_tuned_source_name"
252+
)
253+
ds_model.custom_metadata_list = custom_metadata_list
254+
defined_metadata_list = ModelTaxonomyMetadata()
255+
defined_metadata_list["Hyperparameters"].value = {
256+
"training_data" : "test_training_data",
257+
"val_set_size" : "test_val_set_size"
258+
}
259+
ds_model.defined_metadata_list = defined_metadata_list
260+
ds_model.provenance_metadata = ModelProvenanceMetadata(
261+
training_id="test_training_job_run_id"
262+
)
263+
264+
mock_from_id.return_value = ds_model
265+
mock_read_file.return_value = "test_model_card"
266+
267+
response = MagicMock()
268+
job_run = MagicMock()
269+
job_run.id = "test_job_run_id"
270+
job_run.lifecycle_state = "SUCCEEDED"
271+
job_run.lifecycle_details = "test lifecycle details"
272+
job_run.identifier = "test_job_id",
273+
job_run.display_name = "test_job_name"
274+
job_run.compartment_id = "test_job_run_compartment_id"
275+
job_infrastructure_configuration_details = MagicMock()
276+
job_infrastructure_configuration_details.shape_name = "test_shape_name"
277+
278+
job_configuration_override_details = MagicMock()
279+
job_configuration_override_details.environment_variables = {
280+
"NODE_COUNT" : 1
281+
}
282+
job_run.job_infrastructure_configuration_details = job_infrastructure_configuration_details
283+
job_run.job_configuration_override_details = job_configuration_override_details
284+
log_details = MagicMock()
285+
log_details.log_id = "test_log_id"
286+
log_details.log_group_id = "test_log_group_id"
287+
job_run.log_details = log_details
288+
response.data = job_run
289+
self.app.ds_client.get_job_run = MagicMock(
290+
return_value = response
291+
)
292+
293+
query_resource = MagicMock()
294+
query_resource.display_name = "test_display_name"
295+
mock_query_resource.return_value = query_resource
296+
297+
model = self.app.get(model_id="test_model_id")
298+
299+
mock_from_id.assert_called_with("test_model_id")
300+
mock_read_file.assert_called_with(
301+
file_path="oci://bucket@namespace/prefix/README.md",
302+
auth=self.app._auth,
303+
)
304+
mock_query_resource.assert_called()
305+
306+
assert asdict(model) == {
307+
'compartment_id': f'{ds_model.compartment_id}',
308+
'console_link': (
309+
f'https://cloud.oracle.com/data-science/models/{ds_model.id}?region={self.app.region}',
310+
),
311+
'dataset': 'test_training_data',
312+
'experiment': {'id': '', 'name': '', 'url': ''},
313+
'icon': '',
314+
'id': f'{ds_model.id}',
315+
'is_fine_tuned_model': True,
316+
'job': {'id': '', 'name': '', 'url': ''},
317+
'license': 'test_license',
318+
'lifecycle_details': f'{job_run.lifecycle_details}',
319+
'lifecycle_state': f'{ds_model.lifecycle_state}',
320+
'log': {
321+
'id': f'{log_details.log_id}',
322+
'name': f'{query_resource.display_name}',
323+
'url': 'https://cloud.oracle.com/logging/search?searchQuery=search '
324+
f'"{job_run.compartment_id}/{log_details.log_group_id}/{log_details.log_id}" | '
325+
f"source='{job_run.id}' | sort by datetime desc&regions={self.app.region}"
326+
},
327+
'log_group': {
328+
'id': f'{log_details.log_group_id}',
329+
'name': f'{query_resource.display_name}',
330+
'url': f'https://cloud.oracle.com/logging/log-groups/{log_details.log_group_id}?region={self.app.region}'
331+
},
332+
'metrics': [
333+
{
334+
'category': 'validation',
335+
'name': 'validation_metrics',
336+
'scores': []
337+
},
338+
{
339+
'category': 'training',
340+
'name': 'training_metrics',
341+
'scores': []
342+
},
343+
{
344+
'category': 'validation',
345+
'name': 'validation_metrics_final',
346+
'scores': []
347+
},
348+
{
349+
'category': 'training',
350+
'name': 'training_metrics_final',
351+
'scores': []
352+
}
353+
],
354+
'model_card': f'{mock_read_file.return_value}',
355+
'name': f'{ds_model.display_name}',
356+
'organization': 'test_organization',
357+
'project_id': f'{ds_model.project_id}',
358+
'ready_to_deploy': True,
359+
'ready_to_finetune': False,
360+
'search_text': 'ACTIVE,test_license,test_organization,test_task,test_finetuned_model',
361+
'shape_info': {
362+
'instance_shape': f'{job_infrastructure_configuration_details.shape_name}',
363+
'replica': 1,
364+
},
365+
'source': {'id': '', 'name': '', 'url': ''},
366+
'tags': ds_model.freeform_tags,
367+
'task': 'test_task',
368+
'time_created': f'{ds_model.time_created}',
369+
'validation': {
370+
'type': 'Automatic split',
371+
'value': 'test_val_set_size'
372+
}
373+
}
374+
375+
@patch("ads.aqua.model.read_file")
376+
@patch("ads.aqua.model.get_artifact_path")
377+
def test_load_license(self, mock_get_artifact_path, mock_read_file):
378+
self.app.ds_client.get_model = MagicMock()
379+
mock_get_artifact_path.return_value = "oci://bucket@namespace/prefix/config/LICENSE.txt"
380+
mock_read_file.return_value = "test_license"
381+
382+
license = self.app.load_license(model_id="test_model_id")
383+
384+
mock_get_artifact_path.assert_called()
385+
mock_read_file.assert_called()
386+
387+
assert asdict(license) == {
388+
'id': 'test_model_id', 'license': 'test_license'
389+
}
390+
92391
def test_list_service_models(self):
93392
"""Tests listing service models succesfully."""
94393

0 commit comments

Comments
 (0)