Skip to content

Commit 958198d

Browse files
committed
initial addition to include filters for export_labels
1 parent dba0ce8 commit 958198d

File tree

1 file changed

+50
-7
lines changed

1 file changed

+50
-7
lines changed

labelbox/schema/project.py

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def export_queued_data_rows(self, timeout_seconds=120):
191191
self.uid)
192192
time.sleep(sleep_time)
193193

194-
def video_label_generator(self, timeout_seconds=600):
194+
def video_label_generator(self, timeout_seconds=600, **kwargs):
195195
"""
196196
Download video annotations
197197
@@ -200,7 +200,8 @@ def video_label_generator(self, timeout_seconds=600):
200200
"""
201201
_check_converter_import()
202202
json_data = self.export_labels(download=True,
203-
timeout_seconds=timeout_seconds)
203+
timeout_seconds=timeout_seconds,
204+
**kwargs)
204205
if json_data is None:
205206
raise TimeoutError(
206207
f"Unable to download labels in {timeout_seconds} seconds."
@@ -215,7 +216,7 @@ def video_label_generator(self, timeout_seconds=600):
215216
"Or use project.label_generator() for text and imagery data.")
216217
return LBV1Converter.deserialize_video(json_data, self.client)
217218

218-
def label_generator(self, timeout_seconds=600):
219+
def label_generator(self, timeout_seconds=600, **kwargs):
219220
"""
220221
Download text and image annotations
221222
@@ -224,7 +225,8 @@ def label_generator(self, timeout_seconds=600):
224225
"""
225226
_check_converter_import()
226227
json_data = self.export_labels(download=True,
227-
timeout_seconds=timeout_seconds)
228+
timeout_seconds=timeout_seconds,
229+
**kwargs)
228230
if json_data is None:
229231
raise TimeoutError(
230232
f"Unable to download labels in {timeout_seconds} seconds."
@@ -239,7 +241,7 @@ def label_generator(self, timeout_seconds=600):
239241
"Or use project.video_label_generator() for video data.")
240242
return LBV1Converter.deserialize(json_data)
241243

242-
def export_labels(self, download=False, timeout_seconds=600):
244+
def export_labels(self, download=False, timeout_seconds=600, **kwargs):
243245
""" Calls the server-side Label exporting that generates a JSON
244246
payload, and returns the URL to that payload.
245247
@@ -251,11 +253,52 @@ def export_labels(self, download=False, timeout_seconds=600):
251253
URL of the data file with this Project's labels. If the server didn't
252254
generate during the `timeout_seconds` period, None is returned.
253255
"""
256+
257+
def _string_from_dict(dictionary: dict, value_with_quotes=False) -> str:
258+
"""Returns a concatenated string of the dictionary's keys and values
259+
260+
The string will be formatted as {key}: 'value' for each key. Value will be inclusive of
261+
quotations while key will not. This can be toggled with `value_with_quotes`"""
262+
if value_with_quotes:
263+
return ",".join([
264+
f"""{c}: "{dictionary.get(c)}\"""" for c in dictionary
265+
if dictionary.get(c)
266+
])
267+
return ",".join([
268+
f"""{c}: {dictionary.get(c)}""" for c in dictionary
269+
if dictionary.get(c)
270+
])
271+
272+
def _validate_datetime(string_date: str) -> None:
273+
"""helper function validate that datetime is as follows: YYYY-MM-DD for the export"""
274+
if string_date:
275+
try:
276+
datetime.fromisoformat(string_date)
277+
except:
278+
raise ValueError("Format of date must be \"YYYY-MM-DD\"")
279+
254280
sleep_time = 2
255281
id_param = "projectId"
282+
filter_param = ""
283+
filter_param_dict = {}
284+
285+
if "start" in kwarg or "end" in kwarg:
286+
created_at_dict = {
287+
"start": kwarg.get("start", ""),
288+
"end": kwarg.get("end", "")
289+
}
290+
[_validate_datetime(date) for date in created_at_dict.values()]
291+
filter_param_dict["labelCreatedAt"] = "{%s}" % _string_from_dict(
292+
created_at_dict, value_with_quotes=True)
293+
294+
if filter_param_dict:
295+
296+
filter_param = """, filters: {%s }""" % (_string_from_dict(
297+
filter_param_dict, value_with_quotes=False))
298+
256299
query_str = """mutation GetLabelExportUrlPyApi($%s: ID!)
257-
{exportLabels(data:{projectId: $%s }) {downloadUrl createdAt shouldPoll} }
258-
""" % (id_param, id_param)
300+
{exportLabels(data:{projectId: $%s%s}) {downloadUrl createdAt shouldPoll} }
301+
""" % (id_param, id_param, filter_param)
259302

260303
while True:
261304
res = self.client.execute(query_str, {id_param: self.uid})

0 commit comments

Comments
 (0)