Skip to content

Commit 523ca50

Browse files
authored
[Launch] add bundle_id for model creation (#270)
1 parent a4652dc commit 523ca50

File tree

4 files changed

+56
-7
lines changed

4 files changed

+56
-7
lines changed

nucleus/__init__.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,11 @@ def add_model(
479479
return self.create_model(name, reference_id, metadata)
480480

481481
def create_model(
482-
self, name: str, reference_id: str, metadata: Optional[Dict] = None
482+
self,
483+
name: str,
484+
reference_id: str,
485+
metadata: Optional[Dict] = None,
486+
bundle_name: Optional[str] = None,
483487
) -> Model:
484488
"""Adds a :class:`Model` to Nucleus.
485489
@@ -496,14 +500,23 @@ def create_model(
496500
:class:`Model`: The newly created model as an object.
497501
"""
498502
response = self.make_request(
499-
construct_model_creation_payload(name, reference_id, metadata),
503+
construct_model_creation_payload(
504+
name, reference_id, metadata, bundle_name
505+
),
500506
"models/add",
501507
)
502508
model_id = response.get("model_id", None)
503509
if not model_id:
504510
raise ModelCreationError(response.get("error"))
505511

506-
return Model(model_id, name, reference_id, metadata, self)
512+
return Model(
513+
model_id=model_id,
514+
name=name,
515+
reference_id=reference_id,
516+
metadata=metadata,
517+
bundle_name=bundle_name,
518+
client=self,
519+
)
507520

508521
@deprecated(
509522
"Model runs have been deprecated and will be removed. Use a Model instead"

nucleus/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@
8888
MAX_PAYLOAD_SIZE = 0x1FFFFFE8 # Set to max string size since we currently convert payloads to strings for processing on the server-side
8989
MESSAGE_KEY = "message"
9090
METADATA_KEY = "metadata"
91+
MODEL_BUNDLE_NAME_KEY = "bundle_name"
9192
MODEL_ID_KEY = "model_id"
9293
MODEL_RUN_ID_KEY = "model_run_id"
9394
NAME_KEY = "name"

nucleus/model.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,22 +92,26 @@ class Model:
9292
endpoint, using :meth:`NucleusClient.add_model`.
9393
"""
9494

95-
def __init__(self, model_id, name, reference_id, metadata, client):
95+
def __init__(
96+
self, model_id, name, reference_id, metadata, client, bundle_name=None
97+
):
9698
self.id = model_id
9799
self.name = name
98100
self.reference_id = reference_id
99101
self.metadata = metadata
102+
self.bundle_name = bundle_name
100103
self._client = client
101104

102105
def __repr__(self):
103-
return f"Model(model_id='{self.id}', name='{self.name}', reference_id='{self.reference_id}', metadata={self.metadata}, client={self._client})"
106+
return f"Model(model_id='{self.id}', name='{self.name}', reference_id='{self.reference_id}', metadata={self.metadata}, bundle_name={self.bundle_name}, client={self._client})"
104107

105108
def __eq__(self, other):
106109
return (
107110
(self.id == other.id)
108111
and (self.name == other.name)
109112
and (self.metadata == other.metadata)
110113
and (self._client == other._client)
114+
and (self.bundle_name == other.bundle_name)
111115
)
112116

113117
def __hash__(self):
@@ -187,3 +191,25 @@ def evaluate(self, scenario_test_names: List[str]) -> AsyncJob:
187191
requests_command=requests.post,
188192
)
189193
return AsyncJob.from_json(response, self._client)
194+
195+
def run(self, dataset_id: str, slice_id: Optional[str]) -> str:
196+
"""Runs inference on the bundle associated with the model on the dataset. ::
197+
198+
import nucleus
199+
client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
200+
model = client.list_models()[0]
201+
202+
model.run("ds_123456")
203+
204+
Args:
205+
dataset_id: id of dataset to run inference on
206+
job_id: nucleus job used to track async job progress
207+
slice_id: (optional) id of slice of the dataset to run inference on
208+
"""
209+
response = self._client.make_request(
210+
{"dataset_id": dataset_id, "slice_id": slice_id},
211+
f"model/run/{self.id}/",
212+
requests_command=requests.post,
213+
)
214+
215+
return response

nucleus/payload_constructor.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
ITEMS_KEY,
1616
LABELS_KEY,
1717
METADATA_KEY,
18+
MODEL_BUNDLE_NAME_KEY,
1819
MODEL_ID_KEY,
1920
NAME_KEY,
2021
REFERENCE_ID_KEY,
@@ -122,14 +123,22 @@ def construct_box_predictions_payload(
122123

123124

124125
def construct_model_creation_payload(
125-
name: str, reference_id: str, metadata: Optional[Dict]
126+
name: str,
127+
reference_id: str,
128+
metadata: Optional[Dict],
129+
bundle_name: Optional[str],
126130
) -> dict:
127-
return {
131+
payload = {
128132
NAME_KEY: name,
129133
REFERENCE_ID_KEY: reference_id,
130134
METADATA_KEY: metadata if metadata else {},
131135
}
132136

137+
if bundle_name:
138+
payload[MODEL_BUNDLE_NAME_KEY] = bundle_name
139+
140+
return payload
141+
133142

134143
def construct_model_run_creation_payload(
135144
name: str,

0 commit comments

Comments
 (0)