Skip to content

Commit ebbf257

Browse files
committed
All objects have better repr now
1 parent 17aa444 commit ebbf257

File tree

8 files changed

+42
-7
lines changed

8 files changed

+42
-7
lines changed

nucleus/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,6 @@
141141
)
142142

143143

144-
145144
class NucleusClient:
146145
"""
147146
Nucleus client.
@@ -163,8 +162,6 @@ def __eq__(self, other):
163162
return True
164163
return False
165164

166-
167-
168165
def list_models(self) -> List[Model]:
169166
"""
170167
Lists available models in your repo.

nucleus/annotation.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323

2424
from dataclasses import dataclass
2525

26-
class Annotation:
2726

27+
class Annotation:
2828
def _check_ids(self):
2929
if not bool(self.reference_id) and not bool(self.item_id):
3030
raise Exception(
@@ -40,6 +40,7 @@ def from_json(cls, payload: dict):
4040
else:
4141
return SegmentationAnnotation.from_json(payload)
4242

43+
4344
@dataclass
4445
class Segment:
4546
label: str
@@ -71,11 +72,11 @@ class SegmentationAnnotation(Annotation):
7172
reference_id: Optional[str] = None
7273
item_id: Optional[str] = None
7374
annotation_id: Optional[str] = None
75+
7476
def __post_init__(self):
7577
if not self.mask_url:
7678
raise Exception("You must specify a mask_url.")
7779
self._check_ids()
78-
7980

8081
@classmethod
8182
def from_json(cls, payload: dict):
@@ -121,6 +122,7 @@ class BoxAnnotation(Annotation):
121122
item_id: Optional[str] = None
122123
annotation_id: Optional[str] = None
123124
metadata: Optional[Dict] = None
125+
124126
def __post_init__(self):
125127
self._check_ids()
126128
self.metadata = self.metadata if self.metadata else {}
@@ -165,6 +167,7 @@ class PolygonAnnotation(Annotation):
165167
item_id: Optional[str] = None
166168
annotation_id: Optional[str] = None
167169
metadata: Optional[Dict] = None
170+
168171
def __post_init__(self):
169172
self._check_ids()
170173
self.metadata = self.metadata if self.metadata else {}

nucleus/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def __init__(self, dataset_id: str, client):
3434

3535
def __repr__(self):
3636
return f"Dataset(dataset_id='{self.id}', client={self._client})"
37-
37+
3838
def __eq__(self, other):
3939
if self.id == other.id:
4040
if self._client == other._client:

nucleus/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def __init__(
2929
self._client = client
3030

3131
def __repr__(self):
32-
return f"Model(model_id={self.id}, name={self.name}, reference_id={self.reference_id}, metadata={self.metadata}, client={self._client})"
32+
return f"Model(model_id='{self.id}', name='{self.name}', reference_id='{self.reference_id}', metadata={self.metadata}, client={self._client})"
3333

3434
def __eq__(self, other):
3535
return self.id == other.id

nucleus/model_run.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,15 @@ def __init__(self, model_run_id: str, client):
2323
self.model_run_id = model_run_id
2424
self._client = client
2525

26+
def __repr__(self):
27+
return f"ModelRun(model_run_id='{self.model_run_id}', client={self._client})"
28+
29+
def __eq__(self, other):
30+
if self.model_run_id == other.model_run_id:
31+
if self._client == other._client:
32+
return True
33+
return False
34+
2635
def info(self) -> dict:
2736
"""
2837
provides information about the Model Run:

nucleus/slice.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,15 @@ class Slice:
99
def __init__(self, slice_id: str, client):
1010
self.slice_id = slice_id
1111
self._client = client
12+
13+
def __repr__(self):
14+
return f"Slice(slice_id='{self.slice_id}', client={self._client})"
15+
16+
def __eq__(self, other):
17+
if self.slice_id == other.slice_id:
18+
if self._client == other._client:
19+
return True
20+
return False
1221

1322
def info(self) -> dict:
1423
"""

nucleus/upload_response.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ def __init__(self, json: dict):
3939
self.error_codes: Set[str] = set()
4040
self.error_payload = upload_error_payload
4141

42+
def __repr__(self):
43+
return f"UploadResponse(json={self.json()})"
44+
45+
def __eq__(self, other):
46+
return self.json() == other.json()
47+
4248
def update_response(self, json):
4349
"""
4450
:param json: { new_items: int, updated_items: int, ignored_items: int, upload_errors: int, }

tests/test_models.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
Model,
88
ModelRun,
99
BoxPrediction,
10+
NucleusClient
1011
)
1112
from nucleus.constants import (
1213
NEW_ITEMS,
@@ -23,6 +24,16 @@
2324
TEST_PREDS,
2425
)
2526

27+
# Have to define here in order to have access to all relevant objects
28+
def test_repr(test_object: any):
29+
assert eval(str(test_object)) == test_object
30+
31+
def test_reprs():
32+
client = NucleusClient(api_key="fake_key")
33+
test_repr(Model(client=client, model_id="fake_model_id", name="fake_name", reference_id="fake_reference_id", metadata={"fake": "metadata"}))
34+
test_repr(ModelRun(client=client, model_run_id="fake_model_run_id"))
35+
36+
2637

2738
def test_model_creation_and_listing(CLIENT, dataset):
2839
models_before = CLIENT.list_models()

0 commit comments

Comments
 (0)