Skip to content

Commit 2fbe5a8

Browse files
authored
Merge pull request #141 from Labelbox/ms/cached-relationships
Ms/cached relationships
2 parents e4f4741 + 285529f commit 2fbe5a8

File tree

9 files changed

+117
-6
lines changed

9 files changed

+117
-6
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
* Delete users from organization
1010
* Example notebook added under examples/basics
1111
* Issues and comments export
12-
* Bulk export issues and comments. See `Project.export_labels`
12+
* Bulk export issues and comments. See `Project.export_issues`
1313
* MAL on Tiled Imagery
1414
* Example notebook added under examples/model_assisted_labeling
1515
* `Dataset.create_data_rows` now allows users to upload tms imagery

labelbox/orm/db_object.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,8 @@ def __init__(self, client, field_values):
4343
"""
4444
self.client = client
4545
self._set_field_values(field_values)
46-
4746
for relationship in self.relationships():
48-
value = field_values.get(relationship.name)
47+
value = field_values.get(utils.camel_case(relationship.name))
4948
if relationship.cache and value is None:
5049
raise KeyError(
5150
f"Expected field values for {relationship.name}")
@@ -168,6 +167,7 @@ def _to_one(self):
168167
result = result and result.get(rel.graphql_name)
169168
if result is None:
170169
return None
170+
171171
return rel.destination_type(self.source.client, result)
172172

173173
def connect(self, other):

labelbox/orm/model.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from enum import Enum, auto
2-
from typing import Union
2+
from typing import Dict, List, Union
33

44
from labelbox import utils
55
from labelbox.exceptions import InvalidAttributeError
@@ -239,11 +239,74 @@ class EntityMeta(type):
239239
of the Entity class object so they can be referenced for example like:
240240
Entity.Project.
241241
"""
242+
# Maps Entity name to Relationships for all currently defined Entities
243+
relationship_mappings: Dict[str, List[Relationship]] = {}
242244

243245
def __init__(cls, clsname, superclasses, attributedict):
244246
super().__init__(clsname, superclasses, attributedict)
247+
cls.validate_cached_relationships()
245248
if clsname != "Entity":
246249
setattr(Entity, clsname, cls)
250+
EntityMeta.relationship_mappings[utils.snake_case(
251+
cls.__name__)] = cls.relationships()
252+
253+
@staticmethod
254+
def raise_for_nested_cache(first: str, middle: str, last: List[str]):
255+
raise TypeError(
256+
"Cannot cache a relationship to an Entity with its own cached relationship(s). "
257+
f"`{first}` caches `{middle}` which caches `{last}`")
258+
259+
@staticmethod
260+
def cached_entities(entity_name: str):
261+
"""
262+
Return all cached entities for a given Entity name
263+
"""
264+
cached_entities = EntityMeta.relationship_mappings.get(entity_name, [])
265+
return {
266+
entity.name: entity for entity in cached_entities if entity.cache
267+
}
268+
269+
def validate_cached_relationships(cls):
270+
"""
271+
Graphql doesn't allow for infinite nesting in queries.
272+
This function checks that cached relationships result in valid queries.
273+
* It does this by making sure that a cached relationship does not
274+
reference any entity with its own cached relationships.
275+
276+
This check is performed by looking to see if this entity caches
277+
any entities that have their own cached fields. If this entity
278+
that we are checking has any cached fields then we also check
279+
all currently defined entities to see if they cache this entity.
280+
281+
A two-way check is necessary because checks are performed as classes are being defined,
282+
as opposed to after all objects have been created.
283+
"""
284+
# All cached relationships
285+
cached_rels = [r for r in cls.relationships() if r.cache]
286+
287+
# Check if any cached entities have their own cached fields
288+
for rel in cached_rels:
289+
nested = cls.cached_entities(rel.name)
290+
if nested:
291+
cls.raise_for_nested_cache(utils.snake_case(cls.__name__),
292+
rel.name, list(nested.keys()))
293+
294+
# If the current Entity (cls) has any cached relationships (cached_rels)
295+
# then no other defined Entity (entities in EntityMeta.relationship_mappings) can cache this Entity.
296+
if cached_rels:
297+
# For all currently defined Entities
298+
for entity_name in EntityMeta.relationship_mappings:
299+
# Get all cached ToOne relationships
300+
rels = cls.cached_entities(entity_name)
301+
# Check if the current Entity (cls) is referenced by the Entity with `entity_name`
302+
rel = rels.get(utils.snake_case(cls.__name__))
303+
# If rel exists and is cached then raise an exception
304+
# This means `entity_name` caches `cls` which cached items in `cached_rels`
305+
if rel and rel.cache:
306+
cls.raise_for_nested_cache(
307+
utils.snake_case(entity_name),
308+
utils.snake_case(cls.__name__),
309+
[entity.name for entity in cached_rels])
247310

248311

249312
class Entity(metaclass=EntityMeta):

labelbox/orm/query.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,16 @@ def results_query_part(entity):
4040
Args:
4141
entity (type): The entity which needs fetching.
4242
"""
43-
return " ".join(field.graphql_name for field in entity.fields())
43+
# Query for fields
44+
fields = [field.graphql_name for field in entity.fields()]
45+
46+
# Query for cached relationships
47+
fields.extend([
48+
Query(rel.graphql_name, rel.destination_type).format()[0]
49+
for rel in entity.relationships()
50+
if rel.cache
51+
])
52+
return " ".join(fields)
4453

4554

4655
class Query:

labelbox/schema/project.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,8 @@ class LabelingParameterOverride(DbObject):
676676
priority = Field.Int("priority")
677677
number_of_labels = Field.Int("number_of_labels")
678678

679+
data_row = Relationship.ToOne("DataRow", cache=True)
680+
679681

680682
LabelerPerformance = namedtuple(
681683
"LabelerPerformance", "user count seconds_per_label, total_time_labeling "

labelbox/schema/webhook.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def delete(self):
112112
"""
113113
Deletes the webhook
114114
"""
115-
self.update(status=self.Status.INACTIVE)
115+
self.update(status=self.Status.INACTIVE.value)
116116

117117
def update(self, topics=None, url=None, status=None):
118118
""" Updates the Webhook.

tests/integration/test_labeling_parameter_overrides.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ def test_labeling_parameter_overrides(project, rand_gen):
2626
assert {o.number_of_labels for o in overrides} == {3, 2, 5}
2727
assert {o.priority for o in overrides} == {4, 3, 8}
2828

29+
for override in overrides:
30+
assert isinstance(override.data_row(), DataRow)
31+
2932
success = project.unset_labeling_parameter_overrides(
3033
[data[0][0], data[1][0]])
3134
assert success

tests/integration/test_webhook.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,5 @@ def test_webhook_create_update(project, rand_gen):
3838
webhook.update(topics="invalid..")
3939
assert str(exc_info.value) == \
4040
"Topics must be List[Webhook.Topic]. Found `invalid..`"
41+
42+
webhook.delete()

tests/test_entity_meta.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import pytest
2+
3+
from labelbox.orm.model import Relationship
4+
from labelbox.orm.db_object import DbObject
5+
6+
7+
def test_illegal_cache_cond1():
8+
9+
class TestEntityA(DbObject):
10+
test_entity_b = Relationship.ToOne("TestEntityB", cache=True)
11+
12+
with pytest.raises(TypeError) as exc_info:
13+
14+
class TestEntityB(DbObject):
15+
another_entity = Relationship.ToOne("AnotherEntity", cache=True)
16+
17+
assert "`test_entity_a` caches `test_entity_b` which caches `['another_entity']`" in str(
18+
exc_info.value)
19+
20+
21+
def test_illegal_cache_cond2():
22+
23+
class TestEntityD(DbObject):
24+
another_entity = Relationship.ToOne("AnotherEntity", cache=True)
25+
26+
with pytest.raises(TypeError) as exc_info:
27+
28+
class TestEntityC(DbObject):
29+
test_entity_d = Relationship.ToOne("TestEntityD", cache=True)
30+
31+
assert "`test_entity_c` caches `test_entity_d` which caches `['another_entity']`" in str(
32+
exc_info.value)

0 commit comments

Comments
 (0)