Skip to content

Commit fb2f71d

Browse files
author
Matt Sokoloff
committed
clean up client
1 parent e3fba71 commit fb2f71d

File tree

1 file changed

+23
-28
lines changed

1 file changed

+23
-28
lines changed

labelbox/client.py

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,12 @@
1717
from labelbox import __version__ as SDK_VERSION
1818
from labelbox.orm import query
1919
from labelbox.orm.db_object import DbObject
20+
from labelbox.orm.model import Entity
2021
from labelbox.pagination import PaginatedCollection
21-
from labelbox.schema.project import Project
22-
from labelbox.schema.dataset import Dataset
23-
from labelbox.schema.data_row import DataRow
24-
from labelbox.schema.model import Model
25-
from labelbox.schema.user import User
26-
from labelbox.schema.organization import Organization
2722
from labelbox.schema.data_row_metadata import DataRowMetadataOntology
28-
from labelbox.schema.labeling_frontend import LabelingFrontend
2923
from labelbox.schema.iam_integration import IAMIntegration
3024
from labelbox.schema import role
31-
from labelbox.orm.model import Entity
25+
3226

3327
logger = logging.getLogger(__name__)
3428

@@ -430,22 +424,22 @@ def get_dataset(self, dataset_id):
430424
labelbox.exceptions.ResourceNotFoundError: If there is no
431425
Dataset with the given ID.
432426
"""
433-
return self._get_single(Dataset, dataset_id)
427+
return self._get_single(Entity.Dataset, dataset_id)
434428

435429
def get_user(self):
436430
""" Gets the current User database object.
437431
438432
>>> user = client.get_user()
439433
"""
440-
return self._get_single(User, None)
434+
return self._get_single(Entity.User, None)
441435

442436
def get_organization(self):
443437
""" Gets the Organization DB object of the current user.
444438
445439
>>> organization = client.get_organization()
446440
447441
"""
448-
return self._get_single(Organization, None)
442+
return self._get_single(Entity.Organization, None)
449443

450444
def _get_all(self, db_object_type, where, filter_deleted=True):
451445
""" Fetches all the objects of the given type the user has access to.
@@ -478,7 +472,7 @@ def get_projects(self, where=None):
478472
Returns:
479473
An iterable of Projects (typically a PaginatedCollection).
480474
"""
481-
return self._get_all(Project, where)
475+
return self._get_all(Entity.Project, where)
482476

483477
def get_datasets(self, where=None):
484478
""" Fetches one or more datasets.
@@ -491,7 +485,7 @@ def get_datasets(self, where=None):
491485
Returns:
492486
An iterable of Datasets (typically a PaginatedCollection).
493487
"""
494-
return self._get_all(Dataset, where)
488+
return self._get_all(Entity.Dataset, where)
495489

496490
def get_labeling_frontends(self, where=None):
497491
""" Fetches all the labeling frontends.
@@ -504,7 +498,7 @@ def get_labeling_frontends(self, where=None):
504498
Returns:
505499
An iterable of LabelingFrontends (typically a PaginatedCollection).
506500
"""
507-
return self._get_all(LabelingFrontend, where)
501+
return self._get_all(Entity.LabelingFrontend, where)
508502

509503
def _create(self, db_object_type, data):
510504
""" Creates an object on the server. Attribute values are
@@ -551,7 +545,7 @@ def create_dataset(self, iam_integration=IAMIntegration._DEFAULT, **kwargs):
551545
InvalidAttributeError: If the Dataset type does not contain
552546
any of the attribute names given in kwargs.
553547
"""
554-
dataset = self._create(Dataset, kwargs)
548+
dataset = self._create(Entity.Dataset, kwargs)
555549

556550
if iam_integration == IAMIntegration._DEFAULT:
557551
iam_integration = self.get_organization(
@@ -560,14 +554,15 @@ def create_dataset(self, iam_integration=IAMIntegration._DEFAULT, **kwargs):
560554
if iam_integration is None:
561555
return dataset
562556

563-
if not isinstance(iam_integration, IAMIntegration):
564-
raise TypeError(
565-
f"iam integration must be a reference an `IAMIntegration` object. Found {type(iam_integration)}"
566-
)
567-
568-
if not iam_integration.valid:
569-
raise ValueError("Integration is not valid. Please select another.")
570557
try:
558+
if not isinstance(iam_integration, IAMIntegration):
559+
raise TypeError(
560+
f"iam integration must be a reference an `IAMIntegration` object. Found {type(iam_integration)}"
561+
)
562+
563+
if not iam_integration.valid:
564+
raise ValueError("Integration is not valid. Please select another.")
565+
571566
self.execute(
572567
"""mutation setSignerForDatasetPyApi($signerId: ID!, $datasetId: ID!) {
573568
setSignerForDataset(data: { signerId: $signerId}, where: {id: $datasetId}){id}}
@@ -604,7 +599,7 @@ def create_project(self, **kwargs):
604599
InvalidAttributeError: If the Project type does not contain
605600
any of the attribute names given in kwargs.
606601
"""
607-
return self._create(Project, kwargs)
602+
return self._create(Entity.Project, kwargs)
608603

609604
def get_roles(self):
610605
"""
@@ -621,7 +616,7 @@ def get_data_row(self, data_row_id):
621616
DataRow: returns a single data row given the data row id
622617
"""
623618

624-
return self._get_single(DataRow, data_row_id)
619+
return self._get_single(Entity.DataRow, data_row_id)
625620

626621
def get_data_row_metadata_ontology(self):
627622
"""
@@ -645,7 +640,7 @@ def get_model(self, model_id):
645640
labelbox.exceptions.ResourceNotFoundError: If there is no
646641
Model with the given ID.
647642
"""
648-
return self._get_single(Model, model_id)
643+
return self._get_single(Entity.Model, model_id)
649644

650645
def get_models(self, where=None):
651646
""" Fetches all the models the user has access to.
@@ -658,7 +653,7 @@ def get_models(self, where=None):
658653
Returns:
659654
An iterable of Models (typically a PaginatedCollection).
660655
"""
661-
return self._get_all(Model, where, filter_deleted=False)
656+
return self._get_all(Entity.Model, where, filter_deleted=False)
662657

663658
def create_model(self, name, ontology_id):
664659
""" Creates a Model object on the server.
@@ -678,13 +673,13 @@ def create_model(self, name, ontology_id):
678673
createModel(data: {name : $name, ontologyId : $ontologyId}){
679674
%s
680675
}
681-
}""" % query.results_query_part(Model)
676+
}""" % query.results_query_part(Entity.Model)
682677

683678
result = self.execute(query_str, {
684679
"name": name,
685680
"ontologyId": ontology_id
686681
})
687-
return Model(self, result['createModel'])
682+
return Entity.Model(self, result['createModel'])
688683

689684
def get_data_row_ids_for_external_ids(
690685
self, external_ids: List[str]) -> Dict[str, List[str]]:

0 commit comments

Comments
 (0)