Skip to content

Commit e880965

Browse files
author
Adrian Chang
committed
save
1 parent 4ff5057 commit e880965

File tree

2 files changed

+63
-109
lines changed

2 files changed

+63
-109
lines changed

libs/labelbox/src/labelbox/schema/user_group.py

Lines changed: 54 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
from enum import Enum
2-
from typing import Set, List, Union, Iterator, Optional
2+
from typing import Set, Iterator
3+
from collections import defaultdict
34

45
from labelbox import Client
56
from labelbox.exceptions import ResourceCreationError
67
from labelbox.pydantic_compat import BaseModel
78
from labelbox.schema.user import User
89
from labelbox.schema.project import Project
910
from labelbox.exceptions import UnprocessableEntityError, InvalidQueryError
11+
from labelbox.schema.queue_mode import QueueMode
12+
from labelbox.schema.ontology_kind import EditorTaskType
13+
from labelbox.schema.media_type import MediaType
1014

1115

1216
class UserGroupColor(Enum):
@@ -34,55 +38,6 @@ class UserGroupColor(Enum):
3438
YELLOW = "E7BF00"
3539
GRAY = "B8C4D3"
3640

37-
38-
class UserGroupUser(BaseModel):
39-
"""
40-
Represents a user in a group.
41-
42-
Attributes:
43-
id (str): The ID of the user.
44-
email (str): The email of the user.
45-
"""
46-
id: str
47-
email: str
48-
49-
def __hash__(self):
50-
return hash((self.id))
51-
52-
def __eq__(self, other):
53-
if not isinstance(other, UserGroupUser):
54-
return False
55-
return self.id == other.id
56-
57-
58-
class UserGroupProject(BaseModel):
59-
"""
60-
Represents a project in a group.
61-
62-
Attributes:
63-
id (str): The ID of the project.
64-
name (str): The name of the project.
65-
"""
66-
id: str
67-
name: str
68-
69-
def __hash__(self):
70-
return hash((self.id))
71-
72-
def __eq__(self, other):
73-
"""
74-
Check if this GroupProject object is equal to another GroupProject object.
75-
76-
Args:
77-
other (GroupProject): The other GroupProject object to compare with.
78-
79-
Returns:
80-
bool: True if the two GroupProject objects are equal, False otherwise.
81-
"""
82-
if not isinstance(other, UserGroupProject):
83-
return False
84-
return self.id == other.id
85-
8641

8742
class UserGroup(BaseModel):
8843
"""
@@ -107,8 +62,8 @@ class UserGroup(BaseModel):
10762
id: str
10863
name: str
10964
color: UserGroupColor
110-
users: Set[UserGroupUser]
111-
projects: Set[UserGroupProject]
65+
users: Set[User]
66+
projects: Set[Project]
11267
client: Client
11368

11469
class Config:
@@ -182,14 +137,8 @@ def get(self) -> "UserGroup":
182137
raise InvalidQueryError("Failed to fetch group")
183138
self.name = result["userGroup"]["name"]
184139
self.color = UserGroupColor(result["userGroup"]["color"])
185-
self.projects = {
186-
UserGroupProject(id=project["id"], name=project["name"])
187-
for project in result["userGroup"]["projects"]["nodes"]
188-
}
189-
self.users = {
190-
UserGroupUser(id=member["id"], email=member["email"])
191-
for member in result["userGroup"]["members"]["nodes"]
192-
}
140+
self.projects = self._get_projects_set(result["userGroup"]["projects"]["nodes"])
141+
self.users = self._get_users_set(result["userGroup"]["members"]["nodes"])
193142
return self
194143

195144
def update(self) -> "UserGroup":
@@ -236,10 +185,10 @@ def update(self) -> "UserGroup":
236185
"color":
237186
self.color.value,
238187
"projectIds": [
239-
project.id for project in self.projects
188+
project.uid for project in self.projects
240189
],
241190
"userIds": [
242-
user.id for user in self.users
191+
user.uid for user in self.users
243192
]
244193
}
245194
result = self.client.execute(query, params)
@@ -298,10 +247,10 @@ def create(self) -> "UserGroup":
298247
"color":
299248
self.color.value,
300249
"projectIds": [
301-
project.id for project in self.projects
250+
project.uid for project in self.projects
302251
],
303252
"userIds": [
304-
user.id for user in self.users
253+
user.uid for user in self.users
305254
]
306255
}
307256
result = self.client.execute(query, params)
@@ -387,16 +336,49 @@ def get_user_groups(self) -> Iterator["UserGroup"]:
387336
userGroup.id = group["id"]
388337
userGroup.name = group["name"]
389338
userGroup.color = UserGroupColor(group["color"])
390-
userGroup.users = {
391-
UserGroupUser(id=member["id"], email=member["email"])
392-
for member in group["members"]["nodes"]
393-
}
394-
userGroup.projects = {
395-
UserGroupProject(id=project["id"], name=project["name"])
396-
for project in group["projects"]["nodes"]
397-
}
339+
userGroup.users = self._get_users_set(group["members"]["nodes"])
340+
userGroup.projects = self._get_projects_set(group["projects"]["nodes"])
398341
yield userGroup
399342
nextCursor = userGroups["nextCursor"]
400343
# this doesn't seem to be implemented right now to return a value other than null from the api
401344
if not nextCursor:
402345
break
346+
347+
def _get_users_set(self, user_nodes):
348+
"""
349+
Retrieves a set of User objects from the given user nodes.
350+
351+
Args:
352+
user_nodes (list): A list of user nodes containing user information.
353+
354+
Returns:
355+
set: A set of User objects.
356+
"""
357+
users = set()
358+
for user in user_nodes:
359+
user_values = defaultdict(lambda: None)
360+
user_values["id"] = user["id"]
361+
user_values["email"] = user["email"]
362+
users.add(User(self.client, user_values))
363+
return users
364+
365+
def _get_projects_set(self, project_nodes):
366+
"""
367+
Retrieves a set of projects based on the given project nodes.
368+
369+
Args:
370+
project_nodes (list): A list of project nodes.
371+
372+
Returns:
373+
set: A set of Project objects.
374+
"""
375+
projects = set()
376+
for project in project_nodes:
377+
project_values = defaultdict(lambda: None)
378+
project_values["id"] = project["id"]
379+
project_values["name"] = project["name"]
380+
project_values["queueMode"] = QueueMode.Batch.value
381+
project_values["editorTaskType"] = EditorTaskType.Missing.value
382+
project_values["mediaType"] = MediaType.Audio.value
383+
projects.add(Project(self.client, project_values))
384+
return projects

libs/labelbox/tests/unit/schema/test_user_group.py

Lines changed: 9 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import pytest
2+
from collections import defaultdict
23
from unittest.mock import MagicMock
34
from labelbox import Client
45
from labelbox.exceptions import ResourceCreationError
6+
from labelbox.schema.project import Project
57
from labelbox.schema.user import User
6-
from labelbox.schema.user_group import UserGroup, UserGroupColor, UserGroupUser, UserGroupProject
8+
from labelbox.schema.user_group import UserGroup, UserGroupColor
79

810

911
class TestUserGroupColor:
@@ -20,40 +22,6 @@ def test_user_group_color_values(self):
2022
assert UserGroupColor.GRAY.value == "B8C4D3"
2123

2224

23-
class TestUserGroupUser:
24-
25-
def test_user_group_user_attributes(self):
26-
user = UserGroupUser(id="user_id", email="test@example.com")
27-
assert user.id == "user_id"
28-
assert user.email == "test@example.com"
29-
30-
def test_user_group_user_equality(self):
31-
user1 = UserGroupUser(id="user_id", email="test@example.com")
32-
user2 = UserGroupUser(id="user_id", email="test@example.com")
33-
assert user1 == user2
34-
35-
def test_user_group_user_hash(self):
36-
user = UserGroupUser(id="user_id", email="test@example.com")
37-
assert hash(user) == hash("user_id")
38-
39-
40-
class TestUserGroupProject:
41-
42-
def test_user_group_project_attributes(self):
43-
project = UserGroupProject(id="project_id", name="Test Project")
44-
assert project.id == "project_id"
45-
assert project.name == "Test Project"
46-
47-
def test_user_group_project_equality(self):
48-
project1 = UserGroupProject(id="project_id", name="Test Project")
49-
project2 = UserGroupProject(id="project_id", name="Test Project")
50-
assert project1 == project2
51-
52-
def test_user_group_project_hash(self):
53-
project = UserGroupProject(id="project_id", name="Test Project")
54-
assert hash(project) == hash("project_id")
55-
56-
5725
class TestUserGroup:
5826

5927
def setup_method(self):
@@ -233,9 +201,13 @@ def test_create(self):
233201
group = self.group
234202
group.name = "New Group"
235203
group.color = UserGroupColor.PINK
236-
group.users = {UserGroupUser(id="user_id", email="test@example.com")}
204+
user_values = defaultdict(lambda: None)
205+
user_values["id"] = "user_id"
206+
user_values["email"] = "test@example.com"
207+
group.users = {User(self.client, user_values)}
208+
project_values = defaultdict(lambda: None)
237209
group.projects = {
238-
UserGroupProject(id="project_id", name="Test Project")
210+
Project(self.client, {id="project_id", name="Test Project", qu})
239211
}
240212

241213
self.client.execute.return_value = {

0 commit comments

Comments
 (0)