|
1 | 1 | import json
|
| 2 | +from dataclasses import dataclass |
2 | 3 | from io import BytesIO
|
3 | 4 | from typing import List, Optional
|
4 | 5 |
|
|
13 | 14 | from labelbox.pagination import PaginatedCollection
|
14 | 15 |
|
15 | 16 |
|
| 17 | +@dataclass |
| 18 | +class UploadReportLine: |
| 19 | + email: str |
| 20 | + result: str |
| 21 | + error: Optional[str] = None |
| 22 | + |
| 23 | + |
| 24 | +@dataclass |
| 25 | +class UploadReport: |
| 26 | + lines: List[UploadReportLine] |
| 27 | + |
| 28 | + |
16 | 29 | class UserGroupUpload:
|
17 | 30 | def __init__(self, client: Client):
|
18 | 31 | self.client = client
|
19 | 32 |
|
20 |
| - def upload_members(self, group_id: str, role: str, emails: List[str]): |
| 33 | + def upload_members( |
| 34 | + self, group_id: str, role: str, emails: List[str] |
| 35 | + ) -> Optional[UploadReport]: |
21 | 36 | if len(emails) == 0:
|
22 | 37 | print("No emails to upload.")
|
23 |
| - return |
| 38 | + return None |
24 | 39 |
|
25 | 40 | role_id = self._get_role_id(role)
|
26 | 41 | if role_id is None:
|
27 |
| - raise ResourceNotFoundError(message="The role does not exist.") |
| 42 | + raise ResourceNotFoundError( |
| 43 | + message="Could not find a valid role with the name provided. Please make sure the role name is correct." |
| 44 | + ) |
28 | 45 |
|
29 | 46 | buffer = BytesIO()
|
30 | 47 | buffer.write(b"email\n") # Header row
|
@@ -146,11 +163,20 @@ def _get_role_id(self, role_name: str) -> Optional[str]:
|
146 | 163 |
|
147 | 164 | return role_id
|
148 | 165 |
|
149 |
| - def _parse_csv_report(self, csv_report: str) -> List[dict]: |
| 166 | + def _parse_csv_report(self, csv_report: str) -> UploadReport: |
150 | 167 | lines = csv_report.strip().split("\n")
|
151 | 168 | headers = lines[0].split(",")
|
152 |
| - report_list = [] |
| 169 | + report_lines = [] |
153 | 170 | for line in lines[1:]:
|
154 | 171 | values = line.split(",")
|
155 |
| - report_list.append(dict(zip(headers, values))) |
156 |
| - return report_list |
| 172 | + row = dict(zip(headers, values)) |
| 173 | + report_lines.append( |
| 174 | + UploadReportLine( |
| 175 | + email=row["Email"], |
| 176 | + result=row["Result"], |
| 177 | + error=row.get( |
| 178 | + "Error" |
| 179 | + ), # Using get() since error is optional |
| 180 | + ) |
| 181 | + ) |
| 182 | + return UploadReport(lines=report_lines) |
0 commit comments