Skip to content

Commit 10cbf66

Browse files
author
Adrian Chang
committed
Buffered stream code
1 parent f9e76b1 commit 10cbf66

File tree

3 files changed

+154
-44
lines changed

3 files changed

+154
-44
lines changed

libs/labelbox/src/labelbox/schema/export_task.py

Lines changed: 58 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,11 @@ class JsonConverter(Converter[JsonConverterOutput]): # pylint: disable=too-few-
114114
115115
Deprecated: This converter is deprecated and will be removed in a future release.
116116
"""
117-
117+
118118
def __init__(self) -> None:
119-
warnings.warn("JSON converter is deprecated and will be removed in a future release")
119+
warnings.warn(
120+
"JSON converter is deprecated and will be removed in a future release"
121+
)
120122
super().__init__()
121123

122124
def _find_json_object_offsets(self, data: str) -> List[Tuple[int, int]]:
@@ -397,7 +399,9 @@ class _MultiGCSFileReader(_Reader): # pylint: disable=too-few-public-methods
397399
"""
398400

399401
def __init__(self):
400-
warnings.warn("_MultiGCSFileReader is deprecated and will be removed in a future release")
402+
warnings.warn(
403+
"_MultiGCSFileReader is deprecated and will be removed in a future release"
404+
)
401405
super().__init__()
402406
self._retrieval_strategy = None
403407

@@ -414,6 +418,7 @@ def read(self) -> Iterator[Tuple[_MetadataFileInfo, str]]:
414418
yield file_info, raw_data
415419
result = self._retrieval_strategy.get_next_chunk()
416420

421+
417422
class Stream(Generic[OutputT]):
418423
"""Streams data from a Reader."""
419424

@@ -481,7 +486,7 @@ def __init__(
481486
) -> None:
482487
super().__init__(ctx)
483488
self._current_offset = offset
484-
self._current_line: Optional[int] = None
489+
self._current_line = 0
485490
if self._current_offset >= self._ctx.metadata_header.total_size:
486491
raise ValueError(
487492
f"offset is out of range, max offset is {self._ctx.metadata_header.total_size - 1}"
@@ -523,7 +528,8 @@ def __init__(
523528
self._ctx = ctx
524529
self._reader = _BufferedGCSFileReader()
525530
self._converter = _BufferedJsonConverter()
526-
self._reader.set_retrieval_strategy(_BufferedFileRetrieverByOffset(self._ctx, 0))
531+
self._reader.set_retrieval_strategy(
532+
_BufferedFileRetrieverByOffset(self._ctx, 0))
527533

528534
def __iter__(self):
529535
yield from self._fetch()
@@ -564,13 +570,14 @@ class BufferedJsonConverterOutput:
564570
class _BufferedJsonConverter(Converter[BufferedJsonConverterOutput]):
565571
"""Converts JSON data in a buffered manner
566572
"""
573+
567574
def convert(
568575
self, input_args: Converter.ConverterInputArgs
569576
) -> Iterator[BufferedJsonConverterOutput]:
570577
yield BufferedJsonConverterOutput(json=json.loads(input_args.raw_data))
571578

572579

573-
class _BufferedGCSFileReader(_Reader):
580+
class _BufferedGCSFileReader(_Reader):
574581
"""Reads data from multiple GCS files and buffer them to disk"""
575582

576583
def __init__(self):
@@ -599,10 +606,10 @@ def read(self) -> Iterator[Tuple[_MetadataFileInfo, str]]:
599606
# read buffer
600607
with open(temp_file.name, 'r') as temp_file_reopened:
601608
for idx, line in enumerate(temp_file_reopened):
602-
yield _MetadataFileInfo(
603-
offsets=Range(start=0, end=len(line) - 1),
604-
lines=Range(start=idx, end=idx + 1),
605-
file=temp_file.name), line
609+
yield _MetadataFileInfo(offsets=Range(start=0,
610+
end=len(line) - 1),
611+
lines=Range(start=idx, end=idx + 1),
612+
file=temp_file.name), line
606613
# manually delete buffer
607614
os.unlink(temp_file.name)
608615

@@ -733,9 +740,10 @@ def errors(self):
733740
if metadata_header is None:
734741
return None
735742
BufferedStream(
736-
_TaskContext(self._task.client, self._task.uid, StreamType.ERRORS,
737-
metadata_header),
738-
).start(stream_handler=lambda output: data.append(output.json))
743+
_TaskContext(
744+
self._task.client, self._task.uid, StreamType.ERRORS,
745+
metadata_header),).start(
746+
stream_handler=lambda output: data.append(output.json))
739747
return data
740748

741749
@property
@@ -753,9 +761,10 @@ def result(self):
753761
if metadata_header is None:
754762
return []
755763
BufferedStream(
756-
_TaskContext(self._task.client, self._task.uid,
757-
StreamType.RESULT, metadata_header),
758-
).start(stream_handler=lambda output: data.append(output.json))
764+
_TaskContext(
765+
self._task.client, self._task.uid, StreamType.RESULT,
766+
metadata_header),).start(
767+
stream_handler=lambda output: data.append(output.json))
759768
return data
760769
return self._task.result_url
761770

@@ -830,27 +839,23 @@ def has_errors(self) -> bool:
830839
total_size = self.get_total_file_size(StreamType.ERRORS)
831840
return total_size is not None and total_size > 0
832841

833-
@overload
834-
def get_stream(
842+
def get_buffered_stream(
835843
self,
836-
converter: JsonConverter,
837844
stream_type: StreamType = StreamType.RESULT,
838-
) -> Stream[JsonConverterOutput]:
839-
"""Overload for getting the right typing hints when using a JsonConverter."""
845+
) -> BufferedStream:
846+
"""
847+
Returns the result of the task.
840848
841-
@overload
842-
def get_stream(
843-
self,
844-
converter: FileConverter,
845-
stream_type: StreamType = StreamType.RESULT,
846-
) -> Stream[FileConverterOutput]:
847-
"""Overload for getting the right typing hints when using a FileConverter."""
849+
Args:
850+
stream_type (StreamType, optional): The type of stream to retrieve. Defaults to StreamType.RESULT.
848851
849-
def get_buffered_stream(
850-
self,
851-
stream_type: StreamType = StreamType.RESULT,
852-
) -> Stream:
853-
"""Returns the result of the task."""
852+
Returns:
853+
Stream: The buffered stream object.
854+
855+
Raises:
856+
ExportTask.ExportTaskException: If the task has failed or is not ready yet.
857+
ValueError: If the task does not have the specified stream type.
858+
"""
854859
if self._task.status == "FAILED":
855860
raise ExportTask.ExportTaskException("Task failed")
856861
if self._task.status != "COMPLETE":
@@ -864,15 +869,32 @@ def get_buffered_stream(
864869
)
865870
return BufferedStream(
866871
_TaskContext(self._task.client, self._task.uid, stream_type,
867-
metadata_header),
868-
)
872+
metadata_header),)
873+
874+
@overload
875+
def get_stream(
876+
self,
877+
converter: JsonConverter,
878+
stream_type: StreamType = StreamType.RESULT,
879+
) -> Stream[JsonConverterOutput]:
880+
"""Overload for getting the right typing hints when using a JsonConverter."""
881+
882+
@overload
883+
def get_stream(
884+
self,
885+
converter: FileConverter,
886+
stream_type: StreamType = StreamType.RESULT,
887+
) -> Stream[FileConverterOutput]:
888+
"""Overload for getting the right typing hints when using a FileConverter."""
869889

870890
def get_stream(
871891
self,
872892
converter: Optional[Converter] = None,
873893
stream_type: StreamType = StreamType.RESULT,
874894
) -> Stream:
875-
warnings.warn("get_stream is deprecated and will be removed in a future release, use get_buffered_stream")
895+
warnings.warn(
896+
"get_stream is deprecated and will be removed in a future release, use get_buffered_stream"
897+
)
876898
if converter is None:
877899
converter = JsonConverter()
878900
"""Returns the result of the task."""

libs/labelbox/tests/conftest.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -430,10 +430,14 @@ def consensus_project(client, rand_gen):
430430

431431
@pytest.fixture
432432
def model_config(client, rand_gen, valid_model_id):
433-
model_config = client.create_model_config(name=rand_gen(str), model_id=valid_model_id, inference_params = {"param": "value"})
433+
model_config = client.create_model_config(
434+
name=rand_gen(str),
435+
model_id=valid_model_id,
436+
inference_params={"param": "value"})
434437
yield model_config
435438
client.delete_model_config(model_config.uid)
436439

440+
437441
@pytest.fixture
438442
def consensus_project_with_batch(consensus_project, initial_dataset, rand_gen,
439443
image_url):
@@ -1050,6 +1054,7 @@ def embedding(client: Client):
10501054
yield embedding
10511055
embedding.delete()
10521056

1057+
10531058
@pytest.fixture
10541059
def valid_model_id():
10551060
return "2c903542-d1da-48fd-9db1-8c62571bd3d2"

libs/labelbox/tests/unit/export_task/test_export_task.py

Lines changed: 90 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,53 @@
33
from unittest.mock import MagicMock, patch
44
from labelbox.schema.export_task import ExportTask
55

6+
67
class TestExportTask:
78

89
def test_export_task(self):
910
with patch('requests.get') as mock_requests_get:
1011
mock_task = MagicMock()
1112
mock_task.client.execute.side_effect = [
12-
{"task": {"exportMetadataHeader": { "total_size": 1, "total_lines": 1, "lines": { "start": 0, "end": 1 }, "offsets": { "start": 0, "end": 0 }, "file": "file" } } },
13-
{"task": {"exportFileFromOffset": { "total_size": 1, "total_lines": 1, "lines": { "start": 0, "end": 1 }, "offsets": { "start": 0, "end": 0 }, "file": "file" } } },
13+
{
14+
"task": {
15+
"exportMetadataHeader": {
16+
"total_size": 1,
17+
"total_lines": 1,
18+
"lines": {
19+
"start": 0,
20+
"end": 1
21+
},
22+
"offsets": {
23+
"start": 0,
24+
"end": 0
25+
},
26+
"file": "file"
27+
}
28+
}
29+
},
30+
{
31+
"task": {
32+
"exportFileFromOffset": {
33+
"total_size": 1,
34+
"total_lines": 1,
35+
"lines": {
36+
"start": 0,
37+
"end": 1
38+
},
39+
"offsets": {
40+
"start": 0,
41+
"end": 0
42+
},
43+
"file": "file"
44+
}
45+
}
46+
},
1447
]
1548
mock_task.status = "COMPLETE"
1649
data = {
1750
"data_row": {
18-
"raw_data": """
51+
"raw_data":
52+
"""
1953
{"raw_text":"}{"}
2054
{"raw_text":"\\nbad"}
2155
"""
@@ -30,12 +64,61 @@ def test_export_task_bad_offsets(self):
3064
with patch('requests.get') as mock_requests_get:
3165
mock_task = MagicMock()
3266
mock_task.client.execute.side_effect = [
33-
{"task": {"exportMetadataHeader": { "total_size": 1, "total_lines": 1, "lines": { "start": 0, "end": 1 }, "offsets": { "start": 0, "end": 0 }, "file": "file" } } },
34-
{"task": {"exportFileFromOffset": { "total_size": 1, "total_lines": 1, "lines": { "start": 0, "end": 1 }, "offsets": { "start": 0, "end": 0 }, "file": "file" } } },
67+
{
68+
"task": {
69+
"exportMetadataHeader": {
70+
"total_size": 1,
71+
"total_lines": 1,
72+
"lines": {
73+
"start": 0,
74+
"end": 1
75+
},
76+
"offsets": {
77+
"start": 0,
78+
"end": 0
79+
},
80+
"file": "file"
81+
}
82+
}
83+
},
84+
{
85+
"task": {
86+
"exportFileFromOffset": {
87+
"total_size": 1,
88+
"total_lines": 1,
89+
"lines": {
90+
"start": 0,
91+
"end": 1
92+
},
93+
"offsets": {
94+
"start": 0,
95+
"end": 0
96+
},
97+
"file": "file"
98+
}
99+
}
100+
},
35101
]
36102
mock_task.status = "COMPLETE"
37-
data = {"data_row": {"id": "clwb6wvpv3mpx0712aafl9m00", "external_id": "43cdad5e-1fcf-450d-ad72-df4460edf973", "global_key": "9ab56c5a-5c2f-45ae-8e21-e53eb415cefe", "row_data": "{\"type\":\"application/vnd.labelbox.conversational\",\"version\":1,\"messages\":[{\"messageId\":\"message-0\",\"timestampUsec\":1530718491,\"content\":\"The minimum value of $3 \\\\cos x + 4 \\\\sin x + 8$ is\",\"user\":{\"userId\":\"prompt\",\"name\":\"prompt\"},\"align\":\"left\",\"canLabel\":true}],\"modelOutputs\":[{\"title\":\"Response 1\",\"content\":\"To find the minimum value of the expression $3 \\\\cos x + 4 \\\\sin x + 8$, we can use the fact that $a\\\\cos x+b\\\\sin x=\\\\sqrt{a^2+b^2}\\\\left(\\\\frac{a}{\\\\sqrt{a^2+b^2}}\\\\cos x+\\\\frac{b}{\\\\sqrt{a^2+b^2}}\\\\sin x\\\\right)$. This allows us to rewrite the expression as:\\n\\n$3\\\\cos x+4\\\\sin x+8=\\\\sqrt{3^2+4^2}\\\\left(\\\\frac{3}{\\\\sqrt{3^2+4^2}}\\\\cos x+\\\\frac{4}{\\\\sqrt{3^2+4^2}}\\\\sin x\\\\right)+8=5\\\\left(\\\\frac{3}{5}\\\\cos x+\\\\frac{4}{5}\\\\sin x\\\\right)+8$\\n\\nNow, let's consider the expression $\\\\frac{3}{5}\\\\cos x+\\\\frac{4}{5}\\\\sin x$. Since $\\\\left(\\\\frac{3}{5}\\\\right)^2+\\\\left(\\\\frac{4}{5}\\\\right)^2=1$, we can write $\\\\frac{3}{5}=\\\\cos\\\\theta$ and $\\\\frac{4}{5}=\\\\sin\\\\theta$ for some angle $\\\\theta$. Then:\\n\\n$\\\\frac{3}{5}\\\\cos x+\\\\frac{4}{5}\\\\sin x=\\\\cos\\\\theta\\\\cos x+\\\\sin\\\\theta\\\\sin x=\\\\cos(x-\\\\theta)$\\n\\nSo, the original expression can be written as:\\n\\n$5\\\\cos(x-\\\\theta)+8$\\n\\nSince the minimum value of $\\\\cos(x-\\\\theta)$ is $-1$, the minimum value of the original expression is:\\n\\n$5(-1)+8=-5+8=3$\\n\\nTherefore, the minimum value of $3\\\\cos x + 4\\\\sin x + 8$ is $\\\\boxed{3}$.\",\"modelConfigName\":\"null\"},{\"title\":\"Response 2\",\"content\":\"A nice math question!\\n\\nTo find the minimum value of $3 \\\\cos x + 4 \\\\sin x + 8$, we can use the fact that $a\\\\cos x + b\\\\sin x = \\\\sqrt{a^2 + b^2} \\\\cos(x - \\\\alpha)$, where $\\\\alpha = \\\\tan^{-1}\\\\left(\\\\frac{b}{a}\\\\right)$.\\n\\nIn this case, $a = 3$ and $b = 4$, so $\\\\alpha = \\\\tan^{-1}\\\\left(\\\\frac{4}{3}\\\\right)$.\\n\\nSo, we have:\\n\\n$$3 \\\\cos x + 4 \\\\sin x + 8 = \\\\sqrt{3^2 + 4^2} \\\\cos(x - \\\\alpha) + 8 = 5 \\\\cos(x - \\\\alpha) + 8$$\\n\\nNow, the minimum value of $\\\\cos(x - \\\\alpha)$ is $-1$, so the minimum value of $5 \\\\cos(x - \\\\alpha) + 8$ is:\\n\\n$$5(-1) + 8 = -5 + 8 = 3$$\\n\\nTherefore, the minimum value of $3 \\\\cos x + 4 \\\\sin x + 8$ is $\\\\boxed{3}$.\\n\\nLet me know if you have any questions or need further clarification!\",\"modelConfigName\":\"null\"}]}"}, "media_attributes": {"asset_type": "conversational", "mime_type": "application/vnd.labelbox.conversational", "labelable_ids": ["message-0"], "message_count": 1}}
103+
data = {
104+
"data_row": {
105+
"id":
106+
"clwb6wvpv3mpx0712aafl9m00",
107+
"external_id":
108+
"43cdad5e-1fcf-450d-ad72-df4460edf973",
109+
"global_key":
110+
"9ab56c5a-5c2f-45ae-8e21-e53eb415cefe",
111+
"row_data":
112+
"{\"type\":\"application/vnd.labelbox.conversational\",\"version\":1,\"messages\":[{\"messageId\":\"message-0\",\"timestampUsec\":1530718491,\"content\":\"The minimum value of $3 \\\\cos x + 4 \\\\sin x + 8$ is\",\"user\":{\"userId\":\"prompt\",\"name\":\"prompt\"},\"align\":\"left\",\"canLabel\":true}],\"modelOutputs\":[{\"title\":\"Response 1\",\"content\":\"To find the minimum value of the expression $3 \\\\cos x + 4 \\\\sin x + 8$, we can use the fact that $a\\\\cos x+b\\\\sin x=\\\\sqrt{a^2+b^2}\\\\left(\\\\frac{a}{\\\\sqrt{a^2+b^2}}\\\\cos x+\\\\frac{b}{\\\\sqrt{a^2+b^2}}\\\\sin x\\\\right)$. This allows us to rewrite the expression as:\\n\\n$3\\\\cos x+4\\\\sin x+8=\\\\sqrt{3^2+4^2}\\\\left(\\\\frac{3}{\\\\sqrt{3^2+4^2}}\\\\cos x+\\\\frac{4}{\\\\sqrt{3^2+4^2}}\\\\sin x\\\\right)+8=5\\\\left(\\\\frac{3}{5}\\\\cos x+\\\\frac{4}{5}\\\\sin x\\\\right)+8$\\n\\nNow, let's consider the expression $\\\\frac{3}{5}\\\\cos x+\\\\frac{4}{5}\\\\sin x$. Since $\\\\left(\\\\frac{3}{5}\\\\right)^2+\\\\left(\\\\frac{4}{5}\\\\right)^2=1$, we can write $\\\\frac{3}{5}=\\\\cos\\\\theta$ and $\\\\frac{4}{5}=\\\\sin\\\\theta$ for some angle $\\\\theta$. Then:\\n\\n$\\\\frac{3}{5}\\\\cos x+\\\\frac{4}{5}\\\\sin x=\\\\cos\\\\theta\\\\cos x+\\\\sin\\\\theta\\\\sin x=\\\\cos(x-\\\\theta)$\\n\\nSo, the original expression can be written as:\\n\\n$5\\\\cos(x-\\\\theta)+8$\\n\\nSince the minimum value of $\\\\cos(x-\\\\theta)$ is $-1$, the minimum value of the original expression is:\\n\\n$5(-1)+8=-5+8=3$\\n\\nTherefore, the minimum value of $3\\\\cos x + 4\\\\sin x + 8$ is $\\\\boxed{3}$.\",\"modelConfigName\":\"null\"},{\"title\":\"Response 2\",\"content\":\"A nice math question!\\n\\nTo find the minimum value of $3 \\\\cos x + 4 \\\\sin x + 8$, we can use the fact that $a\\\\cos x + b\\\\sin x = \\\\sqrt{a^2 + b^2} \\\\cos(x - \\\\alpha)$, where $\\\\alpha = \\\\tan^{-1}\\\\left(\\\\frac{b}{a}\\\\right)$.\\n\\nIn this case, $a = 3$ and $b = 4$, so $\\\\alpha = \\\\tan^{-1}\\\\left(\\\\frac{4}{3}\\\\right)$.\\n\\nSo, we have:\\n\\n$$3 \\\\cos x + 4 \\\\sin x + 8 = \\\\sqrt{3^2 + 4^2} \\\\cos(x - \\\\alpha) + 8 = 5 \\\\cos(x - \\\\alpha) + 8$$\\n\\nNow, the minimum value of $\\\\cos(x - \\\\alpha)$ is $-1$, so the minimum value of $5 \\\\cos(x - \\\\alpha) + 8$ is:\\n\\n$$5(-1) + 8 = -5 + 8 = 3$$\\n\\nTherefore, the minimum value of $3 \\\\cos x + 4 \\\\sin x + 8$ is $\\\\boxed{3}$.\\n\\nLet me know if you have any questions or need further clarification!\",\"modelConfigName\":\"null\"}]}"
113+
},
114+
"media_attributes": {
115+
"asset_type": "conversational",
116+
"mime_type": "application/vnd.labelbox.conversational",
117+
"labelable_ids": ["message-0"],
118+
"message_count": 1
119+
}
120+
}
38121
mock_requests_get.return_value.text = json.dumps(data)
39122
mock_requests_get.return_value.content = "b"
40123
export_task = ExportTask(mock_task, is_export_v2=True)
41-
assert export_task.result[0] == data
124+
assert export_task.result[0] == data

0 commit comments

Comments
 (0)