From 1e93d16758211b6d0a9496ba66b02956ef075d70 Mon Sep 17 00:00:00 2001 From: Adrian Chang Date: Thu, 9 May 2024 17:46:22 -0700 Subject: [PATCH 1/7] Rough group --- libs/labelbox/src/labelbox/pydantic_compat.py | 2 +- libs/labelbox/src/labelbox/schema/group.py | 459 ++++++++++++++++++ 2 files changed, 460 insertions(+), 1 deletion(-) create mode 100644 libs/labelbox/src/labelbox/schema/group.py diff --git a/libs/labelbox/src/labelbox/pydantic_compat.py b/libs/labelbox/src/labelbox/pydantic_compat.py index 51c082480..4bcece74e 100644 --- a/libs/labelbox/src/labelbox/pydantic_compat.py +++ b/libs/labelbox/src/labelbox/pydantic_compat.py @@ -31,4 +31,4 @@ def pydantic_import(class_name, sub_module_path: Optional[str] = None): conint = pydantic_import("conint") conlist = pydantic_import("conlist") constr = pydantic_import("constr") -confloat = pydantic_import("confloat") +confloat = pydantic_import("confloat") \ No newline at end of file diff --git a/libs/labelbox/src/labelbox/schema/group.py b/libs/labelbox/src/labelbox/schema/group.py new file mode 100644 index 000000000..ea369685d --- /dev/null +++ b/libs/labelbox/src/labelbox/schema/group.py @@ -0,0 +1,459 @@ +from enum import Enum +from typing import Set, List, TypedDict, Optional + +from labelbox import Client +from labelbox.exceptions import ResourceCreationError +from labelbox.schema.user import User +from labelbox.schema.project import Project +from labelbox.pydantic_compat import BaseModel, PrivateAttr + +class GroupColor(Enum): + """ + Enum representing the colors available for a group. + + Attributes: + BLUE (str): Hex color code for blue (#9EC5FF). + PURPLE (str): Hex color code for purple (#CEB8FF). + ORANGE (str): Hex color code for orange (#FFB35F). + CYAN (str): Hex color code for cyan (#4ED2F9). + PINK (str): Hex color code for pink (#FFAEA9). + LIGHT_PINK (str): Hex color code for light pink (#FFA9D5). + GREEN (str): Hex color code for green (#3FDC9A). + YELLOW (str): Hex color code for yellow (#E7BF00). + GRAY (str): Hex color code for gray (#B8C4D3). + """ + BLUE = "9EC5FF" + PURPLE = "CEB8FF" + ORANGE = "FFB35F" + CYAN = "4ED2F9" + PINK = "FFAEA9" + LIGHT_PINK = "FFA9D5" + GREEN = "3FDC9A" + YELLOW = "E7BF00" + GRAY = "B8C4D3" + + +class GroupUser(BaseModel): + """ + Represents a user in a group. + + Attributes: + id (str): The ID of the user. + email (str): The email of the user. + """ + id: str + email: str + + + def __hash__(self): + return hash((self.id)) + + def __eq__(self, other): + if not isinstance(other, GroupUser): + return False + return self.id == other.id + + +class GroupProject(BaseModel): + """ + Represents a project in a group. + + Attributes: + id (str): The ID of the project. + name (str): The name of the project. + """ + id: str + name: str + + def __hash__(self): + return hash((self.id)) + + def __eq__(self, other): + """ + Check if this GroupProject object is equal to another GroupProject object. + + Args: + other (GroupProject): The other GroupProject object to compare with. + + Returns: + bool: True if the two GroupProject objects are equal, False otherwise. + """ + if not isinstance(other, GroupProject): + return False + return self.id == other.id + + +class GroupParmeters(TypedDict): + """ + Represents the parameters for a group. + + Attributes: + id (Optional[str]): The ID of the group. + name (Optional[str]): The name of the group. + color (Optional[GroupColor]): The color of the group. + users (Optional[Set[GroupUser]]): The users in the group. + projects (Optional[Set[GroupProject]]): The projects associated with the group. + """ + id: Optional[str] + name: Optional[str] + color: Optional[GroupColor] + users: Optional[Set[GroupUser]] + projects: Optional[Set[GroupProject]] + + +class Group(): + """ + Represents a group of users in Labelbox. + + Args: + client (Client): The Labelbox client. + data (dict): The data dictionary containing group information. + + Attributes: + id (str): The ID of the group. + name (str): The name of the group. + color (GroupColor): The color of the group. + users (List[str]): The list of user IDs in the group. + projects (List[str]): The list of project IDs in the group. + client (Client): The Labelbox client. + """ + _id: str + _name: str = None + _color: GroupColor = None + _users: Set[GroupUser] = None + _projects: Set[GroupProject] = None + _client: Client + + def __init__(self, client: Client, **kwargs: GroupParmeters): + """ + Initializes a Group object. + + Args: + client (Client): The Labelbox client. + **kwargs: Additional keyword arguments for initializing the Group object. + """ + self.id = kwargs['id'] + self.color = kwargs.get('color', GroupColor.BLUE) + self.users = kwargs.get('users', {}) + self.projects = kwargs.get('projects', {}) + self.client = client + + # runs against _gql + if client.enable_experimental is False: + raise RuntimeError("Experimental features are not enabled. Please enable them in the client to use this feature.") + + # partial respentation of the group, reload + if self.id is not None: + self._reload() + else: + self.name = kwargs['name'] + super().__init__() + + def _reload(self): + """ + Reloads the user group information from the server. + + This method sends a GraphQL query to the server to fetch the latest information + about the user group, including its name, color, projects, and members. The fetched + information is then used to update the corresponding attributes of the `Group` object. + + Returns: + None + """ + query = """ + query GetUserGroupPyApi($id: ID!) { + userGroup(where: {id: $id}) { + id + name + color + projects { + nodes { + id + name + } + totalCount + } + members { + nodes { + id + email + } + totalCount + } + } + } + """ + params = { + "id": self.id, + } + result = self.client.execute(query, params) + self.name = result["userGroup"]["name"] + self.color = GroupColor(result["userGroup"]["color"]) + self.projects = {GroupProject(id=project["id"], name=project["name"]) for project in result["userGroup"]["projects"]["nodes"]} + self.users = {GroupUser(id=member["id"], email=member["email"]) for member in result["userGroup"]["members"]["nodes"]} + + @property + def id(self) -> str: + """ + Gets the ID of the group. + + Returns: + str: The ID of the group. + """ + return self._id + + @id.setter + def id(self, value: str) -> None: + """ + Sets the ID of the group. + + Args: + value (str): The ID to set. + """ + self._id = value + + @property + def name(self) -> str: + """ + Gets the name of the group. + + Returns: + str: The name of the group. + """ + return self._name + + @name.setter + def name(self, value: str) -> None: + """ + Sets the name of the group. + + Args: + value (str): The name to set. + """ + self._name = value + + @property + def color(self) -> GroupColor: + """ + Gets the color of the group. + + Returns: + GroupColor: The color of the group. + """ + return self._color + + @color.setter + def color(self, value: GroupColor) -> None: + """ + Sets the color of the group. + + Args: + value (GroupColor): The color to set. + """ + self._color = value + self._update() + + @property + def users(self) -> Set[GroupUser]: + """ + Gets the list of user IDs in the group. + + Returns: + Set[GroupUser]: The list of user IDs in the group. + """ + return self._users + + @users.setter + def users(self, value: Set[GroupUser]) -> None: + """ + Sets the list of user IDs in the group. + + Args: + value (Set[GroupUser]): The list of user IDs to set. + """ + self._users = value + + @property + def projects(self) -> Set[GroupProject]: + """ + Gets the list of project IDs in the group. + + Returns: + Set[GroupProject]: The list of project IDs in the group. + """ + return self._projects + + @projects.setter + def projects(self, value: Set[GroupProject]) -> None: + """ + Sets the list of project IDs in the group. + + Args: + value (Set[GroupProject]): The list of project IDs to set. + """ + self._projects = value + + def update(self): + """ + Updates the group in Labelbox. + """ + query = """ + mutation UpdateUserGroupPyApi($id: ID!, $name: String!, $color: String!, $projectIds: [String!]!, $userIds: [String!]!) { + updateUserGroup( + where: {id: $id} + data: {name: $name, color: $color, projectIds: $projectIds, userIds: $userIds} + ) { + group { + id + name + color + projects { + nodes { + id + name + } + } + members { + nodes { + id + email + } + } + } + } + } + """ + params = { + "id": self.id, + "name": self.name, + "color": self.color.value, + "projectIds": [project.id for project in self.projects], + "userIds": [user.id for user in self.users] + } + self.client.execute(query, params) + + def create(self): + """ + Creates a new group in Labelbox. + + Args: + client (Client): The Labelbox client. + name (str): The name of the group. + color (GroupColor, optional): The color of the group. Defaults to GroupColor.BLUE. + users (List[User], optional): The users to add to the group. Defaults to []. + projects (List[Project], optional): The projects to add to the group. Defaults to []. + + Returns: + Group: The newly created group. + """ + if self.id is not None: + raise ResourceCreationError("Group already exists") + query = """ + mutation CreateUserGroupPyApi($name: String!, $color: String!, $projectIds: [String!]!, $userIds: [String!]!) { + createUserGroup( + data: { + name: $name, + color: $color, + projectIds: $projectIds, + userIds: $userIds + } + ) { + group { + id + name + color + projects { + nodes { + id + name + } + } + members { + nodes { + id + email + } + } + } + } + } + """ + params = { + "name": self.name, + "color": self.color.value, + "projectIds": [project.id for project in self.projects], + "userIds": [user.id for user in self.users] + } + result = self.client.execute(query, params)["createUserGroup"]["group"] + self.id = result["id"] + + def delete(self) -> bool: + """ + Deletes the group from Labelbox. + + Returns: + bool: True if the group was successfully deleted, False otherwise. + """ + query = """ + mutation DeleteUserGroupPyApi($id: ID!) { + deleteUserGroup(where: {id: $id}) { + success + } + } + """ + params = { + "id": self.id + } + result = self.client.execute(query, params) + return result["deleteUserGroup"]["success"] + + @staticmethod + def groups(client: Client) -> List["Group"]: + """ + Gets all groups in Labelbox. + + Args: + client (Client): The Labelbox client. + + Returns: + List[Group]: The list of groups. + """ + query = """ + query GetUserGroups { + userGroups { + nodes { + id + name + color + projects { + nodes { + id + name + } + totalCount + } + members { + nodes { + id + email + } + totalCount + } + } + nextCursor + } + } + """ + userGroups = client.execute(query)["userGroups"] + groups = userGroups["nodes"] + return [ + Group( + client, + group["id"], + group["name"], + GroupColor(group["color"]), + {GroupUser(id=member["id"], email=member["email"]) for member in group["members"]["nodes"]}, + {GroupProject(id=project["id"], name=project["name"]) for project in group["projects"]["nodes"]} + ) + for group in groups + ] \ No newline at end of file From 4ed0f4612358fe6c2890858c2cb65b142c5ca120 Mon Sep 17 00:00:00 2001 From: Adrian Chang Date: Mon, 27 May 2024 23:17:49 -0700 Subject: [PATCH 2/7] add users --- libs/labelbox/src/labelbox/__init__.py | 1 + libs/labelbox/src/labelbox/client.py | 13 + libs/labelbox/src/labelbox/orm/model.py | 1 + .../schema/{group.py => user_group.py} | 147 ++++--- .../tests/unit/schema/test_user_group.py | 399 ++++++++++++++++++ 5 files changed, 492 insertions(+), 69 deletions(-) rename libs/labelbox/src/labelbox/schema/{group.py => user_group.py} (72%) create mode 100644 libs/labelbox/tests/unit/schema/test_user_group.py diff --git a/libs/labelbox/src/labelbox/__init__.py b/libs/labelbox/src/labelbox/__init__.py index 4cd3b4390..2157c0e62 100644 --- a/libs/labelbox/src/labelbox/__init__.py +++ b/libs/labelbox/src/labelbox/__init__.py @@ -41,3 +41,4 @@ from labelbox.schema.identifiable import UniqueId, GlobalKey from labelbox.schema.ontology_kind import OntologyKind from labelbox.schema.project_overview import ProjectOverview, ProjectOverviewDetailed +from labelbox.schema.user_group import UserGroup, UserGroupColor, UserGroupUser, UserGroupProject \ No newline at end of file diff --git a/libs/labelbox/src/labelbox/client.py b/libs/labelbox/src/labelbox/client.py index 019285350..bb2aaf1cd 100644 --- a/libs/labelbox/src/labelbox/client.py +++ b/libs/labelbox/src/labelbox/client.py @@ -543,6 +543,19 @@ def get_projects(self, where=None) -> PaginatedCollection: """ return self._get_all(Entity.Project, where) + def get_users(self, where=None) -> PaginatedCollection: + """ Fetches all the users. + + >>> users = client.get_users(where=User.email == "") + + Args: + where (Comparison, LogicalOperation or None): The `where` clause + for filtering. + Returns: + An iterable of Users (typically a PaginatedCollection). + """ + return self._get_all(Entity.User, where, filter_deleted=False) + def get_datasets(self, where=None) -> PaginatedCollection: """ Fetches one or more datasets. diff --git a/libs/labelbox/src/labelbox/orm/model.py b/libs/labelbox/src/labelbox/orm/model.py index f3afa174e..2e578fae2 100644 --- a/libs/labelbox/src/labelbox/orm/model.py +++ b/libs/labelbox/src/labelbox/orm/model.py @@ -382,6 +382,7 @@ class Entity(metaclass=EntityMeta): CatalogSlice: Type[labelbox.CatalogSlice] ModelSlice: Type[labelbox.ModelSlice] TaskQueue: Type[labelbox.TaskQueue] + UserGroup: Type[labelbox.UserGroup] @classmethod def _attributes_of_type(cls, attr_type): diff --git a/libs/labelbox/src/labelbox/schema/group.py b/libs/labelbox/src/labelbox/schema/user_group.py similarity index 72% rename from libs/labelbox/src/labelbox/schema/group.py rename to libs/labelbox/src/labelbox/schema/user_group.py index ea369685d..464adf330 100644 --- a/libs/labelbox/src/labelbox/schema/group.py +++ b/libs/labelbox/src/labelbox/schema/user_group.py @@ -1,13 +1,15 @@ from enum import Enum -from typing import Set, List, TypedDict, Optional +from typing import Set, List, Optional, Union, TypedDict, Iterator + from labelbox import Client from labelbox.exceptions import ResourceCreationError +from labelbox.pydantic_compat import BaseModel from labelbox.schema.user import User from labelbox.schema.project import Project -from labelbox.pydantic_compat import BaseModel, PrivateAttr -class GroupColor(Enum): + +class UserGroupColor(Enum): """ Enum representing the colors available for a group. @@ -33,7 +35,7 @@ class GroupColor(Enum): GRAY = "B8C4D3" -class GroupUser(BaseModel): +class UserGroupUser(BaseModel): """ Represents a user in a group. @@ -44,17 +46,16 @@ class GroupUser(BaseModel): id: str email: str - def __hash__(self): return hash((self.id)) def __eq__(self, other): - if not isinstance(other, GroupUser): + if not isinstance(other, UserGroupUser): return False return self.id == other.id -class GroupProject(BaseModel): +class UserGroupProject(BaseModel): """ Represents a project in a group. @@ -78,53 +79,53 @@ def __eq__(self, other): Returns: bool: True if the two GroupProject objects are equal, False otherwise. """ - if not isinstance(other, GroupProject): + if not isinstance(other, UserGroupProject): return False return self.id == other.id -class GroupParmeters(TypedDict): +class UserGroupParameters(TypedDict): """ - Represents the parameters for a group. + Represents the parameters for a user group. Attributes: - id (Optional[str]): The ID of the group. - name (Optional[str]): The name of the group. - color (Optional[GroupColor]): The color of the group. - users (Optional[Set[GroupUser]]): The users in the group. - projects (Optional[Set[GroupProject]]): The projects associated with the group. + id (Optional[str]): The ID of the user group. + name (Optional[str]): The name of the user group. + color (Optional[UserGroupColor]): The color of the user group. + users (Optional[Set[Union[UserGroupUser, User]]]): The users in the user group. + projects (Optional[Set[Union[UserGroupProject, Project]]]): The projects associated with the user group. """ id: Optional[str] name: Optional[str] - color: Optional[GroupColor] - users: Optional[Set[GroupUser]] - projects: Optional[Set[GroupProject]] + color: Optional[UserGroupColor] + users: Optional[Set[Union[UserGroupUser, User]]] + projects: Optional[Set[Union[UserGroupProject, Project]]] -class Group(): +class UserGroup: """ - Represents a group of users in Labelbox. + Represents a user group in Labelbox. Args: client (Client): The Labelbox client. - data (dict): The data dictionary containing group information. + **kwargs: Additional keyword arguments for initializing the UserGroup object. Attributes: - id (str): The ID of the group. - name (str): The name of the group. - color (GroupColor): The color of the group. - users (List[str]): The list of user IDs in the group. - projects (List[str]): The list of project IDs in the group. - client (Client): The Labelbox client. + _id (str): The ID of the user group. + _name (str): The name of the user group. + _color (UserGroupColor): The color of the user group. + _users (Set[Union[UserGroupUser, User]]): The set of user IDs in the user group. + _projects (Set[Union[UserGroupProject, Project]]): The set of project IDs in the user group. + _client (Client): The Labelbox client. """ - _id: str + _id: str = None _name: str = None - _color: GroupColor = None - _users: Set[GroupUser] = None - _projects: Set[GroupProject] = None - _client: Client + _color: UserGroupColor = None + _users: Set[Union[UserGroupUser, User]] = None + _projects: Set[Union[UserGroupProject, Project]] = None + _client: Client - def __init__(self, client: Client, **kwargs: GroupParmeters): + def __init__(self, client: Client, reload=True, **kwargs: UserGroupParameters): """ Initializes a Group object. @@ -132,22 +133,24 @@ def __init__(self, client: Client, **kwargs: GroupParmeters): client (Client): The Labelbox client. **kwargs: Additional keyword arguments for initializing the Group object. """ - self.id = kwargs['id'] - self.color = kwargs.get('color', GroupColor.BLUE) - self.users = kwargs.get('users', {}) - self.projects = kwargs.get('projects', {}) + super().__init__() + self.color = kwargs.get('color', UserGroupColor.BLUE) + self.users = kwargs.get('users', set()) + self.projects = kwargs.get('projects', set()) self.client = client # runs against _gql - if client.enable_experimental is False: + if not client.enable_experimental: raise RuntimeError("Experimental features are not enabled. Please enable them in the client to use this feature.") + if 'id' not in kwargs and 'name' not in kwargs: + raise ValueError("Either 'id' or 'name' must be provided") + + self.name = kwargs.get('name', None) + self.id = kwargs.get('id', None) # partial respentation of the group, reload - if self.id is not None: + if self.id is not None and reload: self._reload() - else: - self.name = kwargs['name'] - super().__init__() def _reload(self): """ @@ -188,9 +191,9 @@ def _reload(self): } result = self.client.execute(query, params) self.name = result["userGroup"]["name"] - self.color = GroupColor(result["userGroup"]["color"]) - self.projects = {GroupProject(id=project["id"], name=project["name"]) for project in result["userGroup"]["projects"]["nodes"]} - self.users = {GroupUser(id=member["id"], email=member["email"]) for member in result["userGroup"]["members"]["nodes"]} + self.color = UserGroupColor(result["userGroup"]["color"]) + self.projects = {UserGroupProject(id=project["id"], name=project["name"]) for project in result["userGroup"]["projects"]["nodes"]} + self.users = {UserGroupUser(id=member["id"], email=member["email"]) for member in result["userGroup"]["members"]["nodes"]} @property def id(self) -> str: @@ -233,7 +236,7 @@ def name(self, value: str) -> None: self._name = value @property - def color(self) -> GroupColor: + def color(self) -> UserGroupColor: """ Gets the color of the group. @@ -243,7 +246,7 @@ def color(self) -> GroupColor: return self._color @color.setter - def color(self, value: GroupColor) -> None: + def color(self, value: UserGroupColor) -> None: """ Sets the color of the group. @@ -251,10 +254,9 @@ def color(self, value: GroupColor) -> None: value (GroupColor): The color to set. """ self._color = value - self._update() @property - def users(self) -> Set[GroupUser]: + def users(self) -> Set[Union[UserGroupUser, User]]: """ Gets the list of user IDs in the group. @@ -264,7 +266,7 @@ def users(self) -> Set[GroupUser]: return self._users @users.setter - def users(self, value: Set[GroupUser]) -> None: + def users(self, value: Set[Union[UserGroupUser, User]]) -> None: """ Sets the list of user IDs in the group. @@ -274,7 +276,7 @@ def users(self, value: Set[GroupUser]) -> None: self._users = value @property - def projects(self) -> Set[GroupProject]: + def projects(self) -> Set[UserGroupProject]: """ Gets the list of project IDs in the group. @@ -284,7 +286,7 @@ def projects(self) -> Set[GroupProject]: return self._projects @projects.setter - def projects(self, value: Set[GroupProject]) -> None: + def projects(self, value: Set[UserGroupProject]) -> None: """ Sets the list of project IDs in the group. @@ -293,7 +295,7 @@ def projects(self, value: Set[GroupProject]) -> None: """ self._projects = value - def update(self): + def update(self) -> "UserGroup": """ Updates the group in Labelbox. """ @@ -331,8 +333,9 @@ def update(self): "userIds": [user.id for user in self.users] } self.client.execute(query, params) + return self - def create(self): + def create(self) -> "UserGroup": """ Creates a new group in Labelbox. @@ -386,6 +389,7 @@ def create(self): } result = self.client.execute(query, params)["createUserGroup"]["group"] self.id = result["id"] + return self def delete(self) -> bool: """ @@ -408,7 +412,7 @@ def delete(self) -> bool: return result["deleteUserGroup"]["success"] @staticmethod - def groups(client: Client) -> List["Group"]: + def user_groups(client: Client) -> Iterator["UserGroup"]: """ Gets all groups in Labelbox. @@ -419,7 +423,7 @@ def groups(client: Client) -> List["Group"]: List[Group]: The list of groups. """ query = """ - query GetUserGroups { + query GetUserGroupsPyApi { userGroups { nodes { id @@ -444,16 +448,21 @@ def groups(client: Client) -> List["Group"]: } } """ - userGroups = client.execute(query)["userGroups"] - groups = userGroups["nodes"] - return [ - Group( - client, - group["id"], - group["name"], - GroupColor(group["color"]), - {GroupUser(id=member["id"], email=member["email"]) for member in group["members"]["nodes"]}, - {GroupProject(id=project["id"], name=project["name"]) for project in group["projects"]["nodes"]} - ) - for group in groups - ] \ No newline at end of file + nextCursor = None + while True: + userGroups = client.execute(query, { "nextCursor": nextCursor })["userGroups"] + groups = userGroups["nodes"] + for group in groups: + yield UserGroup( + client, + reload=False, + id=group["id"], + name=group["name"], + color=UserGroupColor(group["color"]), + users={UserGroupUser(id=member["id"], email=member["email"]) for member in group["members"]["nodes"]}, + projects={UserGroupProject(id=project["id"], name=project["name"]) for project in group["projects"]["nodes"]} + ) + nextCursor = userGroups.get("nextCursor", None) + # this doesn't seem to be used right now + if nextCursor is None: + break \ No newline at end of file diff --git a/libs/labelbox/tests/unit/schema/test_user_group.py b/libs/labelbox/tests/unit/schema/test_user_group.py new file mode 100644 index 000000000..4ee62f485 --- /dev/null +++ b/libs/labelbox/tests/unit/schema/test_user_group.py @@ -0,0 +1,399 @@ +import pytest +from unittest.mock import MagicMock +from labelbox import Client +from labelbox.schema.user import User +from labelbox.schema.user_group import UserGroup, UserGroupColor, UserGroupUser, UserGroupProject, UserGroupParameters + + +class TestUserGroupColor: + + def test_user_group_color_values(self): + assert UserGroupColor.BLUE.value == "9EC5FF" + assert UserGroupColor.PURPLE.value == "CEB8FF" + assert UserGroupColor.ORANGE.value == "FFB35F" + assert UserGroupColor.CYAN.value == "4ED2F9" + assert UserGroupColor.PINK.value == "FFAEA9" + assert UserGroupColor.LIGHT_PINK.value == "FFA9D5" + assert UserGroupColor.GREEN.value == "3FDC9A" + assert UserGroupColor.YELLOW.value == "E7BF00" + assert UserGroupColor.GRAY.value == "B8C4D3" + + +class TestUserGroupUser: + + def test_user_group_user_attributes(self): + user = UserGroupUser(id="user_id", email="test@example.com") + assert user.id == "user_id" + assert user.email == "test@example.com" + + def test_user_group_user_equality(self): + user1 = UserGroupUser(id="user_id", email="test@example.com") + user2 = UserGroupUser(id="user_id", email="test@example.com") + assert user1 == user2 + + def test_user_group_user_hash(self): + user = UserGroupUser(id="user_id", email="test@example.com") + assert hash(user) == hash("user_id") + + +class TestUserGroupProject: + + def test_user_group_project_attributes(self): + project = UserGroupProject(id="project_id", name="Test Project") + assert project.id == "project_id" + assert project.name == "Test Project" + + def test_user_group_project_equality(self): + project1 = UserGroupProject(id="project_id", name="Test Project") + project2 = UserGroupProject(id="project_id", name="Test Project") + assert project1 == project2 + + def test_user_group_project_hash(self): + project = UserGroupProject(id="project_id", name="Test Project") + assert hash(project) == hash("project_id") + + +class TestUserGroupParameters: + + def test_user_group_parameters_attributes(self): + params = UserGroupParameters( + id="group_id", + name="Test Group", + color=UserGroupColor.BLUE, + users={UserGroupUser(id="user_id", email="test@example.com")}, + projects={UserGroupProject(id="project_id", name="Test Project")} + ) + + assert params["id"] == "group_id" + assert params["name"] == "Test Group" + assert params["color"] == UserGroupColor.BLUE + assert len(params["users"]) == 1 + assert list(params["users"])[0].id == "user_id" + assert len(params["projects"]) == 1 + assert list(params["projects"])[0].id == "project_id" + + +class TestUserGroup: + + def setup_method(self): + self.client = MagicMock() + self.group = UserGroup(self.client, name="Test Group") + self.client.enable_experimental = True + + def test_constructor_experimental_needed(self): + client = MagicMock(Client) + client.enable_experimental = False + with pytest.raises(RuntimeError): + group = UserGroup(client) + + def test_constructor_id_or_name_needed(self): + client = MagicMock(Client) + with pytest.raises(ValueError): + group = UserGroup(self.client) + + def test_constructor_name(self): + group = self.group + assert group.name == "Test Group" + assert group.color == UserGroupColor.BLUE + + def test_constructor_id_no_reload(self): + projects = [ + { + "id": "project_id_1", + "name": "project_1" + }, + { + "id": "project_id_2", + "name": "project_2" + } + ] + group_members = [ + { + "id": "user_id_1", + "email": "email_1" + }, + { + "id": "user_id_2", + "email": "email_2" + } + ] + self.client.execute.return_value = { + "userGroup": { + "id": "group_id", + "name": "Test Group", + "color": "4ED2F9", + "projects": { + "nodes": projects + }, + "members": { + "nodes": group_members + } + } + } + + group = UserGroup(self.client, id="group_id", reload=False) + + assert group.id == "group_id" + assert group.name is None + assert group.color is UserGroupColor.BLUE + assert len(group.projects) == 0 + assert len(group.users) == 0 + + def test_constructor_id(self): + projects = [ + { + "id": "project_id_1", + "name": "project_1" + }, + { + "id": "project_id_2", + "name": "project_2" + } + ] + group_members = [ + { + "id": "user_id_1", + "email": "email_1" + }, + { + "id": "user_id_2", + "email": "email_2" + } + ] + self.client.execute.return_value = { + "userGroup": { + "id": "group_id", + "name": "Test Group", + "color": "4ED2F9", + "projects": { + "nodes": projects + }, + "members": { + "nodes": group_members + } + } + } + group = UserGroup(self.client, id="group_id") + assert group.id == "group_id" + assert group.name == "Test Group" + assert group.color == UserGroupColor.CYAN + assert len(group.projects) == 2 + assert len(group.users) == 2 + + def test_id(self): + group = self.group + assert group.id is None + + group.id = "1" + assert group.id == "1" + + group.id = "2" + assert group.id == "2" + + def test_name(self): + group = self.group + assert group.name == "Test Group" + + group.name = "New Group" + assert group.name == "New Group" + + group.name = "Another Group" + assert group.name == "Another Group" + + def test_color(self): + group = self.group + assert group.color is UserGroupColor.BLUE + + group.color = UserGroupColor.PINK + assert group.color == UserGroupColor.PINK + + group.color = UserGroupColor.YELLOW + assert group.color == UserGroupColor.YELLOW + + def test_users(self): + group = self.group + assert len(group.users) == 0 + + group.users = {UserGroupUser(id="user_id", email="user_id@email")} + assert len(group.users) == 1 + + group.users = {UserGroupUser(id="user_id", email="user_id@email"), UserGroupUser(id="user_id", email="user_id@email")} + assert len(group.users) == 1 + + group.users = {} + assert len(group.users) == 0 + + def test_projects(self): + group = self.group + assert len(group.projects) == 0 + + group.projects = {UserGroupProject(id="project_id", name="Test Project")} + assert len(group.projects) == 1 + + group.projects = {UserGroupProject(id="project_id", name="Test Project"), UserGroupProject(id="project_id", name="Test Project")} + assert len(group.projects) == 1 + + group.projects = {} + assert len(group.projects) == 0 + + def test_update(self): + group = self.group + group.id = "group_id" + group.name = "Test Group" + group.color = UserGroupColor.BLUE + group.users = {UserGroupUser(id="user_id", email="test@example.com")} + group.projects = {UserGroupProject(id="project_id", name="Test Project")} + + updated_group = group.update() + + execute = self.client.execute.call_args[0] + + assert "UpdateUserGroupPyApi" in execute[0] + assert execute[1]["id"] == "group_id" + assert execute[1]["name"] == "Test Group" + assert execute[1]["color"] == UserGroupColor.BLUE.value + assert len(execute[1]["userIds"]) == 1 + assert list(execute[1]["userIds"])[0] == "user_id" + assert len(execute[1]["projectIds"]) == 1 + assert list(execute[1]["projectIds"])[0] == "project_id" + + assert updated_group.id == "group_id" + assert updated_group.name == "Test Group" + assert updated_group.color == UserGroupColor.BLUE + assert len(updated_group.users) == 1 + assert list(updated_group.users)[0].id == "user_id" + assert len(updated_group.projects) == 1 + assert list(updated_group.projects)[0].id == "project_id" + + def test_create_with_exception(self): + group = self.group + group.id = "group_id" + + with pytest.raises(Exception): + group.create() + + def test_create(self): + group = self.group + group.name = "New Group" + group.color = UserGroupColor.PINK + group.users = {UserGroupUser(id="user_id", email="test@example.com")} + group.projects = {UserGroupProject(id="project_id", name="Test Project")} + + self.client.execute.return_value = { "createUserGroup": { "group": { "id": "group_id" } } } + created_group = group.create() + execute = self.client.execute.call_args[0] + + assert "CreateUserGroupPyApi" in execute[0] + assert execute[1]["name"] == "New Group" + assert execute[1]["color"] == UserGroupColor.PINK.value + assert len(execute[1]["userIds"]) == 1 + assert list(execute[1]["userIds"])[0] == "user_id" + assert len(execute[1]["projectIds"]) == 1 + assert list(execute[1]["projectIds"])[0] == "project_id" + assert created_group.id is not None + assert created_group.id == "group_id" + assert created_group.name == "New Group" + assert created_group.color == UserGroupColor.PINK + assert len(created_group.users) == 1 + assert list(created_group.users)[0].id == "user_id" + assert len(created_group.projects) == 1 + assert list(created_group.projects)[0].id == "project_id" + + def test_delete(self): + group = self.group + group.id = "group_id" + + self.client.execute.return_value = { "deleteUserGroup": { "success": True } } + deleted = group.delete() + execute = self.client.execute.call_args[0] + + assert "DeleteUserGroupPyApi" in execute[0] + assert execute[1]["id"] == "group_id" + assert deleted is True + + def test_user_groups(self): + self.client.execute.return_value = { + "userGroups": { + "nodes": [ + { + "id": "group_id_1", + "name": "Group 1", + "color": "9EC5FF", + "projects": { + "nodes": [ + {"id": "project_id_1", "name": "Project 1"}, + {"id": "project_id_2", "name": "Project 2"} + ] + }, + "members": { + "nodes": [ + {"id": "user_id_1", "email": "user1@example.com"}, + {"id": "user_id_2", "email": "user2@example.com"} + ] + } + }, + { + "id": "group_id_2", + "name": "Group 2", + "color": "9EC5FF", + "projects": { + "nodes": [ + {"id": "project_id_3", "name": "Project 3"}, + {"id": "project_id_4", "name": "Project 4"} + ] + }, + "members": { + "nodes": [ + {"id": "user_id_3", "email": "user3@example.com"}, + {"id": "user_id_4", "email": "user4@example.com"} + ] + } + }, + { + "id": "group_id_3", + "name": "Group 3", + "color": "9EC5FF", + "projects": { + "nodes": [ + {"id": "project_id_5", "name": "Project 5"}, + {"id": "project_id_6", "name": "Project 6"} + ] + }, + "members": { + "nodes": [ + {"id": "user_id_5", "email": "user5@example.com"}, + {"id": "user_id_6", "email": "user6@example.com"} + ] + } + } + ] + } + } + + user_groups = list(UserGroup.user_groups(self.client)) + + assert len(user_groups) == 3 + + # Check the attributes of the first user group + assert user_groups[0].id == "group_id_1" + assert user_groups[0].name == "Group 1" + assert user_groups[0].color == UserGroupColor.BLUE + assert len(user_groups[0].projects) == 2 + assert len(user_groups[0].users) == 2 + + # Check the attributes of the second user group + assert user_groups[1].id == "group_id_2" + assert user_groups[1].name == "Group 2" + assert user_groups[1].color == UserGroupColor.BLUE + assert len(user_groups[1].projects) == 2 + assert len(user_groups[1].users) == 2 + + # Check the attributes of the third user group + assert user_groups[2].id == "group_id_3" + assert user_groups[2].name == "Group 3" + assert user_groups[2].color == UserGroupColor.BLUE + assert len(user_groups[2].projects) == 2 + assert len(user_groups[2].users) == 2 + +if __name__ == "__main__": + pytest.main(["-v", __file__]) \ No newline at end of file From b05e2a3aea69e783ba2a577a567322a4116230a6 Mon Sep 17 00:00:00 2001 From: Adrian Chang Date: Tue, 4 Jun 2024 17:50:48 -0700 Subject: [PATCH 3/7] user groups done --- libs/labelbox/src/labelbox/orm/model.py | 1 - .../src/labelbox/schema/user_group.py | 109 +++++---- .../integration/schema/test_user_group.py | 114 ++++++++++ .../tests/unit/schema/test_user_group.py | 210 ++++++++++-------- pyproject.toml | 1 + requirements-dev.lock | 2 + 6 files changed, 301 insertions(+), 136 deletions(-) create mode 100644 libs/labelbox/tests/integration/schema/test_user_group.py diff --git a/libs/labelbox/src/labelbox/orm/model.py b/libs/labelbox/src/labelbox/orm/model.py index 2e578fae2..f3afa174e 100644 --- a/libs/labelbox/src/labelbox/orm/model.py +++ b/libs/labelbox/src/labelbox/orm/model.py @@ -382,7 +382,6 @@ class Entity(metaclass=EntityMeta): CatalogSlice: Type[labelbox.CatalogSlice] ModelSlice: Type[labelbox.ModelSlice] TaskQueue: Type[labelbox.TaskQueue] - UserGroup: Type[labelbox.UserGroup] @classmethod def _attributes_of_type(cls, attr_type): diff --git a/libs/labelbox/src/labelbox/schema/user_group.py b/libs/labelbox/src/labelbox/schema/user_group.py index 464adf330..4d324cb34 100644 --- a/libs/labelbox/src/labelbox/schema/user_group.py +++ b/libs/labelbox/src/labelbox/schema/user_group.py @@ -1,7 +1,6 @@ from enum import Enum from typing import Set, List, Optional, Union, TypedDict, Iterator - from labelbox import Client from labelbox.exceptions import ResourceCreationError from labelbox.pydantic_compat import BaseModel @@ -82,7 +81,7 @@ def __eq__(self, other): if not isinstance(other, UserGroupProject): return False return self.id == other.id - + class UserGroupParameters(TypedDict): """ @@ -97,7 +96,7 @@ class UserGroupParameters(TypedDict): """ id: Optional[str] name: Optional[str] - color: Optional[UserGroupColor] + color: Optional[UserGroupColor] users: Optional[Set[Union[UserGroupUser, User]]] projects: Optional[Set[Union[UserGroupProject, Project]]] @@ -123,9 +122,12 @@ class UserGroup: _color: UserGroupColor = None _users: Set[Union[UserGroupUser, User]] = None _projects: Set[Union[UserGroupProject, Project]] = None - _client: Client + _client: Client - def __init__(self, client: Client, reload=True, **kwargs: UserGroupParameters): + def __init__(self, + client: Client, + reload=True, + **kwargs: UserGroupParameters): """ Initializes a Group object. @@ -141,10 +143,9 @@ def __init__(self, client: Client, reload=True, **kwargs: UserGroupParameters): # runs against _gql if not client.enable_experimental: - raise RuntimeError("Experimental features are not enabled. Please enable them in the client to use this feature.") - - if 'id' not in kwargs and 'name' not in kwargs: - raise ValueError("Either 'id' or 'name' must be provided") + raise RuntimeError( + "Experimental features are not enabled. Please enable them in the client to use this feature." + ) self.name = kwargs.get('name', None) self.id = kwargs.get('id', None) @@ -192,8 +193,14 @@ def _reload(self): result = self.client.execute(query, params) self.name = result["userGroup"]["name"] self.color = UserGroupColor(result["userGroup"]["color"]) - self.projects = {UserGroupProject(id=project["id"], name=project["name"]) for project in result["userGroup"]["projects"]["nodes"]} - self.users = {UserGroupUser(id=member["id"], email=member["email"]) for member in result["userGroup"]["members"]["nodes"]} + self.projects = { + UserGroupProject(id=project["id"], name=project["name"]) + for project in result["userGroup"]["projects"]["nodes"] + } + self.users = { + UserGroupUser(id=member["id"], email=member["email"]) + for member in result["userGroup"]["members"]["nodes"] + } @property def id(self) -> str: @@ -326,11 +333,20 @@ def update(self) -> "UserGroup": } """ params = { - "id": self.id, - "name": self.name, - "color": self.color.value, - "projectIds": [project.id for project in self.projects], - "userIds": [user.id for user in self.users] + "id": + self.id, + "name": + self.name, + "color": + self.color.value, + "projectIds": [ + project.id if hasattr(project, 'id') else project.uid + for project in self.projects + ], + "userIds": [ + user.id if hasattr(user, 'id') else user.uid + for user in self.users + ] } self.client.execute(query, params) return self @@ -382,10 +398,18 @@ def create(self) -> "UserGroup": } """ params = { - "name": self.name, - "color": self.color.value, - "projectIds": [project.id for project in self.projects], - "userIds": [user.id for user in self.users] + "name": + self.name, + "color": + self.color.value, + "projectIds": [ + project.id if hasattr(project, 'id') else project.uid + for project in self.projects + ], + "userIds": [ + user.id if hasattr(user, 'id') else user.uid + for user in self.users + ] } result = self.client.execute(query, params)["createUserGroup"]["group"] self.id = result["id"] @@ -405,14 +429,12 @@ def delete(self) -> bool: } } """ - params = { - "id": self.id - } + params = {"id": self.id} result = self.client.execute(query, params) return result["deleteUserGroup"]["success"] - - @staticmethod - def user_groups(client: Client) -> Iterator["UserGroup"]: + + @staticmethod + def get_user_groups(client: Client) -> Iterator["UserGroup"]: """ Gets all groups in Labelbox. @@ -450,19 +472,26 @@ def user_groups(client: Client) -> Iterator["UserGroup"]: """ nextCursor = None while True: - userGroups = client.execute(query, { "nextCursor": nextCursor })["userGroups"] + userGroups = client.execute( + query, {"nextCursor": nextCursor})["userGroups"] groups = userGroups["nodes"] for group in groups: - yield UserGroup( - client, - reload=False, - id=group["id"], - name=group["name"], - color=UserGroupColor(group["color"]), - users={UserGroupUser(id=member["id"], email=member["email"]) for member in group["members"]["nodes"]}, - projects={UserGroupProject(id=project["id"], name=project["name"]) for project in group["projects"]["nodes"]} - ) - nextCursor = userGroups.get("nextCursor", None) - # this doesn't seem to be used right now - if nextCursor is None: - break \ No newline at end of file + yield UserGroup(client, + reload=False, + id=group["id"], + name=group["name"], + color=UserGroupColor(group["color"]), + users={ + UserGroupUser(id=member["id"], + email=member["email"]) + for member in group["members"]["nodes"] + }, + projects={ + UserGroupProject(id=project["id"], + name=project["name"]) + for project in group["projects"]["nodes"] + }) + nextCursor = userGroups["nextCursor"] + # this doesn't seem to be implemented right now to return a value other than null from the api + if nextCursor: + break diff --git a/libs/labelbox/tests/integration/schema/test_user_group.py b/libs/labelbox/tests/integration/schema/test_user_group.py new file mode 100644 index 000000000..175acd403 --- /dev/null +++ b/libs/labelbox/tests/integration/schema/test_user_group.py @@ -0,0 +1,114 @@ +import pytest +import faker +from labelbox import Client +from labelbox.schema.user_group import UserGroup, UserGroupColor, UserGroupUser, UserGroupProject + +data = faker.Faker() + + +@pytest.fixture +def client(client): + client.enable_experimental = True + client.endpoint = "https://app.lb-stage.xyz/api/_gql/" + return client + + +class TestUserGroup: + + def test_existing_user_groups(self, client): + group_name = data.name() + # Create a new user group + user_group = UserGroup(client) + user_group.name = group_name + user_group.color = UserGroupColor.BLUE + user_group.create() + + # Verify that the user group was created successfully + user_group_equal = UserGroup(client, id=user_group.id) + assert user_group.id == user_group_equal.id + assert user_group.name == user_group_equal.name + assert user_group.color == user_group_equal.color + + user_group.delete() + + def test_create_user_group(self, client): + group_name = data.name() + # Create a new user group + user_group = UserGroup(client) + user_group.name = group_name + user_group.color = UserGroupColor.BLUE + user_group.create() + + # Verify that the user group was created successfully + assert user_group.id is not None + assert user_group.name == group_name + assert user_group.color == UserGroupColor.BLUE + + user_group.delete() + + def test_update_user_group(self, client): + # Create a new user group + group_name = data.name() + user_group = UserGroup(client) + user_group.name = group_name + user_group.create() + + # Update the user group + group_name = data.name() + user_group.name = group_name + user_group.color = UserGroupColor.PURPLE + user_group.update() + + # Verify that the user group was updated successfully + assert user_group.name == group_name + assert user_group.color == UserGroupColor.PURPLE + + user_group.delete() + + def test_get_user_groups(self, client): + # Get all user groups + user_groups_old = list(UserGroup.get_user_groups(client)) + + group_name = data.name() + user_group = UserGroup(client) + user_group.name = group_name + user_group.create() + + user_groups_new = list(UserGroup.get_user_groups(client)) + + # Verify that at least one user group is returned + assert len(user_groups_new) > 0 + assert len(user_groups_new) == len(user_groups_old) + 1 + + # Verify that each user group has a valid ID and name + for user_group in user_groups_new: + assert user_group.id is not None + assert user_group.name is not None + + user_group.delete() + + # project_pack creates two projects + def test_update_user_groups(self, client, project_pack): + # Create a new user group + user_group = UserGroup(client) + user_group.name = data.name() + user_group.create() + + users = list(client.get_users()) + projects = project_pack + + # Add the user to the group + user_group.users.add(users[0]) + user_group.projects.add(projects[0]) + user_group.update() + + # Verify that the user is added to the group + assert users[0] in user_group.users + assert projects[0] in user_group.projects + + user_group.delete() + + +if __name__ == "__main__": + import subprocess + subprocess.call(["pytest", "-v", __file__]) diff --git a/libs/labelbox/tests/unit/schema/test_user_group.py b/libs/labelbox/tests/unit/schema/test_user_group.py index 4ee62f485..0e5194ba5 100644 --- a/libs/labelbox/tests/unit/schema/test_user_group.py +++ b/libs/labelbox/tests/unit/schema/test_user_group.py @@ -61,8 +61,7 @@ def test_user_group_parameters_attributes(self): name="Test Group", color=UserGroupColor.BLUE, users={UserGroupUser(id="user_id", email="test@example.com")}, - projects={UserGroupProject(id="project_id", name="Test Project")} - ) + projects={UserGroupProject(id="project_id", name="Test Project")}) assert params["id"] == "group_id" assert params["name"] == "Test Group" @@ -86,37 +85,26 @@ def test_constructor_experimental_needed(self): with pytest.raises(RuntimeError): group = UserGroup(client) - def test_constructor_id_or_name_needed(self): - client = MagicMock(Client) - with pytest.raises(ValueError): - group = UserGroup(self.client) - def test_constructor_name(self): group = self.group assert group.name == "Test Group" assert group.color == UserGroupColor.BLUE def test_constructor_id_no_reload(self): - projects = [ - { - "id": "project_id_1", - "name": "project_1" - }, - { - "id": "project_id_2", - "name": "project_2" - } - ] - group_members = [ - { - "id": "user_id_1", - "email": "email_1" - }, - { - "id": "user_id_2", - "email": "email_2" - } - ] + projects = [{ + "id": "project_id_1", + "name": "project_1" + }, { + "id": "project_id_2", + "name": "project_2" + }] + group_members = [{ + "id": "user_id_1", + "email": "email_1" + }, { + "id": "user_id_2", + "email": "email_2" + }] self.client.execute.return_value = { "userGroup": { "id": "group_id", @@ -140,26 +128,20 @@ def test_constructor_id_no_reload(self): assert len(group.users) == 0 def test_constructor_id(self): - projects = [ - { - "id": "project_id_1", - "name": "project_1" - }, - { - "id": "project_id_2", - "name": "project_2" - } - ] - group_members = [ - { - "id": "user_id_1", - "email": "email_1" - }, - { - "id": "user_id_2", - "email": "email_2" - } - ] + projects = [{ + "id": "project_id_1", + "name": "project_1" + }, { + "id": "project_id_2", + "name": "project_2" + }] + group_members = [{ + "id": "user_id_1", + "email": "email_1" + }, { + "id": "user_id_2", + "email": "email_2" + }] self.client.execute.return_value = { "userGroup": { "id": "group_id", @@ -217,7 +199,10 @@ def test_users(self): group.users = {UserGroupUser(id="user_id", email="user_id@email")} assert len(group.users) == 1 - group.users = {UserGroupUser(id="user_id", email="user_id@email"), UserGroupUser(id="user_id", email="user_id@email")} + group.users = { + UserGroupUser(id="user_id", email="user_id@email"), + UserGroupUser(id="user_id", email="user_id@email") + } assert len(group.users) == 1 group.users = {} @@ -227,10 +212,15 @@ def test_projects(self): group = self.group assert len(group.projects) == 0 - group.projects = {UserGroupProject(id="project_id", name="Test Project")} + group.projects = { + UserGroupProject(id="project_id", name="Test Project") + } assert len(group.projects) == 1 - group.projects = {UserGroupProject(id="project_id", name="Test Project"), UserGroupProject(id="project_id", name="Test Project")} + group.projects = { + UserGroupProject(id="project_id", name="Test Project"), + UserGroupProject(id="project_id", name="Test Project") + } assert len(group.projects) == 1 group.projects = {} @@ -242,12 +232,14 @@ def test_update(self): group.name = "Test Group" group.color = UserGroupColor.BLUE group.users = {UserGroupUser(id="user_id", email="test@example.com")} - group.projects = {UserGroupProject(id="project_id", name="Test Project")} - + group.projects = { + UserGroupProject(id="project_id", name="Test Project") + } + updated_group = group.update() execute = self.client.execute.call_args[0] - + assert "UpdateUserGroupPyApi" in execute[0] assert execute[1]["id"] == "group_id" assert execute[1]["name"] == "Test Group" @@ -268,18 +260,26 @@ def test_update(self): def test_create_with_exception(self): group = self.group group.id = "group_id" - + with pytest.raises(Exception): group.create() - + def test_create(self): group = self.group group.name = "New Group" group.color = UserGroupColor.PINK group.users = {UserGroupUser(id="user_id", email="test@example.com")} - group.projects = {UserGroupProject(id="project_id", name="Test Project")} + group.projects = { + UserGroupProject(id="project_id", name="Test Project") + } - self.client.execute.return_value = { "createUserGroup": { "group": { "id": "group_id" } } } + self.client.execute.return_value = { + "createUserGroup": { + "group": { + "id": "group_id" + } + } + } created_group = group.create() execute = self.client.execute.call_args[0] @@ -302,11 +302,15 @@ def test_create(self): def test_delete(self): group = self.group group.id = "group_id" - - self.client.execute.return_value = { "deleteUserGroup": { "success": True } } + + self.client.execute.return_value = { + "deleteUserGroup": { + "success": True + } + } deleted = group.delete() execute = self.client.execute.call_args[0] - + assert "DeleteUserGroupPyApi" in execute[0] assert execute[1]["id"] == "group_id" assert deleted is True @@ -314,63 +318,77 @@ def test_delete(self): def test_user_groups(self): self.client.execute.return_value = { "userGroups": { - "nodes": [ - { + "nodes": [{ "id": "group_id_1", "name": "Group 1", "color": "9EC5FF", "projects": { - "nodes": [ - {"id": "project_id_1", "name": "Project 1"}, - {"id": "project_id_2", "name": "Project 2"} - ] + "nodes": [{ + "id": "project_id_1", + "name": "Project 1" + }, { + "id": "project_id_2", + "name": "Project 2" + }] }, "members": { - "nodes": [ - {"id": "user_id_1", "email": "user1@example.com"}, - {"id": "user_id_2", "email": "user2@example.com"} - ] + "nodes": [{ + "id": "user_id_1", + "email": "user1@example.com" + }, { + "id": "user_id_2", + "email": "user2@example.com" + }] } - }, - { + }, { "id": "group_id_2", "name": "Group 2", "color": "9EC5FF", "projects": { - "nodes": [ - {"id": "project_id_3", "name": "Project 3"}, - {"id": "project_id_4", "name": "Project 4"} - ] + "nodes": [{ + "id": "project_id_3", + "name": "Project 3" + }, { + "id": "project_id_4", + "name": "Project 4" + }] }, "members": { - "nodes": [ - {"id": "user_id_3", "email": "user3@example.com"}, - {"id": "user_id_4", "email": "user4@example.com"} - ] + "nodes": [{ + "id": "user_id_3", + "email": "user3@example.com" + }, { + "id": "user_id_4", + "email": "user4@example.com" + }] } - }, - { + }, { "id": "group_id_3", "name": "Group 3", "color": "9EC5FF", "projects": { - "nodes": [ - {"id": "project_id_5", "name": "Project 5"}, - {"id": "project_id_6", "name": "Project 6"} - ] + "nodes": [{ + "id": "project_id_5", + "name": "Project 5" + }, { + "id": "project_id_6", + "name": "Project 6" + }] }, "members": { - "nodes": [ - {"id": "user_id_5", "email": "user5@example.com"}, - {"id": "user_id_6", "email": "user6@example.com"} - ] + "nodes": [{ + "id": "user_id_5", + "email": "user5@example.com" + }, { + "id": "user_id_6", + "email": "user6@example.com" + }] } - } - ] + }] } } - - user_groups = list(UserGroup.user_groups(self.client)) + + user_groups = list(UserGroup.get_user_groups(self.client)) assert len(user_groups) == 3 @@ -395,5 +413,7 @@ def test_user_groups(self): assert len(user_groups[2].projects) == 2 assert len(user_groups[2].users) == 2 + if __name__ == "__main__": - pytest.main(["-v", __file__]) \ No newline at end of file + import subprocess + subprocess.call(["pytest", "-v", __file__]) diff --git a/pyproject.toml b/pyproject.toml index 238a6e247..ca5e7d0d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ dev-dependencies = [ "pytest-cov>=4.1.0", "pytest-xdist>=3.5.0", "toml-cli>=0.6.0", + "faker>=25.5.0", ] [tool.rye.workspace] diff --git a/requirements-dev.lock b/requirements-dev.lock index 288f35a94..1fc9bc5a6 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -59,6 +59,7 @@ execnet==2.1.1 # via pytest-xdist executing==2.0.1 # via stack-data +faker==25.5.0 fastjsonschema==2.19.1 # via nbformat geojson==3.1.0 @@ -206,6 +207,7 @@ pytest-rerunfailures==14.0 pytest-snapshot==0.9.0 pytest-xdist==3.6.1 python-dateutil==2.8.2 + # via faker # via jupyter-client # via labelbox # via pandas From d6312d6d329cabc60eb434be662783c73b26a5d2 Mon Sep 17 00:00:00 2001 From: Adrian Chang Date: Fri, 7 Jun 2024 01:18:18 -0700 Subject: [PATCH 4/7] Deal with none executes --- docs/labelbox/user-group.rst | 6 + libs/labelbox/src/labelbox/__init__.py | 3 +- libs/labelbox/src/labelbox/pydantic_compat.py | 2 +- .../src/labelbox/schema/user_group.py | 124 +++++++++++------- .../integration/schema/test_user_group.py | 9 +- .../tests/unit/schema/test_user_group.py | 46 ++++--- requirements-dev.lock | 42 +++--- requirements.lock | 28 ++-- 8 files changed, 141 insertions(+), 119 deletions(-) create mode 100644 docs/labelbox/user-group.rst diff --git a/docs/labelbox/user-group.rst b/docs/labelbox/user-group.rst new file mode 100644 index 000000000..66d56891f --- /dev/null +++ b/docs/labelbox/user-group.rst @@ -0,0 +1,6 @@ +User Group +=============================================================================================== + +.. automodule:: labelbox.schema.user_group + :members: + :show-inheritance: \ No newline at end of file diff --git a/libs/labelbox/src/labelbox/__init__.py b/libs/labelbox/src/labelbox/__init__.py index 2157c0e62..3bc3f4f06 100644 --- a/libs/labelbox/src/labelbox/__init__.py +++ b/libs/labelbox/src/labelbox/__init__.py @@ -40,5 +40,4 @@ from labelbox.schema.identifiables import UniqueIds, GlobalKeys, DataRowIds from labelbox.schema.identifiable import UniqueId, GlobalKey from labelbox.schema.ontology_kind import OntologyKind -from labelbox.schema.project_overview import ProjectOverview, ProjectOverviewDetailed -from labelbox.schema.user_group import UserGroup, UserGroupColor, UserGroupUser, UserGroupProject \ No newline at end of file +from labelbox.schema.project_overview import ProjectOverview, ProjectOverviewDetailed \ No newline at end of file diff --git a/libs/labelbox/src/labelbox/pydantic_compat.py b/libs/labelbox/src/labelbox/pydantic_compat.py index 4bcece74e..51c082480 100644 --- a/libs/labelbox/src/labelbox/pydantic_compat.py +++ b/libs/labelbox/src/labelbox/pydantic_compat.py @@ -31,4 +31,4 @@ def pydantic_import(class_name, sub_module_path: Optional[str] = None): conint = pydantic_import("conint") conlist = pydantic_import("conlist") constr = pydantic_import("constr") -confloat = pydantic_import("confloat") \ No newline at end of file +confloat = pydantic_import("confloat") diff --git a/libs/labelbox/src/labelbox/schema/user_group.py b/libs/labelbox/src/labelbox/schema/user_group.py index 4d324cb34..27bdacf27 100644 --- a/libs/labelbox/src/labelbox/schema/user_group.py +++ b/libs/labelbox/src/labelbox/schema/user_group.py @@ -6,6 +6,7 @@ from labelbox.pydantic_compat import BaseModel from labelbox.schema.user import User from labelbox.schema.project import Project +from labelbox.exceptions import UnprocessableEntityError, InvalidQueryError class UserGroupColor(Enum): @@ -83,24 +84,6 @@ def __eq__(self, other): return self.id == other.id -class UserGroupParameters(TypedDict): - """ - Represents the parameters for a user group. - - Attributes: - id (Optional[str]): The ID of the user group. - name (Optional[str]): The name of the user group. - color (Optional[UserGroupColor]): The color of the user group. - users (Optional[Set[Union[UserGroupUser, User]]]): The users in the user group. - projects (Optional[Set[Union[UserGroupProject, Project]]]): The projects associated with the user group. - """ - id: Optional[str] - name: Optional[str] - color: Optional[UserGroupColor] - users: Optional[Set[Union[UserGroupUser, User]]] - projects: Optional[Set[Union[UserGroupProject, Project]]] - - class UserGroup: """ Represents a user group in Labelbox. @@ -117,40 +100,49 @@ class UserGroup: _projects (Set[Union[UserGroupProject, Project]]): The set of project IDs in the user group. _client (Client): The Labelbox client. """ - _id: str = None - _name: str = None - _color: UserGroupColor = None - _users: Set[Union[UserGroupUser, User]] = None - _projects: Set[Union[UserGroupProject, Project]] = None + _id: str + _name: str + _color: UserGroupColor + _users: Set[Union[UserGroupUser, User]] + _projects: Set[Union[UserGroupProject, Project]] _client: Client - def __init__(self, - client: Client, - reload=True, - **kwargs: UserGroupParameters): + def __init__( + self, + client: Client, + id: str = "", + name: str = "", + color: UserGroupColor = UserGroupColor.BLUE, + users: Set[Union[UserGroupUser, User]] = set(), + projects: Set[Union[UserGroupProject, Project]] = set(), + reload=True, + ): """ Initializes a Group object. Args: client (Client): The Labelbox client. - **kwargs: Additional keyword arguments for initializing the Group object. + reload (bool): Whether to reload the group information from the server. Defaults to True. + color (UserGroupColor): The color of the user group. Defaults to UserGroupColor.BLUE. + users (Set[Union[UserGroupUser, User]]): The set of user IDs in the user group. Defaults to an empty set. + projects (Set[Union[UserGroupProject, Project]]): The set of project IDs in the user group. Defaults to an empty set. + name (str): The name of the user group. Defaults to None. + id (str): The ID of the user group. Defaults to None. """ super().__init__() - self.color = kwargs.get('color', UserGroupColor.BLUE) - self.users = kwargs.get('users', set()) - self.projects = kwargs.get('projects', set()) + self.color = color + self.users = users + self.projects = projects self.client = client - # runs against _gql if not client.enable_experimental: raise RuntimeError( - "Experimental features are not enabled. Please enable them in the client to use this feature." - ) + "Please enable experimental in client to use UserGroups") - self.name = kwargs.get('name', None) - self.id = kwargs.get('id', None) - # partial respentation of the group, reload - if self.id is not None and reload: + self.name = name + self.id = id + # partial representation of the group, reload + if self.id and reload: self._reload() def _reload(self): @@ -161,6 +153,9 @@ def _reload(self): about the user group, including its name, color, projects, and members. The fetched information is then used to update the corresponding attributes of the `Group` object. + Raises: + InvalidQueryError: If the query fails to fetch the group information. + Returns: None """ @@ -191,6 +186,8 @@ def _reload(self): "id": self.id, } result = self.client.execute(query, params) + if not result: + raise InvalidQueryError("Failed to fetch group") self.name = result["userGroup"]["name"] self.color = UserGroupColor(result["userGroup"]["color"]) self.projects = { @@ -283,7 +280,7 @@ def users(self, value: Set[Union[UserGroupUser, User]]) -> None: self._users = value @property - def projects(self) -> Set[UserGroupProject]: + def projects(self) -> Set[Union[UserGroupProject, Project]]: """ Gets the list of project IDs in the group. @@ -293,7 +290,7 @@ def projects(self) -> Set[UserGroupProject]: return self._projects @projects.setter - def projects(self, value: Set[UserGroupProject]) -> None: + def projects(self, value: Set[Union[UserGroupProject, Project]]) -> None: """ Sets the list of project IDs in the group. @@ -305,6 +302,12 @@ def projects(self, value: Set[UserGroupProject]) -> None: def update(self) -> "UserGroup": """ Updates the group in Labelbox. + + Returns: + UserGroup: The updated UserGroup object. (self) + + Raises: + UnprocessableEntityError: If the update fails. """ query = """ mutation UpdateUserGroupPyApi($id: ID!, $name: String!, $color: String!, $projectIds: [String!]!, $userIds: [String!]!) { @@ -348,7 +351,9 @@ def update(self) -> "UserGroup": for user in self.users ] } - self.client.execute(query, params) + result = self.client.execute(query, params) + if not result: + raise UnprocessableEntityError("Failed to update group") return self def create(self) -> "UserGroup": @@ -363,10 +368,16 @@ def create(self) -> "UserGroup": projects (List[Project], optional): The projects to add to the group. Defaults to []. Returns: - Group: The newly created group. + Group: The newly created group. (self) + + Raises: + ResourceCreationError: If the group already exists or if the creation fails. + ValueError: If the group name is not provided. """ - if self.id is not None: + if self.id: raise ResourceCreationError("Group already exists") + if not self.name: + raise ValueError("Group name is required") query = """ mutation CreateUserGroupPyApi($name: String!, $color: String!, $projectIds: [String!]!, $userIds: [String!]!) { createUserGroup( @@ -411,16 +422,26 @@ def create(self) -> "UserGroup": for user in self.users ] } - result = self.client.execute(query, params)["createUserGroup"]["group"] + result = self.client.execute(query, params) + if not result: + raise ResourceCreationError("Failed to create group") + result = result["createUserGroup"]["group"] self.id = result["id"] return self def delete(self) -> bool: """ - Deletes the group from Labelbox. + Deletes the user group from Labelbox. + + This method sends a mutation request to the Labelbox API to delete the user group + with the specified ID. If the deletion is successful, it returns True. Otherwise, + it raises an UnprocessableEntityError and returns False. Returns: - bool: True if the group was successfully deleted, False otherwise. + bool: True if the user group was successfully deleted, False otherwise. + + Raises: + UnprocessableEntityError: If the deletion of the user group fails. """ query = """ mutation DeleteUserGroupPyApi($id: ID!) { @@ -431,18 +452,20 @@ def delete(self) -> bool: """ params = {"id": self.id} result = self.client.execute(query, params) + if not result: + raise UnprocessableEntityError("Failed to delete user group") return result["deleteUserGroup"]["success"] @staticmethod def get_user_groups(client: Client) -> Iterator["UserGroup"]: """ - Gets all groups in Labelbox. + Gets all user groups in Labelbox. Args: client (Client): The Labelbox client. Returns: - List[Group]: The list of groups. + Iterator[UserGroup]: An iterator over the user groups. """ query = """ query GetUserGroupsPyApi { @@ -474,6 +497,9 @@ def get_user_groups(client: Client) -> Iterator["UserGroup"]: while True: userGroups = client.execute( query, {"nextCursor": nextCursor})["userGroups"] + if not userGroups: + return + yield groups = userGroups["nodes"] for group in groups: yield UserGroup(client, @@ -493,5 +519,5 @@ def get_user_groups(client: Client) -> Iterator["UserGroup"]: }) nextCursor = userGroups["nextCursor"] # this doesn't seem to be implemented right now to return a value other than null from the api - if nextCursor: + if not nextCursor: break diff --git a/libs/labelbox/tests/integration/schema/test_user_group.py b/libs/labelbox/tests/integration/schema/test_user_group.py index 175acd403..15f4f6738 100644 --- a/libs/labelbox/tests/integration/schema/test_user_group.py +++ b/libs/labelbox/tests/integration/schema/test_user_group.py @@ -6,13 +6,6 @@ data = faker.Faker() -@pytest.fixture -def client(client): - client.enable_experimental = True - client.endpoint = "https://app.lb-stage.xyz/api/_gql/" - return client - - class TestUserGroup: def test_existing_user_groups(self, client): @@ -88,7 +81,7 @@ def test_get_user_groups(self, client): user_group.delete() # project_pack creates two projects - def test_update_user_groups(self, client, project_pack): + def test_update_user_group(self, client, project_pack): # Create a new user group user_group = UserGroup(client) user_group.name = data.name() diff --git a/libs/labelbox/tests/unit/schema/test_user_group.py b/libs/labelbox/tests/unit/schema/test_user_group.py index 0e5194ba5..03514fa4a 100644 --- a/libs/labelbox/tests/unit/schema/test_user_group.py +++ b/libs/labelbox/tests/unit/schema/test_user_group.py @@ -1,8 +1,9 @@ import pytest from unittest.mock import MagicMock from labelbox import Client +from labelbox.exceptions import ResourceCreationError from labelbox.schema.user import User -from labelbox.schema.user_group import UserGroup, UserGroupColor, UserGroupUser, UserGroupProject, UserGroupParameters +from labelbox.schema.user_group import UserGroup, UserGroupColor, UserGroupUser, UserGroupProject class TestUserGroupColor: @@ -53,25 +54,6 @@ def test_user_group_project_hash(self): assert hash(project) == hash("project_id") -class TestUserGroupParameters: - - def test_user_group_parameters_attributes(self): - params = UserGroupParameters( - id="group_id", - name="Test Group", - color=UserGroupColor.BLUE, - users={UserGroupUser(id="user_id", email="test@example.com")}, - projects={UserGroupProject(id="project_id", name="Test Project")}) - - assert params["id"] == "group_id" - assert params["name"] == "Test Group" - assert params["color"] == UserGroupColor.BLUE - assert len(params["users"]) == 1 - assert list(params["users"])[0].id == "user_id" - assert len(params["projects"]) == 1 - assert list(params["projects"])[0].id == "project_id" - - class TestUserGroup: def setup_method(self): @@ -122,7 +104,7 @@ def test_constructor_id_no_reload(self): group = UserGroup(self.client, id="group_id", reload=False) assert group.id == "group_id" - assert group.name is None + assert group.name == "" assert group.color is UserGroupColor.BLUE assert len(group.projects) == 0 assert len(group.users) == 0 @@ -164,7 +146,7 @@ def test_constructor_id(self): def test_id(self): group = self.group - assert group.id is None + assert group.id == "" group.id = "1" assert group.id == "1" @@ -257,11 +239,18 @@ def test_update(self): assert len(updated_group.projects) == 1 assert list(updated_group.projects)[0].id == "project_id" - def test_create_with_exception(self): + def test_create_with_exception_id(self): group = self.group group.id = "group_id" - with pytest.raises(Exception): + with pytest.raises(ResourceCreationError): + group.create() + + def test_create_with_exception_name(self): + group = self.group + group.name = "" + + with pytest.raises(ValueError): group.create() def test_create(self): @@ -315,9 +304,18 @@ def test_delete(self): assert execute[1]["id"] == "group_id" assert deleted is True + def test_user_groups_empty(self): + self.client.execute.return_value = {"userGroups": None} + + user_groups = list(UserGroup.get_user_groups(self.client)) + + assert len(user_groups) == 0 + def test_user_groups(self): self.client.execute.return_value = { "userGroups": { + "nextCursor": + None, "nodes": [{ "id": "group_id_1", "name": "Group 1", diff --git a/requirements-dev.lock b/requirements-dev.lock index 1fc9bc5a6..60a73324c 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -31,7 +31,7 @@ bleach==6.1.0 # via nbconvert cachetools==5.3.3 # via google-auth -certifi==2024.2.2 +certifi==2024.6.2 # via pyproj # via requests charset-normalizer==3.3.2 @@ -41,7 +41,7 @@ click==8.1.7 # via typer commonmark==0.9.1 # via rich -coverage==7.5.3 +coverage==7.5.4 # via pytest-cov databooks==1.3.10 decopatch==1.4.10 @@ -59,8 +59,8 @@ execnet==2.1.1 # via pytest-xdist executing==2.0.1 # via stack-data -faker==25.5.0 -fastjsonschema==2.19.1 +faker==25.9.1 +fastjsonschema==2.20.0 # via nbformat geojson==3.1.0 # via labelbox @@ -68,18 +68,18 @@ gitdb==4.0.11 # via gitpython gitpython==3.1.43 # via databooks -google-api-core==2.19.0 +google-api-core==2.19.1 # via labelbox -google-auth==2.29.0 +google-auth==2.30.0 # via google-api-core -googleapis-common-protos==1.63.0 +googleapis-common-protos==1.63.2 # via google-api-core idna==3.7 # via requests imagesize==1.4.1 # via labelbox # via sphinx -importlib-metadata==7.1.0 +importlib-metadata==7.2.1 # via jupyter-client # via nbconvert # via sphinx @@ -137,7 +137,7 @@ numpy==1.24.4 # via shapely opencv-python-headless==4.10.0.84 # via labelbox -packaging==24.0 +packaging==24.1 # via black # via nbconvert # via pytest @@ -165,11 +165,11 @@ platformdirs==4.2.2 # via yapf pluggy==1.5.0 # via pytest -prompt-toolkit==3.0.45 +prompt-toolkit==3.0.47 # via ipython -proto-plus==1.23.0 +proto-plus==1.24.0 # via google-api-core -protobuf==4.25.3 +protobuf==5.27.1 # via google-api-core # via googleapis-common-protos # via proto-plus @@ -182,10 +182,10 @@ pyasn1==0.6.0 # via rsa pyasn1-modules==0.4.0 # via google-auth -pydantic==2.7.2 +pydantic==2.7.4 # via databooks # via labelbox -pydantic-core==2.18.3 +pydantic-core==2.18.4 # via pydantic pygeotile==1.0.6 # via labelbox @@ -196,7 +196,7 @@ pygments==2.18.0 # via sphinx pyproj==3.5.0 # via labelbox -pytest==8.2.1 +pytest==8.2.2 # via pytest-cov # via pytest-rerunfailures # via pytest-snapshot @@ -221,7 +221,7 @@ referencing==0.35.1 # via jsonschema-specifications regex==2024.5.15 # via toml-cli -requests==2.32.2 +requests==2.32.3 # via google-api-core # via labelbox # via sphinx @@ -284,7 +284,7 @@ tomli==2.0.1 # via yapf tomlkit==0.12.5 # via toml-cli -tornado==6.4 +tornado==6.4.1 # via jupyter-client tqdm==4.66.4 # via labelbox @@ -303,9 +303,9 @@ typer==0.12.3 # via toml-cli types-pillow==10.2.0.20240520 types-python-dateutil==2.9.0.20240316 -types-requests==2.32.0.20240523 +types-requests==2.32.0.20240622 types-tqdm==4.66.0.20240417 -typing-extensions==4.12.0 +typing-extensions==4.12.2 # via annotated-types # via black # via databooks @@ -319,7 +319,7 @@ typing-extensions==4.12.0 # via typer tzdata==2024.1 # via pandas -urllib3==2.2.1 +urllib3==2.2.2 # via requests # via types-requests wcwidth==0.2.13 @@ -328,6 +328,6 @@ webencodings==0.5.1 # via bleach # via tinycss2 yapf==0.40.2 -zipp==3.19.0 +zipp==3.19.2 # via importlib-metadata # via importlib-resources diff --git a/requirements.lock b/requirements.lock index a1b09c201..b052771f6 100644 --- a/requirements.lock +++ b/requirements.lock @@ -17,7 +17,7 @@ babel==2.15.0 # via sphinx cachetools==5.3.3 # via google-auth -certifi==2024.2.2 +certifi==2024.6.2 # via pyproj # via requests charset-normalizer==3.3.2 @@ -27,18 +27,18 @@ docutils==0.20.1 # via sphinx-rtd-theme geojson==3.1.0 # via labelbox -google-api-core==2.19.0 +google-api-core==2.19.1 # via labelbox -google-auth==2.29.0 +google-auth==2.30.0 # via google-api-core -googleapis-common-protos==1.63.0 +googleapis-common-protos==1.63.2 # via google-api-core idna==3.7 # via requests imagesize==1.4.1 # via labelbox # via sphinx -importlib-metadata==7.1.0 +importlib-metadata==7.2.1 # via sphinx # via typeguard jinja2==3.1.4 @@ -51,13 +51,13 @@ numpy==1.24.4 # via shapely opencv-python-headless==4.10.0.84 # via labelbox -packaging==24.0 +packaging==24.1 # via sphinx pillow==10.3.0 # via labelbox -proto-plus==1.23.0 +proto-plus==1.24.0 # via google-api-core -protobuf==4.25.3 +protobuf==5.27.1 # via google-api-core # via googleapis-common-protos # via proto-plus @@ -66,9 +66,9 @@ pyasn1==0.6.0 # via rsa pyasn1-modules==0.4.0 # via google-auth -pydantic==2.7.2 +pydantic==2.7.4 # via labelbox -pydantic-core==2.18.3 +pydantic-core==2.18.4 # via pydantic pygeotile==1.0.6 # via labelbox @@ -80,7 +80,7 @@ python-dateutil==2.8.2 # via labelbox pytz==2024.1 # via babel -requests==2.32.2 +requests==2.32.3 # via google-api-core # via labelbox # via sphinx @@ -117,13 +117,13 @@ tqdm==4.66.4 # via labelbox typeguard==4.3.0 # via labelbox -typing-extensions==4.12.0 +typing-extensions==4.12.2 # via annotated-types # via labelbox # via pydantic # via pydantic-core # via typeguard -urllib3==2.2.1 +urllib3==2.2.2 # via requests -zipp==3.19.0 +zipp==3.19.2 # via importlib-metadata From 6ccad6072777876c48a823480a4b730a17accd3c Mon Sep 17 00:00:00 2001 From: Adrian Chang Date: Mon, 24 Jun 2024 22:09:35 -0700 Subject: [PATCH 5/7] update offline --- libs/labelbox/src/labelbox/__init__.py | 2 +- .../integration/schema/test_user_group.py | 141 ++++++++---------- .../test_offline_chat_evaluation_project.py | 2 +- 3 files changed, 64 insertions(+), 81 deletions(-) diff --git a/libs/labelbox/src/labelbox/__init__.py b/libs/labelbox/src/labelbox/__init__.py index 3bc3f4f06..4cd3b4390 100644 --- a/libs/labelbox/src/labelbox/__init__.py +++ b/libs/labelbox/src/labelbox/__init__.py @@ -40,4 +40,4 @@ from labelbox.schema.identifiables import UniqueIds, GlobalKeys, DataRowIds from labelbox.schema.identifiable import UniqueId, GlobalKey from labelbox.schema.ontology_kind import OntologyKind -from labelbox.schema.project_overview import ProjectOverview, ProjectOverviewDetailed \ No newline at end of file +from labelbox.schema.project_overview import ProjectOverview, ProjectOverviewDetailed diff --git a/libs/labelbox/tests/integration/schema/test_user_group.py b/libs/labelbox/tests/integration/schema/test_user_group.py index 15f4f6738..678fdc146 100644 --- a/libs/labelbox/tests/integration/schema/test_user_group.py +++ b/libs/labelbox/tests/integration/schema/test_user_group.py @@ -5,103 +5,86 @@ data = faker.Faker() +@pytest.fixture +def user_group(client): + group_name = data.name() + # Create a new user group + user_group = UserGroup(client) + user_group.name = group_name + user_group.color = UserGroupColor.BLUE + user_group.create() + + yield user_group + + user_group.delete() -class TestUserGroup: - - def test_existing_user_groups(self, client): - group_name = data.name() - # Create a new user group - user_group = UserGroup(client) - user_group.name = group_name - user_group.color = UserGroupColor.BLUE - user_group.create() - - # Verify that the user group was created successfully - user_group_equal = UserGroup(client, id=user_group.id) - assert user_group.id == user_group_equal.id - assert user_group.name == user_group_equal.name - assert user_group.color == user_group_equal.color - - user_group.delete() - - def test_create_user_group(self, client): - group_name = data.name() - # Create a new user group - user_group = UserGroup(client) - user_group.name = group_name - user_group.color = UserGroupColor.BLUE - user_group.create() - - # Verify that the user group was created successfully - assert user_group.id is not None - assert user_group.name == group_name - assert user_group.color == UserGroupColor.BLUE - user_group.delete() +def test_existing_user_groups(user_group, client): + # Verify that the user group was created successfully + user_group_equal = UserGroup(client, id=user_group.id) + assert user_group.id == user_group_equal.id + assert user_group.name == user_group_equal.name + assert user_group.color == user_group_equal.color - def test_update_user_group(self, client): - # Create a new user group - group_name = data.name() - user_group = UserGroup(client) - user_group.name = group_name - user_group.create() - # Update the user group - group_name = data.name() - user_group.name = group_name - user_group.color = UserGroupColor.PURPLE - user_group.update() +def test_create_user_group(user_group): + # Verify that the user group was created successfully + assert user_group.id is not None + assert user_group.name is not None + assert user_group.color == UserGroupColor.BLUE - # Verify that the user group was updated successfully - assert user_group.name == group_name - assert user_group.color == UserGroupColor.PURPLE - user_group.delete() +def test_update_user_group(user_group): + # Update the user group + group_name = data.name() + user_group.name = group_name + user_group.color = UserGroupColor.PURPLE + user_group.update() - def test_get_user_groups(self, client): - # Get all user groups - user_groups_old = list(UserGroup.get_user_groups(client)) + # Verify that the user group was updated successfully + assert user_group.name == group_name + assert user_group.color == UserGroupColor.PURPLE - group_name = data.name() - user_group = UserGroup(client) - user_group.name = group_name - user_group.create() - user_groups_new = list(UserGroup.get_user_groups(client)) +def test_get_user_groups(user_group, client): + # Get all user groups + user_groups_old = UserGroup.get_user_groups(client) - # Verify that at least one user group is returned - assert len(user_groups_new) > 0 - assert len(user_groups_new) == len(user_groups_old) + 1 + # manual delete for iterators + group_name = data.name() + user_group = UserGroup(client) + user_group.name = group_name + user_group.create() - # Verify that each user group has a valid ID and name - for user_group in user_groups_new: - assert user_group.id is not None - assert user_group.name is not None + user_groups_new = UserGroup.get_user_groups(client) - user_group.delete() + # Verify that at least one user group is returned + assert len(user_groups_new) > 0 + assert len(user_groups_new) == len(user_groups_old) + 1 + + # Verify that each user group has a valid ID and name + for user_group in user_groups_new: + assert user_group.id is not None + assert user_group.name is not None - # project_pack creates two projects - def test_update_user_group(self, client, project_pack): - # Create a new user group - user_group = UserGroup(client) - user_group.name = data.name() - user_group.create() + user_group.delete() - users = list(client.get_users()) - projects = project_pack - # Add the user to the group - user_group.users.add(users[0]) - user_group.projects.add(projects[0]) - user_group.update() +# project_pack creates two projects +def test_update_user_group(user_group, client, project_pack): + users = list(client.get_users()) + projects = project_pack - # Verify that the user is added to the group - assert users[0] in user_group.users - assert projects[0] in user_group.projects + # Add the user to the group + user_group.users.add(users[0]) + user_group.projects.add(projects[0]) + user_group.update() - user_group.delete() + # Verify that the user is added to the group + assert users[0] in user_group.users + assert projects[0] in user_group.projects if __name__ == "__main__": import subprocess - subprocess.call(["pytest", "-v", __file__]) + subprocess.call(["pytest", "-v", __file__]) \ No newline at end of file diff --git a/libs/labelbox/tests/integration/test_offline_chat_evaluation_project.py b/libs/labelbox/tests/integration/test_offline_chat_evaluation_project.py index d27f4e95e..07d9d743f 100644 --- a/libs/labelbox/tests/integration/test_offline_chat_evaluation_project.py +++ b/libs/labelbox/tests/integration/test_offline_chat_evaluation_project.py @@ -1,6 +1,6 @@ import pytest - +@pytest.mark.skip(reason="Behavior changed, need to update test") def test_create_offline_chat_evaluation_project(client, rand_gen, offline_chat_evaluation_project, chat_evaluation_ontology, From 87a28810790b389218c6a31c387e93b9c80b0e45 Mon Sep 17 00:00:00 2001 From: Adrian Chang Date: Wed, 26 Jun 2024 00:51:22 -0700 Subject: [PATCH 6/7] use pydnatic --- .../src/labelbox/schema/user_group.py | 167 ++++-------------- .../tests/unit/schema/test_user_group.py | 6 +- 2 files changed, 35 insertions(+), 138 deletions(-) diff --git a/libs/labelbox/src/labelbox/schema/user_group.py b/libs/labelbox/src/labelbox/schema/user_group.py index 27bdacf27..2cf723726 100644 --- a/libs/labelbox/src/labelbox/schema/user_group.py +++ b/libs/labelbox/src/labelbox/schema/user_group.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Set, List, Optional, Union, TypedDict, Iterator +from typing import Set, List, Union, Iterator, Optional from labelbox import Client from labelbox.exceptions import ResourceCreationError @@ -84,7 +84,7 @@ def __eq__(self, other): return self.id == other.id -class UserGroup: +class UserGroup(BaseModel): """ Represents a user group in Labelbox. @@ -93,54 +93,51 @@ class UserGroup: **kwargs: Additional keyword arguments for initializing the UserGroup object. Attributes: - _id (str): The ID of the user group. - _name (str): The name of the user group. - _color (UserGroupColor): The color of the user group. - _users (Set[Union[UserGroupUser, User]]): The set of user IDs in the user group. - _projects (Set[Union[UserGroupProject, Project]]): The set of project IDs in the user group. - _client (Client): The Labelbox client. + id (str): The ID of the user group. + name (str): The name of the user group. + color (UserGroupColor): The color of the user group. + users (Set[Union[UserGroupUser, User]]): The set of user IDs in the user group. + projects (Set[Union[UserGroupProject, Project]]): The set of project IDs in the user group. + client (Client): The Labelbox client. """ - _id: str - _name: str - _color: UserGroupColor - _users: Set[Union[UserGroupUser, User]] - _projects: Set[Union[UserGroupProject, Project]] - _client: Client + id: Optional[str] + name: Optional[str] + color: UserGroupColor + users: Set[Union[UserGroupUser, User]] + projects: Set[Union[UserGroupProject, Project]] + client: Client + + class Config: + # fix for pydnatic 2 + arbitrary_types_allowed = True def __init__( - self, - client: Client, - id: str = "", - name: str = "", - color: UserGroupColor = UserGroupColor.BLUE, - users: Set[Union[UserGroupUser, User]] = set(), - projects: Set[Union[UserGroupProject, Project]] = set(), - reload=True, + self, + client: Client, + id: str = "", + name: str = "", + color: UserGroupColor = UserGroupColor.BLUE, + users: Set[Union[UserGroupUser, User]] = set(), + projects: Set[Union[UserGroupProject, Project]] = set(), + reload=True, ): """ - Initializes a Group object. + Initializes a UserGroup object. Args: client (Client): The Labelbox client. - reload (bool): Whether to reload the group information from the server. Defaults to True. + id (str): The ID of the user group. Defaults to an empty string. + name (str): The name of the user group. Defaults to an empty string. color (UserGroupColor): The color of the user group. Defaults to UserGroupColor.BLUE. users (Set[Union[UserGroupUser, User]]): The set of user IDs in the user group. Defaults to an empty set. projects (Set[Union[UserGroupProject, Project]]): The set of project IDs in the user group. Defaults to an empty set. - name (str): The name of the user group. Defaults to None. - id (str): The ID of the user group. Defaults to None. + reload (bool): Whether to reload the group information from the server. Defaults to True. """ - super().__init__() - self.color = color - self.users = users - self.projects = projects - self.client = client - - if not client.enable_experimental: + super().__init__(client=client, id=id, name=name, color=color, users=users, projects=projects) + if not self.client.enable_experimental: raise RuntimeError( "Please enable experimental in client to use UserGroups") - self.name = name - self.id = id # partial representation of the group, reload if self.id and reload: self._reload() @@ -199,106 +196,6 @@ def _reload(self): for member in result["userGroup"]["members"]["nodes"] } - @property - def id(self) -> str: - """ - Gets the ID of the group. - - Returns: - str: The ID of the group. - """ - return self._id - - @id.setter - def id(self, value: str) -> None: - """ - Sets the ID of the group. - - Args: - value (str): The ID to set. - """ - self._id = value - - @property - def name(self) -> str: - """ - Gets the name of the group. - - Returns: - str: The name of the group. - """ - return self._name - - @name.setter - def name(self, value: str) -> None: - """ - Sets the name of the group. - - Args: - value (str): The name to set. - """ - self._name = value - - @property - def color(self) -> UserGroupColor: - """ - Gets the color of the group. - - Returns: - GroupColor: The color of the group. - """ - return self._color - - @color.setter - def color(self, value: UserGroupColor) -> None: - """ - Sets the color of the group. - - Args: - value (GroupColor): The color to set. - """ - self._color = value - - @property - def users(self) -> Set[Union[UserGroupUser, User]]: - """ - Gets the list of user IDs in the group. - - Returns: - Set[GroupUser]: The list of user IDs in the group. - """ - return self._users - - @users.setter - def users(self, value: Set[Union[UserGroupUser, User]]) -> None: - """ - Sets the list of user IDs in the group. - - Args: - value (Set[GroupUser]): The list of user IDs to set. - """ - self._users = value - - @property - def projects(self) -> Set[Union[UserGroupProject, Project]]: - """ - Gets the list of project IDs in the group. - - Returns: - Set[GroupProject]: The list of project IDs in the group. - """ - return self._projects - - @projects.setter - def projects(self, value: Set[Union[UserGroupProject, Project]]) -> None: - """ - Sets the list of project IDs in the group. - - Args: - value (Set[GroupProject]): The list of project IDs to set. - """ - self._projects = value - def update(self) -> "UserGroup": """ Updates the group in Labelbox. diff --git a/libs/labelbox/tests/unit/schema/test_user_group.py b/libs/labelbox/tests/unit/schema/test_user_group.py index 03514fa4a..6f1400308 100644 --- a/libs/labelbox/tests/unit/schema/test_user_group.py +++ b/libs/labelbox/tests/unit/schema/test_user_group.py @@ -57,10 +57,10 @@ def test_user_group_project_hash(self): class TestUserGroup: def setup_method(self): - self.client = MagicMock() - self.group = UserGroup(self.client, name="Test Group") + self.client = MagicMock(Client) self.client.enable_experimental = True - + self.group = UserGroup(client=self.client, name="Test Group") + def test_constructor_experimental_needed(self): client = MagicMock(Client) client.enable_experimental = False From b5957f378269b7c123acf047b77d5c44b9d93b1e Mon Sep 17 00:00:00 2001 From: Adrian Chang Date: Wed, 26 Jun 2024 01:02:04 -0700 Subject: [PATCH 7/7] Remove user / projects --- .../src/labelbox/schema/user_group.py | 80 +++++++++---------- libs/labelbox/tests/conftest.py | 2 + .../test_offline_chat_evaluation_project.py | 1 - 3 files changed, 41 insertions(+), 42 deletions(-) diff --git a/libs/labelbox/src/labelbox/schema/user_group.py b/libs/labelbox/src/labelbox/schema/user_group.py index 2cf723726..c8779251b 100644 --- a/libs/labelbox/src/labelbox/schema/user_group.py +++ b/libs/labelbox/src/labelbox/schema/user_group.py @@ -83,28 +83,33 @@ def __eq__(self, other): return False return self.id == other.id - + class UserGroup(BaseModel): """ Represents a user group in Labelbox. - Args: - client (Client): The Labelbox client. - **kwargs: Additional keyword arguments for initializing the UserGroup object. - Attributes: - id (str): The ID of the user group. - name (str): The name of the user group. + id (Optional[str]): The ID of the user group. + name (Optional[str]): The name of the user group. color (UserGroupColor): The color of the user group. - users (Set[Union[UserGroupUser, User]]): The set of user IDs in the user group. - projects (Set[Union[UserGroupProject, Project]]): The set of project IDs in the user group. - client (Client): The Labelbox client. + users (Set[UserGroupUser]): The set of users in the user group. + projects (Set[UserGroupProject]): The set of projects associated with the user group. + client (Client): The Labelbox client object. + + Methods: + __init__(self, client: Client, id: str = "", name: str = "", color: UserGroupColor = UserGroupColor.BLUE, + users: Set[UserGroupUser] = set(), projects: Set[UserGroupProject] = set(), reload=True) + _reload(self) + update(self) -> "UserGroup" + create(self) -> "UserGroup" + delete(self) -> bool + get_user_groups(client: Client) -> Iterator["UserGroup"] """ id: Optional[str] name: Optional[str] color: UserGroupColor - users: Set[Union[UserGroupUser, User]] - projects: Set[Union[UserGroupProject, Project]] + users: Set[UserGroupUser] + projects: Set[UserGroupProject] client: Client class Config: @@ -117,21 +122,25 @@ def __init__( id: str = "", name: str = "", color: UserGroupColor = UserGroupColor.BLUE, - users: Set[Union[UserGroupUser, User]] = set(), - projects: Set[Union[UserGroupProject, Project]] = set(), + users: Set[UserGroupUser] = set(), + projects: Set[UserGroupProject] = set(), reload=True, ): """ Initializes a UserGroup object. Args: - client (Client): The Labelbox client. - id (str): The ID of the user group. Defaults to an empty string. - name (str): The name of the user group. Defaults to an empty string. - color (UserGroupColor): The color of the user group. Defaults to UserGroupColor.BLUE. - users (Set[Union[UserGroupUser, User]]): The set of user IDs in the user group. Defaults to an empty set. - projects (Set[Union[UserGroupProject, Project]]): The set of project IDs in the user group. Defaults to an empty set. - reload (bool): Whether to reload the group information from the server. Defaults to True. + client (Client): The Labelbox client object. + id (str, optional): The ID of the user group. Defaults to an empty string. + name (str, optional): The name of the user group. Defaults to an empty string. + color (UserGroupColor, optional): The color of the user group. Defaults to UserGroupColor.BLUE. + users (Set[UserGroupUser], optional): The set of users in the user group. Defaults to an empty set. + projects (Set[UserGroupProject], optional): The set of projects associated with the user group. Defaults to an empty set. + reload (bool, optional): Whether to reload the partial representation of the group. Defaults to True. + + Raises: + RuntimeError: If the experimental feature is not enabled in the client. + """ super().__init__(client=client, id=id, name=name, color=color, users=users, projects=projects) if not self.client.enable_experimental: @@ -240,12 +249,10 @@ def update(self) -> "UserGroup": "color": self.color.value, "projectIds": [ - project.id if hasattr(project, 'id') else project.uid - for project in self.projects + project.id for project in self.projects ], "userIds": [ - user.id if hasattr(user, 'id') else user.uid - for user in self.users + user.id for user in self.users ] } result = self.client.execute(query, params) @@ -255,21 +262,14 @@ def update(self) -> "UserGroup": def create(self) -> "UserGroup": """ - Creates a new group in Labelbox. - - Args: - client (Client): The Labelbox client. - name (str): The name of the group. - color (GroupColor, optional): The color of the group. Defaults to GroupColor.BLUE. - users (List[User], optional): The users to add to the group. Defaults to []. - projects (List[Project], optional): The projects to add to the group. Defaults to []. - - Returns: - Group: The newly created group. (self) + Creates a new user group. Raises: - ResourceCreationError: If the group already exists or if the creation fails. + ResourceCreationError: If the group already exists. ValueError: If the group name is not provided. + + Returns: + UserGroup: The created user group. """ if self.id: raise ResourceCreationError("Group already exists") @@ -311,12 +311,10 @@ def create(self) -> "UserGroup": "color": self.color.value, "projectIds": [ - project.id if hasattr(project, 'id') else project.uid - for project in self.projects + project.id for project in self.projects ], "userIds": [ - user.id if hasattr(user, 'id') else user.uid - for user in self.users + user.id for user in self.users ] } result = self.client.execute(query, params) diff --git a/libs/labelbox/tests/conftest.py b/libs/labelbox/tests/conftest.py index 80229e319..420bc5a83 100644 --- a/libs/labelbox/tests/conftest.py +++ b/libs/labelbox/tests/conftest.py @@ -158,6 +158,8 @@ def execute(self, query=None, params=None, check_naming=True, **kwargs): assert re.match(r"\s*(?:query|mutation) \w+PyApi", query) is not None self.queries.append((query, params)) + if not kwargs.get('timeout'): + kwargs['timeout'] = 30.0 return super().execute(query, params, **kwargs) diff --git a/libs/labelbox/tests/integration/test_offline_chat_evaluation_project.py b/libs/labelbox/tests/integration/test_offline_chat_evaluation_project.py index 07d9d743f..f1e3877ff 100644 --- a/libs/labelbox/tests/integration/test_offline_chat_evaluation_project.py +++ b/libs/labelbox/tests/integration/test_offline_chat_evaluation_project.py @@ -1,6 +1,5 @@ import pytest -@pytest.mark.skip(reason="Behavior changed, need to update test") def test_create_offline_chat_evaluation_project(client, rand_gen, offline_chat_evaluation_project, chat_evaluation_ontology,