From d9464b93433121f2e7e6813623e87727feb1a299 Mon Sep 17 00:00:00 2001 From: Val Brodsky Date: Thu, 5 Dec 2024 14:13:59 -0800 Subject: [PATCH 1/7] Get role id by name --- .../src/labelbox/schema/user_group.py | 5 +- .../src/labelbox/schema/user_group_upload.py | 46 +++++++++++++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) create mode 100644 libs/labelbox/src/labelbox/schema/user_group_upload.py diff --git a/libs/labelbox/src/labelbox/schema/user_group.py b/libs/labelbox/src/labelbox/schema/user_group.py index 2e93b4376..35e3c5d7d 100644 --- a/libs/labelbox/src/labelbox/schema/user_group.py +++ b/libs/labelbox/src/labelbox/schema/user_group.py @@ -1,6 +1,6 @@ from collections import defaultdict from enum import Enum -from typing import Iterator, Set +from typing import Iterator, List, Set from lbox.exceptions import ( MalformedQueryException, @@ -317,6 +317,9 @@ def delete(self) -> bool: ) return result["deleteUserGroup"]["success"] + def import_members(self, role: str, emails: List[str]): + pass + def get_user_groups(self) -> Iterator["UserGroup"]: """ Gets all user groups in Labelbox. diff --git a/libs/labelbox/src/labelbox/schema/user_group_upload.py b/libs/labelbox/src/labelbox/schema/user_group_upload.py new file mode 100644 index 000000000..65abc47a1 --- /dev/null +++ b/libs/labelbox/src/labelbox/schema/user_group_upload.py @@ -0,0 +1,46 @@ +from typing import List, Optional + +from lbox.exceptions import ResourceNotFoundError + +from labelbox.client import Client +from labelbox.pagination import PaginatedCollection + + +class UserGroupUpload: + def __init__(self, client: Client): + self.client = client + + def upload_members(self, group_id: str, role: str, emails: List[str]): + role_id = self._get_role_id(role) + if role_id is None: + raise ResourceNotFoundError(message="The role does not exist.") + + def _get_role_id(self, role_name: str) -> Optional[str]: + role_id = None + query = """query GetAvailableUserRolesPyPi { + roles(skip: %d, first: %d) { + id + organizationId + name + description + } + } + """ + + result = PaginatedCollection( + client=self.client, + query=query, + params={}, + dereferencing=["roles"], + obj_class=lambda _, data: data, # type: ignore + ) + if result is None: + raise ResourceNotFoundError( + message="Could not find any valid roles." + ) + for role in result: + if role["name"].strip() == role_name.strip(): + role_id = role["id"] + break + + return role_id From 2c8e358b8080b07bd0a4e625345b7fc4643a8e30 Mon Sep 17 00:00:00 2001 From: Val Brodsky Date: Thu, 5 Dec 2024 15:49:11 -0800 Subject: [PATCH 2/7] Add support for member upload --- .../src/labelbox/schema/user_group_upload.py | 112 +++++++++++++++++- 1 file changed, 111 insertions(+), 1 deletion(-) diff --git a/libs/labelbox/src/labelbox/schema/user_group_upload.py b/libs/labelbox/src/labelbox/schema/user_group_upload.py index 65abc47a1..19599dd61 100644 --- a/libs/labelbox/src/labelbox/schema/user_group_upload.py +++ b/libs/labelbox/src/labelbox/schema/user_group_upload.py @@ -1,6 +1,13 @@ +import json +from io import BytesIO from typing import List, Optional -from lbox.exceptions import ResourceNotFoundError +import requests +from lbox.exceptions import ( + InternalServerError, + LabelboxError, + ResourceNotFoundError, +) from labelbox.client import Client from labelbox.pagination import PaginatedCollection @@ -11,10 +18,104 @@ def __init__(self, client: Client): self.client = client def upload_members(self, group_id: str, role: str, emails: List[str]): + if len(emails) == 0: + print("No emails to upload.") + return + role_id = self._get_role_id(role) if role_id is None: raise ResourceNotFoundError(message="The role does not exist.") + buffer = BytesIO() + buffer.write(b"email\n") # Header row + for email in emails: + buffer.write(f"{email}\n".encode("utf-8")) + # Reset pointer to start of stream + buffer.seek(0) + + multipart_file_field = "1" + gql_file_field = "file" + files = { + multipart_file_field: ( + f"{multipart_file_field}.csv", + buffer, + "text/csv", + ) + } + query = """mutation ImportMembersToGroup( + $roleId: ID! + $file: Upload! + $where: WhereUniqueIdInput! + ) { + importUsersAsCsvToGroup(roleId: $roleId, file: $file, where: $where) { + csvReport + addedCount + count + } + } + """ + params = { + "roleId": role_id, + gql_file_field: None, + "where": {"id": group_id}, + } + + request_data = { + "operations": json.dumps( + { + "variables": params, + "query": query, + } + ), + "map": ( + None, + json.dumps( + {multipart_file_field: [f"variables.{gql_file_field}"]} + ), + ), + } + + client = self.client + headers = dict(client.connection.headers) + headers.pop("Content-Type", None) + request = requests.Request( + "POST", + client.endpoint, + headers=headers, + data=request_data, + files=files, + ) + + prepped: requests.PreparedRequest = request.prepare() + + response = client.connection.send(prepped) + + if response.status_code == 502: + error_502 = "502 Bad Gateway" + raise InternalServerError(error_502) + elif response.status_code == 503: + raise InternalServerError(response.text) + elif response.status_code == 520: + raise InternalServerError(response.text) + + try: + file_data = response.json().get("data", None) + except ValueError as e: # response is not valid JSON + raise LabelboxError("Failed to upload, unknown cause", e) + + if not file_data or not file_data.get("importUsersAsCsvToGroup", None): + try: + errors = response.json().get("errors", []) + error_msg = next(iter(errors), {}).get( + "message", "Unknown error" + ) + except Exception: + error_msg = "Unknown error" + raise LabelboxError("Failed to upload, message: %s" % error_msg) + + csv_report = file_data["importUsersAsCsvToGroup"]["csvReport"] + return self._parse_csv_report(csv_report) + def _get_role_id(self, role_name: str) -> Optional[str]: role_id = None query = """query GetAvailableUserRolesPyPi { @@ -44,3 +145,12 @@ def _get_role_id(self, role_name: str) -> Optional[str]: break return role_id + + def _parse_csv_report(self, csv_report: str) -> List[dict]: + lines = csv_report.strip().split("\n") + headers = lines[0].split(",") + report_list = [] + for line in lines[1:]: + values = line.split(",") + report_list.append(dict(zip(headers, values))) + return report_list From b0ed92d8b801f53878ebc932362a5e0768c41598 Mon Sep 17 00:00:00 2001 From: Val Brodsky Date: Thu, 5 Dec 2024 15:58:11 -0800 Subject: [PATCH 3/7] Add UploadReport --- .../src/labelbox/schema/user_group_upload.py | 40 +++++++++++++++---- 1 file changed, 33 insertions(+), 7 deletions(-) diff --git a/libs/labelbox/src/labelbox/schema/user_group_upload.py b/libs/labelbox/src/labelbox/schema/user_group_upload.py index 19599dd61..9cf4d334e 100644 --- a/libs/labelbox/src/labelbox/schema/user_group_upload.py +++ b/libs/labelbox/src/labelbox/schema/user_group_upload.py @@ -1,4 +1,5 @@ import json +from dataclasses import dataclass from io import BytesIO from typing import List, Optional @@ -13,18 +14,34 @@ from labelbox.pagination import PaginatedCollection +@dataclass +class UploadReportLine: + email: str + result: str + error: Optional[str] = None + + +@dataclass +class UploadReport: + lines: List[UploadReportLine] + + class UserGroupUpload: def __init__(self, client: Client): self.client = client - def upload_members(self, group_id: str, role: str, emails: List[str]): + def upload_members( + self, group_id: str, role: str, emails: List[str] + ) -> Optional[UploadReport]: if len(emails) == 0: print("No emails to upload.") - return + return None role_id = self._get_role_id(role) if role_id is None: - raise ResourceNotFoundError(message="The role does not exist.") + raise ResourceNotFoundError( + message="Could not find a valid role with the name provided. Please make sure the role name is correct." + ) buffer = BytesIO() buffer.write(b"email\n") # Header row @@ -146,11 +163,20 @@ def _get_role_id(self, role_name: str) -> Optional[str]: return role_id - def _parse_csv_report(self, csv_report: str) -> List[dict]: + def _parse_csv_report(self, csv_report: str) -> UploadReport: lines = csv_report.strip().split("\n") headers = lines[0].split(",") - report_list = [] + report_lines = [] for line in lines[1:]: values = line.split(",") - report_list.append(dict(zip(headers, values))) - return report_list + row = dict(zip(headers, values)) + report_lines.append( + UploadReportLine( + email=row["Email"], + result=row["Result"], + error=row.get( + "Error" + ), # Using get() since error is optional + ) + ) + return UploadReport(lines=report_lines) From 5a2ac10179b7f751a5ed080b72610fc2f5539aad Mon Sep 17 00:00:00 2001 From: Val Brodsky Date: Fri, 6 Dec 2024 10:16:21 -0800 Subject: [PATCH 4/7] Added docstring and readthedocs page --- docs/labelbox/index.rst | 1 + .../src/labelbox/schema/user_group_upload.py | 43 +++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/docs/labelbox/index.rst b/docs/labelbox/index.rst index 41207c78f..fd9941c1f 100644 --- a/docs/labelbox/index.rst +++ b/docs/labelbox/index.rst @@ -52,4 +52,5 @@ Labelbox Python SDK Documentation task task-queue user + user-group-upload webhook diff --git a/libs/labelbox/src/labelbox/schema/user_group_upload.py b/libs/labelbox/src/labelbox/schema/user_group_upload.py index 9cf4d334e..01640cb2a 100644 --- a/libs/labelbox/src/labelbox/schema/user_group_upload.py +++ b/libs/labelbox/src/labelbox/schema/user_group_upload.py @@ -16,6 +16,30 @@ @dataclass class UploadReportLine: + """A single line in the CSV report of the upload members mutation. + Both errors and successes are reported here. + + Example output when using dataclasses.asdict(): + >>> { + >>> 'lines': [ + >>> { + >>> 'email': '...', + >>> 'result': 'Not added', + >>> 'error': 'User not found in the current organization' + >>> }, + >>> { + >>> 'email': '...', + >>> 'result': 'Not added', + >>> 'error': 'Member already exists in group' + >>> }, + >>> { + >>> 'email': '...', + >>> 'result': 'Added', + >>> 'error': '' + >>> } + >>> ] + >>> } + """ email: str result: str error: Optional[str] = None @@ -23,16 +47,35 @@ class UploadReportLine: @dataclass class UploadReport: + """The report of the upload members mutation.""" lines: List[UploadReportLine] class UserGroupUpload: + """Upload members to a user group.""" + def __init__(self, client: Client): self.client = client def upload_members( self, group_id: str, role: str, emails: List[str] ) -> Optional[UploadReport]: + """Upload members to a user group. + + Args: + group_id: A valid ID of the user group. + role: The name of the role to assign to the uploaded members as it appears in the UI on the Import Members popup. + emails: The list of emails of the members to upload. + + Returns: + UploadReport: The report of the upload members mutation. + + Raises: + ResourceNotFoundError: If the role is not found. + LabelboxError: If the upload fails. + + For indicvidual email errors, the error message is available in the UploadReport. + """ if len(emails) == 0: print("No emails to upload.") return None From 7fa0bb2dad5a4bd889319ae4516bec4eb42bd7f8 Mon Sep 17 00:00:00 2001 From: Val Brodsky Date: Fri, 6 Dec 2024 10:28:38 -0800 Subject: [PATCH 5/7] Fix linting --- libs/labelbox/src/labelbox/schema/user_group.py | 5 +---- libs/labelbox/src/labelbox/schema/user_group_upload.py | 2 ++ 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/libs/labelbox/src/labelbox/schema/user_group.py b/libs/labelbox/src/labelbox/schema/user_group.py index 35e3c5d7d..2e93b4376 100644 --- a/libs/labelbox/src/labelbox/schema/user_group.py +++ b/libs/labelbox/src/labelbox/schema/user_group.py @@ -1,6 +1,6 @@ from collections import defaultdict from enum import Enum -from typing import Iterator, List, Set +from typing import Iterator, Set from lbox.exceptions import ( MalformedQueryException, @@ -317,9 +317,6 @@ def delete(self) -> bool: ) return result["deleteUserGroup"]["success"] - def import_members(self, role: str, emails: List[str]): - pass - def get_user_groups(self) -> Iterator["UserGroup"]: """ Gets all user groups in Labelbox. diff --git a/libs/labelbox/src/labelbox/schema/user_group_upload.py b/libs/labelbox/src/labelbox/schema/user_group_upload.py index 01640cb2a..b15341f35 100644 --- a/libs/labelbox/src/labelbox/schema/user_group_upload.py +++ b/libs/labelbox/src/labelbox/schema/user_group_upload.py @@ -40,6 +40,7 @@ class UploadReportLine: >>> ] >>> } """ + email: str result: str error: Optional[str] = None @@ -48,6 +49,7 @@ class UploadReportLine: @dataclass class UploadReport: """The report of the upload members mutation.""" + lines: List[UploadReportLine] From 17edde4ec1ac25e20cb8a0b2541584510faafbdd Mon Sep 17 00:00:00 2001 From: Val Brodsky Date: Fri, 6 Dec 2024 11:05:33 -0800 Subject: [PATCH 6/7] Fix mypy errors --- libs/labelbox/src/labelbox/schema/user_group_upload.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/libs/labelbox/src/labelbox/schema/user_group_upload.py b/libs/labelbox/src/labelbox/schema/user_group_upload.py index b15341f35..904b5057f 100644 --- a/libs/labelbox/src/labelbox/schema/user_group_upload.py +++ b/libs/labelbox/src/labelbox/schema/user_group_upload.py @@ -10,7 +10,7 @@ ResourceNotFoundError, ) -from labelbox.client import Client +from labelbox import Client from labelbox.pagination import PaginatedCollection @@ -168,9 +168,9 @@ def upload_members( if not file_data or not file_data.get("importUsersAsCsvToGroup", None): try: errors = response.json().get("errors", []) - error_msg = next(iter(errors), {}).get( - "message", "Unknown error" - ) + error_msg = "Unknown error" + if errors: + error_msg = errors[0].get("message", "Unknown error") except Exception: error_msg = "Unknown error" raise LabelboxError("Failed to upload, message: %s" % error_msg) From 0409a96abd541cb0745456e331e98611d3e9c981 Mon Sep 17 00:00:00 2001 From: Val Brodsky Date: Fri, 6 Dec 2024 11:14:54 -0800 Subject: [PATCH 7/7] Add experimental warning --- libs/labelbox/src/labelbox/schema/user_group_upload.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/libs/labelbox/src/labelbox/schema/user_group_upload.py b/libs/labelbox/src/labelbox/schema/user_group_upload.py index 904b5057f..8d6d2bc25 100644 --- a/libs/labelbox/src/labelbox/schema/user_group_upload.py +++ b/libs/labelbox/src/labelbox/schema/user_group_upload.py @@ -1,4 +1,5 @@ import json +import warnings from dataclasses import dataclass from io import BytesIO from typing import List, Optional @@ -78,6 +79,10 @@ def upload_members( For indicvidual email errors, the error message is available in the UploadReport. """ + warnings.warn( + "The upload_members for UserGroupUpload is in beta. The method name and signature may change in the future.”", + ) + if len(emails) == 0: print("No emails to upload.") return None