1
1
from enum import Enum
2
- from typing import Set , List , Union , Iterator , Optional
2
+ from typing import Set , Iterator
3
+ from collections import defaultdict
3
4
4
5
from labelbox import Client
5
6
from labelbox .exceptions import ResourceCreationError
6
7
from labelbox .pydantic_compat import BaseModel
7
8
from labelbox .schema .user import User
8
9
from labelbox .schema .project import Project
9
10
from labelbox .exceptions import UnprocessableEntityError , InvalidQueryError
11
+ from labelbox .schema .queue_mode import QueueMode
12
+ from labelbox .schema .ontology_kind import EditorTaskType
13
+ from labelbox .schema .media_type import MediaType
10
14
11
15
12
16
class UserGroupColor (Enum ):
@@ -34,55 +38,6 @@ class UserGroupColor(Enum):
34
38
YELLOW = "E7BF00"
35
39
GRAY = "B8C4D3"
36
40
37
-
38
- class UserGroupUser (BaseModel ):
39
- """
40
- Represents a user in a group.
41
-
42
- Attributes:
43
- id (str): The ID of the user.
44
- email (str): The email of the user.
45
- """
46
- id : str
47
- email : str
48
-
49
- def __hash__ (self ):
50
- return hash ((self .id ))
51
-
52
- def __eq__ (self , other ):
53
- if not isinstance (other , UserGroupUser ):
54
- return False
55
- return self .id == other .id
56
-
57
-
58
- class UserGroupProject (BaseModel ):
59
- """
60
- Represents a project in a group.
61
-
62
- Attributes:
63
- id (str): The ID of the project.
64
- name (str): The name of the project.
65
- """
66
- id : str
67
- name : str
68
-
69
- def __hash__ (self ):
70
- return hash ((self .id ))
71
-
72
- def __eq__ (self , other ):
73
- """
74
- Check if this GroupProject object is equal to another GroupProject object.
75
-
76
- Args:
77
- other (GroupProject): The other GroupProject object to compare with.
78
-
79
- Returns:
80
- bool: True if the two GroupProject objects are equal, False otherwise.
81
- """
82
- if not isinstance (other , UserGroupProject ):
83
- return False
84
- return self .id == other .id
85
-
86
41
87
42
class UserGroup (BaseModel ):
88
43
"""
@@ -107,8 +62,8 @@ class UserGroup(BaseModel):
107
62
id : str
108
63
name : str
109
64
color : UserGroupColor
110
- users : Set [UserGroupUser ]
111
- projects : Set [UserGroupProject ]
65
+ users : Set [User ]
66
+ projects : Set [Project ]
112
67
client : Client
113
68
114
69
class Config :
@@ -182,14 +137,8 @@ def get(self) -> "UserGroup":
182
137
raise InvalidQueryError ("Failed to fetch group" )
183
138
self .name = result ["userGroup" ]["name" ]
184
139
self .color = UserGroupColor (result ["userGroup" ]["color" ])
185
- self .projects = {
186
- UserGroupProject (id = project ["id" ], name = project ["name" ])
187
- for project in result ["userGroup" ]["projects" ]["nodes" ]
188
- }
189
- self .users = {
190
- UserGroupUser (id = member ["id" ], email = member ["email" ])
191
- for member in result ["userGroup" ]["members" ]["nodes" ]
192
- }
140
+ self .projects = self ._get_projects_set (result ["userGroup" ]["projects" ]["nodes" ])
141
+ self .users = self ._get_users_set (result ["userGroup" ]["members" ]["nodes" ])
193
142
return self
194
143
195
144
def update (self ) -> "UserGroup" :
@@ -236,10 +185,10 @@ def update(self) -> "UserGroup":
236
185
"color" :
237
186
self .color .value ,
238
187
"projectIds" : [
239
- project .id for project in self .projects
188
+ project .uid for project in self .projects
240
189
],
241
190
"userIds" : [
242
- user .id for user in self .users
191
+ user .uid for user in self .users
243
192
]
244
193
}
245
194
result = self .client .execute (query , params )
@@ -298,10 +247,10 @@ def create(self) -> "UserGroup":
298
247
"color" :
299
248
self .color .value ,
300
249
"projectIds" : [
301
- project .id for project in self .projects
250
+ project .uid for project in self .projects
302
251
],
303
252
"userIds" : [
304
- user .id for user in self .users
253
+ user .uid for user in self .users
305
254
]
306
255
}
307
256
result = self .client .execute (query , params )
@@ -387,16 +336,49 @@ def get_user_groups(self) -> Iterator["UserGroup"]:
387
336
userGroup .id = group ["id" ]
388
337
userGroup .name = group ["name" ]
389
338
userGroup .color = UserGroupColor (group ["color" ])
390
- userGroup .users = {
391
- UserGroupUser (id = member ["id" ], email = member ["email" ])
392
- for member in group ["members" ]["nodes" ]
393
- }
394
- userGroup .projects = {
395
- UserGroupProject (id = project ["id" ], name = project ["name" ])
396
- for project in group ["projects" ]["nodes" ]
397
- }
339
+ userGroup .users = self ._get_users_set (group ["members" ]["nodes" ])
340
+ userGroup .projects = self ._get_projects_set (group ["projects" ]["nodes" ])
398
341
yield userGroup
399
342
nextCursor = userGroups ["nextCursor" ]
400
343
# this doesn't seem to be implemented right now to return a value other than null from the api
401
344
if not nextCursor :
402
345
break
346
+
347
+ def _get_users_set (self , user_nodes ):
348
+ """
349
+ Retrieves a set of User objects from the given user nodes.
350
+
351
+ Args:
352
+ user_nodes (list): A list of user nodes containing user information.
353
+
354
+ Returns:
355
+ set: A set of User objects.
356
+ """
357
+ users = set ()
358
+ for user in user_nodes :
359
+ user_values = defaultdict (lambda : None )
360
+ user_values ["id" ] = user ["id" ]
361
+ user_values ["email" ] = user ["email" ]
362
+ users .add (User (self .client , user_values ))
363
+ return users
364
+
365
+ def _get_projects_set (self , project_nodes ):
366
+ """
367
+ Retrieves a set of projects based on the given project nodes.
368
+
369
+ Args:
370
+ project_nodes (list): A list of project nodes.
371
+
372
+ Returns:
373
+ set: A set of Project objects.
374
+ """
375
+ projects = set ()
376
+ for project in project_nodes :
377
+ project_values = defaultdict (lambda : None )
378
+ project_values ["id" ] = project ["id" ]
379
+ project_values ["name" ] = project ["name" ]
380
+ project_values ["queueMode" ] = QueueMode .Batch .value
381
+ project_values ["editorTaskType" ] = EditorTaskType .Missing .value
382
+ project_values ["mediaType" ] = MediaType .Audio .value
383
+ projects .add (Project (self .client , project_values ))
384
+ return projects
0 commit comments