diff --git a/libs/labelbox/src/labelbox/schema/ontology.py b/libs/labelbox/src/labelbox/schema/ontology.py index 1cc827c5f..897759af1 100644 --- a/libs/labelbox/src/labelbox/schema/ontology.py +++ b/libs/labelbox/src/labelbox/schema/ontology.py @@ -165,6 +165,8 @@ def tool_cls_from_type(tool_type: str): tool_cls = map_tool_type_to_tool_cls(tool_type) if tool_cls is not None: return tool_cls + if tool_type == Tool.Type.RELATIONSHIP: + return RelationshipTool return Tool diff --git a/libs/labelbox/src/labelbox/schema/tool_building/relationship_tool.py b/libs/labelbox/src/labelbox/schema/tool_building/relationship_tool.py new file mode 100644 index 000000000..f11d204a9 --- /dev/null +++ b/libs/labelbox/src/labelbox/schema/tool_building/relationship_tool.py @@ -0,0 +1,93 @@ +# type: ignore + +import uuid +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple + +from labelbox.schema.ontology import Tool + + +@dataclass +class RelationshipTool(Tool): + """ + A relationship tool to be added to a Project's ontology. + + The "tool" parameter is automatically set to Tool.Type.RELATIONSHIP + and doesn't need to be passed during instantiation. + + The "classifications" parameter holds a list of Classification objects. + This can be used to add nested classifications to a tool. + + Example(s): + tool = RelationshipTool( + name = "Relationship Tool example", + constraints = [ + ("source_tool_feature_schema_id_1", "target_tool_feature_schema_id_1"), + ("source_tool_feature_schema_id_2", "target_tool_feature_schema_id_2") + ] + ) + classification = Classification( + class_type = Classification.Type.TEXT, + instructions = "Classification Example") + tool.add_classification(classification) + + Attributes: + tool: Tool.Type.RELATIONSHIP (automatically set) + name: (str) + required: (bool) + color: (str) + classifications: (list) + schema_id: (str) + feature_schema_id: (str) + attributes: (list) + constraints: (list of [str, str]) + """ + + constraints: Optional[List[Tuple[str, str]]] = None + + def __init__( + self, + name: str, + constraints: Optional[List[Tuple[str, str]]] = None, + **kwargs, + ): + super().__init__(Tool.Type.RELATIONSHIP, name, **kwargs) + if constraints is not None: + self.constraints = constraints + + def __post_init__(self): + # Ensure tool type is set to RELATIONSHIP + self.tool = Tool.Type.RELATIONSHIP + super().__post_init__() + + def asdict(self) -> Dict[str, Any]: + result = super().asdict() + if self.constraints is not None: + result["definition"] = {"constraints": self.constraints} + return result + + def add_constraint(self, start: Tool, end: Tool) -> None: + if self.constraints is None: + self.constraints = [] + + # Ensure feature schema ids are set for the tools, + # the newly set ids will be changed during ontology creation + # but we need to refer to the same ids in the constraints array + # to ensure that the valid constraints are created. + if start.feature_schema_id is None: + start.feature_schema_id = str(uuid.uuid4()) + if start.schema_id is None: + start.schema_id = str(uuid.uuid4()) + if end.feature_schema_id is None: + end.feature_schema_id = str(uuid.uuid4()) + if end.schema_id is None: + end.schema_id = str(uuid.uuid4()) + + self.constraints.append( + (start.feature_schema_id, end.feature_schema_id) + ) + + def set_constraints(self, constraints: List[Tuple[Tool, Tool]]) -> None: + self.constraints = [] + for constraint in constraints: + self.add_constraint(constraint[0], constraint[1]) diff --git a/libs/labelbox/tests/unit/test_unit_relationship_tool.py b/libs/labelbox/tests/unit/test_unit_relationship_tool.py new file mode 100644 index 000000000..0f50187da --- /dev/null +++ b/libs/labelbox/tests/unit/test_unit_relationship_tool.py @@ -0,0 +1,212 @@ +import uuid +from unittest.mock import patch + +from labelbox.schema.ontology import Tool +from labelbox.schema.tool_building.relationship_tool import RelationshipTool + + +def test_basic_instantiation(): + tool = RelationshipTool(name="Test Relationship Tool") + + assert tool.name == "Test Relationship Tool" + assert tool.tool == Tool.Type.RELATIONSHIP + assert tool.constraints is None + assert tool.required is False + assert tool.color is None + assert tool.schema_id is None + assert tool.feature_schema_id is None + + +def test_instantiation_with_constraints(): + constraints = [ + ("source_id_1", "target_id_1"), + ("source_id_2", "target_id_2"), + ] + tool = RelationshipTool(name="Test Tool", constraints=constraints) + + assert tool.name == "Test Tool" + assert tool.constraints == constraints + assert len(tool.constraints) == 2 + + +def test_post_init_sets_tool_type(): + tool = RelationshipTool(name="Test Tool") + assert tool.tool == Tool.Type.RELATIONSHIP + + +def test_asdict_without_constraints(): + tool = RelationshipTool(name="Test Tool", required=True, color="#FF0000") + + result = tool.asdict() + expected = { + "tool": "edge", + "name": "Test Tool", + "required": True, + "color": "#FF0000", + "classifications": [], + "schemaNodeId": None, + "featureSchemaId": None, + "attributes": None, + } + + assert result == expected + + +def test_asdict_with_constraints(): + constraints = [("source_id", "target_id")] + tool = RelationshipTool(name="Test Tool", constraints=constraints) + + result = tool.asdict() + + assert "definition" in result + assert result["definition"] == {"constraints": constraints} + assert result["tool"] == "edge" + assert result["name"] == "Test Tool" + + +def test_add_constraint_to_empty_constraints(): + tool = RelationshipTool(name="Test Tool") + start_tool = Tool(Tool.Type.BBOX, "Start Tool") + end_tool = Tool(Tool.Type.POLYGON, "End Tool") + + with patch("uuid.uuid4") as mock_uuid: + mock_uuid.return_value.hex = "test-uuid" + tool.add_constraint(start_tool, end_tool) + + assert tool.constraints is not None + assert len(tool.constraints) == 1 + assert start_tool.feature_schema_id is not None + assert start_tool.schema_id is not None + assert end_tool.feature_schema_id is not None + assert end_tool.schema_id is not None + + +def test_add_constraint_to_existing_constraints(): + existing_constraints = [("existing_source", "existing_target")] + tool = RelationshipTool(name="Test Tool", constraints=existing_constraints) + + start_tool = Tool(Tool.Type.BBOX, "Start Tool") + end_tool = Tool(Tool.Type.POLYGON, "End Tool") + + tool.add_constraint(start_tool, end_tool) + + assert len(tool.constraints) == 2 + assert tool.constraints[0] == ("existing_source", "existing_target") + assert tool.constraints[1] == ( + start_tool.feature_schema_id, + end_tool.feature_schema_id, + ) + + +def test_add_constraint_preserves_existing_ids(): + tool = RelationshipTool(name="Test Tool") + start_tool_feature_schema_id = "start_tool_feature_schema_id" + start_tool_schema_id = "start_tool_schema_id" + start_tool = Tool( + Tool.Type.BBOX, + "Start Tool", + feature_schema_id=start_tool_feature_schema_id, + schema_id=start_tool_schema_id, + ) + end_tool_feature_schema_id = "end_tool_feature_schema_id" + end_tool_schema_id = "end_tool_schema_id" + end_tool = Tool( + Tool.Type.POLYGON, + "End Tool", + feature_schema_id=end_tool_feature_schema_id, + schema_id=end_tool_schema_id, + ) + + tool.add_constraint(start_tool, end_tool) + + assert start_tool.feature_schema_id == start_tool_feature_schema_id + assert start_tool.schema_id == start_tool_schema_id + assert end_tool.feature_schema_id == end_tool_feature_schema_id + assert end_tool.schema_id == end_tool_schema_id + assert tool.constraints == [ + (start_tool_feature_schema_id, end_tool_feature_schema_id) + ] + + +def test_set_constraints(): + tool = RelationshipTool(name="Test Tool") + + start_tool1 = Tool(Tool.Type.BBOX, "Start Tool 1") + end_tool1 = Tool(Tool.Type.POLYGON, "End Tool 1") + start_tool2 = Tool(Tool.Type.POINT, "Start Tool 2") + end_tool2 = Tool(Tool.Type.LINE, "End Tool 2") + + tool.set_constraints([(start_tool1, end_tool1), (start_tool2, end_tool2)]) + + assert len(tool.constraints) == 2 + assert tool.constraints[0] == ( + start_tool1.feature_schema_id, + end_tool1.feature_schema_id, + ) + assert tool.constraints[1] == ( + start_tool2.feature_schema_id, + end_tool2.feature_schema_id, + ) + + +def test_set_constraints_replaces_existing(): + existing_constraints = [("old_source", "old_target")] + tool = RelationshipTool(name="Test Tool", constraints=existing_constraints) + + start_tool = Tool(Tool.Type.BBOX, "Start Tool") + end_tool = Tool(Tool.Type.POLYGON, "End Tool") + + tool.set_constraints([(start_tool, end_tool)]) + + assert len(tool.constraints) == 1 + assert tool.constraints[0] != ("old_source", "old_target") + assert tool.constraints[0] == ( + start_tool.feature_schema_id, + end_tool.feature_schema_id, + ) + + +def test_uuid_generation_in_add_constraint(): + tool = RelationshipTool(name="Test Tool") + + start_tool = Tool(Tool.Type.BBOX, "Start Tool") + end_tool = Tool(Tool.Type.POLYGON, "End Tool") + + # Ensure tools don't have IDs initially + assert start_tool.feature_schema_id is None + assert start_tool.schema_id is None + assert end_tool.feature_schema_id is None + assert end_tool.schema_id is None + + tool.add_constraint(start_tool, end_tool) + + # Check that UUIDs were generated + assert start_tool.feature_schema_id is not None + assert start_tool.schema_id is not None + assert end_tool.feature_schema_id is not None + assert end_tool.schema_id is not None + + # Check that they are valid UUID strings + uuid.UUID(start_tool.feature_schema_id) # Will raise ValueError if invalid + uuid.UUID(start_tool.schema_id) + uuid.UUID(end_tool.feature_schema_id) + uuid.UUID(end_tool.schema_id) + + +def test_constraints_in_asdict(): + tool = RelationshipTool(name="Test Tool") + + start_tool = Tool(Tool.Type.BBOX, "Start Tool") + end_tool = Tool(Tool.Type.POLYGON, "End Tool") + + tool.add_constraint(start_tool, end_tool) + + result = tool.asdict() + + assert "definition" in result + assert "constraints" in result["definition"] + assert len(result["definition"]["constraints"]) == 1 + assert result["definition"]["constraints"][0] == ( + start_tool.feature_schema_id, + end_tool.feature_schema_id, + )