diff --git a/libs/labelbox/src/labelbox/schema/user_group.py b/libs/labelbox/src/labelbox/schema/user_group.py index c8779251b..01657f758 100644 --- a/libs/labelbox/src/labelbox/schema/user_group.py +++ b/libs/labelbox/src/labelbox/schema/user_group.py @@ -1,5 +1,6 @@ from enum import Enum -from typing import Set, List, Union, Iterator, Optional +from typing import Set, Iterator +from collections import defaultdict from labelbox import Client from labelbox.exceptions import ResourceCreationError @@ -7,6 +8,9 @@ from labelbox.schema.user import User from labelbox.schema.project import Project from labelbox.exceptions import UnprocessableEntityError, InvalidQueryError +from labelbox.schema.queue_mode import QueueMode +from labelbox.schema.ontology_kind import EditorTaskType +from labelbox.schema.media_type import MediaType class UserGroupColor(Enum): @@ -34,82 +38,32 @@ class UserGroupColor(Enum): YELLOW = "E7BF00" GRAY = "B8C4D3" - -class UserGroupUser(BaseModel): - """ - Represents a user in a group. - - Attributes: - id (str): The ID of the user. - email (str): The email of the user. - """ - id: str - email: str - - def __hash__(self): - return hash((self.id)) - - def __eq__(self, other): - if not isinstance(other, UserGroupUser): - return False - return self.id == other.id - - -class UserGroupProject(BaseModel): - """ - Represents a project in a group. - - Attributes: - id (str): The ID of the project. - name (str): The name of the project. - """ - id: str - name: str - - def __hash__(self): - return hash((self.id)) - - def __eq__(self, other): - """ - Check if this GroupProject object is equal to another GroupProject object. - - Args: - other (GroupProject): The other GroupProject object to compare with. - - Returns: - bool: True if the two GroupProject objects are equal, False otherwise. - """ - if not isinstance(other, UserGroupProject): - return False - return self.id == other.id - class UserGroup(BaseModel): """ Represents a user group in Labelbox. Attributes: - id (Optional[str]): The ID of the user group. - name (Optional[str]): The name of the user group. + id (str): The ID of the user group. + name (str): The name of the user group. color (UserGroupColor): The color of the user group. users (Set[UserGroupUser]): The set of users in the user group. projects (Set[UserGroupProject]): The set of projects associated with the user group. client (Client): The Labelbox client object. Methods: - __init__(self, client: Client, id: str = "", name: str = "", color: UserGroupColor = UserGroupColor.BLUE, - users: Set[UserGroupUser] = set(), projects: Set[UserGroupProject] = set(), reload=True) - _reload(self) + __init__(self, client: Client) + get(self) -> "UserGroup" update(self) -> "UserGroup" create(self) -> "UserGroup" delete(self) -> bool get_user_groups(client: Client) -> Iterator["UserGroup"] """ - id: Optional[str] - name: Optional[str] + id: str + name: str color: UserGroupColor - users: Set[UserGroupUser] - projects: Set[UserGroupProject] + users: Set[User] + projects: Set[Project] client: Client class Config: @@ -122,9 +76,8 @@ def __init__( id: str = "", name: str = "", color: UserGroupColor = UserGroupColor.BLUE, - users: Set[UserGroupUser] = set(), - projects: Set[UserGroupProject] = set(), - reload=True, + users: Set[User] = set(), + projects: Set[Project] = set() ): """ Initializes a UserGroup object. @@ -134,24 +87,17 @@ def __init__( id (str, optional): The ID of the user group. Defaults to an empty string. name (str, optional): The name of the user group. Defaults to an empty string. color (UserGroupColor, optional): The color of the user group. Defaults to UserGroupColor.BLUE. - users (Set[UserGroupUser], optional): The set of users in the user group. Defaults to an empty set. - projects (Set[UserGroupProject], optional): The set of projects associated with the user group. Defaults to an empty set. - reload (bool, optional): Whether to reload the partial representation of the group. Defaults to True. + users (Set[User], optional): The set of users in the user group. Defaults to an empty set. + projects (Set[Project], optional): The set of projects associated with the user group. Defaults to an empty set. Raises: RuntimeError: If the experimental feature is not enabled in the client. - """ super().__init__(client=client, id=id, name=name, color=color, users=users, projects=projects) if not self.client.enable_experimental: - raise RuntimeError( - "Please enable experimental in client to use UserGroups") + raise RuntimeError("Please enable experimental in client to use UserGroups") - # partial representation of the group, reload - if self.id and reload: - self._reload() - - def _reload(self): + def get(self) -> "UserGroup": """ Reloads the user group information from the server. @@ -159,11 +105,14 @@ def _reload(self): about the user group, including its name, color, projects, and members. The fetched information is then used to update the corresponding attributes of the `Group` object. - Raises: - InvalidQueryError: If the query fails to fetch the group information. + Args: + id (str): The ID of the user group to fetch. Returns: - None + UserGroup of passed in ID (self) + + Raises: + InvalidQueryError: If the query fails to fetch the group information. """ query = """ query GetUserGroupPyApi($id: ID!) { @@ -196,14 +145,9 @@ def _reload(self): raise InvalidQueryError("Failed to fetch group") self.name = result["userGroup"]["name"] self.color = UserGroupColor(result["userGroup"]["color"]) - self.projects = { - UserGroupProject(id=project["id"], name=project["name"]) - for project in result["userGroup"]["projects"]["nodes"] - } - self.users = { - UserGroupUser(id=member["id"], email=member["email"]) - for member in result["userGroup"]["members"]["nodes"] - } + self.projects = self._get_projects_set(result["userGroup"]["projects"]["nodes"]) + self.users = self._get_users_set(result["userGroup"]["members"]["nodes"]) + return self def update(self) -> "UserGroup": """ @@ -249,10 +193,10 @@ def update(self) -> "UserGroup": "color": self.color.value, "projectIds": [ - project.id for project in self.projects + project.uid for project in self.projects ], "userIds": [ - user.id for user in self.users + user.uid for user in self.users ] } result = self.client.execute(query, params) @@ -311,10 +255,10 @@ def create(self) -> "UserGroup": "color": self.color.value, "projectIds": [ - project.id for project in self.projects + project.uid for project in self.projects ], "userIds": [ - user.id for user in self.users + user.uid for user in self.users ] } result = self.client.execute(query, params) @@ -351,8 +295,7 @@ def delete(self) -> bool: raise UnprocessableEntityError("Failed to delete user group") return result["deleteUserGroup"]["success"] - @staticmethod - def get_user_groups(client: Client) -> Iterator["UserGroup"]: + def get_user_groups(self) -> Iterator["UserGroup"]: """ Gets all user groups in Labelbox. @@ -390,29 +333,60 @@ def get_user_groups(client: Client) -> Iterator["UserGroup"]: """ nextCursor = None while True: - userGroups = client.execute( + userGroups = self.client.execute( query, {"nextCursor": nextCursor})["userGroups"] if not userGroups: return yield groups = userGroups["nodes"] for group in groups: - yield UserGroup(client, - reload=False, - id=group["id"], - name=group["name"], - color=UserGroupColor(group["color"]), - users={ - UserGroupUser(id=member["id"], - email=member["email"]) - for member in group["members"]["nodes"] - }, - projects={ - UserGroupProject(id=project["id"], - name=project["name"]) - for project in group["projects"]["nodes"] - }) + userGroup = UserGroup(self.client) + userGroup.id = group["id"] + userGroup.name = group["name"] + userGroup.color = UserGroupColor(group["color"]) + userGroup.users = self._get_users_set(group["members"]["nodes"]) + userGroup.projects = self._get_projects_set(group["projects"]["nodes"]) + yield userGroup nextCursor = userGroups["nextCursor"] # this doesn't seem to be implemented right now to return a value other than null from the api if not nextCursor: break + + def _get_users_set(self, user_nodes): + """ + Retrieves a set of User objects from the given user nodes. + + Args: + user_nodes (list): A list of user nodes containing user information. + + Returns: + set: A set of User objects. + """ + users = set() + for user in user_nodes: + user_values = defaultdict(lambda: None) + user_values["id"] = user["id"] + user_values["email"] = user["email"] + users.add(User(self.client, user_values)) + return users + + def _get_projects_set(self, project_nodes): + """ + Retrieves a set of projects based on the given project nodes. + + Args: + project_nodes (list): A list of project nodes. + + Returns: + set: A set of Project objects. + """ + projects = set() + for project in project_nodes: + project_values = defaultdict(lambda: None) + project_values["id"] = project["id"] + project_values["name"] = project["name"] + project_values["queueMode"] = QueueMode.Batch.value + project_values["editorTaskType"] = EditorTaskType.Missing.value + project_values["mediaType"] = MediaType.Image.value + projects.add(Project(self.client, project_values)) + return projects diff --git a/libs/labelbox/tests/integration/schema/test_user_group.py b/libs/labelbox/tests/integration/schema/test_user_group.py index d426a218f..810ae5242 100644 --- a/libs/labelbox/tests/integration/schema/test_user_group.py +++ b/libs/labelbox/tests/integration/schema/test_user_group.py @@ -1,7 +1,7 @@ import pytest import faker from labelbox import Client -from labelbox.schema.user_group import UserGroup, UserGroupColor, UserGroupUser, UserGroupProject +from labelbox.schema.user_group import UserGroup, UserGroupColor data = faker.Faker() @@ -21,7 +21,9 @@ def user_group(client): def test_existing_user_groups(user_group, client): # Verify that the user group was created successfully - user_group_equal = UserGroup(client, id=user_group.id) + user_group_equal = UserGroup(client) + user_group_equal.id = user_group.id + user_group_equal.get() assert user_group.id == user_group_equal.id assert user_group.name == user_group_equal.name assert user_group.color == user_group_equal.color @@ -48,7 +50,7 @@ def test_update_user_group(user_group): def test_get_user_groups(user_group, client): # Get all user groups - user_groups_old = list(UserGroup.get_user_groups(client)) + user_groups_old = list(UserGroup(client).get_user_groups()) # manual delete for iterators group_name = data.name() @@ -56,7 +58,7 @@ def test_get_user_groups(user_group, client): user_group.name = group_name user_group.create() - user_groups_new = list(UserGroup.get_user_groups(client)) + user_groups_new = list(UserGroup(client).get_user_groups()) # Verify that at least one user group is returned assert len(user_groups_new) > 0 @@ -77,15 +79,7 @@ def test_update_user_group(user_group, client, project_pack): # Add the user to the group user = users[0] - user = UserGroupUser( - id=user.uid, - email=user.email - ) project = projects[0] - project = UserGroupProject( - id=project.uid, - name=project.name - ) user_group.users.add(user) user_group.projects.add(project) user_group.update() diff --git a/libs/labelbox/tests/unit/schema/test_user_group.py b/libs/labelbox/tests/unit/schema/test_user_group.py index 6f1400308..4217f68bf 100644 --- a/libs/labelbox/tests/unit/schema/test_user_group.py +++ b/libs/labelbox/tests/unit/schema/test_user_group.py @@ -1,9 +1,32 @@ import pytest +from collections import defaultdict from unittest.mock import MagicMock from labelbox import Client from labelbox.exceptions import ResourceCreationError +from labelbox.schema.project import Project from labelbox.schema.user import User -from labelbox.schema.user_group import UserGroup, UserGroupColor, UserGroupUser, UserGroupProject +from labelbox.schema.user_group import UserGroup, UserGroupColor +from labelbox.schema.queue_mode import QueueMode +from labelbox.schema.ontology_kind import EditorTaskType +from labelbox.schema.media_type import MediaType + +@pytest.fixture +def group_user(): + user_values = defaultdict(lambda: None) + user_values["id"] = "user_id" + user_values["email"] = "test@example.com" + return User(MagicMock(Client), user_values) + + +@pytest.fixture +def group_project(): + project_values = defaultdict(lambda: None) + project_values["id"] = "project_id" + project_values["name"] = "Test Project" + project_values["queueMode"] = QueueMode.Batch.value + project_values["editorTaskType"] = EditorTaskType.Missing.value + project_values["mediaType"] = MediaType.Image.value + return Project(MagicMock(Client), project_values) class TestUserGroupColor: @@ -20,46 +43,12 @@ def test_user_group_color_values(self): assert UserGroupColor.GRAY.value == "B8C4D3" -class TestUserGroupUser: - - def test_user_group_user_attributes(self): - user = UserGroupUser(id="user_id", email="test@example.com") - assert user.id == "user_id" - assert user.email == "test@example.com" - - def test_user_group_user_equality(self): - user1 = UserGroupUser(id="user_id", email="test@example.com") - user2 = UserGroupUser(id="user_id", email="test@example.com") - assert user1 == user2 - - def test_user_group_user_hash(self): - user = UserGroupUser(id="user_id", email="test@example.com") - assert hash(user) == hash("user_id") - - -class TestUserGroupProject: - - def test_user_group_project_attributes(self): - project = UserGroupProject(id="project_id", name="Test Project") - assert project.id == "project_id" - assert project.name == "Test Project" - - def test_user_group_project_equality(self): - project1 = UserGroupProject(id="project_id", name="Test Project") - project2 = UserGroupProject(id="project_id", name="Test Project") - assert project1 == project2 - - def test_user_group_project_hash(self): - project = UserGroupProject(id="project_id", name="Test Project") - assert hash(project) == hash("project_id") - - class TestUserGroup: def setup_method(self): self.client = MagicMock(Client) self.client.enable_experimental = True - self.group = UserGroup(client=self.client, name="Test Group") + self.group = UserGroup(client=self.client) def test_constructor_experimental_needed(self): client = MagicMock(Client) @@ -67,63 +56,36 @@ def test_constructor_experimental_needed(self): with pytest.raises(RuntimeError): group = UserGroup(client) - def test_constructor_name(self): - group = self.group - assert group.name == "Test Group" - assert group.color == UserGroupColor.BLUE - - def test_constructor_id_no_reload(self): - projects = [{ - "id": "project_id_1", - "name": "project_1" - }, { - "id": "project_id_2", - "name": "project_2" - }] - group_members = [{ - "id": "user_id_1", - "email": "email_1" - }, { - "id": "user_id_2", - "email": "email_2" - }] - self.client.execute.return_value = { - "userGroup": { - "id": "group_id", - "name": "Test Group", - "color": "4ED2F9", - "projects": { - "nodes": projects - }, - "members": { - "nodes": group_members - } - } - } - - group = UserGroup(self.client, id="group_id", reload=False) + def test_constructor(self): + group = UserGroup(self.client) - assert group.id == "group_id" + assert group.id == "" assert group.name == "" assert group.color is UserGroupColor.BLUE assert len(group.projects) == 0 assert len(group.users) == 0 - def test_constructor_id(self): - projects = [{ - "id": "project_id_1", - "name": "project_1" - }, { - "id": "project_id_2", - "name": "project_2" - }] - group_members = [{ - "id": "user_id_1", - "email": "email_1" - }, { - "id": "user_id_2", - "email": "email_2" - }] + def test_get(self): + projects = [ + { + "id": "project_id_1", + "name": "project_1" + }, + { + "id": "project_id_2", + "name": "project_2" + } + ] + group_members = [ + { + "id": "user_id_1", + "email": "email_1" + }, + { + "id": "user_id_2", + "email": "email_2" + } + ] self.client.execute.return_value = { "userGroup": { "id": "group_id", @@ -137,86 +99,29 @@ def test_constructor_id(self): } } } - group = UserGroup(self.client, id="group_id") - assert group.id == "group_id" - assert group.name == "Test Group" - assert group.color == UserGroupColor.CYAN - assert len(group.projects) == 2 - assert len(group.users) == 2 - - def test_id(self): - group = self.group + group = UserGroup(self.client) assert group.id == "" - - group.id = "1" - assert group.id == "1" - - group.id = "2" - assert group.id == "2" - - def test_name(self): - group = self.group - assert group.name == "Test Group" - - group.name = "New Group" - assert group.name == "New Group" - - group.name = "Another Group" - assert group.name == "Another Group" - - def test_color(self): - group = self.group + assert group.name == "" assert group.color is UserGroupColor.BLUE - - group.color = UserGroupColor.PINK - assert group.color == UserGroupColor.PINK - - group.color = UserGroupColor.YELLOW - assert group.color == UserGroupColor.YELLOW - - def test_users(self): - group = self.group - assert len(group.users) == 0 - - group.users = {UserGroupUser(id="user_id", email="user_id@email")} - assert len(group.users) == 1 - - group.users = { - UserGroupUser(id="user_id", email="user_id@email"), - UserGroupUser(id="user_id", email="user_id@email") - } - assert len(group.users) == 1 - - group.users = {} - assert len(group.users) == 0 - - def test_projects(self): - group = self.group assert len(group.projects) == 0 + assert len(group.users) == 0 - group.projects = { - UserGroupProject(id="project_id", name="Test Project") - } - assert len(group.projects) == 1 - - group.projects = { - UserGroupProject(id="project_id", name="Test Project"), - UserGroupProject(id="project_id", name="Test Project") - } - assert len(group.projects) == 1 + group.id = "group_id" + group.get() - group.projects = {} - assert len(group.projects) == 0 + assert group.id == "group_id" + assert group.name == "Test Group" + assert group.color is UserGroupColor.CYAN + assert len(group.projects) == 2 + assert len(group.users) == 2 - def test_update(self): + def test_update(self, group_user, group_project): group = self.group group.id = "group_id" group.name = "Test Group" group.color = UserGroupColor.BLUE - group.users = {UserGroupUser(id="user_id", email="test@example.com")} - group.projects = { - UserGroupProject(id="project_id", name="Test Project") - } + group.users = { group_user } + group.projects = { group_project } updated_group = group.update() @@ -227,17 +132,17 @@ def test_update(self): assert execute[1]["name"] == "Test Group" assert execute[1]["color"] == UserGroupColor.BLUE.value assert len(execute[1]["userIds"]) == 1 - assert list(execute[1]["userIds"])[0] == "user_id" + assert list(execute[1]["userIds"])[0] == group_user.uid assert len(execute[1]["projectIds"]) == 1 - assert list(execute[1]["projectIds"])[0] == "project_id" + assert list(execute[1]["projectIds"])[0] == group_project.uid assert updated_group.id == "group_id" assert updated_group.name == "Test Group" assert updated_group.color == UserGroupColor.BLUE assert len(updated_group.users) == 1 - assert list(updated_group.users)[0].id == "user_id" + assert list(updated_group.users)[0].uid == group_user.uid assert len(updated_group.projects) == 1 - assert list(updated_group.projects)[0].id == "project_id" + assert list(updated_group.projects)[0].uid == group_project.uid def test_create_with_exception_id(self): group = self.group @@ -253,14 +158,12 @@ def test_create_with_exception_name(self): with pytest.raises(ValueError): group.create() - def test_create(self): + def test_create(self, group_user, group_project): group = self.group group.name = "New Group" group.color = UserGroupColor.PINK - group.users = {UserGroupUser(id="user_id", email="test@example.com")} - group.projects = { - UserGroupProject(id="project_id", name="Test Project") - } + group.users = { group_user } + group.projects = { group_project } self.client.execute.return_value = { "createUserGroup": { @@ -284,9 +187,9 @@ def test_create(self): assert created_group.name == "New Group" assert created_group.color == UserGroupColor.PINK assert len(created_group.users) == 1 - assert list(created_group.users)[0].id == "user_id" + assert list(created_group.users)[0].uid == "user_id" assert len(created_group.projects) == 1 - assert list(created_group.projects)[0].id == "project_id" + assert list(created_group.projects)[0].uid == "project_id" def test_delete(self): group = self.group @@ -307,7 +210,7 @@ def test_delete(self): def test_user_groups_empty(self): self.client.execute.return_value = {"userGroups": None} - user_groups = list(UserGroup.get_user_groups(self.client)) + user_groups = list(UserGroup(self.client).get_user_groups()) assert len(user_groups) == 0 @@ -386,8 +289,10 @@ def test_user_groups(self): } } - user_groups = list(UserGroup.get_user_groups(self.client)) + user_groups = list(UserGroup(self.client).get_user_groups()) + execute = self.client.execute.call_args[0] + assert "GetUserGroupsPyApi" in execute[0] assert len(user_groups) == 3 # Check the attributes of the first user group