diff --git a/libs/labelbox/src/labelbox/client.py b/libs/labelbox/src/labelbox/client.py index 6269f4927..af3b5b3fc 100644 --- a/libs/labelbox/src/labelbox/client.py +++ b/libs/labelbox/src/labelbox/client.py @@ -179,7 +179,12 @@ def execute( Args: query (str): The query to execute. - variables (dict): Variables to pass to the query. + params (dict): Variables to pass to the query. + data (dict): Includes the query and variables as well as the map for file upload multipart/form-data requests as per GraphQL multipart request specification. + files (dict): File descriptors to pass to the query for file upload multipart/form-data requests. + timeout (float): Timeout for the request. + experimental (bool): Whether to use experimental features. + error_log_key (str): Key to use for error logging. raise_return_resource_not_found (bool): If True, raise a ResourceNotFoundError if the query returns None. error_handlers (dict): A dictionary mapping graphql error code to handler functions. @@ -187,6 +192,8 @@ def execute( Returns: dict: The response from the server. + + See UserGroupV2.upload_members for an example of how to use this method for file upload. """ return self._request_client.execute( query, @@ -264,42 +271,7 @@ def upload_data( if (filename and content_type) else content } - headers = self.connection.headers.copy() - headers.pop("Content-Type", None) - request = requests.Request( - "POST", - self.endpoint, - headers=headers, - data=request_data, - files=files, - ) - - prepped: requests.PreparedRequest = request.prepare() - - response = self.connection.send(prepped) - - if response.status_code == 502: - error_502 = "502 Bad Gateway" - raise InternalServerError(error_502) - elif response.status_code == 503: - raise InternalServerError(response.text) - elif response.status_code == 520: - raise InternalServerError(response.text) - - try: - file_data = response.json().get("data", None) - except ValueError as e: # response is not valid JSON - raise LabelboxError("Failed to upload, unknown cause", e) - - if not file_data or not file_data.get("uploadFile", None): - try: - errors = response.json().get("errors", []) - error_msg = next(iter(errors), {}).get( - "message", "Unknown error" - ) - except Exception: - error_msg = "Unknown error" - raise LabelboxError("Failed to upload, message: %s" % error_msg) + file_data = self.execute(data=request_data, files=files) return file_data["uploadFile"]["url"] diff --git a/libs/labelbox/src/labelbox/schema/user_group_v2.py b/libs/labelbox/src/labelbox/schema/user_group_v2.py index a734eb397..0880123b1 100644 --- a/libs/labelbox/src/labelbox/schema/user_group_v2.py +++ b/libs/labelbox/src/labelbox/schema/user_group_v2.py @@ -4,9 +4,7 @@ from io import BytesIO from typing import List, Optional -import requests from lbox.exceptions import ( - InternalServerError, LabelboxError, ResourceNotFoundError, ) @@ -107,11 +105,14 @@ def upload_members( # Reset pointer to start of stream buffer.seek(0) - multipart_file_field = "1" + # Use 0-based indexing as per common convention + multipart_file_field = "0" gql_file_field = "file" + + # Prepare the file content files = { multipart_file_field: ( - f"{multipart_file_field}.csv", + "members.csv", # More descriptive filename buffer, "text/csv", ) @@ -128,63 +129,30 @@ def upload_members( } } """ - params = { - "roleId": role_id, - gql_file_field: None, - "where": {"id": group_id}, + # Construct the multipart request following the spec + operations = { + "query": query, + "variables": { + "roleId": role_id, + gql_file_field: None, # Placeholder for file + "where": {"id": group_id}, + }, } + # Map file to the variable + map_data = {multipart_file_field: [f"variables.{gql_file_field}"]} + request_data = { - "operations": json.dumps( - { - "variables": params, - "query": query, - } - ), - "map": ( - None, - json.dumps( - {multipart_file_field: [f"variables.{gql_file_field}"]} - ), - ), + "operations": json.dumps(operations), + "map": json.dumps( + map_data + ), # Remove the unnecessary (None, ...) tuple } - client = self.client - headers = dict(client.connection.headers) - headers.pop("Content-Type", None) - request = requests.Request( - "POST", - client.endpoint, - headers=headers, - data=request_data, - files=files, - ) - - prepped: requests.PreparedRequest = request.prepare() - - response = client.connection.send(prepped) - - if response.status_code == 502: - error_502 = "502 Bad Gateway" - raise InternalServerError(error_502) - elif response.status_code == 503: - raise InternalServerError(response.text) - elif response.status_code == 520: - raise InternalServerError(response.text) - - try: - file_data = response.json().get("data", None) - except ValueError as e: # response is not valid JSON - raise LabelboxError("Failed to upload, unknown cause", e) + file_data = self.client.execute(data=request_data, files=files) if not file_data or not file_data.get("importUsersAsCsvToGroup", None): - try: - errors = response.json().get("errors", []) - error_msg = "Unknown error" - if errors: - error_msg = errors[0].get("message", "Unknown error") - except Exception: - error_msg = "Unknown error" + error_msg = "Unknown error" raise LabelboxError("Failed to upload, message: %s" % error_msg) csv_report = file_data["importUsersAsCsvToGroup"]["csvReport"]