Skip to content

Commit c8887be

Browse files
authored
Merge branch 'develop' into al-1351
2 parents 8240417 + 89c7877 commit c8887be

30 files changed

+284
-187
lines changed

labelbox/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@
1313
from labelbox.schema.user import User
1414
from labelbox.schema.organization import Organization
1515
from labelbox.schema.task import Task
16-
from labelbox.schema.labeling_frontend import LabelingFrontend
16+
from labelbox.schema.labeling_frontend import LabelingFrontend, LabelingFrontendOptions
1717
from labelbox.schema.asset_attachment import AssetAttachment
1818
from labelbox.schema.webhook import Webhook
1919
from labelbox.schema.ontology import Ontology, OntologyBuilder, Classification, Option, Tool, FeatureSchema
2020
from labelbox.schema.role import Role, ProjectRole
2121
from labelbox.schema.invite import Invite, InviteLimit
2222
from labelbox.schema.data_row_metadata import DataRowMetadataOntology
2323
from labelbox.schema.model_run import ModelRun
24+
from labelbox.schema.benchmark import Benchmark
25+
from labelbox.schema.iam_integration import IAMIntegration

labelbox/client.py

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from google.api_core import retry
1212
import requests
1313
import requests.exceptions
14+
from labelbox.data.annotation_types.feature import FeatureSchema
15+
from labelbox.data.serialization.ndjson.base import DataRow
1416

1517
import labelbox.exceptions
1618
from labelbox import utils
@@ -20,9 +22,16 @@
2022
from labelbox.orm.model import Entity
2123
from labelbox.pagination import PaginatedCollection
2224
from labelbox.schema.data_row_metadata import DataRowMetadataOntology
25+
from labelbox.schema.dataset import Dataset
2326
from labelbox.schema.iam_integration import IAMIntegration
2427
from labelbox.schema import role
25-
from labelbox.schema.ontology import Tool, Classification
28+
from labelbox.schema.labeling_frontend import LabelingFrontend
29+
from labelbox.schema.model import Model
30+
from labelbox.schema.ontology import Ontology, Tool, Classification
31+
from labelbox.schema.organization import Organization
32+
from labelbox.schema.user import User
33+
from labelbox.schema.project import Project
34+
from labelbox.schema.role import Role
2635

2736
logger = logging.getLogger(__name__)
2837

@@ -411,7 +420,7 @@ def get_project(self, project_id):
411420
"""
412421
return self._get_single(Entity.Project, project_id)
413422

414-
def get_dataset(self, dataset_id):
423+
def get_dataset(self, dataset_id) -> Dataset:
415424
""" Gets a single Dataset with the given ID.
416425
417426
>>> dataset = client.get_dataset("<dataset_id>")
@@ -426,14 +435,14 @@ def get_dataset(self, dataset_id):
426435
"""
427436
return self._get_single(Entity.Dataset, dataset_id)
428437

429-
def get_user(self):
438+
def get_user(self) -> User:
430439
""" Gets the current User database object.
431440
432441
>>> user = client.get_user()
433442
"""
434443
return self._get_single(Entity.User, None)
435444

436-
def get_organization(self):
445+
def get_organization(self) -> Organization:
437446
""" Gets the Organization DB object of the current user.
438447
439448
>>> organization = client.get_organization()
@@ -461,7 +470,7 @@ def _get_all(self, db_object_type, where, filter_deleted=True):
461470
[utils.camel_case(db_object_type.type_name()) + "s"],
462471
db_object_type)
463472

464-
def get_projects(self, where=None):
473+
def get_projects(self, where=None) -> List[Project]:
465474
""" Fetches all the projects the user has access to.
466475
467476
>>> projects = client.get_projects(where=(Project.name == "<project_name>") & (Project.description == "<project_description>"))
@@ -474,7 +483,7 @@ def get_projects(self, where=None):
474483
"""
475484
return self._get_all(Entity.Project, where)
476485

477-
def get_datasets(self, where=None):
486+
def get_datasets(self, where=None) -> List[Dataset]:
478487
""" Fetches one or more datasets.
479488
480489
>>> datasets = client.get_datasets(where=(Dataset.name == "<dataset_name>") & (Dataset.description == "<dataset_description>"))
@@ -487,7 +496,7 @@ def get_datasets(self, where=None):
487496
"""
488497
return self._get_all(Entity.Dataset, where)
489498

490-
def get_labeling_frontends(self, where=None):
499+
def get_labeling_frontends(self, where=None) -> List[LabelingFrontend]:
491500
""" Fetches all the labeling frontends.
492501
493502
>>> frontend = client.get_labeling_frontends(where=LabelingFrontend.name == "Editor")
@@ -527,7 +536,9 @@ def _create(self, db_object_type, data):
527536
res = res["create%s" % db_object_type.type_name()]
528537
return db_object_type(self, res)
529538

530-
def create_dataset(self, iam_integration=IAMIntegration._DEFAULT, **kwargs):
539+
def create_dataset(self,
540+
iam_integration=IAMIntegration._DEFAULT,
541+
**kwargs) -> Dataset:
531542
""" Creates a Dataset object on the server.
532543
533544
Attribute values are passed as keyword arguments.
@@ -585,7 +596,7 @@ def create_dataset(self, iam_integration=IAMIntegration._DEFAULT, **kwargs):
585596
raise e
586597
return dataset
587598

588-
def create_project(self, **kwargs):
599+
def create_project(self, **kwargs) -> Project:
589600
""" Creates a Project object on the server.
590601
591602
Attribute values are passed as keyword arguments.
@@ -602,15 +613,15 @@ def create_project(self, **kwargs):
602613
"""
603614
return self._create(Entity.Project, kwargs)
604615

605-
def get_roles(self):
616+
def get_roles(self) -> List[Role]:
606617
"""
607618
Returns:
608619
Roles: Provides information on available roles within an organization.
609620
Roles are used for user management.
610621
"""
611622
return role.get_roles(self)
612623

613-
def get_data_row(self, data_row_id):
624+
def get_data_row(self, data_row_id) -> DataRow:
614625
"""
615626
616627
Returns:
@@ -619,7 +630,7 @@ def get_data_row(self, data_row_id):
619630

620631
return self._get_single(Entity.DataRow, data_row_id)
621632

622-
def get_data_row_metadata_ontology(self):
633+
def get_data_row_metadata_ontology(self) -> DataRowMetadataOntology:
623634
"""
624635
625636
Returns:
@@ -628,7 +639,7 @@ def get_data_row_metadata_ontology(self):
628639
"""
629640
return DataRowMetadataOntology(self)
630641

631-
def get_model(self, model_id):
642+
def get_model(self, model_id) -> Model:
632643
""" Gets a single Model with the given ID.
633644
634645
>>> model = client.get_model("<model_id>")
@@ -643,7 +654,7 @@ def get_model(self, model_id):
643654
"""
644655
return self._get_single(Entity.Model, model_id)
645656

646-
def get_models(self, where=None):
657+
def get_models(self, where=None) -> List[Model]:
647658
""" Fetches all the models the user has access to.
648659
649660
>>> models = client.get_models(where=(Model.name == "<model_name>"))
@@ -656,7 +667,7 @@ def get_models(self, where=None):
656667
"""
657668
return self._get_all(Entity.Model, where, filter_deleted=False)
658669

659-
def create_model(self, name, ontology_id):
670+
def create_model(self, name, ontology_id) -> Model:
660671
""" Creates a Model object on the server.
661672
662673
>>> model = client.create_model(<model_name>, <ontology_id>)
@@ -707,7 +718,7 @@ def get_data_row_ids_for_external_ids(
707718
result[row['externalId']].append(row['dataRowId'])
708719
return result
709720

710-
def get_ontology(self, ontology_id):
721+
def get_ontology(self, ontology_id) -> Ontology:
711722
"""
712723
Fetches an Ontology by id.
713724
@@ -718,7 +729,7 @@ def get_ontology(self, ontology_id):
718729
"""
719730
return self._get_single(Entity.Ontology, ontology_id)
720731

721-
def get_ontologies(self, name_contains):
732+
def get_ontologies(self, name_contains) -> PaginatedCollection:
722733
"""
723734
Fetches all ontologies with names that match the name_contains string.
724735
@@ -739,7 +750,7 @@ def get_ontologies(self, name_contains):
739750
['ontologies', 'nodes'], Entity.Ontology,
740751
['ontologies', 'nextCursor'])
741752

742-
def get_feature_schema(self, feature_schema_id):
753+
def get_feature_schema(self, feature_schema_id) -> FeatureSchema:
743754
"""
744755
Fetches a feature schema. Only supports top level feature schemas.
745756
@@ -760,7 +771,7 @@ def get_feature_schema(self, feature_schema_id):
760771
res['id'] = res['normalized']['featureSchemaId']
761772
return Entity.FeatureSchema(self, res)
762773

763-
def get_feature_schemas(self, name_contains):
774+
def get_feature_schemas(self, name_contains) -> PaginatedCollection:
764775
"""
765776
Fetches top level feature schemas with names that match the `name_contains` string
766777
@@ -789,7 +800,8 @@ def rootSchemaPayloadToFeatureSchema(client, payload):
789800
rootSchemaPayloadToFeatureSchema,
790801
['rootSchemaNodes', 'nextCursor'])
791802

792-
def create_ontology_from_feature_schemas(self, name, feature_schema_ids):
803+
def create_ontology_from_feature_schemas(self, name,
804+
feature_schema_ids) -> Ontology:
793805
"""
794806
Creates an ontology from a list of feature schema ids
795807
@@ -828,7 +840,7 @@ def create_ontology_from_feature_schemas(self, name, feature_schema_ids):
828840
normalized = {'tools': tools, 'classifications': classifications}
829841
return self.create_ontology(name, normalized)
830842

831-
def create_ontology(self, name, normalized):
843+
def create_ontology(self, name, normalized) -> Ontology:
832844
"""
833845
Creates an ontology from normalized data
834846
>>> normalized = {"tools" : [{'tool': 'polygon', 'name': 'cat', 'color': 'black'}], "classifications" : []}
@@ -855,7 +867,7 @@ def create_ontology(self, name, normalized):
855867
res = self.execute(query_str, params)
856868
return Entity.Ontology(self, res['upsertOntology'])
857869

858-
def create_feature_schema(self, normalized):
870+
def create_feature_schema(self, normalized) -> FeatureSchema:
859871
"""
860872
Creates a feature schema from normalized data.
861873
>>> normalized = {'tool': 'polygon', 'name': 'cat', 'color': 'black'}

labelbox/data/annotation_types/classification/classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class ClassificationAnswer(FeatureSchema):
3333
extra: Dict[str, Any] = {}
3434
keyframe: Optional[bool] = None
3535

36-
def dict(self, *args, **kwargs):
36+
def dict(self, *args, **kwargs) -> Dict[str, str]:
3737
res = super().dict(*args, **kwargs)
3838
if res['keyframe'] is None:
3939
res.pop('keyframe')

labelbox/data/annotation_types/collection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def _ensure_unique_external_ids(self) -> None:
133133
)
134134
external_ids.add(label.data.external_id)
135135

136-
def append(self, label: Label):
136+
def append(self, label: Label) -> None:
137137
self._data.append(label)
138138

139139
def __iter__(self) -> "LabelList":

labelbox/data/annotation_types/data/raster.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ class RasterData(BaseModel, ABC):
2323

2424
@classmethod
2525
def from_2D_arr(cls, arr: Union[TypedArray[Literal['uint8']],
26-
TypedArray[Literal['int']]], **kwargs):
26+
TypedArray[Literal['int']]],
27+
**kwargs) -> "RasterData":
2728
"""Construct from a 2D numpy array
2829
2930
Args:

labelbox/data/annotation_types/data/text.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class TextData(BaseData):
1111
"""
1212
Represents text data. Requires arg file_path, text, or url
1313
14-
>>> TextData(text="")
14+
>>> TextData(text="")
1515
1616
Args:
1717
file_path (str)

labelbox/data/annotation_types/data/tiled_image.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ class TileLayer(BaseModel):
9595
url: str
9696
name: Optional[str] = "default"
9797

98-
def asdict(self) -> Dict:
98+
def asdict(self) -> Dict[str, str]:
9999
return {"tileLayerUrl": self.url, "name": self.name}
100100

101101
@validator('url')
@@ -138,7 +138,7 @@ def __post_init__(self) -> None:
138138
if self.max_native_zoom is None:
139139
self.max_native_zoom = self.zoom_levels[0]
140140

141-
def asdict(self) -> Dict:
141+
def asdict(self) -> Dict[str, str]:
142142
return {
143143
"tileLayerUrl": self.tile_layer.url,
144144
"bounds": [[
@@ -411,7 +411,7 @@ def geo_and_pixel(cls,
411411

412412
if src_epsg == EPSG.SIMPLEPIXEL:
413413

414-
def transform(x: int, y: int) -> Callable:
414+
def transform(x: int, y: int) -> Callable[[int, int], Transformer]:
415415
scaled_xy = (x * (global_x_range) / (local_x_range),
416416
y * (global_y_range) / (local_y_range))
417417

@@ -431,7 +431,7 @@ def transform(x: int, y: int) -> Callable:
431431
#handles 4326 from lat,lng
432432
elif src_epsg == EPSG.EPSG4326:
433433

434-
def transform(x: int, y: int) -> Callable:
434+
def transform(x: int, y: int) -> Callable[[int, int], Transformer]:
435435
point_in_px = PygeoPoint.from_latitude_longitude(
436436
latitude=y, longitude=x).pixels(zoom)
437437

@@ -446,7 +446,7 @@ def transform(x: int, y: int) -> Callable:
446446
#handles 3857 from meters
447447
elif src_epsg == EPSG.EPSG3857:
448448

449-
def transform(x: int, y: int) -> Callable:
449+
def transform(x: int, y: int) -> Callable[[int, int], Transformer]:
450450
point_in_px = PygeoPoint.from_meters(meter_y=y,
451451
meter_x=x).pixels(zoom)
452452

@@ -459,8 +459,9 @@ def transform(x: int, y: int) -> Callable:
459459
return transform
460460

461461
@classmethod
462-
def create_geo_to_geo_transformer(cls, src_epsg: EPSG,
463-
tgt_epsg: EPSG) -> Callable:
462+
def create_geo_to_geo_transformer(
463+
cls, src_epsg: EPSG,
464+
tgt_epsg: EPSG) -> Callable[[int, int], Transformer]:
464465
"""method to change from one projection to another projection.
465466
466467
supports EPSG transformations not Simple.
@@ -474,11 +475,12 @@ def create_geo_to_geo_transformer(cls, src_epsg: EPSG,
474475
src_epsg.value, tgt_epsg.value, always_xy=True).transform)
475476

476477
@classmethod
477-
def create_geo_to_pixel_transformer(cls,
478-
src_epsg,
479-
pixel_bounds: TiledBounds,
480-
geo_bounds: TiledBounds,
481-
zoom=0) -> Callable:
478+
def create_geo_to_pixel_transformer(
479+
cls,
480+
src_epsg,
481+
pixel_bounds: TiledBounds,
482+
geo_bounds: TiledBounds,
483+
zoom=0) -> Callable[[int, int], Transformer]:
482484
"""method to change from a geo projection to Simple"""
483485

484486
transform_function = cls.geo_and_pixel(src_epsg=src_epsg,
@@ -488,11 +490,12 @@ def create_geo_to_pixel_transformer(cls,
488490
return EPSGTransformer(transformer=transform_function)
489491

490492
@classmethod
491-
def create_pixel_to_geo_transformer(cls,
492-
src_epsg,
493-
pixel_bounds: TiledBounds,
494-
geo_bounds: TiledBounds,
495-
zoom=0) -> Callable:
493+
def create_pixel_to_geo_transformer(
494+
cls,
495+
src_epsg,
496+
pixel_bounds: TiledBounds,
497+
geo_bounds: TiledBounds,
498+
zoom=0) -> Callable[[int, int], Transformer]:
496499
"""method to change from a geo projection to Simple"""
497500
transform_function = cls.geo_and_pixel(src_epsg=src_epsg,
498501
pixel_bounds=pixel_bounds,

labelbox/data/annotation_types/geometry/mask.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Callable, Optional, Tuple, Union, List
1+
from typing import Callable, Optional, Tuple, Union, Dict, List
22

33
import numpy as np
44
from pydantic.class_validators import validator
@@ -36,7 +36,7 @@ class Mask(Geometry):
3636
color: Union[Tuple[int, int, int], int]
3737

3838
@property
39-
def geometry(self):
39+
def geometry(self) -> Dict[str, Tuple[int, int, int]]:
4040
mask = self.draw(color=1)
4141
contours, hierarchy = cv2.findContours(image=mask,
4242
mode=cv2.RETR_TREE,

labelbox/data/annotation_types/geometry/rectangle.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,6 @@ def draw(self,
7474
return cv2.polylines(canvas, pts, True, color, thickness)
7575

7676
@classmethod
77-
def from_xyhw(cls, x: float, y: float, h: float, w: float):
77+
def from_xyhw(cls, x: float, y: float, h: float, w: float) -> "Rectangle":
7878
"""Create Rectangle from x,y, height width format"""
7979
return cls(start=Point(x=x, y=y), end=Point(x=x + w, y=y + h))

0 commit comments

Comments
 (0)