Skip to content

[PLT-2151] Vb/graphql upload plt 2151 #1937

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 9 additions & 37 deletions libs/labelbox/src/labelbox/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,14 +179,21 @@ def execute(

Args:
query (str): The query to execute.
variables (dict): Variables to pass to the query.
params (dict): Variables to pass to the query.
data (dict): Includes the query and variables, as well as the map for file upload multipart/form-data requests, as per the GraphQL multipart request specification.
files (dict): File descriptors to pass to the query for file upload multipart/form-data requests.
timeout (float): Timeout for the request.
experimental (bool): Whether to use experimental features.
error_log_key (str): Key to use for error logging.
raise_return_resource_not_found (bool): If True, raise a
ResourceNotFoundError if the query returns None.
error_handlers (dict): A dictionary mapping GraphQL error codes to handler functions.
Allows a caller to handle the reporting of specific errors in a custom way or to produce more user-friendly, readable messages.

Returns:
dict: The response from the server.

See UserGroupV2.upload_members for an example of how to use this method for file upload.
"""
return self._request_client.execute(
query,
Expand Down Expand Up @@ -264,42 +271,7 @@ def upload_data(
if (filename and content_type)
else content
}
headers = self.connection.headers.copy()
headers.pop("Content-Type", None)
request = requests.Request(
"POST",
self.endpoint,
headers=headers,
data=request_data,
files=files,
)

prepped: requests.PreparedRequest = request.prepare()

response = self.connection.send(prepped)

if response.status_code == 502:
error_502 = "502 Bad Gateway"
raise InternalServerError(error_502)
elif response.status_code == 503:
raise InternalServerError(response.text)
elif response.status_code == 520:
raise InternalServerError(response.text)

try:
file_data = response.json().get("data", None)
except ValueError as e: # response is not valid JSON
raise LabelboxError("Failed to upload, unknown cause", e)

if not file_data or not file_data.get("uploadFile", None):
try:
errors = response.json().get("errors", [])
error_msg = next(iter(errors), {}).get(
"message", "Unknown error"
)
except Exception:
error_msg = "Unknown error"
raise LabelboxError("Failed to upload, message: %s" % error_msg)
file_data = self.execute(data=request_data, files=files)

return file_data["uploadFile"]["url"]

Expand Down
76 changes: 22 additions & 54 deletions libs/labelbox/src/labelbox/schema/user_group_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
from io import BytesIO
from typing import List, Optional

import requests
from lbox.exceptions import (
InternalServerError,
LabelboxError,
ResourceNotFoundError,
)
Expand Down Expand Up @@ -107,11 +105,14 @@ def upload_members(
# Reset pointer to start of stream
buffer.seek(0)

multipart_file_field = "1"
# Use 0-based indexing as per common convention
multipart_file_field = "0"
gql_file_field = "file"

# Prepare the file content
files = {
multipart_file_field: (
f"{multipart_file_field}.csv",
"members.csv", # More descriptive filename
buffer,
"text/csv",
)
Expand All @@ -128,63 +129,30 @@ def upload_members(
}
}
"""
params = {
"roleId": role_id,
gql_file_field: None,
"where": {"id": group_id},
# Construct the multipart request following the spec
operations = {
"query": query,
"variables": {
"roleId": role_id,
gql_file_field: None, # Placeholder for file
"where": {"id": group_id},
},
}

# Map file to the variable
map_data = {multipart_file_field: [f"variables.{gql_file_field}"]}

request_data = {
"operations": json.dumps(
{
"variables": params,
"query": query,
}
),
"map": (
None,
json.dumps(
{multipart_file_field: [f"variables.{gql_file_field}"]}
),
),
"operations": json.dumps(operations),
"map": json.dumps(
map_data
), # Remove the unnecessary (None, ...) tuple
}

client = self.client
headers = dict(client.connection.headers)
headers.pop("Content-Type", None)
request = requests.Request(
"POST",
client.endpoint,
headers=headers,
data=request_data,
files=files,
)

prepped: requests.PreparedRequest = request.prepare()

response = client.connection.send(prepped)

if response.status_code == 502:
error_502 = "502 Bad Gateway"
raise InternalServerError(error_502)
elif response.status_code == 503:
raise InternalServerError(response.text)
elif response.status_code == 520:
raise InternalServerError(response.text)

try:
file_data = response.json().get("data", None)
except ValueError as e: # response is not valid JSON
raise LabelboxError("Failed to upload, unknown cause", e)
file_data = self.client.execute(data=request_data, files=files)

if not file_data or not file_data.get("importUsersAsCsvToGroup", None):
try:
errors = response.json().get("errors", [])
error_msg = "Unknown error"
if errors:
error_msg = errors[0].get("message", "Unknown error")
except Exception:
error_msg = "Unknown error"
error_msg = "Unknown error"
raise LabelboxError("Failed to upload, message: %s" % error_msg)

csv_report = file_data["importUsersAsCsvToGroup"]["csvReport"]
Expand Down
Loading