Skip to content

Commit a382367

Browse files
authored
[PLT-630] Add support for export embeddings from SDK (#1573)
1 parent 638e96f commit a382367

File tree

10 files changed: 108 additions and 26 deletions.

libs/labelbox/src/labelbox/schema/catalog.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def export_v2(
2626
) -> Union[Task, ExportTask]:
2727
"""
2828
Creates a catalog export task with the given params, filters and returns the task.
29-
29+
3030
>>> import labelbox as lb
3131
>>> client = lb.Client(<API_KEY>)
3232
>>> catalog = client.get_catalog()
@@ -98,6 +98,7 @@ def _export(self,
9898

9999
_params = params or CatalogExportParams({
100100
"attachments": False,
101+
"embeddings": False,
101102
"metadata_fields": False,
102103
"data_row_details": False,
103104
"project_details": False,
@@ -142,6 +143,8 @@ def _export(self,
142143
if media_type_override is not None else None,
143144
"includeAttachments":
144145
_params.get('attachments', False),
146+
"includeEmbeddings":
147+
_params.get('embeddings', False),
145148
"includeMetadata":
146149
_params.get('metadata_fields', False),
147150
"includeDataRowDetails":

libs/labelbox/src/labelbox/schema/data_row.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ def _export(
266266
) -> Tuple[Task, bool]:
267267
_params = params or CatalogExportParams({
268268
"attachments": False,
269+
"embeddings": False,
269270
"metadata_fields": False,
270271
"data_row_details": False,
271272
"project_details": False,
@@ -325,6 +326,8 @@ def _export(
325326
if media_type_override is not None else None,
326327
"includeAttachments":
327328
_params.get('attachments', False),
329+
"includeEmbeddings":
330+
_params.get('embeddings', False),
328331
"includeMetadata":
329332
_params.get('metadata_fields', False),
330333
"includeDataRowDetails":

libs/labelbox/src/labelbox/schema/dataset.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -717,6 +717,7 @@ def _export(
717717
) -> Tuple[Task, bool]:
718718
_params = params or CatalogExportParams({
719719
"attachments": False,
720+
"embeddings": False,
720721
"metadata_fields": False,
721722
"data_row_details": False,
722723
"project_details": False,
@@ -763,6 +764,8 @@ def _export(
763764
if media_type_override is not None else None,
764765
"includeAttachments":
765766
_params.get('attachments', False),
767+
"includeEmbeddings":
768+
_params.get('embeddings', False),
766769
"includeMetadata":
767770
_params.get('metadata_fields', False),
768771
"includeDataRowDetails":

libs/labelbox/src/labelbox/schema/export_params.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ class DataRowParams(TypedDict):
1515
data_row_details: Optional[bool]
1616
metadata_fields: Optional[bool]
1717
attachments: Optional[bool]
18+
embeddings: Optional[bool]
1819
media_type_override: Optional[MediaType]
1920

2021

libs/labelbox/src/labelbox/schema/model_run.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def upsert_labels(self,
5757
label_ids: Optional[List[str]] = None,
5858
project_id: Optional[str] = None,
5959
timeout_seconds=3600):
60-
"""
60+
"""
6161
Adds data rows and labels to a Model Run
6262
6363
Args:
@@ -273,16 +273,16 @@ def add_predictions(
273273
name: str,
274274
predictions: Union[str, Path, Iterable[Dict], Iterable["Label"]],
275275
) -> 'MEAPredictionImport': # type: ignore
276-
"""
276+
"""
277277
Uploads predictions to a new Editor project.
278-
278+
279279
Args:
280280
name (str): name of the AnnotationImport job
281281
predictions (str or Path or Iterable): url that is publicly accessible by Labelbox containing an
282282
ndjson file
283283
OR local path to an ndjson file
284284
OR iterable of annotation rows
285-
285+
286286
Returns:
287287
AnnotationImport
288288
"""
@@ -566,6 +566,8 @@ def _export(
566566
_params.get('media_type_override', None),
567567
"includeAttachments":
568568
_params.get('attachments', False),
569+
"includeEmbeddings":
570+
_params.get('embeddings', False),
569571
"includeMetadata":
570572
_params.get('metadata_fields', False),
571573
"includeDataRowDetails":

libs/labelbox/src/labelbox/schema/project.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,7 @@ def _export(
518518
) -> Tuple[Task, bool]:
519519
_params = params or ProjectExportParams({
520520
"attachments": False,
521+
"embeddings": False,
521522
"metadata_fields": False,
522523
"data_row_details": False,
523524
"project_details": False,
@@ -560,6 +561,8 @@ def _export(
560561
if media_type_override is not None else None,
561562
"includeAttachments":
562563
_params.get('attachments', False),
564+
"includeEmbeddings":
565+
_params.get('embeddings', False),
563566
"includeMetadata":
564567
_params.get('metadata_fields', False),
565568
"includeDataRowDetails":

libs/labelbox/src/labelbox/schema/slice.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ def _export(
167167
) -> Tuple[Task, bool]:
168168
_params = params or CatalogExportParams({
169169
"attachments": False,
170+
"embeddings": False,
170171
"metadata_fields": False,
171172
"data_row_details": False,
172173
"project_details": False,
@@ -201,6 +202,8 @@ def _export(
201202
if media_type_override is not None else None,
202203
"includeAttachments":
203204
_params.get('attachments', False),
205+
"includeEmbeddings":
206+
_params.get('embeddings', False),
204207
"includeMetadata":
205208
_params.get('metadata_fields', False),
206209
"includeDataRowDetails":

libs/labelbox/tests/conftest.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ class Environ(Enum):
7878

7979
@pytest.fixture
8080
def image_url() -> str:
81-
return IMAGE_URL
81+
return MASKABLE_IMG_URL
8282

8383

8484
@pytest.fixture
@@ -376,14 +376,6 @@ def client(environ: str):
376376
return IntegrationClient(environ)
377377

378378

379-
@pytest.fixture(scope="session")
380-
def image_url(client):
381-
return client.upload_data(requests.get(MASKABLE_IMG_URL).content,
382-
content_type="image/jpeg",
383-
filename="image.jpeg",
384-
sign=True)
385-
386-
387379
@pytest.fixture(scope="session")
388380
def pdf_url(client):
389381
pdf_url = client.upload_file('tests/assets/loremipsum.pdf')
@@ -1042,4 +1034,12 @@ def configured_project_with_complex_ontology(client, initial_dataset, rand_gen,
10421034
project.setup(editor, ontology.asdict())
10431035

10441036
yield [project, data_row]
1045-
project.delete()
1037+
project.delete()
1038+
1039+
1040+
@pytest.fixture
def embedding(client: Client):
    """Yield a freshly created embedding and delete it on teardown.

    The embedding gets a unique ``sdk-int-<hex uuid>`` name so repeated or
    parallel runs do not collide, and is created with a dimension
    argument of 8.
    """
    unique_name = f"sdk-int-{uuid.uuid4().hex}"
    created = client.create_embedding(unique_name, 8)
    yield created
    # Teardown: runs after the consuming test finishes.
    created.delete()
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import json
2+
import random
3+
4+
from labelbox import StreamType, JsonConverter
5+
6+
7+
class TestExportEmbeddings:
    """Integration tests for exporting embeddings through ``dataset.export``."""

    @staticmethod
    def _collect_results(export_task):
        """Stream the RESULT output of *export_task* into a list of dicts.

        Each streamed item's ``json_str`` is parsed with ``json.loads`` and
        appended in arrival order.
        """
        results = []
        export_task.get_stream(
            converter=JsonConverter(),
            stream_type=StreamType.RESULT,
        ).start(stream_handler=lambda output: results.append(
            json.loads(output.json_str)))
        return results

    @staticmethod
    def _assert_export_succeeded(export_task):
        """Wait for *export_task* and assert it completed without errors."""
        export_task.wait_till_done()
        assert export_task.status == "COMPLETE"
        assert export_task.has_result()
        assert export_task.has_errors() is False

    def test_export_embeddings_precomputed(self, client, dataset, environ,
                                           image_url):
        """Exporting with ``embeddings=True`` returns the precomputed vector."""
        data_row_specs = [{
            "row_data": image_url,
            "external_id": "image",
        }]
        task = dataset.create_data_rows(data_row_specs)
        task.wait_till_done()
        # Fail early with a clear cause if the import itself did not succeed,
        # mirroring the check in test_export_embeddings_custom.
        assert task.status == "COMPLETE"

        export_task = dataset.export(params={"embeddings": True})
        self._assert_export_succeeded(export_task)

        results = self._collect_results(export_task)
        assert len(results) == len(data_row_specs)

        result = results[0]
        assert "embeddings" in result
        assert len(result["embeddings"]) > 0
        assert result["embeddings"][0][
            "name"] == "Image Embedding V2 (CLIP ViT-B/32)"
        assert len(result["embeddings"][0]["values"]) == 1

    def test_export_embeddings_custom(self, client, dataset, image_url,
                                      embedding):
        """A custom embedding vector attached on import appears in the export."""
        vector = [random.uniform(1.0, 2.0) for _ in range(embedding.dims)]
        import_task = dataset.create_data_rows([{
            "row_data": image_url,
            "embeddings": [{
                "embedding_id": embedding.id,
                "vector": vector,
            }],
        }])
        import_task.wait_till_done()
        assert import_task.status == "COMPLETE"

        export_task = dataset.export(params={"embeddings": True})
        self._assert_export_succeeded(export_task)

        results = self._collect_results(export_task)
        assert len(results) == 1
        assert "embeddings" in results[0]

        # Select the custom embedding explicitly: asserting only inside an
        # ``if emb["id"] == embedding.id`` branch would let the test pass
        # vacuously when the custom embedding is absent from the export.
        custom = [
            emb for emb in results[0]["embeddings"]
            if emb["id"] == embedding.id
        ]
        assert custom, "custom embedding not found in exported embeddings"
        for emb in custom:
            assert emb["name"] == embedding.name
            assert emb["dimensions"] == embedding.dims
            assert emb["is_custom"] is True
            assert len(emb["values"]) == 1
            assert emb["values"][0]["value"] == vector

libs/labelbox/tests/integration/test_embedding.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import json
22
import random
33
import threading
4-
import uuid
54
from tempfile import NamedTemporaryFile
65
from typing import List, Dict, Any
76

@@ -12,14 +11,6 @@
1211
from labelbox.schema.embedding import Embedding
1312

1413

15-
@pytest.fixture
16-
def embedding(client: Client):
17-
uuid_str = uuid.uuid4().hex
18-
embedding = client.create_embedding(f"sdk-int-{uuid_str}", 8)
19-
yield embedding
20-
embedding.delete()
21-
22-
2314
def test_get_embedding_by_id(client: Client, embedding: Embedding):
2415
e = client.get_embedding_by_id(embedding.id)
2516
assert e.id == embedding.id
@@ -43,7 +34,7 @@ def test_get_embeddings(client: Client, embedding: Embedding):
4334
@pytest.mark.parametrize('data_rows', [10], indirect=True)
4435
def test_import_vectors_from_file(data_rows: List[DataRow],
4536
embedding: Embedding):
46-
vector = [random.uniform(1.0, 2.0) for _ in range(8)]
37+
vector = [random.uniform(1.0, 2.0) for _ in range(embedding.dims)]
4738
event = threading.Event()
4839

4940
def callback(_: Dict[str, Any]):
@@ -66,7 +57,7 @@ def callback(_: Dict[str, Any]):
6657
def test_get_imported_vector_count(dataset: Dataset, embedding: Embedding):
6758
assert embedding.get_imported_vector_count() == 0
6859

69-
vector = [random.uniform(1.0, 2.0) for _ in range(8)]
60+
vector = [random.uniform(1.0, 2.0) for _ in range(embedding.dims)]
7061
dataset.create_data_row(row_data="foo",
7162
embeddings=[{
7263
"embedding_id": embedding.id,

0 commit comments

Comments
 (0)