Commit 3f5b9a4

Buffered stream

Author: Adrian Chang (committed)
1 parent 99b07e5 commit 3f5b9a4
File tree: 1 file changed, +160 -59 lines changed


libs/labelbox/src/labelbox/schema/export_task.py

Lines changed: 160 additions & 59 deletions
@@ -414,59 +414,6 @@ def read(self) -> Iterator[Tuple[_MetadataFileInfo, str]]:
             yield file_info, raw_data
             result = self._retrieval_strategy.get_next_chunk()
 
-
-@dataclass
-class BufferedJsonConverterOutput:
-    """Output with the JSON object"""
-    json: Any
-
-
-class _BufferedJsonConverter(Converter[BufferedJsonConverterOutput]):
-    """Converts JSON data in a buffered manner"""
-
-    def convert(
-            self, input_args: Converter.ConverterInputArgs
-    ) -> Iterator[BufferedJsonConverterOutput]:
-        yield BufferedJsonConverterOutput(json=json.loads(input_args.raw_data))
-
-
-class _BufferedGCSFileReader(_Reader):
-    """Reads data from multiple GCS files and buffers them to disk"""
-
-    def __init__(self):
-        super().__init__()
-        self._retrieval_strategy = None
-
-    def set_retrieval_strategy(self, strategy: FileRetrieverStrategy) -> None:
-        """Sets the retrieval strategy."""
-        self._retrieval_strategy = strategy
-
-    def read(self) -> Iterator[Tuple[_MetadataFileInfo, str]]:
-        if not self._retrieval_strategy:
-            raise ValueError("retrieval strategy not set")
-        # create a buffer
-        with tempfile.NamedTemporaryFile(mode='w+', delete=False) as temp_file:
-            result = self._retrieval_strategy.get_next_chunk()
-            while result:
-                _, raw_data = result
-                # there is something wrong with the way the offsets are calculated,
-                # so just write all of the chunks as is to the file, with the pointer
-                # initially at the start of the file (matching the layout in GCS),
-                # and do not rely on offsets for file location
-                # temp_file.seek(file_info.offsets.start)
-                temp_file.write(raw_data)
-                result = self._retrieval_strategy.get_next_chunk()
-        # read buffer
-        with open(temp_file.name, 'r') as temp_file_reopened:
-            for idx, line in enumerate(temp_file_reopened):
-                yield _MetadataFileInfo(
-                    offsets=Range(start=0, end=len(line) - 1),
-                    lines=Range(start=idx, end=idx + 1),
-                    file=temp_file.name), line
-        # manually delete buffer
-        os.unlink(temp_file.name)
-
-
 class Stream(Generic[OutputT]):
     """Streams data from a Reader."""
 
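Note on the buffering pattern used by _BufferedGCSFileReader.read above: chunks are appended to a named temporary file in arrival order (offsets are deliberately ignored), the file is reopened and replayed line by line, and only then unlinked, which is why delete=False is passed. A minimal standalone sketch of the same pattern; the function name and sample chunks are illustrative, not part of the SDK:

import os
import tempfile
from typing import Iterator, List, Tuple

def buffered_lines(chunks: List[str]) -> Iterator[Tuple[int, str]]:
    # write every chunk as-is; delete=False keeps the file after the `with`
    with tempfile.NamedTemporaryFile(mode='w+', delete=False) as temp_file:
        for raw_data in chunks:
            temp_file.write(raw_data)
    # reopen the buffer and yield (line index, line) pairs
    with open(temp_file.name, 'r') as reopened:
        for idx, line in enumerate(reopened):
            yield idx, line
    # manually delete the buffer once it has been fully consumed
    os.unlink(temp_file.name)

for idx, line in buffered_lines(['{"a": 1}\n{"b"', ': 2}\n']):
    print(idx, line.rstrip())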
@@ -524,6 +471,142 @@ def start(
                 stream_handler(output)
 
 
+class _BufferedFileRetrieverByOffset(FileRetrieverStrategy):  # pylint: disable=too-few-public-methods
+    """Retrieves files by offset."""
+
+    def __init__(
+        self,
+        ctx: _TaskContext,
+        offset: int,
+    ) -> None:
+        super().__init__(ctx)
+        self._current_offset = offset
+        self._current_line: Optional[int] = None
+        if self._current_offset >= self._ctx.metadata_header.total_size:
+            raise ValueError(
+                f"offset is out of range, max offset is {self._ctx.metadata_header.total_size - 1}"
+            )
+
+    def get_next_chunk(self) -> Optional[Tuple[_MetadataFileInfo, str]]:
+        if self._current_offset >= self._ctx.metadata_header.total_size:
+            return None
+        query = (
+            f"query GetExportFileFromOffsetPyApi"
+            f"($where: WhereUniqueIdInput, $streamType: TaskStreamType!, $offset: UInt64!)"
+            f"{{task(where: $where)"
+            f"{{{'exportFileFromOffset'}(streamType: $streamType, offset: $offset)"
+            f"{{offsets {{start end}} lines {{start end}} file}}"
+            f"}}}}")
+        variables = {
+            "where": {
+                "id": self._ctx.task_id
+            },
+            "streamType": self._ctx.stream_type.value,
+            "offset": str(self._current_offset),
+        }
+        file_info, file_content = self._get_file_content(
+            query, variables, "exportFileFromOffset")
+        file_info.offsets.start = self._current_offset
+        file_info.lines.start = self._current_line
+        self._current_offset = file_info.offsets.end + 1
+        self._current_line = file_info.lines.end + 1
+        return file_info, file_content
+
+
+class BufferedStream(Generic[OutputT]):
+    """Streams data from a Reader."""
+
+    def __init__(
+        self,
+        ctx: _TaskContext,
+    ):
+        self._ctx = ctx
+        self._reader = _BufferedGCSFileReader()
+        self._converter = _BufferedJsonConverter()
+        self._reader.set_retrieval_strategy(
+            _BufferedFileRetrieverByOffset(self._ctx, 0))
+
+    def __iter__(self):
+        yield from self._fetch()
+
+    def _fetch(self) -> Iterator[OutputT]:
+        """Fetches the result data.
+        Returns an iterator that yields the offset and the data.
+        """
+        if self._ctx.metadata_header.total_size is None:
+            return
+
+        stream = self._reader.read()
+        with self._converter as converter:
+            for file_info, raw_data in stream:
+                for output in converter.convert(
+                        Converter.ConverterInputArgs(self._ctx, file_info,
+                                                     raw_data)):
+                    yield output
+
+    def start(
+            self,
+            stream_handler: Optional[Callable[[OutputT], None]] = None) -> None:
+        """Starts streaming the result data.
+        Calls the stream_handler for each result.
+        """
+        # this calls the __iter__ method, which in turn calls the _fetch method
+        for output in self:
+            if stream_handler:
+                stream_handler(output)
+
+
+@dataclass
+class BufferedJsonConverterOutput:
+    """Output with the JSON object"""
+    json: Any
+
+
+class _BufferedJsonConverter(Converter[BufferedJsonConverterOutput]):
+    """Converts JSON data in a buffered manner"""
+
+    def convert(
+            self, input_args: Converter.ConverterInputArgs
+    ) -> Iterator[BufferedJsonConverterOutput]:
+        yield BufferedJsonConverterOutput(json=json.loads(input_args.raw_data))
+
+
+class _BufferedGCSFileReader(_Reader):
+    """Reads data from multiple GCS files and buffers them to disk"""
+
+    def __init__(self):
+        super().__init__()
+        self._retrieval_strategy = None
+
+    def set_retrieval_strategy(self, strategy: FileRetrieverStrategy) -> None:
+        """Sets the retrieval strategy."""
+        self._retrieval_strategy = strategy
+
+    def read(self) -> Iterator[Tuple[_MetadataFileInfo, str]]:
+        if not self._retrieval_strategy:
+            raise ValueError("retrieval strategy not set")
+        # create a buffer
+        with tempfile.NamedTemporaryFile(mode='w+', delete=False) as temp_file:
+            result = self._retrieval_strategy.get_next_chunk()
+            while result:
+                _, raw_data = result
+                # there is something wrong with the way the offsets are calculated,
+                # so just write all of the chunks as is to the file, with the pointer
+                # initially at the start of the file (matching the layout in GCS),
+                # and do not rely on offsets for file location
+                # temp_file.seek(file_info.offsets.start)
+                temp_file.write(raw_data)
+                result = self._retrieval_strategy.get_next_chunk()
+        # read buffer
+        with open(temp_file.name, 'r') as temp_file_reopened:
+            for idx, line in enumerate(temp_file_reopened):
+                yield _MetadataFileInfo(
+                    offsets=Range(start=0, end=len(line) - 1),
+                    lines=Range(start=idx, end=idx + 1),
+                    file=temp_file.name), line
+        # manually delete buffer
+        os.unlink(temp_file.name)
+
+
 class ExportTask:
     """
     An adapter class for working with task objects, providing extended functionality
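The retriever above pages through the export stream by byte offset: each GraphQL call returns the chunk beginning at the current offset together with its end offset and line range, and the cursor advances to offsets.end + 1 until it reaches metadata_header.total_size. The same advance-until-exhausted loop in isolation, with fake_fetch, TOTAL_SIZE, and the 10-byte chunk size standing in for the real endpoint (all three are assumptions for illustration):

from typing import Iterator, Tuple

TOTAL_SIZE = 25  # plays the role of metadata_header.total_size

def fake_fetch(offset: int) -> Tuple[int, str]:
    # stand-in for the exportFileFromOffset call: returns the end offset
    # and content of the chunk starting at `offset`
    end = min(offset + 10, TOTAL_SIZE) - 1
    return end, f"bytes {offset}..{end}"

def chunks_from(offset: int) -> Iterator[str]:
    # mirrors get_next_chunk: stop once the cursor passes total_size,
    # otherwise fetch and advance to end + 1
    while offset < TOTAL_SIZE:
        end, content = fake_fetch(offset)
        yield content
        offset = end + 1

print(list(chunks_from(0)))  # ['bytes 0..9', 'bytes 10..19', 'bytes 20..24']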
@@ -649,11 +732,9 @@ def errors(self):
             self._task.client, self._task.uid, StreamType.ERRORS)
         if metadata_header is None:
             return None
-        Stream(
+        BufferedStream(
             _TaskContext(self._task.client, self._task.uid, StreamType.ERRORS,
                          metadata_header),
-            _BufferedGCSFileReader(),
-            _BufferedJsonConverter(),
         ).start(stream_handler=lambda output: data.append(output.json))
         return data
 
@@ -671,11 +752,9 @@ def result(self):
                 self._task.client, self._task.uid, StreamType.RESULT)
             if metadata_header is None:
                 return []
-            Stream(
+            BufferedStream(
                 _TaskContext(self._task.client, self._task.uid,
                              StreamType.RESULT, metadata_header),
-                _BufferedGCSFileReader(),
-                _BufferedJsonConverter(),
             ).start(stream_handler=lambda output: data.append(output.json))
             return data
         return self._task.result_url
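Both properties now drive a BufferedStream through a collecting handler instead of wiring up a reader and converter by hand. The handler pattern on its own, as a sketch (collect_json is a hypothetical helper; stream is any BufferedStream):

from typing import Any, List

def collect_json(stream) -> List[Any]:
    # each output is a BufferedJsonConverterOutput; .json holds the
    # decoded object for one line of the export
    data: List[Any] = []
    stream.start(stream_handler=lambda output: data.append(output.json))
    return data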
@@ -767,11 +846,33 @@ def get_stream(
     ) -> Stream[FileConverterOutput]:
         """Overload for getting the right typing hints when using a FileConverter."""
 
+    def get_buffered_stream(
+        self,
+        stream_type: StreamType = StreamType.RESULT,
+    ) -> Stream:
+        """Returns the result of the task as a buffered stream."""
+        if self._task.status == "FAILED":
+            raise ExportTask.ExportTaskException("Task failed")
+        if self._task.status != "COMPLETE":
+            raise ExportTask.ExportTaskException("Task is not ready yet")
+
+        metadata_header = self._get_metadata_header(self._task.client,
+                                                    self._task.uid, stream_type)
+        if metadata_header is None:
+            raise ValueError(
+                f"Task {self._task.uid} does not have a {stream_type.value} stream"
+            )
+        return BufferedStream(
+            _TaskContext(self._task.client, self._task.uid, stream_type,
+                         metadata_header),
+        )
+
     def get_stream(
         self,
         converter: Optional[Converter] = None,
         stream_type: StreamType = StreamType.RESULT,
     ) -> Stream:
+        warnings.warn(
+            "get_stream is deprecated and will be removed in a future release, use get_buffered_stream")
         if converter is None:
             converter = JsonConverter()
         """Returns the result of the task."""
