Skip to content

Commit 0b2f68c

Browse files
authored
Allow embedding info upload for Dataset Items (#405)
1 parent 317d2bf commit 0b2f68c

File tree

4 files changed

+33
-1
lines changed

4 files changed

+33
-1
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,11 @@ All notable changes to the [Nucleus Python Client](https://github.com/scaleapi/n
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
## [0.16.7](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.16.7) - 2023-11-03
9+
10+
### Added
11+
- Allow embedding vectors to be uploaded directly together with dataset items. `DatasetItem` now accepts an additional parameter, `embedding_info`, which can be used to upload an item's embedding at the same time the item itself is uploaded to a dataset.
12+
813

914
## [0.16.6](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.16.6) - 2023-11-01
1015

nucleus/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
DEFAULT_ANNOTATION_UPDATE_MODE = False
4949
DEFAULT_NETWORK_TIMEOUT_SEC = 120
5050
DIMENSIONS_KEY = "dimensions"
51+
EMBEDDING_INFO_KEY = "embedding_info"
5152
EMBEDDING_VECTOR_KEY = "embedding_vector"
5253
EMBEDDINGS_URL_KEY = "embeddings_urls"
5354
EMBEDDING_DIMENSION_KEY = "embedding_dimension"
@@ -70,6 +71,7 @@
7071
IMAGE_LOCATION_KEY = "image_location"
7172
IMAGE_URL_KEY = "image_url"
7273
INDEX_KEY = "index"
74+
INDEX_ID_KEY = "index_id"
7375
INDEX_CONTINUOUS_ENABLE_KEY = "enable"
7476
IOU_KEY = "iou"
7577
ITEMS_KEY = "items"

nucleus/dataset_item.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
from .constants import (
1111
BACKEND_REFERENCE_ID_KEY,
1212
CAMERA_PARAMS_KEY,
13+
EMBEDDING_INFO_KEY,
14+
EMBEDDING_VECTOR_KEY,
1315
IMAGE_URL_KEY,
16+
INDEX_ID_KEY,
1417
METADATA_KEY,
1518
ORIGINAL_IMAGE_URL_KEY,
1619
POINTCLOUD_URL_KEY,
@@ -26,6 +29,18 @@ class DatasetItemType(Enum):
2629
POINTCLOUD = "pointcloud"
2730

2831

32+
@dataclass
33+
class DatasetItemEmbeddingInfo:
34+
index_id: str
35+
embedding_vector: list
36+
37+
def to_payload(self) -> dict:
38+
return {
39+
INDEX_ID_KEY: self.index_id,
40+
EMBEDDING_VECTOR_KEY: self.embedding_vector,
41+
}
42+
43+
2944
@dataclass # pylint: disable=R0902
3045
class DatasetItem: # pylint: disable=R0902
3146
"""A dataset item is an image or pointcloud that has associated metadata.
@@ -113,16 +128,23 @@ class DatasetItem: # pylint: disable=R0902
113128
metadata: Optional[dict] = None
114129
pointcloud_location: Optional[str] = None
115130
upload_to_scale: Optional[bool] = True
131+
embedding_info: Optional[DatasetItemEmbeddingInfo] = None
116132

117133
def __post_init__(self):
118134
assert self.reference_id != "DUMMY_VALUE", "reference_id is required."
119135
assert bool(self.image_location) != bool(
120136
self.pointcloud_location
121137
), "Must specify exactly one of the image_location or pointcloud_location parameters"
138+
if self.pointcloud_location and self.embedding_info:
139+
raise AssertionError(
140+
"Cannot upload embedding vector if pointcloud_location is set"
141+
)
142+
122143
if (self.pointcloud_location) and not self.upload_to_scale:
123144
raise NotImplementedError(
124145
"Skipping upload to Scale is not currently implemented for pointclouds."
125146
)
147+
126148
self.local = (
127149
is_local_path(self.image_location) if self.image_location else None
128150
)
@@ -179,6 +201,9 @@ def to_payload(self, is_scene=False) -> dict:
179201

180202
payload[REFERENCE_ID_KEY] = self.reference_id
181203

204+
if self.embedding_info:
205+
payload[EMBEDDING_INFO_KEY] = self.embedding_info.to_payload()
206+
182207
if is_scene:
183208
if self.image_location:
184209
payload[URL_KEY] = self.image_location

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ ignore = ["E501", "E741", "E731", "F401"] # Easy ignore for getting it running
2525

2626
[tool.poetry]
2727
name = "scale-nucleus"
28-
version = "0.16.6"
28+
version = "0.16.7"
2929
description = "The official Python client library for Nucleus, the Data Platform for AI"
3030
license = "MIT"
3131
authors = ["Scale AI Nucleus Team <nucleusapi@scaleapi.com>"]

0 commit comments

Comments
 (0)