Skip to content

[PLT-1205] Improvements for QA #1706

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 80 additions & 106 deletions libs/labelbox/src/labelbox/schema/user_group.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from enum import Enum
from typing import Set, List, Union, Iterator, Optional
from typing import Set, Iterator
from collections import defaultdict

from labelbox import Client
from labelbox.exceptions import ResourceCreationError
from labelbox.pydantic_compat import BaseModel
from labelbox.schema.user import User
from labelbox.schema.project import Project
from labelbox.exceptions import UnprocessableEntityError, InvalidQueryError
from labelbox.schema.queue_mode import QueueMode
from labelbox.schema.ontology_kind import EditorTaskType
from labelbox.schema.media_type import MediaType


class UserGroupColor(Enum):
Expand Down Expand Up @@ -34,82 +38,32 @@ class UserGroupColor(Enum):
YELLOW = "E7BF00"
GRAY = "B8C4D3"


class UserGroupUser(BaseModel):
"""
Represents a user in a group.

Attributes:
id (str): The ID of the user.
email (str): The email of the user.
"""
id: str
email: str

def __hash__(self):
return hash((self.id))

def __eq__(self, other):
if not isinstance(other, UserGroupUser):
return False
return self.id == other.id


class UserGroupProject(BaseModel):
"""
Represents a project in a group.

Attributes:
id (str): The ID of the project.
name (str): The name of the project.
"""
id: str
name: str

def __hash__(self):
return hash((self.id))

def __eq__(self, other):
"""
Check if this GroupProject object is equal to another GroupProject object.

Args:
other (GroupProject): The other GroupProject object to compare with.

Returns:
bool: True if the two GroupProject objects are equal, False otherwise.
"""
if not isinstance(other, UserGroupProject):
return False
return self.id == other.id


class UserGroup(BaseModel):
"""
Represents a user group in Labelbox.

Attributes:
id (Optional[str]): The ID of the user group.
name (Optional[str]): The name of the user group.
id (str): The ID of the user group.
name (str): The name of the user group.
color (UserGroupColor): The color of the user group.
users (Set[UserGroupUser]): The set of users in the user group.
projects (Set[UserGroupProject]): The set of projects associated with the user group.
client (Client): The Labelbox client object.

Methods:
__init__(self, client: Client, id: str = "", name: str = "", color: UserGroupColor = UserGroupColor.BLUE,
users: Set[UserGroupUser] = set(), projects: Set[UserGroupProject] = set(), reload=True)
_reload(self)
__init__(self, client: Client)
get(self) -> "UserGroup"
update(self) -> "UserGroup"
create(self) -> "UserGroup"
delete(self) -> bool
get_user_groups(client: Client) -> Iterator["UserGroup"]
"""
id: Optional[str]
name: Optional[str]
id: str
name: str
color: UserGroupColor
users: Set[UserGroupUser]
projects: Set[UserGroupProject]
users: Set[User]
projects: Set[Project]
client: Client

class Config:
Expand All @@ -122,9 +76,8 @@ def __init__(
id: str = "",
name: str = "",
color: UserGroupColor = UserGroupColor.BLUE,
users: Set[UserGroupUser] = set(),
projects: Set[UserGroupProject] = set(),
reload=True,
users: Set[User] = set(),
projects: Set[Project] = set()
):
"""
Initializes a UserGroup object.
Expand All @@ -134,36 +87,32 @@ def __init__(
id (str, optional): The ID of the user group. Defaults to an empty string.
name (str, optional): The name of the user group. Defaults to an empty string.
color (UserGroupColor, optional): The color of the user group. Defaults to UserGroupColor.BLUE.
users (Set[UserGroupUser], optional): The set of users in the user group. Defaults to an empty set.
projects (Set[UserGroupProject], optional): The set of projects associated with the user group. Defaults to an empty set.
reload (bool, optional): Whether to reload the partial representation of the group. Defaults to True.
users (Set[User], optional): The set of users in the user group. Defaults to an empty set.
projects (Set[Project], optional): The set of projects associated with the user group. Defaults to an empty set.

Raises:
RuntimeError: If the experimental feature is not enabled in the client.

"""
super().__init__(client=client, id=id, name=name, color=color, users=users, projects=projects)
if not self.client.enable_experimental:
raise RuntimeError(
"Please enable experimental in client to use UserGroups")
raise RuntimeError("Please enable experimental in client to use UserGroups")

# partial representation of the group, reload
if self.id and reload:
self._reload()

def _reload(self):
def get(self) -> "UserGroup":
"""
Reloads the user group information from the server.

This method sends a GraphQL query to the server to fetch the latest information
about the user group, including its name, color, projects, and members. The fetched
information is then used to update the corresponding attributes of the `Group` object.

Raises:
InvalidQueryError: If the query fails to fetch the group information.
Args:
id (str): The ID of the user group to fetch.

Returns:
None
UserGroup of passed in ID (self)

Raises:
InvalidQueryError: If the query fails to fetch the group information.
"""
query = """
query GetUserGroupPyApi($id: ID!) {
Expand Down Expand Up @@ -196,14 +145,9 @@ def _reload(self):
raise InvalidQueryError("Failed to fetch group")
self.name = result["userGroup"]["name"]
self.color = UserGroupColor(result["userGroup"]["color"])
self.projects = {
UserGroupProject(id=project["id"], name=project["name"])
for project in result["userGroup"]["projects"]["nodes"]
}
self.users = {
UserGroupUser(id=member["id"], email=member["email"])
for member in result["userGroup"]["members"]["nodes"]
}
self.projects = self._get_projects_set(result["userGroup"]["projects"]["nodes"])
self.users = self._get_users_set(result["userGroup"]["members"]["nodes"])
return self

def update(self) -> "UserGroup":
"""
Expand Down Expand Up @@ -249,10 +193,10 @@ def update(self) -> "UserGroup":
"color":
self.color.value,
"projectIds": [
project.id for project in self.projects
project.uid for project in self.projects
],
"userIds": [
user.id for user in self.users
user.uid for user in self.users
]
}
result = self.client.execute(query, params)
Expand Down Expand Up @@ -311,10 +255,10 @@ def create(self) -> "UserGroup":
"color":
self.color.value,
"projectIds": [
project.id for project in self.projects
project.uid for project in self.projects
],
"userIds": [
user.id for user in self.users
user.uid for user in self.users
]
}
result = self.client.execute(query, params)
Expand Down Expand Up @@ -351,8 +295,7 @@ def delete(self) -> bool:
raise UnprocessableEntityError("Failed to delete user group")
return result["deleteUserGroup"]["success"]

@staticmethod
def get_user_groups(client: Client) -> Iterator["UserGroup"]:
def get_user_groups(self) -> Iterator["UserGroup"]:
"""
Gets all user groups in Labelbox.

Expand Down Expand Up @@ -390,29 +333,60 @@ def get_user_groups(client: Client) -> Iterator["UserGroup"]:
"""
nextCursor = None
while True:
userGroups = client.execute(
userGroups = self.client.execute(
query, {"nextCursor": nextCursor})["userGroups"]
if not userGroups:
return
yield
groups = userGroups["nodes"]
for group in groups:
yield UserGroup(client,
reload=False,
id=group["id"],
name=group["name"],
color=UserGroupColor(group["color"]),
users={
UserGroupUser(id=member["id"],
email=member["email"])
for member in group["members"]["nodes"]
},
projects={
UserGroupProject(id=project["id"],
name=project["name"])
for project in group["projects"]["nodes"]
})
userGroup = UserGroup(self.client)
userGroup.id = group["id"]
userGroup.name = group["name"]
userGroup.color = UserGroupColor(group["color"])
userGroup.users = self._get_users_set(group["members"]["nodes"])
userGroup.projects = self._get_projects_set(group["projects"]["nodes"])
yield userGroup
nextCursor = userGroups["nextCursor"]
# this doesn't seem to be implemented right now to return a value other than null from the api
if not nextCursor:
break

def _get_users_set(self, user_nodes):
"""
Retrieves a set of User objects from the given user nodes.

Args:
user_nodes (list): A list of user nodes containing user information.

Returns:
set: A set of User objects.
"""
users = set()
for user in user_nodes:
user_values = defaultdict(lambda: None)
user_values["id"] = user["id"]
user_values["email"] = user["email"]
users.add(User(self.client, user_values))
return users

def _get_projects_set(self, project_nodes):
"""
Retrieves a set of projects based on the given project nodes.

Args:
project_nodes (list): A list of project nodes.

Returns:
set: A set of Project objects.
"""
projects = set()
for project in project_nodes:
project_values = defaultdict(lambda: None)
project_values["id"] = project["id"]
project_values["name"] = project["name"]
project_values["queueMode"] = QueueMode.Batch.value
project_values["editorTaskType"] = EditorTaskType.Missing.value
project_values["mediaType"] = MediaType.Image.value
projects.add(Project(self.client, project_values))
return projects
18 changes: 6 additions & 12 deletions libs/labelbox/tests/integration/schema/test_user_group.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import faker
from labelbox import Client
from labelbox.schema.user_group import UserGroup, UserGroupColor, UserGroupUser, UserGroupProject
from labelbox.schema.user_group import UserGroup, UserGroupColor

data = faker.Faker()

Expand All @@ -21,7 +21,9 @@ def user_group(client):

def test_existing_user_groups(user_group, client):
# Verify that the user group was created successfully
user_group_equal = UserGroup(client, id=user_group.id)
user_group_equal = UserGroup(client)
user_group_equal.id = user_group.id
user_group_equal.get()
assert user_group.id == user_group_equal.id
assert user_group.name == user_group_equal.name
assert user_group.color == user_group_equal.color
Expand All @@ -48,15 +50,15 @@ def test_update_user_group(user_group):

def test_get_user_groups(user_group, client):
# Get all user groups
user_groups_old = list(UserGroup.get_user_groups(client))
user_groups_old = list(UserGroup(client).get_user_groups())

# manual delete for iterators
group_name = data.name()
user_group = UserGroup(client)
user_group.name = group_name
user_group.create()

user_groups_new = list(UserGroup.get_user_groups(client))
user_groups_new = list(UserGroup(client).get_user_groups())

# Verify that at least one user group is returned
assert len(user_groups_new) > 0
Expand All @@ -77,15 +79,7 @@ def test_update_user_group(user_group, client, project_pack):

# Add the user to the group
user = users[0]
user = UserGroupUser(
id=user.uid,
email=user.email
)
project = projects[0]
project = UserGroupProject(
id=project.uid,
name=project.name
)
user_group.users.add(user)
user_group.projects.add(project)
user_group.update()
Expand Down
Loading