Skip to content

Commit df66bcd

Browse files
authored
[PLT-2151] Vb/graphql upload plt 2151 (#1937)
1 parent 0482921 commit df66bcd

File tree

2 files changed

+31
-91
lines changed

2 files changed

+31
-91
lines changed

libs/labelbox/src/labelbox/client.py

Lines changed: 9 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -179,14 +179,21 @@ def execute(
179179
180180
Args:
181181
query (str): The query to execute.
182-
variables (dict): Variables to pass to the query.
182+
params (dict): Variables to pass to the query.
183+
data (dict): Includes the query and variables as well as the map for file upload multipart/form-data requests as per GraphQL multipart request specification.
184+
files (dict): File descriptors to pass to the query for file upload multipart/form-data requests.
185+
timeout (float): Timeout for the request.
186+
experimental (bool): Whether to use experimental features.
187+
error_log_key (str): Key to use for error logging.
183188
raise_return_resource_not_found (bool): If True, raise a
184189
ResourceNotFoundError if the query returns None.
185190
error_handlers (dict): A dictionary mapping graphql error code to handler functions.
186191
Allows a caller to handle specific errors reporting in a custom way or produce more user-friendly readable messages
187192
188193
Returns:
189194
dict: The response from the server.
195+
196+
See UserGroupV2.upload_members for an example of how to use this method for file upload.
190197
"""
191198
return self._request_client.execute(
192199
query,
@@ -264,42 +271,7 @@ def upload_data(
264271
if (filename and content_type)
265272
else content
266273
}
267-
headers = self.connection.headers.copy()
268-
headers.pop("Content-Type", None)
269-
request = requests.Request(
270-
"POST",
271-
self.endpoint,
272-
headers=headers,
273-
data=request_data,
274-
files=files,
275-
)
276-
277-
prepped: requests.PreparedRequest = request.prepare()
278-
279-
response = self.connection.send(prepped)
280-
281-
if response.status_code == 502:
282-
error_502 = "502 Bad Gateway"
283-
raise InternalServerError(error_502)
284-
elif response.status_code == 503:
285-
raise InternalServerError(response.text)
286-
elif response.status_code == 520:
287-
raise InternalServerError(response.text)
288-
289-
try:
290-
file_data = response.json().get("data", None)
291-
except ValueError as e: # response is not valid JSON
292-
raise LabelboxError("Failed to upload, unknown cause", e)
293-
294-
if not file_data or not file_data.get("uploadFile", None):
295-
try:
296-
errors = response.json().get("errors", [])
297-
error_msg = next(iter(errors), {}).get(
298-
"message", "Unknown error"
299-
)
300-
except Exception:
301-
error_msg = "Unknown error"
302-
raise LabelboxError("Failed to upload, message: %s" % error_msg)
274+
file_data = self.execute(data=request_data, files=files)
303275

304276
return file_data["uploadFile"]["url"]
305277

libs/labelbox/src/labelbox/schema/user_group_v2.py

Lines changed: 22 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@
44
from io import BytesIO
55
from typing import List, Optional
66

7-
import requests
87
from lbox.exceptions import (
9-
InternalServerError,
108
LabelboxError,
119
ResourceNotFoundError,
1210
)
@@ -107,11 +105,14 @@ def upload_members(
107105
# Reset pointer to start of stream
108106
buffer.seek(0)
109107

110-
multipart_file_field = "1"
108+
# Use 0-based indexing as per common convention
109+
multipart_file_field = "0"
111110
gql_file_field = "file"
111+
112+
# Prepare the file content
112113
files = {
113114
multipart_file_field: (
114-
f"{multipart_file_field}.csv",
115+
"members.csv", # More descriptive filename
115116
buffer,
116117
"text/csv",
117118
)
@@ -128,63 +129,30 @@ def upload_members(
128129
}
129130
}
130131
"""
131-
params = {
132-
"roleId": role_id,
133-
gql_file_field: None,
134-
"where": {"id": group_id},
132+
# Construct the multipart request following the spec
133+
operations = {
134+
"query": query,
135+
"variables": {
136+
"roleId": role_id,
137+
gql_file_field: None, # Placeholder for file
138+
"where": {"id": group_id},
139+
},
135140
}
136141

142+
# Map file to the variable
143+
map_data = {multipart_file_field: [f"variables.{gql_file_field}"]}
144+
137145
request_data = {
138-
"operations": json.dumps(
139-
{
140-
"variables": params,
141-
"query": query,
142-
}
143-
),
144-
"map": (
145-
None,
146-
json.dumps(
147-
{multipart_file_field: [f"variables.{gql_file_field}"]}
148-
),
149-
),
146+
"operations": json.dumps(operations),
147+
"map": json.dumps(
148+
map_data
149+
), # Remove the unnecessary (None, ...) tuple
150150
}
151151

152-
client = self.client
153-
headers = dict(client.connection.headers)
154-
headers.pop("Content-Type", None)
155-
request = requests.Request(
156-
"POST",
157-
client.endpoint,
158-
headers=headers,
159-
data=request_data,
160-
files=files,
161-
)
162-
163-
prepped: requests.PreparedRequest = request.prepare()
164-
165-
response = client.connection.send(prepped)
166-
167-
if response.status_code == 502:
168-
error_502 = "502 Bad Gateway"
169-
raise InternalServerError(error_502)
170-
elif response.status_code == 503:
171-
raise InternalServerError(response.text)
172-
elif response.status_code == 520:
173-
raise InternalServerError(response.text)
174-
175-
try:
176-
file_data = response.json().get("data", None)
177-
except ValueError as e: # response is not valid JSON
178-
raise LabelboxError("Failed to upload, unknown cause", e)
152+
file_data = self.client.execute(data=request_data, files=files)
179153

180154
if not file_data or not file_data.get("importUsersAsCsvToGroup", None):
181-
try:
182-
errors = response.json().get("errors", [])
183-
error_msg = "Unknown error"
184-
if errors:
185-
error_msg = errors[0].get("message", "Unknown error")
186-
except Exception:
187-
error_msg = "Unknown error"
155+
error_msg = "Unknown error"
188156
raise LabelboxError("Failed to upload, message: %s" % error_msg)
189157

190158
csv_report = file_data["importUsersAsCsvToGroup"]["csvReport"]

0 commit comments

Comments
 (0)