2
2
3
3
# Copyright (c) 2021, 2023 Oracle and/or its affiliates.
4
4
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
-
6
- from ads .catalog .project import ProjectCatalog , ProjectSummaryList
7
- from ads .common import auth , oci_client
8
- from ads .common .utils import random_valid_ocid
9
- from ads .config import NB_SESSION_COMPARTMENT_OCID
5
+ import os
6
+ import unittest
10
7
from collections import namedtuple
11
8
from datetime import datetime , timezone , timedelta
12
- from oci . exceptions import ServiceError
9
+ from importlib import reload
13
10
from unittest import mock
14
11
from unittest .mock import MagicMock , Mock , patch
12
+
15
13
import oci
16
- import os
17
14
import pytest
18
- import unittest
15
+ from oci .exceptions import ServiceError
16
+
17
+ import ads .config
18
+ from ads .catalog .project import ProjectCatalog , ProjectSummaryList
19
+ from ads .common import auth , oci_client
20
+ from ads .common .utils import random_valid_ocid
21
+ from ads .config import NB_SESSION_COMPARTMENT_OCID
19
22
20
23
21
24
def generate_project_list (
@@ -62,7 +65,9 @@ class ProjectCatalogTest(unittest.TestCase):
62
65
with patch .object (auth , "default_signer" ):
63
66
with patch .object (oci_client , "OCIClientFactory" ):
64
67
project_id = "ocid1.projectcatalog.oc1.iad.<unique_ocid>"
65
- comp_id = os .environ .get ("NB_SESSION_COMPARTMENT_OCID" , "ocid1.compartment.oc1.iad.<unique_ocid>" )
68
+ comp_id = os .environ .get (
69
+ "NB_SESSION_COMPARTMENT_OCID" , "ocid1.compartment.oc1.iad.<unique_ocid>"
70
+ )
66
71
date_time = datetime (2020 , 7 , 1 , 18 , 24 , 42 , 110000 , tzinfo = timezone .utc )
67
72
68
73
pc = ProjectCatalog (compartment_id = comp_id )
@@ -71,8 +76,20 @@ class ProjectCatalogTest(unittest.TestCase):
71
76
72
77
psl = ProjectSummaryList (generate_project_list ())
73
78
79
+ def setUp (self ) -> None :
80
+ os .environ [
81
+ "NB_SESSION_COMPARTMENT_OCID"
82
+ ] = "ocid1.compartment.oc1.<unique_ocid>"
83
+ reload (ads .config )
84
+ return super ().setUp ()
85
+
86
+ def tearDown (self ) -> None :
87
+ os .environ .pop ("NB_SESSION_COMPARTMENT_OCID" , None )
88
+ reload (ads .config )
89
+ return super ().tearDown ()
90
+
74
91
@staticmethod
75
- def generate_project_response_data (self , compartment_id = None , project_id = None ):
92
+ def generate_project_response_data (compartment_id = None , project_id = None ):
76
93
entity_item = {
77
94
"compartment_id" : compartment_id ,
78
95
"created_by" : "mock_user" ,
@@ -82,7 +99,7 @@ def generate_project_response_data(self, compartment_id=None, project_id=None):
82
99
"freeform_tags" : {},
83
100
"id" : project_id ,
84
101
"lifecycle_state" : "ACTIVE" ,
85
- "time_created" : self .date_time .isoformat (),
102
+ "time_created" : ProjectCatalogTest .date_time .isoformat (),
86
103
}
87
104
project_response = oci .data_science .models .Project (** entity_item )
88
105
return project_response
@@ -104,7 +121,7 @@ def test_project_init_without_compartment_id(self, mock_client, mock_signer):
104
121
def test_decorate_project_session_attributes (self ):
105
122
"""Test ProjectCatalog._decorate_project method."""
106
123
project = self .generate_project_response_data (
107
- self , compartment_id = self .comp_id , project_id = self .project_id
124
+ compartment_id = self .comp_id , project_id = self .project_id
108
125
)
109
126
110
127
def generate_get_user_data (self , compartment_id = None ):
@@ -160,7 +177,6 @@ def test_get_project_with_short_id(self):
160
177
def mock_get_notebook_session (project_id = id ):
161
178
return Mock (
162
179
data = self .generate_project_response_data (
163
- self ,
164
180
compartment_id = self .comp_id ,
165
181
project_id = short_id_index [short_id ],
166
182
)
@@ -258,7 +274,7 @@ def test_update_project_with_short_id(self):
258
274
wrapper = namedtuple ("wrapper" , ["data" ])
259
275
client_update_project_response = wrapper (
260
276
data = self .generate_project_response_data (
261
- self , compartment_id = self .comp_id , project_id = short_id_index [short_id ]
277
+ compartment_id = self .comp_id , project_id = short_id_index [short_id ]
262
278
)
263
279
)
264
280
self .pc .ds_client .update_project = MagicMock (
@@ -315,5 +331,5 @@ def test_project_summary_list_filter_invalid_param(self):
315
331
# selection is a notebook session instance
316
332
with pytest .raises (ValueError ):
317
333
self .psl .filter (
318
- selection = self .generate_project_response_data (self ), instance = None
334
+ selection = self .generate_project_response_data (), instance = None
319
335
)
0 commit comments