Skip to content

Commit 4c4cb99

Browse files
authored
[PLT-1205] Improvements for QA (#1706)
1 parent 7ddeef7 commit 4c4cb99

File tree

3 files changed

+161
-288
lines changed

3 files changed

+161
-288
lines changed

libs/labelbox/src/labelbox/schema/user_group.py

Lines changed: 80 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
from enum import Enum
2-
from typing import Set, List, Union, Iterator, Optional
2+
from typing import Set, Iterator
3+
from collections import defaultdict
34

45
from labelbox import Client
56
from labelbox.exceptions import ResourceCreationError
67
from labelbox.pydantic_compat import BaseModel
78
from labelbox.schema.user import User
89
from labelbox.schema.project import Project
910
from labelbox.exceptions import UnprocessableEntityError, InvalidQueryError
11+
from labelbox.schema.queue_mode import QueueMode
12+
from labelbox.schema.ontology_kind import EditorTaskType
13+
from labelbox.schema.media_type import MediaType
1014

1115

1216
class UserGroupColor(Enum):
@@ -34,82 +38,32 @@ class UserGroupColor(Enum):
3438
YELLOW = "E7BF00"
3539
GRAY = "B8C4D3"
3640

37-
38-
class UserGroupUser(BaseModel):
39-
"""
40-
Represents a user in a group.
41-
42-
Attributes:
43-
id (str): The ID of the user.
44-
email (str): The email of the user.
45-
"""
46-
id: str
47-
email: str
48-
49-
def __hash__(self):
50-
return hash((self.id))
51-
52-
def __eq__(self, other):
53-
if not isinstance(other, UserGroupUser):
54-
return False
55-
return self.id == other.id
56-
57-
58-
class UserGroupProject(BaseModel):
59-
"""
60-
Represents a project in a group.
61-
62-
Attributes:
63-
id (str): The ID of the project.
64-
name (str): The name of the project.
65-
"""
66-
id: str
67-
name: str
68-
69-
def __hash__(self):
70-
return hash((self.id))
71-
72-
def __eq__(self, other):
73-
"""
74-
Check if this GroupProject object is equal to another GroupProject object.
75-
76-
Args:
77-
other (GroupProject): The other GroupProject object to compare with.
78-
79-
Returns:
80-
bool: True if the two GroupProject objects are equal, False otherwise.
81-
"""
82-
if not isinstance(other, UserGroupProject):
83-
return False
84-
return self.id == other.id
85-
8641

8742
class UserGroup(BaseModel):
8843
"""
8944
Represents a user group in Labelbox.
9045
9146
Attributes:
92-
id (Optional[str]): The ID of the user group.
93-
name (Optional[str]): The name of the user group.
47+
id (str): The ID of the user group.
48+
name (str): The name of the user group.
9449
color (UserGroupColor): The color of the user group.
9550
users (Set[UserGroupUser]): The set of users in the user group.
9651
projects (Set[UserGroupProject]): The set of projects associated with the user group.
9752
client (Client): The Labelbox client object.
9853
9954
Methods:
100-
__init__(self, client: Client, id: str = "", name: str = "", color: UserGroupColor = UserGroupColor.BLUE,
101-
users: Set[UserGroupUser] = set(), projects: Set[UserGroupProject] = set(), reload=True)
102-
_reload(self)
55+
__init__(self, client: Client)
56+
get(self) -> "UserGroup"
10357
update(self) -> "UserGroup"
10458
create(self) -> "UserGroup"
10559
delete(self) -> bool
10660
get_user_groups(client: Client) -> Iterator["UserGroup"]
10761
"""
108-
id: Optional[str]
109-
name: Optional[str]
62+
id: str
63+
name: str
11064
color: UserGroupColor
111-
users: Set[UserGroupUser]
112-
projects: Set[UserGroupProject]
65+
users: Set[User]
66+
projects: Set[Project]
11367
client: Client
11468

11569
class Config:
@@ -122,9 +76,8 @@ def __init__(
12276
id: str = "",
12377
name: str = "",
12478
color: UserGroupColor = UserGroupColor.BLUE,
125-
users: Set[UserGroupUser] = set(),
126-
projects: Set[UserGroupProject] = set(),
127-
reload=True,
79+
users: Set[User] = set(),
80+
projects: Set[Project] = set()
12881
):
12982
"""
13083
Initializes a UserGroup object.
@@ -134,36 +87,32 @@ def __init__(
13487
id (str, optional): The ID of the user group. Defaults to an empty string.
13588
name (str, optional): The name of the user group. Defaults to an empty string.
13689
color (UserGroupColor, optional): The color of the user group. Defaults to UserGroupColor.BLUE.
137-
users (Set[UserGroupUser], optional): The set of users in the user group. Defaults to an empty set.
138-
projects (Set[UserGroupProject], optional): The set of projects associated with the user group. Defaults to an empty set.
139-
reload (bool, optional): Whether to reload the partial representation of the group. Defaults to True.
90+
users (Set[User], optional): The set of users in the user group. Defaults to an empty set.
91+
projects (Set[Project], optional): The set of projects associated with the user group. Defaults to an empty set.
14092
14193
Raises:
14294
RuntimeError: If the experimental feature is not enabled in the client.
143-
14495
"""
14596
super().__init__(client=client, id=id, name=name, color=color, users=users, projects=projects)
14697
if not self.client.enable_experimental:
147-
raise RuntimeError(
148-
"Please enable experimental in client to use UserGroups")
98+
raise RuntimeError("Please enable experimental in client to use UserGroups")
14999

150-
# partial representation of the group, reload
151-
if self.id and reload:
152-
self._reload()
153-
154-
def _reload(self):
100+
def get(self) -> "UserGroup":
155101
"""
156102
Reloads the user group information from the server.
157103
158104
This method sends a GraphQL query to the server to fetch the latest information
159105
about the user group, including its name, color, projects, and members. The fetched
160106
information is then used to update the corresponding attributes of the `Group` object.
161107
162-
Raises:
163-
InvalidQueryError: If the query fails to fetch the group information.
108+
Args:
109+
id (str): The ID of the user group to fetch.
164110
165111
Returns:
166-
None
112+
UserGroup of passed in ID (self)
113+
114+
Raises:
115+
InvalidQueryError: If the query fails to fetch the group information.
167116
"""
168117
query = """
169118
query GetUserGroupPyApi($id: ID!) {
@@ -196,14 +145,9 @@ def _reload(self):
196145
raise InvalidQueryError("Failed to fetch group")
197146
self.name = result["userGroup"]["name"]
198147
self.color = UserGroupColor(result["userGroup"]["color"])
199-
self.projects = {
200-
UserGroupProject(id=project["id"], name=project["name"])
201-
for project in result["userGroup"]["projects"]["nodes"]
202-
}
203-
self.users = {
204-
UserGroupUser(id=member["id"], email=member["email"])
205-
for member in result["userGroup"]["members"]["nodes"]
206-
}
148+
self.projects = self._get_projects_set(result["userGroup"]["projects"]["nodes"])
149+
self.users = self._get_users_set(result["userGroup"]["members"]["nodes"])
150+
return self
207151

208152
def update(self) -> "UserGroup":
209153
"""
@@ -249,10 +193,10 @@ def update(self) -> "UserGroup":
249193
"color":
250194
self.color.value,
251195
"projectIds": [
252-
project.id for project in self.projects
196+
project.uid for project in self.projects
253197
],
254198
"userIds": [
255-
user.id for user in self.users
199+
user.uid for user in self.users
256200
]
257201
}
258202
result = self.client.execute(query, params)
@@ -311,10 +255,10 @@ def create(self) -> "UserGroup":
311255
"color":
312256
self.color.value,
313257
"projectIds": [
314-
project.id for project in self.projects
258+
project.uid for project in self.projects
315259
],
316260
"userIds": [
317-
user.id for user in self.users
261+
user.uid for user in self.users
318262
]
319263
}
320264
result = self.client.execute(query, params)
@@ -351,8 +295,7 @@ def delete(self) -> bool:
351295
raise UnprocessableEntityError("Failed to delete user group")
352296
return result["deleteUserGroup"]["success"]
353297

354-
@staticmethod
355-
def get_user_groups(client: Client) -> Iterator["UserGroup"]:
298+
def get_user_groups(self) -> Iterator["UserGroup"]:
356299
"""
357300
Gets all user groups in Labelbox.
358301
@@ -390,29 +333,60 @@ def get_user_groups(client: Client) -> Iterator["UserGroup"]:
390333
"""
391334
nextCursor = None
392335
while True:
393-
userGroups = client.execute(
336+
userGroups = self.client.execute(
394337
query, {"nextCursor": nextCursor})["userGroups"]
395338
if not userGroups:
396339
return
397340
yield
398341
groups = userGroups["nodes"]
399342
for group in groups:
400-
yield UserGroup(client,
401-
reload=False,
402-
id=group["id"],
403-
name=group["name"],
404-
color=UserGroupColor(group["color"]),
405-
users={
406-
UserGroupUser(id=member["id"],
407-
email=member["email"])
408-
for member in group["members"]["nodes"]
409-
},
410-
projects={
411-
UserGroupProject(id=project["id"],
412-
name=project["name"])
413-
for project in group["projects"]["nodes"]
414-
})
343+
userGroup = UserGroup(self.client)
344+
userGroup.id = group["id"]
345+
userGroup.name = group["name"]
346+
userGroup.color = UserGroupColor(group["color"])
347+
userGroup.users = self._get_users_set(group["members"]["nodes"])
348+
userGroup.projects = self._get_projects_set(group["projects"]["nodes"])
349+
yield userGroup
415350
nextCursor = userGroups["nextCursor"]
416351
# this doesn't seem to be implemented right now to return a value other than null from the api
417352
if not nextCursor:
418353
break
354+
355+
def _get_users_set(self, user_nodes):
356+
"""
357+
Retrieves a set of User objects from the given user nodes.
358+
359+
Args:
360+
user_nodes (list): A list of user nodes containing user information.
361+
362+
Returns:
363+
set: A set of User objects.
364+
"""
365+
users = set()
366+
for user in user_nodes:
367+
user_values = defaultdict(lambda: None)
368+
user_values["id"] = user["id"]
369+
user_values["email"] = user["email"]
370+
users.add(User(self.client, user_values))
371+
return users
372+
373+
def _get_projects_set(self, project_nodes):
374+
"""
375+
Retrieves a set of projects based on the given project nodes.
376+
377+
Args:
378+
project_nodes (list): A list of project nodes.
379+
380+
Returns:
381+
set: A set of Project objects.
382+
"""
383+
projects = set()
384+
for project in project_nodes:
385+
project_values = defaultdict(lambda: None)
386+
project_values["id"] = project["id"]
387+
project_values["name"] = project["name"]
388+
project_values["queueMode"] = QueueMode.Batch.value
389+
project_values["editorTaskType"] = EditorTaskType.Missing.value
390+
project_values["mediaType"] = MediaType.Image.value
391+
projects.add(Project(self.client, project_values))
392+
return projects

libs/labelbox/tests/integration/schema/test_user_group.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22
import faker
33
from labelbox import Client
4-
from labelbox.schema.user_group import UserGroup, UserGroupColor, UserGroupUser, UserGroupProject
4+
from labelbox.schema.user_group import UserGroup, UserGroupColor
55

66
data = faker.Faker()
77

@@ -21,7 +21,9 @@ def user_group(client):
2121

2222
def test_existing_user_groups(user_group, client):
2323
# Verify that the user group was created successfully
24-
user_group_equal = UserGroup(client, id=user_group.id)
24+
user_group_equal = UserGroup(client)
25+
user_group_equal.id = user_group.id
26+
user_group_equal.get()
2527
assert user_group.id == user_group_equal.id
2628
assert user_group.name == user_group_equal.name
2729
assert user_group.color == user_group_equal.color
@@ -48,15 +50,15 @@ def test_update_user_group(user_group):
4850

4951
def test_get_user_groups(user_group, client):
5052
# Get all user groups
51-
user_groups_old = list(UserGroup.get_user_groups(client))
53+
user_groups_old = list(UserGroup(client).get_user_groups())
5254

5355
# manual delete for iterators
5456
group_name = data.name()
5557
user_group = UserGroup(client)
5658
user_group.name = group_name
5759
user_group.create()
5860

59-
user_groups_new = list(UserGroup.get_user_groups(client))
61+
user_groups_new = list(UserGroup(client).get_user_groups())
6062

6163
# Verify that at least one user group is returned
6264
assert len(user_groups_new) > 0
@@ -77,15 +79,7 @@ def test_update_user_group(user_group, client, project_pack):
7779

7880
# Add the user to the group
7981
user = users[0]
80-
user = UserGroupUser(
81-
id=user.uid,
82-
email=user.email
83-
)
8482
project = projects[0]
85-
project = UserGroupProject(
86-
id=project.uid,
87-
name=project.name
88-
)
8983
user_group.users.add(user)
9084
user_group.projects.add(project)
9185
user_group.update()

0 commit comments

Comments
 (0)