10
10
from importlib import reload
11
11
from unittest .mock import MagicMock
12
12
13
+ from mock import patch
13
14
import oci
14
15
from parameterized import parameterized
15
16
16
17
import ads .aqua .model
17
18
import ads .config
18
19
from ads .aqua .model import AquaModelApp , AquaModelSummary
20
+ from ads .model .datascience_model import DataScienceModel
21
+ from ads .model .model_metadata import ModelCustomMetadata , ModelProvenanceMetadata , ModelTaxonomyMetadata
19
22
20
23
21
24
class TestDataset :
@@ -71,6 +74,8 @@ class TestAquaModel(unittest.TestCase):
71
74
"""Contains unittests for AquaModelApp."""
72
75
73
76
def setUp (self ):
77
+ import ads
78
+ ads .set_auth ("security_token" )
74
79
self .app = AquaModelApp ()
75
80
76
81
@classmethod
@@ -89,11 +94,279 @@ def tearDownClass(cls):
89
94
reload (ads .aqua )
90
95
reload (ads .aqua .model )
91
96
92
- def test_create_model (self ):
93
- pass
97
+ @patch .object (DataScienceModel , "create" )
98
+ @patch ("ads.model.datascience_model.validate" )
99
+ @patch .object (DataScienceModel , "from_id" )
100
+ def test_create_model (self , mock_from_id , mock_validate , mock_create ):
101
+ service_model = MagicMock ()
102
+ service_model .model_file_description = {"test_key" :"test_value" }
103
+ service_model .display_name = "test_display_name"
104
+ service_model .description = "test_description"
105
+ service_model .freeform_tags = {"test_key" :"test_value" }
106
+ custom_metadata_list = ModelCustomMetadata ()
107
+ custom_metadata_list .add (
108
+ key = "test_metadata_item_key" ,
109
+ value = "test_metadata_item_value"
110
+ )
111
+ service_model .custom_metadata_list = custom_metadata_list
112
+ service_model .provenance_metadata = ModelProvenanceMetadata (
113
+ training_id = "test_training_id"
114
+ )
115
+ mock_from_id .return_value = service_model
116
+
117
+ # will not copy service model
118
+ self .app .create (
119
+ model_id = "test_model_id" ,
120
+ project_id = "test_project_id" ,
121
+ compartment_id = "test_compartment_id" ,
122
+ )
123
+
124
+ mock_from_id .assert_called_with ("test_model_id" )
125
+ mock_validate .assert_not_called ()
126
+ mock_create .assert_not_called ()
127
+
128
+ service_model .compartment_id = TestDataset .SERVICE_COMPARTMENT_ID
129
+ mock_from_id .return_value = service_model
130
+
131
+ # will copy service model
132
+ self .app .create (
133
+ model_id = "test_model_id" ,
134
+ project_id = "test_project_id" ,
135
+ compartment_id = "test_compartment_id"
136
+ )
137
+
138
+ mock_from_id .assert_called_with ("test_model_id" )
139
+ mock_validate .assert_called ()
140
+ mock_create .assert_called_with (
141
+ model_by_reference = True
142
+ )
143
+
144
+ @patch ("ads.aqua.model.read_file" )
145
+ @patch .object (DataScienceModel , "from_id" )
146
+ def test_get_model_not_fine_tuned (self , mock_from_id , mock_read_file ):
147
+ ds_model = MagicMock ()
148
+ ds_model .id = "test_id"
149
+ ds_model .compartment_id = "test_compartment_id"
150
+ ds_model .project_id = "test_project_id"
151
+ ds_model .display_name = "test_display_name"
152
+ ds_model .description = "test_description"
153
+ ds_model .freeform_tags = {
154
+ "OCI_AQUA" :"ACTIVE" ,
155
+ "license" :"test_license" ,
156
+ "organization" :"test_organization" ,
157
+ "task" :"test_task"
158
+ }
159
+ ds_model .time_created = "2024-01-19T17:57:39.158000+00:00"
160
+ custom_metadata_list = ModelCustomMetadata ()
161
+ custom_metadata_list .add (
162
+ key = "artifact_location" ,
163
+ value = "oci://bucket@namespace/prefix/"
164
+ )
165
+ ds_model .custom_metadata_list = custom_metadata_list
166
+
167
+ mock_from_id .return_value = ds_model
168
+ mock_read_file .return_value = "test_model_card"
169
+
170
+ aqua_model = self .app .get (model_id = "test_model_id" )
171
+
172
+ mock_from_id .assert_called_with ("test_model_id" )
173
+ mock_read_file .assert_called_with (
174
+ file_path = "oci://bucket@namespace/prefix/README.md" ,
175
+ auth = self .app ._auth ,
176
+ )
177
+
178
+ assert asdict (aqua_model ) == {
179
+ 'compartment_id' : f'{ ds_model .compartment_id } ' ,
180
+ 'console_link' : (
181
+ f'https://cloud.oracle.com/data-science/models/{ ds_model .id } ?region={ self .app .region } ' ,
182
+ ),
183
+ 'icon' : '' ,
184
+ 'id' : f'{ ds_model .id } ' ,
185
+ 'is_fine_tuned_model' : False ,
186
+ 'license' : f'{ ds_model .freeform_tags ["license" ]} ' ,
187
+ 'model_card' : f'{ mock_read_file .return_value } ' ,
188
+ 'name' : f'{ ds_model .display_name } ' ,
189
+ 'organization' : f'{ ds_model .freeform_tags ["organization" ]} ' ,
190
+ 'project_id' : f'{ ds_model .project_id } ' ,
191
+ 'ready_to_deploy' : True ,
192
+ 'ready_to_finetune' : False ,
193
+ 'search_text' : 'ACTIVE,test_license,test_organization,test_task' ,
194
+ 'tags' : ds_model .freeform_tags ,
195
+ 'task' : f'{ ds_model .freeform_tags ["task" ]} ' ,
196
+ 'time_created' : f'{ ds_model .time_created } '
197
+ }
198
+
199
+ @patch ("ads.aqua.utils.query_resource" )
200
+ @patch ("ads.aqua.model.read_file" )
201
+ @patch .object (DataScienceModel , "from_id" )
202
+ def test_get_model_fine_tuned (self , mock_from_id , mock_read_file , mock_query_resource ):
203
+ ds_model = MagicMock ()
204
+ ds_model .id = "test_id"
205
+ ds_model .compartment_id = "test_model_compartment_id"
206
+ ds_model .project_id = "test_project_id"
207
+ ds_model .display_name = "test_display_name"
208
+ ds_model .description = "test_description"
209
+ ds_model .model_version_set_id = "test_model_version_set_id"
210
+ ds_model .model_version_set_name = "test_model_version_set_name"
211
+ ds_model .freeform_tags = {
212
+ "OCI_AQUA" :"ACTIVE" ,
213
+ "license" :"test_license" ,
214
+ "organization" :"test_organization" ,
215
+ "task" :"test_task" ,
216
+ "aqua_fine_tuned_model" :"test_finetuned_model"
217
+ }
218
+ ds_model .time_created = "2024-01-19T17:57:39.158000+00:00"
219
+ ds_model .lifecycle_state = "ACTIVE"
220
+ custom_metadata_list = ModelCustomMetadata ()
221
+ custom_metadata_list .add (
222
+ key = "artifact_location" ,
223
+ value = "oci://bucket@namespace/prefix/"
224
+ )
225
+ custom_metadata_list .add (
226
+ key = "fine_tune_source" ,
227
+ value = "test_fine_tuned_source_id"
228
+ )
229
+ custom_metadata_list .add (
230
+ key = "fine_tune_source_name" ,
231
+ value = "test_fine_tuned_source_name"
232
+ )
233
+ ds_model .custom_metadata_list = custom_metadata_list
234
+ defined_metadata_list = ModelTaxonomyMetadata ()
235
+ defined_metadata_list ["Hyperparameters" ].value = {
236
+ "training_data" : "test_training_data" ,
237
+ "val_set_size" : "test_val_set_size"
238
+ }
239
+ ds_model .defined_metadata_list = defined_metadata_list
240
+ ds_model .provenance_metadata = ModelProvenanceMetadata (
241
+ training_id = "test_training_job_run_id"
242
+ )
243
+
244
+ mock_from_id .return_value = ds_model
245
+ mock_read_file .return_value = "test_model_card"
246
+
247
+ response = MagicMock ()
248
+ job_run = MagicMock ()
249
+ job_run .id = "test_job_run_id"
250
+ job_run .lifecycle_state = "SUCCEEDED"
251
+ job_run .lifecycle_details = "test lifecycle details"
252
+ job_run .identifier = "test_job_id" ,
253
+ job_run .display_name = "test_job_name"
254
+ job_run .compartment_id = "test_job_run_compartment_id"
255
+ job_infrastructure_configuration_details = MagicMock ()
256
+ job_infrastructure_configuration_details .shape_name = "test_shape_name"
257
+
258
+ job_configuration_override_details = MagicMock ()
259
+ job_configuration_override_details .environment_variables = {
260
+ "NODE_COUNT" : 1
261
+ }
262
+ job_run .job_infrastructure_configuration_details = job_infrastructure_configuration_details
263
+ job_run .job_configuration_override_details = job_configuration_override_details
264
+ log_details = MagicMock ()
265
+ log_details .log_id = "test_log_id"
266
+ log_details .log_group_id = "test_log_group_id"
267
+ job_run .log_details = log_details
268
+ response .data = job_run
269
+ self .app .ds_client .get_job_run = MagicMock (
270
+ return_value = response
271
+ )
272
+
273
+ query_resource = MagicMock ()
274
+ query_resource .display_name = "test_display_name"
275
+ mock_query_resource .return_value = query_resource
276
+
277
+ model = self .app .get (model_id = "test_model_id" )
278
+
279
+ mock_from_id .assert_called_with ("test_model_id" )
280
+ mock_read_file .assert_called_with (
281
+ file_path = "oci://bucket@namespace/prefix/README.md" ,
282
+ auth = self .app ._auth ,
283
+ )
284
+ mock_query_resource .assert_called ()
285
+
286
+ assert asdict (model ) == {
287
+ 'compartment_id' : f'{ ds_model .compartment_id } ' ,
288
+ 'console_link' : (
289
+ f'https://cloud.oracle.com/data-science/models/{ ds_model .id } ?region={ self .app .region } ' ,
290
+ ),
291
+ 'dataset' : 'test_training_data' ,
292
+ 'experiment' : {'id' : '' , 'name' : '' , 'url' : '' },
293
+ 'icon' : '' ,
294
+ 'id' : f'{ ds_model .id } ' ,
295
+ 'is_fine_tuned_model' : True ,
296
+ 'job' : {'id' : '' , 'name' : '' , 'url' : '' },
297
+ 'license' : 'test_license' ,
298
+ 'lifecycle_details' : f'{ job_run .lifecycle_details } ' ,
299
+ 'lifecycle_state' : f'{ ds_model .lifecycle_state } ' ,
300
+ 'log' : {
301
+ 'id' : f'{ log_details .log_id } ' ,
302
+ 'name' : f'{ query_resource .display_name } ' ,
303
+ 'url' : 'https://cloud.oracle.com/logging/search?searchQuery=search '
304
+ f'"{ job_run .compartment_id } /{ log_details .log_group_id } /{ log_details .log_id } " | '
305
+ f"source='{ job_run .id } ' | sort by datetime desc®ions={ self .app .region } "
306
+ },
307
+ 'log_group' : {
308
+ 'id' : f'{ log_details .log_group_id } ' ,
309
+ 'name' : f'{ query_resource .display_name } ' ,
310
+ 'url' : f'https://cloud.oracle.com/logging/log-groups/{ log_details .log_group_id } ?region={ self .app .region } '
311
+ },
312
+ 'metrics' : [
313
+ {
314
+ 'category' : 'validation' ,
315
+ 'name' : 'validation_metrics' ,
316
+ 'scores' : []
317
+ },
318
+ {
319
+ 'category' : 'training' ,
320
+ 'name' : 'training_metrics' ,
321
+ 'scores' : []
322
+ },
323
+ {
324
+ 'category' : 'validation' ,
325
+ 'name' : 'validation_metrics_final' ,
326
+ 'scores' : []
327
+ },
328
+ {
329
+ 'category' : 'training' ,
330
+ 'name' : 'training_metrics_final' ,
331
+ 'scores' : []
332
+ }
333
+ ],
334
+ 'model_card' : f'{ mock_read_file .return_value } ' ,
335
+ 'name' : f'{ ds_model .display_name } ' ,
336
+ 'organization' : 'test_organization' ,
337
+ 'project_id' : f'{ ds_model .project_id } ' ,
338
+ 'ready_to_deploy' : True ,
339
+ 'ready_to_finetune' : False ,
340
+ 'search_text' : 'ACTIVE,test_license,test_organization,test_task,test_finetuned_model' ,
341
+ 'shape_info' : {
342
+ 'instance_shape' : f'{ job_infrastructure_configuration_details .shape_name } ' ,
343
+ 'replica' : 1 ,
344
+ },
345
+ 'source' : {'id' : '' , 'name' : '' , 'url' : '' },
346
+ 'tags' : ds_model .freeform_tags ,
347
+ 'task' : 'test_task' ,
348
+ 'time_created' : f'{ ds_model .time_created } ' ,
349
+ 'validation' : {
350
+ 'type' : 'Automatic split' ,
351
+ 'value' : 'test_val_set_size'
352
+ }
353
+ }
354
+
355
+ @patch ("ads.aqua.model.read_file" )
356
+ @patch ("ads.aqua.model.get_artifact_path" )
357
+ def test_load_license (self , mock_get_artifact_path , mock_read_file ):
358
+ self .app .ds_client .get_model = MagicMock ()
359
+ mock_get_artifact_path .return_value = "oci://bucket@namespace/prefix/config/LICENSE.txt"
360
+ mock_read_file .return_value = "test_license"
361
+
362
+ license = self .app .load_license (model_id = "test_model_id" )
363
+
364
+ mock_get_artifact_path .assert_called ()
365
+ mock_read_file .assert_called ()
94
366
95
- def test_get_model (self ):
96
- pass
367
+ assert asdict (license ) == {
368
+ 'id' : 'test_model_id' , 'license' : 'test_license'
369
+ }
97
370
98
371
def test_list_service_models (self ):
99
372
"""Tests listing service models succesfully."""
0 commit comments