Commit 58e48cc

Author: Val Brodsky (committed)
Refactor upsert code so that it can be reused for create
- Extract spec generation
- Extract data row upload logic
- Extract chunk generation and upload
- Update create data row
- Rename DatarowUploader --> DataRowUploader
- Reuse upsert backend for create_data_rows
- Add DataUpsertTask
1 parent cbf5dd1 commit 58e48cc

File tree

9 files changed: +570, -350 lines changed

libs/labelbox/src/labelbox/client.py

Lines changed: 1 addition & 0 deletions
@@ -407,6 +407,7 @@ def upload_data(self,
             }),
             "map": (None, json.dumps({"1": ["variables.file"]})),
         }
+
         response = requests.post(
             self.endpoint,
             headers={"authorization": "Bearer %s" % self.api_key},

libs/labelbox/src/labelbox/schema/dataset.py

Lines changed: 36 additions & 335 deletions
Large diffs are not rendered by default.
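The dataset.py diff is collapsed here, but given the commit description ("Reuse upsert backend for create_data_rows", "Add DataUpsertTask"), the create path presumably now builds create specs and hands them to the shared uploader introduced in the new files below. The following is only a hypothetical sketch of that flow; the method body, the uploader's module path, and the task wiring are assumptions, not the rendered diff:

# Hypothetical sketch only -- not the actual dataset.py change.
from labelbox.schema.internal.data_row_create_upsert import DataRowCreateItem
from labelbox.schema.internal.data_row_uploader import DataRowUploader  # assumed module path
from labelbox.schema.internal.datarow_upload_constants import (
    FILE_UPLOAD_THREAD_COUNT, UPSERT_CHUNK_SIZE)


def _create_data_rows_via_upsert_backend(dataset, items):
    # Build create specs (no UniqueId/GlobalKey keys allowed on the create path).
    specs = DataRowCreateItem.build(dataset.uid, items)
    # Reuse the upsert backend: chunk, serialize, and upload the specs concurrently.
    manifest = DataRowUploader.upload_in_chunks(dataset.client, specs,
                                                FILE_UPLOAD_THREAD_COUNT,
                                                UPSERT_CHUNK_SIZE)
    # The manifest would then be passed to the backend mutation that the commit's
    # new DataUpsertTask tracks (that wiring is not shown in this rendered diff).
    return manifest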
libs/labelbox/src/labelbox/schema/internal/data_row_create_upsert.py

Lines changed: 66 additions & 0 deletions

@@ -0,0 +1,66 @@
from abc import ABC, abstractmethod
from typing import List, Tuple, Optional

from labelbox.schema.identifiable import UniqueId, GlobalKey
from labelbox.pydantic_compat import BaseModel


class DataRowItemBase(BaseModel, ABC):
    id: dict
    payload: dict

    @classmethod
    @abstractmethod
    def build(
        cls,
        dataset_id: str,
        items: List[dict],
        key_types: Optional[Tuple[type, ...]] = ()
    ) -> List["DataRowItemBase"]:
        upload_items = []

        for item in items:
            # enforce current dataset's id for all specs
            item['dataset_id'] = dataset_id
            key = item.pop('key', None)
            if not key:
                key = {'type': 'AUTO', 'value': ''}
            elif isinstance(key, key_types):
                key = {'type': key.id_type.value, 'value': key.key}
            else:
                if not key_types:
                    raise ValueError(
                        f"Can not have a key for this item, got: {key}")
                raise ValueError(
                    f"Key must be an instance of {', '.join([t.__name__ for t in key_types])}, got: {type(item['key']).__name__}"
                )
            item = {
                k: v for k, v in item.items() if v is not None
            }  # remove None values
            upload_items.append(cls(payload=item, id=key))
        return upload_items

    def is_empty(self) -> bool:
        """
        The payload is considered empty if it's actually empty or the only key is `dataset_id`.
        :return: bool
        """
        return (not self.payload or
                len(self.payload.keys()) == 1 and "dataset_id" in self.payload)


class DataRowUpsertItem(DataRowItemBase):

    @classmethod
    def build(cls, dataset_id: str,
              items: List[dict]) -> List["DataRowUpsertItem"]:
        return super().build(dataset_id, items, (UniqueId, GlobalKey))


class DataRowCreateItem(DataRowItemBase):

    @classmethod
    def build(cls, dataset_id: str,
              items: List[dict]) -> List["DataRowCreateItem"]:
        return super().build(dataset_id, items, ())
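A quick illustration of how these builders behave, based on the code above (the dataset id, data row id, and URLs are made-up values):

# Illustrative usage only; ids and URLs below are placeholders.
from labelbox.schema.identifiable import UniqueId

items = [
    {"row_data": "https://example.com/img_01.jpg", "key": UniqueId("cl_existing_row")},
    {"row_data": "https://example.com/img_02.jpg"},  # no key -> {'type': 'AUTO', 'value': ''}
]

# Upsert specs accept UniqueId/GlobalKey keys and translate them to {'type', 'value'} dicts.
upsert_specs = DataRowUpsertItem.build("cl_dataset_id", items)

# Create specs pass an empty key_types tuple, so supplying any 'key' raises ValueError.
create_specs = DataRowCreateItem.build(
    "cl_dataset_id", [{"row_data": "https://example.com/img_03.jpg"}])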
Lines changed: 294 additions & 0 deletions

@@ -0,0 +1,294 @@
import json
import os
from concurrent.futures import ThreadPoolExecutor, as_completed

from typing import Iterable, List

from labelbox.exceptions import InvalidQueryError
from labelbox.exceptions import InvalidAttributeError
from labelbox.exceptions import MalformedQueryException
from labelbox.orm.model import Entity
from labelbox.orm.model import Field
from labelbox.schema.embedding import EmbeddingVector
from labelbox.schema.internal.data_row_create_upsert import DataRowItemBase
from labelbox.schema.internal.datarow_upload_constants import MAX_DATAROW_PER_API_OPERATION


class UploadManifest:

    def __init__(self, source: str, item_count: int, chunk_uris: List[str]):
        self.source = source
        self.item_count = item_count
        self.chunk_uris = chunk_uris

    def to_dict(self):
        return {
            "source": self.source,
            "item_count": self.item_count,
            "chunk_uris": self.chunk_uris
        }


class DataRowUploader:

    @staticmethod
    def create_descriptor_file(client,
                               items,
                               max_attachments_per_data_row=None,
                               is_upsert=False):
        """
        This function is shared by `Dataset.create_data_rows`, `Dataset.create_data_rows_sync` and `Dataset.update_data_rows`.
        It is used to prepare the input file. The user defined input is validated, processed, and json stringified.
        Finally the json data is uploaded to gcs and a uri is returned. This uri can be passed to

        Each element in `items` can be either a `str` or a `dict`. If
        it is a `str`, then it is interpreted as a local file path. The file
        is uploaded to Labelbox and a DataRow referencing it is created.

        If an item is a `dict`, then it could support one of the two following structures
            1. For static imagery, video, and text it should map `DataRow` field names to values.
               At the minimum an `item` passed as a `dict` must contain a `row_data` key and value.
               If the value for row_data is a local file path and the path exists,
               then the local file will be uploaded to labelbox.

            2. For tiled imagery the dict must match the import structure specified in the link below
               https://docs.labelbox.com/data-model/en/index-en#tiled-imagery-import

        >>> dataset.create_data_rows([
        >>>     {DataRow.row_data:"http://my_site.com/photos/img_01.jpg"},
        >>>     {DataRow.row_data:"/path/to/file1.jpg"},
        >>>     "path/to/file2.jpg",
        >>>     {DataRow.row_data: {"tileLayerUrl" : "http://", ...}}
        >>>     {DataRow.row_data: {"type" : ..., 'version' : ..., 'messages' : [...]}}
        >>> ])

        For an example showing how to upload tiled data_rows see the following notebook:
            https://github.com/Labelbox/labelbox-python/blob/ms/develop/model_assisted_labeling/tiled_imagery_mal.ipynb

        Args:
            items (iterable of (dict or str)): See above for details.
            max_attachments_per_data_row (Optional[int]): Param used during attachment validation to determine
                if the user has provided too many attachments.

        Returns:
            uri (string): A reference to the uploaded json data.

        Raises:
            InvalidQueryError: If the `items` parameter does not conform to
                the specification above or if the server did not accept the
                DataRow creation request (unknown reason).
            InvalidAttributeError: If there are fields in `items` not valid for
                a DataRow.
            ValueError: When the upload parameters are invalid
        """
        file_upload_thread_count = 20
        DataRow = Entity.DataRow
        AssetAttachment = Entity.AssetAttachment

        def upload_if_necessary(item):
            if is_upsert and 'row_data' not in item:
                # When upserting, row_data is not required
                return item
            row_data = item['row_data']
            if isinstance(row_data, str) and os.path.exists(row_data):
                item_url = client.upload_file(row_data)
                item['row_data'] = item_url
                if 'external_id' not in item:
                    # Default `external_id` to local file name
                    item['external_id'] = row_data
            return item

        def validate_attachments(item):
            attachments = item.get('attachments')
            if attachments:
                if isinstance(attachments, list):
                    if max_attachments_per_data_row and len(
                            attachments) > max_attachments_per_data_row:
                        raise ValueError(
                            f"Max attachments number of supported attachments per data row is {max_attachments_per_data_row}."
                            f" Found {len(attachments)}. Condense multiple attachments into one with the HTML attachment type if necessary."
                        )
                    for attachment in attachments:
                        AssetAttachment.validate_attachment_json(attachment)
                else:
                    raise ValueError(
                        f"Attachments must be a list. Found {type(attachments)}"
                    )
            return attachments

        def validate_embeddings(item):
            embeddings = item.get("embeddings")
            if embeddings:
                item["embeddings"] = [
                    EmbeddingVector(**e).to_gql() for e in embeddings
                ]

        def validate_conversational_data(conversational_data: list) -> None:
            """
            Checks each conversational message for keys expected as per https://docs.labelbox.com/reference/text-conversational#sample-conversational-json

            Args:
                conversational_data (list): list of dictionaries.
            """

            def check_message_keys(message):
                accepted_message_keys = set([
                    "messageId", "timestampUsec", "content", "user", "align",
                    "canLabel"
                ])
                for key in message.keys():
                    if not key in accepted_message_keys:
                        raise KeyError(
                            f"Invalid {key} key found! Accepted keys in messages list is {accepted_message_keys}"
                        )

            if conversational_data and not isinstance(conversational_data,
                                                      list):
                raise ValueError(
                    f"conversationalData must be a list. Found {type(conversational_data)}"
                )

            [check_message_keys(message) for message in conversational_data]

        def parse_metadata_fields(item):
            metadata_fields = item.get('metadata_fields')
            if metadata_fields:
                mdo = client.get_data_row_metadata_ontology()
                item['metadata_fields'] = mdo.parse_upsert_metadata(
                    metadata_fields)

        def format_row(item):
            # Formats user input into a consistent dict structure
            if isinstance(item, dict):
                # Convert fields to strings
                item = {
                    key.name if isinstance(key, Field) else key: value
                    for key, value in item.items()
                }
            elif isinstance(item, str):
                # The main advantage of using a string over a dict is that the user is specifying
                # that the file should exist locally.
                # That info is lost after this section so we should check for it here.
                if not os.path.exists(item):
                    raise ValueError(f"Filepath {item} does not exist.")
                item = {"row_data": item, "external_id": item}
            return item

        def validate_keys(item):
            if not is_upsert and 'row_data' not in item:
                raise InvalidQueryError(
                    "`row_data` missing when creating DataRow.")

            if isinstance(item.get('row_data'),
                          str) and item.get('row_data').startswith("s3:/"):
                raise InvalidQueryError(
                    "row_data: s3 assets must start with 'https'.")
            allowed_extra_fields = {
                'attachments', 'media_type', 'dataset_id', 'embeddings'
            }
            invalid_keys = set(item) - {f.name for f in DataRow.fields()
                                       } - allowed_extra_fields
            if invalid_keys:
                raise InvalidAttributeError(DataRow, invalid_keys)
            return item

        def formatLegacyConversationalData(item):
            messages = item.pop("conversationalData")
            version = item.pop("version", 1)
            type = item.pop("type", "application/vnd.labelbox.conversational")
            if "externalId" in item:
                external_id = item.pop("externalId")
                item["external_id"] = external_id
            if "globalKey" in item:
                global_key = item.pop("globalKey")
                item["globalKey"] = global_key
            validate_conversational_data(messages)
            one_conversation = \
                {
                    "type": type,
                    "version": version,
                    "messages": messages
                }
            item["row_data"] = one_conversation
            return item

        def convert_item(data_row_item):
            if isinstance(data_row_item, DataRowItemBase):
                item = data_row_item.payload
            else:
                item = data_row_item

            if "tileLayerUrl" in item:
                validate_attachments(item)
                return item

            if "conversationalData" in item:
                formatLegacyConversationalData(item)

            # Convert all payload variations into the same dict format
            item = format_row(item)
            # Make sure required keys exist (and there are no extra keys)
            validate_keys(item)
            # Make sure attachments are valid
            validate_attachments(item)
            # Make sure embeddings are valid
            validate_embeddings(item)
            # Parse metadata fields if they exist
            parse_metadata_fields(item)
            # Upload any local file paths
            item = upload_if_necessary(item)

            if isinstance(data_row_item, DataRowItemBase):
                return {'id': data_row_item.id, 'payload': item}
            else:
                return item

        if not isinstance(items, Iterable):
            raise ValueError(
                f"Must pass an iterable to create_data_rows. Found {type(items)}"
            )

        if len(items) > MAX_DATAROW_PER_API_OPERATION:
            raise MalformedQueryException(
                f"Cannot create more than {MAX_DATAROW_PER_API_OPERATION} DataRows per function call."
            )

        with ThreadPoolExecutor(file_upload_thread_count) as executor:
            futures = [executor.submit(convert_item, item) for item in items]
            items = [future.result() for future in as_completed(futures)]
        # Prepare and upload the desciptor file
        data = json.dumps(items)
        return client.upload_data(data,
                                  content_type="application/json",
                                  filename="json_import.json")

    @staticmethod
    def upload_in_chunks(client, specs: List[DataRowItemBase],
                         file_upload_thread_count: int,
                         upsert_chunk_size: int) -> UploadManifest:
        empty_specs = list(filter(lambda spec: spec.is_empty(), specs))

        if empty_specs:
            ids = list(map(lambda spec: spec.id.get("value"), empty_specs))
            raise ValueError(
                f"The following items have an empty payload: {ids}")

        chunks = [
            specs[i:i + upsert_chunk_size]
            for i in range(0, len(specs), upsert_chunk_size)
        ]

        def _upload_chunk(_chunk):
            return DataRowUploader.create_descriptor_file(client,
                                                          _chunk,
                                                          is_upsert=True)

        with ThreadPoolExecutor(file_upload_thread_count) as executor:
            futures = [
                executor.submit(_upload_chunk, chunk) for chunk in chunks
            ]
            chunk_uris = [future.result() for future in as_completed(futures)]

        return UploadManifest(source="SDK",
                              item_count=len(specs),
                              chunk_uris=chunk_uris)
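Taken together with the spec builders above, a chunked upsert upload would look roughly like this (hedged sketch; the API key, dataset id, and asset URLs are placeholders, and the manifest is what upload_in_chunks returns before it is handed to the backend):

# Rough end-to-end sketch; the API key, dataset id, and URLs are placeholders.
import labelbox as lb

client = lb.Client(api_key="<YOUR_API_KEY>")
items = [{"row_data": f"https://example.com/img_{i}.jpg"} for i in range(25_000)]
specs = DataRowUpsertItem.build("cl_dataset_id", items)

# 25,000 specs with upsert_chunk_size=10_000 -> 3 chunks; each chunk is serialized by
# create_descriptor_file(..., is_upsert=True) and the chunk files are uploaded concurrently.
manifest = DataRowUploader.upload_in_chunks(client,
                                            specs,
                                            file_upload_thread_count=20,
                                            upsert_chunk_size=10_000)
print(manifest.to_dict())
# e.g. {'source': 'SDK', 'item_count': 25000, 'chunk_uris': ['https://storage...', ...]}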
libs/labelbox/src/labelbox/schema/internal/datarow_upload_constants.py

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
MAX_DATAROW_PER_API_OPERATION = 150_000
FILE_UPLOAD_THREAD_COUNT = 20
UPSERT_CHUNK_SIZE = 10_000
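At these defaults, a maximal call of MAX_DATAROW_PER_API_OPERATION rows is split into at most 15 upsert chunks, each uploaded by one of up to 20 worker threads; a quick sanity check:

# Quick check of what the defaults above imply for chunk counts.
import math
print(math.ceil(150_000 / 10_000))  # -> 15 chunks for a maximal API operation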
