
Commit fc0b731

Claire Pajot committed:
Merge branch 'master' into add_classification_type_to_groundtruth
2 parents: 0d1ffd1 + fb9b452

17 files changed, +741 −28 lines

.gitignore (3 additions, 0 deletions)

@@ -134,3 +134,6 @@ dmypy.json
 
 # Poetry lockfile (no need for deploys, best practice is to not check this in)
 poetry.lock
+
+# vscode
+.vscode/

nucleus/__init__.py (10 additions, 0 deletions)

@@ -1157,6 +1157,16 @@ def list_autotags(self, dataset_id: str) -> List[str]:
         )
         return response[AUTOTAGS_KEY] if AUTOTAGS_KEY in response else response
 
+    def delete_autotag(self, autotag_id: str) -> dict:
+        """
+        Deletes an autotag based on autotagId.
+        Returns an empty payload where response status `200` indicates
+        the autotag has been successfully deleted.
+        :param autotag_id: id of the autotag to delete.
+        :return: {}
+        """
+        return self.make_request({}, f"autotag/{autotag_id}", requests.delete)
+
     def delete_model(self, model_id: str) -> dict:
         """
         This endpoint deletes the specified model, along with all
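For reference, a minimal sketch of calling the new method; the API key and autotag id are placeholders, not values from this commit:

import nucleus

client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")  # placeholder key
# A 200 response carries an empty payload ({}).
client.delete_autotag("tag_sample_autotag_id")  # hypothetical autotag id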

nucleus/autocurate.py (new file: 26 additions, 0 deletions)

@@ -0,0 +1,26 @@
+import datetime
+import requests
+from nucleus.constants import (
+    JOB_CREATION_TIME_KEY,
+    JOB_LAST_KNOWN_STATUS_KEY,
+    JOB_TYPE_KEY,
+)
+from nucleus.job import AsyncJob
+
+
+def entropy(name, model_run, client):
+    model_run_ids = [model_run.model_run_id]
+    dataset_id = model_run.dataset_id
+    response = client.make_request(
+        payload={"modelRunIds": model_run_ids},
+        route=f"autocurate/{dataset_id}/single_model_entropy/{name}",
+        requests_command=requests.post,
+    )
+    # TODO: the response should already have the below three fields populated
+    response[JOB_LAST_KNOWN_STATUS_KEY] = "Started"
+    response[JOB_TYPE_KEY] = "autocurateEntropy"
+    response[JOB_CREATION_TIME_KEY] = (
+        datetime.datetime.now().isoformat("T", "milliseconds") + "Z"
+    )
+    job = AsyncJob.from_json(response, client)
+    return job
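A sketch of driving the new helper end to end, assuming the existing NucleusClient.get_model_run accessor and AsyncJob.sleep_until_complete; all ids are placeholders:

import nucleus
from nucleus import autocurate

client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")  # placeholder key
model_run = client.get_model_run("run_sample_id", "ds_sample_id")  # hypothetical ids
# Kicks off server-side entropy autocuration and wraps the response as an AsyncJob.
job = autocurate.entropy("entropy-demo-curation", model_run, client)
job.sleep_until_complete()  # poll until the job leaves the "Started" state

Note that entropy() reads model_run.dataset_id, which is why the nucleus/model_run.py change below promotes _dataset_id to a public attribute.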

nucleus/constants.py (1 addition, 0 deletions)

@@ -87,6 +87,7 @@
 TYPE_KEY = "type"
 UPDATED_ITEMS = "updated_items"
 UPDATE_KEY = "update"
+UPLOAD_TO_SCALE_KEY = "upload_to_scale"
 URL_KEY = "url"
 VERTICES_KEY = "vertices"
 WIDTH_KEY = "width"

nucleus/dataset.py (18 additions, 0 deletions)

@@ -430,6 +430,24 @@ def items_and_annotations(
         )
         return convert_export_payload(api_payload[EXPORTED_ROWS])
 
+    def export_embeddings(
+        self,
+    ) -> List[Dict[str, Union[str, List[float]]]]:
+        """Returns a pd.Dataframe-ready format of dataset embeddings.
+
+        Returns:
+            A list, where each item is a dict with two keys representing a row
+            in the dataset.
+            * One value in the dict is the reference id
+            * The other value is a list of the embedding values
+        """
+        api_payload = self._client.make_request(
+            payload=None,
+            route=f"dataset/{self.id}/embeddings",
+            requests_command=requests.get,
+        )
+        return api_payload
+
     def delete_annotations(
         self, reference_ids: list = None, keep_history=False
     ):
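Since the docstring promises a pd.DataFrame-ready shape, a sketch of the round trip; the dataset id is a placeholder, and the dict keys shown in the comment are illustrative, as the diff does not name them:

import pandas as pd

dataset = client.get_dataset("ds_sample_id")  # hypothetical dataset id
rows = dataset.export_embeddings()
# Illustrative row shape: {"reference_id": "img_0001", "embedding_vector": [0.12, -0.48, ...]}
df = pd.DataFrame(rows)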

nucleus/dataset_item.py (11 additions, 8 deletions)

@@ -7,10 +7,10 @@
 
 from .annotation import is_local_path, Point3D
 from .constants import (
-    DATASET_ITEM_ID_KEY,
     IMAGE_URL_KEY,
     METADATA_KEY,
     ORIGINAL_IMAGE_URL_KEY,
+    UPLOAD_TO_SCALE_KEY,
     REFERENCE_ID_KEY,
     TYPE_KEY,
     URL_KEY,
@@ -91,14 +91,19 @@ class DatasetItemType(Enum):
 class DatasetItem:  # pylint: disable=R0902
     image_location: Optional[str] = None
     reference_id: Optional[str] = None
-    item_id: Optional[str] = None
     metadata: Optional[dict] = None
     pointcloud_location: Optional[str] = None
+    upload_to_scale: Optional[bool] = True
 
     def __post_init__(self):
+        assert self.reference_id is not None, "reference_id is required."
         assert bool(self.image_location) != bool(
             self.pointcloud_location
         ), "Must specify exactly one of the image_location, pointcloud_location parameters"
+        if self.pointcloud_location and not self.upload_to_scale:
+            raise NotImplementedError(
+                "Skipping upload to Scale is not currently implemented for pointclouds."
+            )
         self.local = (
             is_local_path(self.image_location) if self.image_location else None
         )
@@ -127,15 +132,14 @@ def from_json(cls, payload: dict, is_scene=False):
             image_location=image_url,
             pointcloud_location=payload.get(POINTCLOUD_URL_KEY, None),
             reference_id=payload.get(REFERENCE_ID_KEY, None),
-            item_id=payload.get(DATASET_ITEM_ID_KEY, None),
             metadata=payload.get(METADATA_KEY, {}),
         )
 
         return cls(
             image_location=image_url,
             reference_id=payload.get(REFERENCE_ID_KEY, None),
-            item_id=payload.get(DATASET_ITEM_ID_KEY, None),
             metadata=payload.get(METADATA_KEY, {}),
+            upload_to_scale=payload.get(UPLOAD_TO_SCALE_KEY, None),
         )
 
     def local_file_exists(self):
@@ -145,10 +149,8 @@ def to_payload(self, is_scene=False) -> dict:
         payload: Dict[str, Any] = {
             METADATA_KEY: self.metadata or {},
         }
-        if self.reference_id:
-            payload[REFERENCE_ID_KEY] = self.reference_id
-        if self.item_id:
-            payload[DATASET_ITEM_ID_KEY] = self.item_id
+
+        payload[REFERENCE_ID_KEY] = self.reference_id
 
         if is_scene:
             if self.image_location:
@@ -163,6 +165,7 @@ def to_payload(self, is_scene=False) -> dict:
                 self.image_location
             ), "Must specify image_location for DatasetItems not in a LidarScene"
             payload[IMAGE_URL_KEY] = self.image_location
+            payload[UPLOAD_TO_SCALE_KEY] = self.upload_to_scale
 
         return payload
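A sketch of constructing an item under the new contract (reference_id now required, item_id removed, upload_to_scale defaulting to True); the URL and ids are placeholders:

from nucleus import DatasetItem  # assuming the package-root re-export

item = DatasetItem(
    image_location="https://example.com/img_0001.jpg",  # placeholder URL
    reference_id="img_0001",  # required; __post_init__ now asserts it
    metadata={"split": "train"},
    upload_to_scale=False,  # reference the image in place rather than copying it to Scale
)
payload = item.to_payload()  # always carries reference_id; includes upload_to_scale for images

Per the __post_init__ guard, combining upload_to_scale=False with a pointcloud_location raises NotImplementedError.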

nucleus/model_run.py (3 additions, 3 deletions)

@@ -31,10 +31,10 @@ class ModelRun:
     def __init__(self, model_run_id: str, dataset_id: str, client):
         self.model_run_id = model_run_id
         self._client = client
-        self._dataset_id = dataset_id
+        self.dataset_id = dataset_id
 
     def __repr__(self):
-        return f"ModelRun(model_run_id='{self.model_run_id}', dataset_id='{self._dataset_id}', client={self._client})"
+        return f"ModelRun(model_run_id='{self.model_run_id}', dataset_id='{self.dataset_id}', client={self._client})"
 
     def __eq__(self, other):
         if self.model_run_id == other.model_run_id:
@@ -115,7 +115,7 @@ def predict(
         check_all_mask_paths_remote(annotations)
 
         request_id = serialize_and_write_to_presigned_url(
-            annotations, self._dataset_id, self._client
+            annotations, self.dataset_id, self._client
         )
         response = self._client.make_request(
             payload={REQUEST_ID_KEY: request_id, UPDATE_KEY: update},

nucleus/slice.py (18 additions, 0 deletions)

@@ -126,6 +126,24 @@ def send_to_labeling(self, project_id: str):
         )
         return AsyncJob.from_json(response, self._client)
 
+    def export_embeddings(
+        self,
+    ) -> List[Dict[str, Union[str, List[float]]]]:
+        """Returns a pd.Dataframe-ready format of dataset embeddings.
+
+        Returns:
+            A list, where each item is a dict with two keys representing a row
+            in the dataset.
+            * One value in the dict is the reference id
+            * The other value is a list of the embedding values
+        """
+        api_payload = self._client.make_request(
+            payload=None,
+            route=f"slice/{self.slice_id}/embeddings",
+            requests_command=requests.get,
+        )
+        return api_payload
+
 
 def check_annotations_are_in_slice(
     annotations: List[Annotation], slice_to_check: Slice
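The slice-level variant mirrors Dataset.export_embeddings, differing only in the route; a sketch assuming the client's get_slice accessor:

slc = client.get_slice("slc_sample_id")  # hypothetical slice id
rows = slc.export_embeddings()  # same row shape as Dataset.export_embeddings()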

pyproject.toml (2 additions, 1 deletion)

@@ -21,7 +21,7 @@ exclude = '''
 
 [tool.poetry]
 name = "scale-nucleus"
-version = "0.1.17"
+version = "0.1.18"
 description = "The official Python client library for Nucleus, the Data Platform for AI"
 license = "MIT"
 authors = ["Scale AI Nucleus Team <nucleusapi@scaleapi.com>"]
@@ -48,6 +48,7 @@ flake8 = "^3.9.1"
 mypy = "^0.812"
 coverage = "^5.5"
 pre-commit = "^2.12.1"
+jupyterlab = "^3.1.10"
 
 [tool.pytest.ini_options]
 markers = [
