Skip to content

Commit 8c5f079

Browse files
ardilaUbuntu
andauthored
model using model run id (#234)
* model using model run id * Review feedback Co-authored-by: Ubuntu <diego.ardila@scale.com>
1 parent 464f3d0 commit 8c5f079

File tree

2 files changed

+20
-5
lines changed

2 files changed

+20
-5
lines changed

nucleus/__init__.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -290,19 +290,34 @@ def get_job(self, job_id: str) -> AsyncJob:
290290
)
291291
return AsyncJob.from_json(payload=payload, client=self)
292292

293-
def get_model(self, model_id: str) -> Model:
293+
def get_model(
294+
self, model_id: str = None, model_run_id: str = None
295+
) -> Model:
294296
"""Fetches a model by its ID.
295297
296298
Parameters:
297-
model_id: Nucleus-generated model ID (starts with ``prj_``). This can
298-
be retrieved via :meth:`list_models` or a Nucleus dashboard URL.
299+
model_id: You can pass either a model ID (starts with ``prj_``) or a model run id (starts with ``run_``) This can
300+
be retrieved via :meth:`list_models` or a Nucleus dashboard URL. Model run ids result from the application of a model to a dataset.
301+
model_run_id: You can pass either a model ID (starts with ``prj_``), or a model run id (starts with ``run_``) This can
302+
be retrieved via :meth:`list_models` or a Nucleus dashboard URL. Model run ids result from the application of a model to a dataset.
303+
304+
In the future, we plan to hide model_run_ids fully from users.
299305
300306
Returns:
301307
:class:`Model`: The Nucleus model as an object.
302308
"""
309+
if model_id is None and model_run_id is None:
310+
raise ValueError("Must pass either a model_id or a model_run_id")
311+
if model_id is not None and model_run_id is not None:
312+
raise ValueError("Must pass either a model_id or a model_run_id")
313+
314+
model_or_model_run_id = (
315+
model_id if model_id is not None else model_run_id
316+
)
317+
303318
payload = self.make_request(
304319
payload={},
305-
route=f"model/{model_id}",
320+
route=f"model/{model_or_model_run_id}",
306321
requests_command=requests.get,
307322
)
308323
return Model.from_json(payload=payload, client=self)

tests/test_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ def test_model_creation_and_listing(CLIENT, dataset):
6262
model = CLIENT.create_model(model_name, model_reference)
6363
model_run = TEST_MODEL_RUN + get_uuid()
6464
m_run = model.create_run(model_run, dataset, TEST_PREDS)
65-
m_run.commit()
6665

6766
assert isinstance(model, Model)
6867
assert isinstance(m_run, ModelRun)
@@ -72,6 +71,7 @@ def test_model_creation_and_listing(CLIENT, dataset):
7271

7372
# Get a model
7473
m = CLIENT.get_model(model.id)
74+
m = CLIENT.get_model(m_run.model_run_id)
7575
assert m == model
7676

7777
assert model in ms

0 commit comments

Comments
 (0)