Skip to content

Commit d8c7b83

Browse files
author
Matt Sokoloff
committed
ontology features working
1 parent fb2f71d commit d8c7b83

File tree

6 files changed

+105
-9
lines changed

6 files changed

+105
-9
lines changed

labelbox/client.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ def check_errors(keywords, *path):
196196
return None
197197

198198
def get_error_status_code(error):
199+
print(error)
199200
return error["extensions"]["exception"].get("status")
200201

201202
if check_errors(["AUTHENTICATION_ERROR"], "extensions",
@@ -409,7 +410,7 @@ def get_project(self, project_id):
409410
labelbox.exceptions.ResourceNotFoundError: If there is no
410411
Project with the given ID.
411412
"""
412-
return self._get_single(Project, project_id)
413+
return self._get_single(Entity.Project, project_id)
413414

414415
def get_dataset(self, dataset_id):
415416
""" Gets a single Dataset with the given ID.
@@ -720,7 +721,62 @@ def get_ontologies(self, name_contains: str):
720721
res = PaginatedCollection(
721722
self, query_str, {'search' : name_contains, 'filter' :{'status' : 'ALL'}}, ['ontologies', 'nodes'],
722723
Entity.Ontology, ['ontologies', 'nextCursor'])
724+
# status can be ALL or UNUSED
723725
return res
724726

727+
def get_root_schema(self, root_schema_id):
728+
return self._get_single(Entity.RootSchemaNode, root_schema_id)
729+
730+
731+
def get_root_schemas(self, name_contains):
732+
query_str = """query getRootSchemaNodePyApi($search: String, $filter: RootSchemaNodeFilter, $from : String, $first: PageSize){
733+
rootSchemaNodes(where: {filter: $filter, search: $search}, after: $from, first: $first){
734+
nodes {%s}
735+
nextCursor
736+
}
737+
}
738+
""" % query.results_query_part(Entity.RootSchemaNode)
739+
return PaginatedCollection(
740+
self, query_str, {'search' : name_contains, 'filter' :{'status' : 'ALL'}}, ['rootSchemaNodes', 'nodes'],
741+
Entity.RootSchema, ['rootSchemaNodes', 'nextCursor'])
742+
743+
# TODO: Also supports FeatreSchemaKind in the filter
744+
# status can be ALL or UNUSED
745+
746+
def create_ontology(self, name , normalized_ontology = None, root_schema_ids = None):
747+
"""
748+
- If I create an ontology with an empty ontology does it create the root schemas?
749+
- If I mix ontology with root schemas it reuses right?
750+
751+
- should be able to lookup root schema nodes for an ontology. Add relationship..
752+
"""
753+
query_str = """mutation upsertRootSchemaNodePyApi($data: UpsertOntologyInput!){
754+
upsertOntology(data: $data){
755+
%s
756+
}
757+
} """ % query.results_query_part(Entity.Ontology)
758+
if normalized_ontology is None:
759+
if root_schema_ids is None:
760+
raise ValueError("Must provide either a normalized ontology or a list of root_schema_ids")
761+
return root_schema_ids
762+
763+
res = self.execute(query_str, {'data' : {'name' : name ,'normalized' : json.dumps(normalized_ontology)}})
764+
return Entity.Ontology(self, res['upsertOntology'])
765+
766+
767+
def create_root_schema(self, normalized_ontology):
768+
query_str = """mutation upsertRootSchemaNodePyApi($data: UpsertRootSchemaNodeInput!){
769+
upsertRootSchemaNode(data: $data){
770+
%s
771+
}
772+
} """ % query.results_query_part(Entity.RootSchema)
773+
# TODO: Is this necessary?
774+
normalized_ontology = {k:v for k,v in normalized_ontology.items() if v}
775+
# Check color. Quick gotcha..
776+
if 'color' not in normalized_ontology:
777+
raise KeyError("Must provide color.")
778+
return Entity.RootSchemaNode(self, self.execute(query_str, {'data' : {'normalized' : json.dumps(normalized_ontology)}})['upsertRootSchemaNode'])
779+
780+
725781

726782

labelbox/schema/ontology.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,14 @@
1616
SchemaId: Type[str] = constr(min_length=25, max_length=25)
1717

1818

19+
20+
class RootSchemaNode(DbObject):
21+
name = Field.String("name")
22+
color = Field.String("name")
23+
definition = Field.Json("definition")
24+
normalized = Field.Json("normalized")
25+
26+
1927
@dataclass
2028
class Option:
2129
"""
@@ -52,6 +60,10 @@ def from_dict(cls, dictionary: Dict[str, Any]):
5260
for o in dictionary.get("options", [])
5361
])
5462

63+
@classmethod
64+
def from_root_schema(cls, root_schema: RootSchemaNode):
65+
return cls.from_dict(root_schema.normalized)
66+
5567
def asdict(self) -> Dict[str, Any]:
5668
return {
5769
"schemaNodeId": self.schema_id,
@@ -130,6 +142,10 @@ def from_dict(cls, dictionary: Dict[str, Any]):
130142
schema_id=dictionary.get("schemaNodeId", None),
131143
feature_schema_id=dictionary.get("featureSchemaId", None))
132144

145+
@classmethod
146+
def from_root_schema(cls, root_schema: RootSchemaNode):
147+
return cls.from_dict(root_schema.normalized)
148+
133149
def asdict(self) -> Dict[str, Any]:
134150
if self.class_type in self._REQUIRES_OPTIONS \
135151
and len(self.options) < 1:
@@ -233,11 +249,6 @@ def add_classification(self, classification: Classification):
233249
self.classifications.append(classification)
234250

235251

236-
class FeatureSchema(DbObject):
237-
...
238-
239-
240-
241252
class Ontology(DbObject):
242253
"""An ontology specifies which tools and classifications are available
243254
to a project. This is read only for now.
@@ -345,6 +356,11 @@ def from_project(cls, project: "project.Project"):
345356
ontology = project.ontology().normalized
346357
return cls.from_dict(ontology)
347358

359+
@classmethod
360+
def from_ontology(cls, ontology: Ontology):
361+
return cls.from_dict(ontology.normalized)
362+
363+
348364
def add_tool(self, tool: Tool):
349365
if tool.name in (t.name for t in self.tools):
350366
raise InconsistentOntologyException(
@@ -358,3 +374,6 @@ def add_classification(self, classification: Classification):
358374
f"Duplicate classification instructions '{classification.instructions}'. "
359375
)
360376
self.classifications.append(classification)
377+
378+
379+

labelbox/schema/project.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import enum
22
import json
3+
from labelbox.schema.labeling_frontend import LabelingFrontend
34
import logging
45
import time
56
import warnings
@@ -422,6 +423,16 @@ def review_metrics(self, net_score):
422423
res = self.client.execute(query_str, {id_param: self.uid})
423424
return res["project"]["reviewMetrics"]["labelAggregate"]["count"]
424425

426+
427+
def setup_editor(self, ontology):
428+
fe = next(self.client.get_labeling_frontends(where = LabelingFrontend.name == "Editor"))
429+
self.labeling_frontend.connect(fe)
430+
query_str = """mutation ConnectOntology($projectId: ID!, $ontologyId: ID!) {project(where: {id: $projectId}) {connectOntology(ontologyId: $ontologyId) {id}}}"""
431+
self.client.execute(query_str, {'ontologyId' : ontology.uid, 'projectId' : self.uid})
432+
timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
433+
self.update(setup_complete=timestamp)
434+
435+
425436
def setup(self, labeling_frontend, labeling_frontend_options):
426437
""" Finalizes the Project setup.
427438
@@ -481,7 +492,7 @@ def _post_batch(self, method, data_row_ids: List[str]):
481492
}
482493
}
483494
}
484-
}
495+
}
485496
""" % (method, method)
486497

487498
res = self.client.execute(query, {
@@ -517,7 +528,7 @@ def _update_queue_mode(self, mode: QueueMode) -> QueueMode:
517528
tagSetStatus
518529
}
519530
}
520-
}
531+
}
521532
""" % "setTagSetStatusPyApi"
522533

523534
self.client.execute(query_str, {
@@ -533,7 +544,7 @@ def queue_mode(self):
533544
project(where: {id: $projectId}) {
534545
tagSetStatus
535546
}
536-
}
547+
}
537548
""" % "GetTagSetStatusPyApi"
538549

539550
status = self.client.execute(

tests/integration/annotation_import/test_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,6 @@ def test_model_delete(client, model):
2929
after = list(client.get_models())
3030

3131
assert len(before) == len(after) + 1
32+
33+
34+

tests/integration/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,3 +333,4 @@ def configured_project_with_label(client, rand_gen, image_url):
333333
yield [project, label]
334334
dataset.delete()
335335
project.delete()
336+

tests/integration/test_ontology.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,3 +240,9 @@ def test_ontology_asdict(project) -> None:
240240
def test_from_project_ontology(client, project) -> None:
241241
o = OntologyBuilder.from_project(project)
242242
assert o.asdict() == project.ontology().normalized
243+
244+
245+
246+
def test_create_ontology(client, rand_gen):
247+
client.create_ontology(name = f"test-ontology-{rand_gen(str)}")
248+

0 commit comments

Comments
 (0)