Skip to content

Commit cb10ec6

Browse files
authored
Add support for asynchronous embeddings export (#394)
* Add support for asynchronous embeddings export * Add changelog * Add changed section to changelog * Allow waiting for completion in result_urls * Add from_id to AsyncJob * Add documentation for from_id * Adapt tests
1 parent 36d1315 commit cb10ec6

File tree

9 files changed

+96
-12
lines changed

9 files changed

+96
-12
lines changed

.circleci/config.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ jobs:
3939
command: |
4040
poetry run black --check .
4141
- run:
42-
name: Ruff Lint Check # See pyproject.tooml [tool.ruff]
42+
name: Ruff Lint Check # See pyproject.toml [tool.ruff]
4343
command: |
4444
poetry run ruff .
4545
- run:

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ repos:
1111
- repo: local
1212
hooks:
1313
- id: system
14-
name: flake8
14+
name: ruff
1515
entry: poetry run ruff nucleus
1616
pass_filenames: false
1717
language: system

CHANGELOG.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,14 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

88

9+
## [0.16.1](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.16.1) - 2023-09-18
10+
11+
### Added
12+
- Added `asynchronous` parameter for `slice.export_embeddings()` and `dataset.export_embeddings()` to allow embeddings to be exported asynchronously.
13+
14+
### Changed
15+
- Changed `slice.export_embeddings()` and `dataset.export_embeddings()` to be asynchronous by deafult.
16+
917
## [0.16.0](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.16.0) - 2023-09-18
1018

1119
### Removed

nucleus/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
__all__ = [
44
"AsyncJob",
5+
"EmbeddingsExportJob",
56
"BoxAnnotation",
67
"BoxPrediction",
78
"CameraParams",
@@ -68,7 +69,7 @@
6869
Segment,
6970
SegmentationAnnotation,
7071
)
71-
from .async_job import AsyncJob
72+
from .async_job import AsyncJob, EmbeddingsExportJob
7273
from .camera_params import CameraParams
7374
from .connection import Connection
7475
from .constants import (
@@ -236,7 +237,7 @@ def models(self) -> List[Model]:
236237
def jobs(
237238
self,
238239
) -> List[AsyncJob]:
239-
"""Lists all jobs, see NucleusClinet.list_jobs(...) for advanced options
240+
"""Lists all jobs, see NucleusClient.list_jobs(...) for advanced options
240241
241242
Returns:
242243
List of all AsyncJobs

nucleus/async_job.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,20 @@ def sleep_until_complete(self, verbose_std_out=True):
119119
if final_status["status"] == "Errored":
120120
raise JobError(final_status, self)
121121

122+
@classmethod
123+
def from_id(cls, job_id: str, client: "NucleusClient"): # type: ignore # noqa: F821
124+
"""Creates a job instance from a specific job Id.
125+
126+
Parameters:
127+
job_id: Defines the job Id
128+
client: The client to use for the request.
129+
130+
Returns:
131+
The specific AsyncMethod (or inherited) instance.
132+
"""
133+
job = client.get_job(job_id)
134+
return cls.from_json(job.__dict__, client)
135+
122136
@classmethod
123137
def from_json(cls, payload: dict, client):
124138
# TODO: make private
@@ -131,6 +145,34 @@ def from_json(cls, payload: dict, client):
131145
)
132146

133147

148+
class EmbeddingsExportJob(AsyncJob):
149+
def result_urls(self, wait_for_completion=True) -> List[str]:
150+
"""Gets a list of signed Scale URLs for each embedding batch.
151+
152+
Parameters:
153+
wait_for_completion: Defines whether the call shall wait for
154+
the job to complete. Defaults to True
155+
156+
Returns:
157+
A list of signed Scale URLs which contain batches of embeddings.
158+
159+
The files contain a JSON array of embedding records with the following schema:
160+
[{
161+
"reference_id": str,
162+
"embedding_vector": List[float]
163+
}]
164+
"""
165+
if wait_for_completion:
166+
self.sleep_until_complete(verbose_std_out=False)
167+
168+
status = self.status()
169+
170+
if status["status"] != "Completed":
171+
raise JobError(status, self)
172+
173+
return status["message"]["result"] # type: ignore
174+
175+
134176
class JobError(Exception):
135177
def __init__(self, job_status: Dict[str, str], job: AsyncJob):
136178
final_status_message = job_status["message"]

nucleus/dataset.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import requests
1616

1717
from nucleus.annotation_uploader import AnnotationUploader, PredictionUploader
18-
from nucleus.async_job import AsyncJob
18+
from nucleus.async_job import AsyncJob, EmbeddingsExportJob
1919
from nucleus.prediction import Prediction, from_json
2020
from nucleus.track import Track
2121
from nucleus.url_utils import sanitize_string_args
@@ -1421,18 +1421,34 @@ def items_and_annotation_generator(
14211421

14221422
def export_embeddings(
14231423
self,
1424-
) -> List[Dict[str, Union[str, List[float]]]]:
1424+
asynchronous: bool = True,
1425+
) -> Union[List[Dict[str, Union[str, List[float]]]], EmbeddingsExportJob]:
14251426
"""Fetches a pd.DataFrame-ready list of dataset embeddings.
14261427
1428+
Parameters:
1429+
asynchronous: Whether or not to process the export asynchronously (and
1430+
return an :class:`EmbeddingsExportJob` object). Default is True.
1431+
14271432
Returns:
1428-
A list, where each item is a dict with two keys representing a row
1433+
If synchronous, a list where each item is a dict with two keys representing a row
14291434
in the dataset::
14301435
14311436
List[{
14321437
"reference_id": str,
14331438
"embedding_vector": List[float]
14341439
}]
1440+
1441+
Otherwise, returns an :class:`EmbeddingsExportJob` object.
14351442
"""
1443+
if asynchronous:
1444+
api_payload = self._client.make_request(
1445+
payload=None,
1446+
route=f"dataset/{self.id}/async_export_embeddings",
1447+
requests_command=requests.post,
1448+
)
1449+
1450+
return EmbeddingsExportJob.from_json(api_payload, self._client)
1451+
14361452
api_payload = self._client.make_request(
14371453
payload=None,
14381454
route=f"dataset/{self.id}/embeddings",

nucleus/job.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class CustomerJobTypes(str, Enum):
2727
CLONE_DATASET = "cloneDataset"
2828
METADATA_UPDATE = "metadataUpdate"
2929
TRIGGER_EVALUATE = "triggerEvaluate"
30+
EXPORT_EMBEDDINGS = "exportEmbeddings"
3031

3132
def __contains__(self, item):
3233
try:

nucleus/slice.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import requests
88

99
from nucleus.annotation import Annotation
10-
from nucleus.async_job import AsyncJob
10+
from nucleus.async_job import AsyncJob, EmbeddingsExportJob
1111
from nucleus.constants import EXPORT_FOR_TRAINING_KEY, EXPORTED_ROWS, ITEMS_KEY
1212
from nucleus.dataset_item import DatasetItem
1313
from nucleus.errors import NucleusAPIError
@@ -600,17 +600,33 @@ def send_to_labeling(self, project_id: str):
600600

601601
def export_embeddings(
602602
self,
603-
) -> List[Dict[str, Union[str, List[float]]]]:
603+
asynchronous: bool = True,
604+
) -> Union[List[Dict[str, Union[str, List[float]]]], EmbeddingsExportJob]:
604605
"""Fetches a pd.DataFrame-ready list of slice embeddings.
605606
607+
Parameters:
608+
asynchronous: Whether or not to process the export asynchronously (and
609+
return an :class:`EmbeddingsExportJob` object). Default is True.
610+
606611
Returns:
607-
A list where each element is a columnar mapping::
612+
If synchronous, a list where each element is a columnar mapping::
608613
609614
List[{
610615
"reference_id": str,
611616
"embedding_vector": List[float]
612617
}]
618+
619+
Otherwise, returns an :class:`EmbeddingsExportJob` object.
613620
"""
621+
if asynchronous:
622+
api_payload = self._client.make_request(
623+
payload=None,
624+
route=f"dataset/{self.id}/async_export_embeddings",
625+
requests_command=requests.post,
626+
)
627+
628+
return EmbeddingsExportJob.from_json(api_payload, self._client)
629+
614630
api_payload = self._client.make_request(
615631
payload=None,
616632
route=f"slice/{self.id}/embeddings",

tests/test_autotag.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def test_export_embeddings(CLIENT):
6060
if running_as_nucleus_pytest_user(CLIENT):
6161
embeddings = Dataset(
6262
DATASET_WITH_EMBEDDINGS, CLIENT
63-
).export_embeddings()
63+
).export_embeddings(asynchronous=False)
6464
assert "embedding_vector" in embeddings[0]
6565
assert "reference_id" in embeddings[0]
6666

@@ -100,7 +100,7 @@ def test_dataset_export_autotag_tagged_items(CLIENT):
100100
def test_export_slice_embeddings(CLIENT):
101101
if running_as_nucleus_pytest_user(CLIENT):
102102
test_slice = CLIENT.get_slice("slc_c8jwtmj372xg07g9v3k0")
103-
embeddings = test_slice.export_embeddings()
103+
embeddings = test_slice.export_embeddings(asynchronous=False)
104104
assert "embedding_vector" in embeddings[0]
105105
assert "reference_id" in embeddings[0]
106106

0 commit comments

Comments
 (0)