1
1
from enum import Enum
2
- from typing import Set , List , Union , Iterator , Optional
2
+ from typing import Set , Iterator
3
+ from collections import defaultdict
3
4
4
5
from labelbox import Client
5
6
from labelbox .exceptions import ResourceCreationError
6
7
from labelbox .pydantic_compat import BaseModel
7
8
from labelbox .schema .user import User
8
9
from labelbox .schema .project import Project
9
10
from labelbox .exceptions import UnprocessableEntityError , InvalidQueryError
11
+ from labelbox .schema .queue_mode import QueueMode
12
+ from labelbox .schema .ontology_kind import EditorTaskType
13
+ from labelbox .schema .media_type import MediaType
10
14
11
15
12
16
class UserGroupColor (Enum ):
@@ -34,82 +38,32 @@ class UserGroupColor(Enum):
34
38
YELLOW = "E7BF00"
35
39
GRAY = "B8C4D3"
36
40
37
-
38
- class UserGroupUser (BaseModel ):
39
- """
40
- Represents a user in a group.
41
-
42
- Attributes:
43
- id (str): The ID of the user.
44
- email (str): The email of the user.
45
- """
46
- id : str
47
- email : str
48
-
49
- def __hash__ (self ):
50
- return hash ((self .id ))
51
-
52
- def __eq__ (self , other ):
53
- if not isinstance (other , UserGroupUser ):
54
- return False
55
- return self .id == other .id
56
-
57
-
58
- class UserGroupProject (BaseModel ):
59
- """
60
- Represents a project in a group.
61
-
62
- Attributes:
63
- id (str): The ID of the project.
64
- name (str): The name of the project.
65
- """
66
- id : str
67
- name : str
68
-
69
- def __hash__ (self ):
70
- return hash ((self .id ))
71
-
72
- def __eq__ (self , other ):
73
- """
74
- Check if this GroupProject object is equal to another GroupProject object.
75
-
76
- Args:
77
- other (GroupProject): The other GroupProject object to compare with.
78
-
79
- Returns:
80
- bool: True if the two GroupProject objects are equal, False otherwise.
81
- """
82
- if not isinstance (other , UserGroupProject ):
83
- return False
84
- return self .id == other .id
85
-
86
41
87
42
class UserGroup (BaseModel ):
88
43
"""
89
44
Represents a user group in Labelbox.
90
45
91
46
Attributes:
92
- id (Optional[ str] ): The ID of the user group.
93
- name (Optional[ str] ): The name of the user group.
47
+ id (str): The ID of the user group.
48
+ name (str): The name of the user group.
94
49
color (UserGroupColor): The color of the user group.
95
50
users (Set[UserGroupUser]): The set of users in the user group.
96
51
projects (Set[UserGroupProject]): The set of projects associated with the user group.
97
52
client (Client): The Labelbox client object.
98
53
99
54
Methods:
100
- __init__(self, client: Client, id: str = "", name: str = "", color: UserGroupColor = UserGroupColor.BLUE,
101
- users: Set[UserGroupUser] = set(), projects: Set[UserGroupProject] = set(), reload=True)
102
- _reload(self)
55
+ __init__(self, client: Client)
56
+ get(self) -> "UserGroup"
103
57
update(self) -> "UserGroup"
104
58
create(self) -> "UserGroup"
105
59
delete(self) -> bool
106
60
get_user_groups(client: Client) -> Iterator["UserGroup"]
107
61
"""
108
- id : Optional [ str ]
109
- name : Optional [ str ]
62
+ id : str
63
+ name : str
110
64
color : UserGroupColor
111
- users : Set [UserGroupUser ]
112
- projects : Set [UserGroupProject ]
65
+ users : Set [User ]
66
+ projects : Set [Project ]
113
67
client : Client
114
68
115
69
class Config :
@@ -122,9 +76,8 @@ def __init__(
122
76
id : str = "" ,
123
77
name : str = "" ,
124
78
color : UserGroupColor = UserGroupColor .BLUE ,
125
- users : Set [UserGroupUser ] = set (),
126
- projects : Set [UserGroupProject ] = set (),
127
- reload = True ,
79
+ users : Set [User ] = set (),
80
+ projects : Set [Project ] = set ()
128
81
):
129
82
"""
130
83
Initializes a UserGroup object.
@@ -134,36 +87,32 @@ def __init__(
134
87
id (str, optional): The ID of the user group. Defaults to an empty string.
135
88
name (str, optional): The name of the user group. Defaults to an empty string.
136
89
color (UserGroupColor, optional): The color of the user group. Defaults to UserGroupColor.BLUE.
137
- users (Set[UserGroupUser], optional): The set of users in the user group. Defaults to an empty set.
138
- projects (Set[UserGroupProject], optional): The set of projects associated with the user group. Defaults to an empty set.
139
- reload (bool, optional): Whether to reload the partial representation of the group. Defaults to True.
90
+ users (Set[User], optional): The set of users in the user group. Defaults to an empty set.
91
+ projects (Set[Project], optional): The set of projects associated with the user group. Defaults to an empty set.
140
92
141
93
Raises:
142
94
RuntimeError: If the experimental feature is not enabled in the client.
143
-
144
95
"""
145
96
super ().__init__ (client = client , id = id , name = name , color = color , users = users , projects = projects )
146
97
if not self .client .enable_experimental :
147
- raise RuntimeError (
148
- "Please enable experimental in client to use UserGroups" )
98
+ raise RuntimeError ("Please enable experimental in client to use UserGroups" )
149
99
150
- # partial representation of the group, reload
151
- if self .id and reload :
152
- self ._reload ()
153
-
154
- def _reload (self ):
100
+ def get (self ) -> "UserGroup" :
155
101
"""
156
102
Reloads the user group information from the server.
157
103
158
104
This method sends a GraphQL query to the server to fetch the latest information
159
105
about the user group, including its name, color, projects, and members. The fetched
160
106
information is then used to update the corresponding attributes of the `Group` object.
161
107
162
- Raises :
163
- InvalidQueryError: If the query fails to fetch the group information .
108
+ Args :
109
+ id (str): The ID of the user group to fetch.
164
110
165
111
Returns:
166
- None
112
+ UserGroup of passed in ID (self)
113
+
114
+ Raises:
115
+ InvalidQueryError: If the query fails to fetch the group information.
167
116
"""
168
117
query = """
169
118
query GetUserGroupPyApi($id: ID!) {
@@ -196,14 +145,9 @@ def _reload(self):
196
145
raise InvalidQueryError ("Failed to fetch group" )
197
146
self .name = result ["userGroup" ]["name" ]
198
147
self .color = UserGroupColor (result ["userGroup" ]["color" ])
199
- self .projects = {
200
- UserGroupProject (id = project ["id" ], name = project ["name" ])
201
- for project in result ["userGroup" ]["projects" ]["nodes" ]
202
- }
203
- self .users = {
204
- UserGroupUser (id = member ["id" ], email = member ["email" ])
205
- for member in result ["userGroup" ]["members" ]["nodes" ]
206
- }
148
+ self .projects = self ._get_projects_set (result ["userGroup" ]["projects" ]["nodes" ])
149
+ self .users = self ._get_users_set (result ["userGroup" ]["members" ]["nodes" ])
150
+ return self
207
151
208
152
def update (self ) -> "UserGroup" :
209
153
"""
@@ -249,10 +193,10 @@ def update(self) -> "UserGroup":
249
193
"color" :
250
194
self .color .value ,
251
195
"projectIds" : [
252
- project .id for project in self .projects
196
+ project .uid for project in self .projects
253
197
],
254
198
"userIds" : [
255
- user .id for user in self .users
199
+ user .uid for user in self .users
256
200
]
257
201
}
258
202
result = self .client .execute (query , params )
@@ -311,10 +255,10 @@ def create(self) -> "UserGroup":
311
255
"color" :
312
256
self .color .value ,
313
257
"projectIds" : [
314
- project .id for project in self .projects
258
+ project .uid for project in self .projects
315
259
],
316
260
"userIds" : [
317
- user .id for user in self .users
261
+ user .uid for user in self .users
318
262
]
319
263
}
320
264
result = self .client .execute (query , params )
@@ -351,8 +295,7 @@ def delete(self) -> bool:
351
295
raise UnprocessableEntityError ("Failed to delete user group" )
352
296
return result ["deleteUserGroup" ]["success" ]
353
297
354
- @staticmethod
355
- def get_user_groups (client : Client ) -> Iterator ["UserGroup" ]:
298
+ def get_user_groups (self ) -> Iterator ["UserGroup" ]:
356
299
"""
357
300
Gets all user groups in Labelbox.
358
301
@@ -390,29 +333,60 @@ def get_user_groups(client: Client) -> Iterator["UserGroup"]:
390
333
"""
391
334
nextCursor = None
392
335
while True :
393
- userGroups = client .execute (
336
+ userGroups = self . client .execute (
394
337
query , {"nextCursor" : nextCursor })["userGroups" ]
395
338
if not userGroups :
396
339
return
397
340
yield
398
341
groups = userGroups ["nodes" ]
399
342
for group in groups :
400
- yield UserGroup (client ,
401
- reload = False ,
402
- id = group ["id" ],
403
- name = group ["name" ],
404
- color = UserGroupColor (group ["color" ]),
405
- users = {
406
- UserGroupUser (id = member ["id" ],
407
- email = member ["email" ])
408
- for member in group ["members" ]["nodes" ]
409
- },
410
- projects = {
411
- UserGroupProject (id = project ["id" ],
412
- name = project ["name" ])
413
- for project in group ["projects" ]["nodes" ]
414
- })
343
+ userGroup = UserGroup (self .client )
344
+ userGroup .id = group ["id" ]
345
+ userGroup .name = group ["name" ]
346
+ userGroup .color = UserGroupColor (group ["color" ])
347
+ userGroup .users = self ._get_users_set (group ["members" ]["nodes" ])
348
+ userGroup .projects = self ._get_projects_set (group ["projects" ]["nodes" ])
349
+ yield userGroup
415
350
nextCursor = userGroups ["nextCursor" ]
416
351
# this doesn't seem to be implemented right now to return a value other than null from the api
417
352
if not nextCursor :
418
353
break
354
+
355
+ def _get_users_set (self , user_nodes ):
356
+ """
357
+ Retrieves a set of User objects from the given user nodes.
358
+
359
+ Args:
360
+ user_nodes (list): A list of user nodes containing user information.
361
+
362
+ Returns:
363
+ set: A set of User objects.
364
+ """
365
+ users = set ()
366
+ for user in user_nodes :
367
+ user_values = defaultdict (lambda : None )
368
+ user_values ["id" ] = user ["id" ]
369
+ user_values ["email" ] = user ["email" ]
370
+ users .add (User (self .client , user_values ))
371
+ return users
372
+
373
+ def _get_projects_set (self , project_nodes ):
374
+ """
375
+ Retrieves a set of projects based on the given project nodes.
376
+
377
+ Args:
378
+ project_nodes (list): A list of project nodes.
379
+
380
+ Returns:
381
+ set: A set of Project objects.
382
+ """
383
+ projects = set ()
384
+ for project in project_nodes :
385
+ project_values = defaultdict (lambda : None )
386
+ project_values ["id" ] = project ["id" ]
387
+ project_values ["name" ] = project ["name" ]
388
+ project_values ["queueMode" ] = QueueMode .Batch .value
389
+ project_values ["editorTaskType" ] = EditorTaskType .Missing .value
390
+ project_values ["mediaType" ] = MediaType .Image .value
391
+ projects .add (Project (self .client , project_values ))
392
+ return projects
0 commit comments