Skip to content

Commit 98eb8fd

Browse files
committed
updates to type hinting and adding to some inits as well as providing additional errors
1 parent 8037a3e commit 98eb8fd

File tree

9 files changed

+61
-19
lines changed

9 files changed

+61
-19
lines changed

labelbox/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@
1313
from labelbox.schema.user import User
1414
from labelbox.schema.organization import Organization
1515
from labelbox.schema.task import Task
16-
from labelbox.schema.labeling_frontend import LabelingFrontend
16+
from labelbox.schema.labeling_frontend import LabelingFrontend, LabelingFrontendOptions
1717
from labelbox.schema.asset_attachment import AssetAttachment
1818
from labelbox.schema.webhook import Webhook
1919
from labelbox.schema.ontology import Ontology, OntologyBuilder, Classification, Option, Tool, FeatureSchema
2020
from labelbox.schema.role import Role, ProjectRole
2121
from labelbox.schema.invite import Invite, InviteLimit
2222
from labelbox.schema.data_row_metadata import DataRowMetadataOntology
2323
from labelbox.schema.model_run import ModelRun
24+
from labelbox.schema.benchmark import Benchmark
25+
from labelbox.schema.iam_integration import IAMIntegration

labelbox/client.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@
2929
from labelbox.schema.model import Model
3030
from labelbox.schema.ontology import Ontology, Tool, Classification
3131
from labelbox.schema.organization import Organization
32+
from labelbox.schema.user import User
33+
from labelbox.schema.project import Project
34+
from labelbox.schema.role import Role
3235

3336
logger = logging.getLogger(__name__)
3437

labelbox/data/serialization/labelbox_v1/converter.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from labelbox.data.serialization.labelbox_v1.objects import LBV1Mask
2-
from typing import Any, Dict, Generator, Iterable
2+
from typing import Any, Dict, Generator, Iterable, Union
33
import logging
44

55
import ndjson
@@ -19,7 +19,7 @@
1919
class LBV1Converter:
2020

2121
@staticmethod
22-
def deserialize_video(json_data: Iterable[Dict[str, Any]],
22+
def deserialize_video(json_data: Union[str, Iterable[Dict[str, Any]]],
2323
client: "labelbox.Client") -> LabelGenerator:
2424
"""
2525
Converts a labelbox video export into the common labelbox format.
@@ -36,7 +36,8 @@ def deserialize_video(json_data: Iterable[Dict[str, Any]],
3636
return LabelGenerator(data=label_generator)
3737

3838
@staticmethod
39-
def deserialize(json_data: Iterable[Dict[str, Any]]) -> LabelGenerator:
39+
def deserialize(
40+
json_data: Union[str, Iterable[Dict[str, Any]]]) -> LabelGenerator:
4041
"""
4142
Converts a labelbox export (non-video) into the common labelbox format.
4243

labelbox/orm/model.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from enum import Enum, auto
2-
from typing import Dict, List, Union
2+
from typing import Dict, List, Union, Any, Type, TYPE_CHECKING
33

4+
import labelbox
45
from labelbox import utils
56
from labelbox.exceptions import InvalidAttributeError
67
from labelbox.orm.comparison import Comparison
@@ -245,6 +246,9 @@ class EntityMeta(type):
245246
# Maps Entity name to Relationships for all currently defined Entities
246247
relationship_mappings: Dict[str, List[Relationship]] = {}
247248

249+
def __setattr__(self, key: Any, value: Any):
250+
super().__setattr__(key, value)
251+
248252
def __init__(cls, clsname, superclasses, attributedict):
249253
super().__init__(clsname, superclasses, attributedict)
250254
cls.validate_cached_relationships()
@@ -325,6 +329,21 @@ class Entity(metaclass=EntityMeta):
325329
# suchs as `fields()`.
326330
deleted = Field.Boolean("deleted")
327331

332+
if TYPE_CHECKING:
333+
DataRow: Type[labelbox.DataRow]
334+
Webhook: Type[labelbox.Webhook]
335+
Task: Type[labelbox.Task]
336+
AssetAttachment: Type[labelbox.AssetAttachment]
337+
ModelRun: Type[labelbox.ModelRun]
338+
Review: Type[labelbox.Review]
339+
User: Type[labelbox.User]
340+
LabelingFrontend: Type[labelbox.LabelingFrontend]
341+
BulkImportRequest: Type[labelbox.BulkImportRequest]
342+
Benchmark: Type[labelbox.Benchmark]
343+
IAMIntegration: Type[labelbox.IAMIntegration]
344+
LabelingFrontendOptions: Type[labelbox.LabelingFrontendOptions]
345+
Label: Type[labelbox.Label]
346+
328347
@classmethod
329348
def _attributes_of_type(cls, attr_type):
330349
""" Yields all the attributes in `cls` of the given `attr_type`. """

labelbox/pagination.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Size of a single page in a paginated query.
22
from abc import ABC, abstractmethod
3-
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
3+
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
44

55
from typing import TYPE_CHECKING
66
if TYPE_CHECKING:
@@ -23,8 +23,8 @@ def __init__(self,
2323
client: "Client",
2424
query: str,
2525
params: Dict[str, str],
26-
dereferencing: Dict[str, Any],
27-
obj_class: Type["DbObject"],
26+
dereferencing: Union[List[str], Dict[str, Any]],
27+
obj_class: Union[Type["DbObject"], Callable[[Any, Any], Any]],
2828
cursor_path: Optional[Dict[str, Any]] = None,
2929
experimental: bool = False):
3030
""" Creates a PaginatedCollection.

labelbox/schema/dataset.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Generator, List
1+
from typing import Generator, List, Union, Any
22
import os
33
import json
44
import logging
@@ -10,6 +10,7 @@
1010
from labelbox.data.serialization.ndjson.base import DataRow
1111
from labelbox.schema import iam_integration
1212
from labelbox.schema.task import Task
13+
from labelbox.schema.user import User
1314
from labelbox import utils
1415
from concurrent.futures import ThreadPoolExecutor, as_completed
1516
from io import StringIO
@@ -121,7 +122,7 @@ def create_data_rows_sync(self, items) -> None:
121122
url_param: descriptor_url
122123
})
123124

124-
def create_data_rows(self, items) -> Task:
125+
def create_data_rows(self, items) -> Union[Task, List[Any]]:
125126
""" Asynchronously bulk upload data rows
126127
127128
Use this instead of `Dataset.create_data_rows_sync` uploads for batches that contain more than 1000 data rows.
@@ -164,14 +165,15 @@ def create_data_rows(self, items) -> Task:
164165

165166
# Fetch and return the task.
166167
task_id = res["taskId"]
167-
user = self.client.get_user()
168-
task = list(user.created_tasks(where=Entity.Task.uid == task_id))
168+
user: User = self.client.get_user()
169+
tasks: List[Task] = list(
170+
user.created_tasks(where=Entity.Task.uid == task_id))
169171
# Cache user in a private variable as the relationship can't be
170172
# resolved due to server-side limitations (see Task.created_by)
171173
# for more info.
172-
if len(task) != 1:
174+
if len(tasks) != 1:
173175
raise ResourceNotFoundError(Entity.Task, task_id)
174-
task = task[0]
176+
task: Task = tasks[0]
175177
task._user = user
176178
return task
177179

labelbox/schema/ontology.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from pydantic import constr
99

10-
from labelbox.schema import project
10+
# from labelbox.schema import project
1111
from labelbox.exceptions import InconsistentOntologyException
1212
from labelbox.orm.db_object import DbObject
1313
from labelbox.orm.model import Field, Relationship

labelbox/schema/project.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from collections import namedtuple
77
from datetime import datetime, timezone
88
from pathlib import Path
9-
from typing import Dict, Union, Iterable, List, Optional
9+
from typing import Dict, Union, Iterable, List, Optional, Any
1010
from urllib.parse import urlparse
1111

1212
import ndjson
@@ -203,6 +203,10 @@ def video_label_generator(self, timeout_seconds=600) -> LabelGenerator:
203203
_check_converter_import()
204204
json_data = self.export_labels(download=True,
205205
timeout_seconds=timeout_seconds)
206+
# assert that the instance this would fail is only if timeout runs out
207+
assert isinstance(
208+
json_data,
209+
List), "Unable to successfully get labels. Please try again"
206210
if json_data is None:
207211
raise TimeoutError(
208212
f"Unable to download labels in {timeout_seconds} seconds."
@@ -227,6 +231,10 @@ def label_generator(self, timeout_seconds=600) -> LabelGenerator:
227231
_check_converter_import()
228232
json_data = self.export_labels(download=True,
229233
timeout_seconds=timeout_seconds)
234+
# assert that the instance this would fail is only if timeout runs out
235+
assert isinstance(
236+
json_data,
237+
List), "Unable to successfully get labels. Please try again"
230238
if json_data is None:
231239
raise TimeoutError(
232240
f"Unable to download labels in {timeout_seconds} seconds."
@@ -241,7 +249,10 @@ def label_generator(self, timeout_seconds=600) -> LabelGenerator:
241249
"Or use project.video_label_generator() for video data.")
242250
return LBV1Converter.deserialize(json_data)
243251

244-
def export_labels(self, download=False, timeout_seconds=600) -> str:
252+
def export_labels(
253+
self,
254+
download=False,
255+
timeout_seconds=600) -> Optional[Union[str, List[Dict[Any, Any]]]]:
245256
""" Calls the server-side Label exporting that generates a JSON
246257
payload, and returns the URL to that payload.
247258
@@ -564,7 +575,7 @@ def _update_queue_mode(self, mode: QueueMode) -> QueueMode:
564575

565576
return mode
566577

567-
def queue_mode(self) -> str:
578+
def queue_mode(self) -> QueueMode:
568579
"""Provides the status of if queue mode is enabled in the project."""
569580

570581
query_str = """query %s($projectId: ID!) {

labelbox/schema/task.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import logging
22
import time
3+
from typing import Optional
34

45
from labelbox.exceptions import ResourceNotFoundError
56
from labelbox.orm.db_object import DbObject
6-
from labelbox.orm.model import Field, Relationship
7+
from labelbox.orm.model import Field, Relationship, Entity
8+
from labelbox.schema.user import User
79

810
logger = logging.getLogger(__name__)
911

@@ -27,13 +29,15 @@ class Task(DbObject):
2729
name = Field.String("name")
2830
status = Field.String("status")
2931
completion_percentage = Field.Float("completion_percentage")
32+
_user: Optional[User] = None
3033

3134
# Relationships
3235
created_by = Relationship.ToOne("User", False, "created_by")
3336
organization = Relationship.ToOne("Organization")
3437

3538
def refresh(self) -> None:
3639
""" Refreshes Task data from the server. """
40+
assert self._user is not None
3741
tasks = list(self._user.created_tasks(where=Task.uid == self.uid))
3842
if len(tasks) != 1:
3943
raise ResourceNotFoundError(Task, self.uid)

0 commit comments

Comments
 (0)