Skip to content
This repository was archived by the owner on Jul 29, 2024. It is now read-only.

Commit fa81989

Browse files
Update data_rows.py
1 parent a5a7bfd commit fa81989

File tree

1 file changed

+35
-18
lines changed

1 file changed

+35
-18
lines changed

labelpandas/data_rows.py

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,18 @@
66
from concurrent.futures import ThreadPoolExecutor, as_completed
77

88
def create_data_row_upload_dict(client:labelboxClient, table:pd.core.frame.DataFrame,
9-
row_data_col:str, global_key_col:str, external_id_col:str,
10-
metadata_index:dict, attachment_index:dict,
9+
row_data_col:str, global_key_col:str, external_id_col:str, dataset_id_col:str,
10+
dataset_id:str, metadata_index:dict, attachment_index:dict,
1111
divider:str, verbose:bool):
1212
""" Multithreads over a Pandas DataFrame, calling create_data_rows() on each row to return an upload dictionary
1313
Args:
14+
client : Required (labelbox.client.Client) - Labelbox Client object
1415
table : Required (pandas.core.frame.DataFrame) - Pandas DataFrame
15-
client : Required (labelbox.client.Client) - Labelbox Client object
1616
row_data_col : Required (str) - Column containing asset URL or raw text
1717
global_key_col : Required (str) - Column name containing the data row global key - defaults to row data
1818
external_id_col : Required (str) - Column name containing the data row external ID - defaults to global key
19+
dataset_id_col : Required (str) - Column name containing the dataset ID to add data rows to
20+
dataset_id : Required (str) - Default dataset ID to add data rows to if dataset_id_col == ""
1921
metadata_index : Required (dict) - Dictionary where {key=column_name : value=metadata_type}
2022
attachment_index : Required (dict) - Dictionary where {key=column_name : value=attachment_type}
2123
divider : Required (str) - String delimiter for all name keys generated
@@ -32,6 +34,10 @@ def create_data_row_upload_dict(client:labelboxClient, table:pd.core.frame.DataF
3234
print(f"Warning: Your global key column is not unique - upload will resume, only uploading 1 data row per unique global key")
3335
metadata_schema_to_name_key = labelbase.metadata.get_metadata_schema_to_name_key(client=lb_client, lb_mdo=False, divider=divider, invert=False)
3436
metadata_name_key_to_schema = labelbase.metadata.get_metadata_schema_to_name_key(client=lb_client, lb_mdo=False, divider=divider, invert=True)
37+
if dataset_id:
38+
dataset_to_global_key_to_upload_dict = {dataset_id : {}}
39+
else:
40+
dataset_to_global_key_to_upload_dict = {id : {} for id in connector.get_unique_values_function(table=table)}
3541
with ThreadPoolExecutor(max_workers=8) as exc:
3642
global_key_to_upload_dict = {}
3743
errors = []
@@ -41,13 +47,15 @@ def create_data_row_upload_dict(client:labelboxClient, table:pd.core.frame.DataF
4147
for index, row in tqdm(table.iterrows()):
4248
futures.append(exc.submit(
4349
create_data_rows, client, row, metadata_name_key_to_schema, metadata_schema_to_name_key,
44-
row_data_col, global_key_col, external_id_col, metadata_index, attachment_index, divider
50+
row_data_col, global_key_col, external_id_col, dataset_id_col,
51+
dataset_id, metadata_index, attachment_index, divider
4552
))
4653
else:
4754
for index, row in table.iterrows():
4855
futures.append(exc.submit(
4956
create_data_rows, client, row, metadata_name_key_to_schema, metadata_schema_to_name_key,
50-
row_data_col, global_key_col, external_id_col, metadata_index, attachment_index, divider
57+
row_data_col, global_key_col, external_id_col, dataset_id_col,
58+
dataset_id, metadata_index, attachment_index, divider
5159
))
5260
if verbose:
5361
print(f'Processing data rows...')
@@ -56,22 +64,28 @@ def create_data_row_upload_dict(client:labelboxClient, table:pd.core.frame.DataF
5664
if res['error']:
5765
errors.append(res)
5866
else:
59-
global_key_to_upload_dict[str(res['data_row']["global_key"])] = res['data_row']
67+
id = str(list(res.keys()))[0]
68+
data_row_dict = res["res"][id]
69+
global_key = str(data_row_dict["global_key"])
70+
dataset_to_global_key_to_upload_dict[id].update({global_key:data_row_dict})
6071
else:
6172
for f in as_completed(futures):
6273
res = f.result()
6374
if res['error']:
6475
errors.append(res)
6576
else:
66-
global_key_to_upload_dict[str(res['data_row']["global_key"])] = res['data_row']
77+
id = str(list(res.keys()))[0]
78+
data_row_dict = res["res"][id]
79+
global_key = str(data_row_dict["global_key"])
80+
dataset_to_global_key_to_upload_dict[id].update({global_key:data_row_dict})
6781
if verbose:
68-
print(f'Generated upload list - {len(global_key_to_upload_dict)} data rows to upload')
82+
print(f'Generated upload list')
6983
return global_key_to_upload_dict, errors
7084

7185
def create_data_rows(client:labelboxClient, row:pandas.core.series.Series,
7286
metadata_name_key_to_schema:dict, metadata_schema_to_name_key:dict,
73-
row_data_col:str, global_key_col:str, external_id_col:str,
74-
metadata_index:dict, attachment_index:dict,
87+
row_data_col:str, global_key_col:str, external_id_col:str, dataset_id_col:str,
88+
metadata_index:str, metadata_index:dict, attachment_index:dict,
7589
divider:str):
7690
""" Function to-be-multithreaded to create data row dictionaries from a Pandas DataFrame
7791
Args:
@@ -82,6 +96,8 @@ def create_data_rows(client:labelboxClient, row:pandas.core.series.Series,
8296
row_data_col : Required (str) - Column containing asset URL or raw text
8397
global_key_col : Required (str) - Column name containing the data row global key
8498
external_id_col : Required (str) - Column name containing the data row external ID
99+
dataset_id_col : Required (str) - Column name containing the dataset ID to add data rows to
100+
dataset_id : Required (str) - Default dataset if dataset_id_col == ""
85101
metadata_index : Required (dict) - Dictionary where {key=column_name : value=metadata_type}
86102
attachment_index : Required (dict) - Dictionary where {key=column_name : value=attachment_type}
87103
divider : Required (str) - String delimiter for all name keys generated
@@ -90,11 +106,13 @@ def create_data_rows(client:labelboxClient, row:pandas.core.series.Series,
90106
- "error" - If there's a value in the "error" key, the script will skip it on upload and return the error at the end
91107
- "data_row" - Dictionary with "global_key" "external_id" "row_data" and "metadata_fields" keys in the proper format to-be-uploaded
92108
"""
93-
return_value = {"error" : None, "data_row" : {}}
109+
return_value = {"error" : None, "res" : {}}
94110
try:
95-
return_value["data_row"]["row_data"] = str(row[row_data_col])
96-
return_value["data_row"]["global_key"] = str(row[global_key_col])
97-
return_value["data_row"]["external_id"] = str(row[external_id_col])
111+
id = dataset_id if dataset_id else row["dataset_id_col"]
112+
return_value["res"] = {id : {}}
113+
return_value["res"][id]["row_data"] = str(row[row_data_col])
114+
return_value["res"][id]["global_key"] = str(row[global_key_col])
115+
return_value["res"][id]["external_id"] = str(row[external_id_col])
98116
metadata_fields = [{"schema_id" : metadata_name_key_to_schema['lb_integration_source'], "value" : "Pandas"}]
99117
if metadata_index:
100118
for metadata_field_name in metadata_index.keys():
@@ -106,12 +124,11 @@ def create_data_rows(client:labelboxClient, row:pandas.core.series.Series,
106124
metadata_fields.append({"schema_id" : metadata_name_key_to_schema[metadata_field_name], "value" : input_metadata})
107125
else:
108126
continue
109-
return_value["data_row"]["metadata_fields"] = metadata_fields
127+
return_value["res"][id]["metadata_fields"] = metadata_fields
110128
if attachment_index:
111-
return_value['data_row']['attachments'] = []
129+
return_value["res"][id]["attachments"] = []
112130
for column_name in attachment_index:
113-
return_value['data_row']['attachments'].append({"type" : attachment_index[column_name], "value" : row[column_name]})
131+
return_value["res"][id]['attachments'].append({"type" : attachment_index[column_name], "value" : row[column_name]})
114132
except Exception as e:
115133
return_value["error"] = e
116-
return_value["data_row"]["global_key"] = str(row[global_key_col])
117134
return return_value

0 commit comments

Comments
 (0)