Skip to content

Commit 2c8e358

Browse files
author
Val Brodsky
committed
Add support for member upload
1 parent d9464b9 commit 2c8e358

File tree

1 file changed

+111
-1
lines changed

1 file changed

+111
-1
lines changed

libs/labelbox/src/labelbox/schema/user_group_upload.py

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
1+
import json
2+
from io import BytesIO
13
from typing import List, Optional
24

3-
from lbox.exceptions import ResourceNotFoundError
5+
import requests
6+
from lbox.exceptions import (
7+
InternalServerError,
8+
LabelboxError,
9+
ResourceNotFoundError,
10+
)
411

512
from labelbox.client import Client
613
from labelbox.pagination import PaginatedCollection
@@ -11,10 +18,104 @@ def __init__(self, client: Client):
1118
self.client = client
1219

1320
def upload_members(self, group_id: str, role: str, emails: List[str]):
21+
if len(emails) == 0:
22+
print("No emails to upload.")
23+
return
24+
1425
role_id = self._get_role_id(role)
1526
if role_id is None:
1627
raise ResourceNotFoundError(message="The role does not exist.")
1728

29+
buffer = BytesIO()
30+
buffer.write(b"email\n") # Header row
31+
for email in emails:
32+
buffer.write(f"{email}\n".encode("utf-8"))
33+
# Reset pointer to start of stream
34+
buffer.seek(0)
35+
36+
multipart_file_field = "1"
37+
gql_file_field = "file"
38+
files = {
39+
multipart_file_field: (
40+
f"{multipart_file_field}.csv",
41+
buffer,
42+
"text/csv",
43+
)
44+
}
45+
query = """mutation ImportMembersToGroup(
46+
$roleId: ID!
47+
$file: Upload!
48+
$where: WhereUniqueIdInput!
49+
) {
50+
importUsersAsCsvToGroup(roleId: $roleId, file: $file, where: $where) {
51+
csvReport
52+
addedCount
53+
count
54+
}
55+
}
56+
"""
57+
params = {
58+
"roleId": role_id,
59+
gql_file_field: None,
60+
"where": {"id": group_id},
61+
}
62+
63+
request_data = {
64+
"operations": json.dumps(
65+
{
66+
"variables": params,
67+
"query": query,
68+
}
69+
),
70+
"map": (
71+
None,
72+
json.dumps(
73+
{multipart_file_field: [f"variables.{gql_file_field}"]}
74+
),
75+
),
76+
}
77+
78+
client = self.client
79+
headers = dict(client.connection.headers)
80+
headers.pop("Content-Type", None)
81+
request = requests.Request(
82+
"POST",
83+
client.endpoint,
84+
headers=headers,
85+
data=request_data,
86+
files=files,
87+
)
88+
89+
prepped: requests.PreparedRequest = request.prepare()
90+
91+
response = client.connection.send(prepped)
92+
93+
if response.status_code == 502:
94+
error_502 = "502 Bad Gateway"
95+
raise InternalServerError(error_502)
96+
elif response.status_code == 503:
97+
raise InternalServerError(response.text)
98+
elif response.status_code == 520:
99+
raise InternalServerError(response.text)
100+
101+
try:
102+
file_data = response.json().get("data", None)
103+
except ValueError as e: # response is not valid JSON
104+
raise LabelboxError("Failed to upload, unknown cause", e)
105+
106+
if not file_data or not file_data.get("importUsersAsCsvToGroup", None):
107+
try:
108+
errors = response.json().get("errors", [])
109+
error_msg = next(iter(errors), {}).get(
110+
"message", "Unknown error"
111+
)
112+
except Exception:
113+
error_msg = "Unknown error"
114+
raise LabelboxError("Failed to upload, message: %s" % error_msg)
115+
116+
csv_report = file_data["importUsersAsCsvToGroup"]["csvReport"]
117+
return self._parse_csv_report(csv_report)
118+
18119
def _get_role_id(self, role_name: str) -> Optional[str]:
19120
role_id = None
20121
query = """query GetAvailableUserRolesPyPi {
@@ -44,3 +145,12 @@ def _get_role_id(self, role_name: str) -> Optional[str]:
44145
break
45146

46147
return role_id
148+
149+
def _parse_csv_report(self, csv_report: str) -> List[dict]:
150+
lines = csv_report.strip().split("\n")
151+
headers = lines[0].split(",")
152+
report_list = []
153+
for line in lines[1:]:
154+
values = line.split(",")
155+
report_list.append(dict(zip(headers, values)))
156+
return report_list

0 commit comments

Comments
 (0)