1
+ import json
2
+ from io import BytesIO
1
3
from typing import List , Optional
2
4
3
- from lbox .exceptions import ResourceNotFoundError
5
+ import requests
6
+ from lbox .exceptions import (
7
+ InternalServerError ,
8
+ LabelboxError ,
9
+ ResourceNotFoundError ,
10
+ )
4
11
5
12
from labelbox .client import Client
6
13
from labelbox .pagination import PaginatedCollection
@@ -11,10 +18,104 @@ def __init__(self, client: Client):
11
18
self .client = client
12
19
13
20
def upload_members (self , group_id : str , role : str , emails : List [str ]):
21
+ if len (emails ) == 0 :
22
+ print ("No emails to upload." )
23
+ return
24
+
14
25
role_id = self ._get_role_id (role )
15
26
if role_id is None :
16
27
raise ResourceNotFoundError (message = "The role does not exist." )
17
28
29
+ buffer = BytesIO ()
30
+ buffer .write (b"email\n " ) # Header row
31
+ for email in emails :
32
+ buffer .write (f"{ email } \n " .encode ("utf-8" ))
33
+ # Reset pointer to start of stream
34
+ buffer .seek (0 )
35
+
36
+ multipart_file_field = "1"
37
+ gql_file_field = "file"
38
+ files = {
39
+ multipart_file_field : (
40
+ f"{ multipart_file_field } .csv" ,
41
+ buffer ,
42
+ "text/csv" ,
43
+ )
44
+ }
45
+ query = """mutation ImportMembersToGroup(
46
+ $roleId: ID!
47
+ $file: Upload!
48
+ $where: WhereUniqueIdInput!
49
+ ) {
50
+ importUsersAsCsvToGroup(roleId: $roleId, file: $file, where: $where) {
51
+ csvReport
52
+ addedCount
53
+ count
54
+ }
55
+ }
56
+ """
57
+ params = {
58
+ "roleId" : role_id ,
59
+ gql_file_field : None ,
60
+ "where" : {"id" : group_id },
61
+ }
62
+
63
+ request_data = {
64
+ "operations" : json .dumps (
65
+ {
66
+ "variables" : params ,
67
+ "query" : query ,
68
+ }
69
+ ),
70
+ "map" : (
71
+ None ,
72
+ json .dumps (
73
+ {multipart_file_field : [f"variables.{ gql_file_field } " ]}
74
+ ),
75
+ ),
76
+ }
77
+
78
+ client = self .client
79
+ headers = dict (client .connection .headers )
80
+ headers .pop ("Content-Type" , None )
81
+ request = requests .Request (
82
+ "POST" ,
83
+ client .endpoint ,
84
+ headers = headers ,
85
+ data = request_data ,
86
+ files = files ,
87
+ )
88
+
89
+ prepped : requests .PreparedRequest = request .prepare ()
90
+
91
+ response = client .connection .send (prepped )
92
+
93
+ if response .status_code == 502 :
94
+ error_502 = "502 Bad Gateway"
95
+ raise InternalServerError (error_502 )
96
+ elif response .status_code == 503 :
97
+ raise InternalServerError (response .text )
98
+ elif response .status_code == 520 :
99
+ raise InternalServerError (response .text )
100
+
101
+ try :
102
+ file_data = response .json ().get ("data" , None )
103
+ except ValueError as e : # response is not valid JSON
104
+ raise LabelboxError ("Failed to upload, unknown cause" , e )
105
+
106
+ if not file_data or not file_data .get ("importUsersAsCsvToGroup" , None ):
107
+ try :
108
+ errors = response .json ().get ("errors" , [])
109
+ error_msg = next (iter (errors ), {}).get (
110
+ "message" , "Unknown error"
111
+ )
112
+ except Exception :
113
+ error_msg = "Unknown error"
114
+ raise LabelboxError ("Failed to upload, message: %s" % error_msg )
115
+
116
+ csv_report = file_data ["importUsersAsCsvToGroup" ]["csvReport" ]
117
+ return self ._parse_csv_report (csv_report )
118
+
18
119
def _get_role_id (self , role_name : str ) -> Optional [str ]:
19
120
role_id = None
20
121
query = """query GetAvailableUserRolesPyPi {
@@ -44,3 +145,12 @@ def _get_role_id(self, role_name: str) -> Optional[str]:
44
145
break
45
146
46
147
return role_id
148
+
149
+ def _parse_csv_report (self , csv_report : str ) -> List [dict ]:
150
+ lines = csv_report .strip ().split ("\n " )
151
+ headers = lines [0 ].split ("," )
152
+ report_list = []
153
+ for line in lines [1 :]:
154
+ values = line .split ("," )
155
+ report_list .append (dict (zip (headers , values )))
156
+ return report_list
0 commit comments