Skip to content

Commit 9bea518

Browse files
author
Val Brodsky
committed
Add step by step reasoning ontology tool
1 parent 5140584 commit 9bea518

File tree

7 files changed

+192
-6
lines changed

7 files changed

+192
-6
lines changed

libs/labelbox/src/labelbox/schema/ontology.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
from labelbox.orm.db_object import DbObject
1414
from labelbox.orm.model import Field, Relationship
15+
from labelbox.schema.tool_building.step_reasoning_tool import StepReasoningTool
16+
from labelbox.schema.tool_building.tool_type import ToolType
1517

1618
FeatureSchemaId: Type[str] = Annotated[
1719
str, StringConstraints(min_length=25, max_length=25)
@@ -187,7 +189,7 @@ def __post_init__(self):
187189
@classmethod
188190
def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]:
189191
return cls(
190-
class_type=cls.Type(dictionary["type"]),
192+
class_type=Classification.Type(dictionary["type"]),
191193
name=dictionary["name"],
192194
instructions=dictionary["instructions"],
193195
required=dictionary.get("required", False),
@@ -351,7 +353,7 @@ class Type(Enum):
351353
@classmethod
352354
def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]:
353355
return cls(
354-
class_type=cls.Type(dictionary["type"]),
356+
class_type=Type(dictionary["type"]),
355357
name=dictionary["name"],
356358
instructions=dictionary["instructions"],
357359
required=True, # always required
@@ -458,7 +460,7 @@ def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]:
458460
schema_id=dictionary.get("schemaNodeId", None),
459461
feature_schema_id=dictionary.get("featureSchemaId", None),
460462
required=dictionary.get("required", False),
461-
tool=cls.Type(dictionary["tool"]),
463+
tool=Tool.Type(dictionary["tool"]),
462464
classifications=[
463465
Classification.from_dict(c)
464466
for c in dictionary["classifications"]
@@ -488,6 +490,16 @@ def add_classification(self, classification: Classification) -> None:
488490
self.classifications.append(classification)
489491

490492

493+
def tool_cls_from_type(tool_type: str):
494+
if tool_type.lower() == ToolType.STEP_REASONING.value:
495+
from labelbox.schema.tool_building.step_reasoning_tool import (
496+
StepReasoningTool,
497+
)
498+
499+
return StepReasoningTool
500+
return Tool
501+
502+
491503
class Ontology(DbObject):
492504
"""An ontology specifies which tools and classifications are available
493505
to a project. This is read only for now.
@@ -521,11 +533,18 @@ def __init__(self, *args, **kwargs) -> None:
521533
Union[List[Classification], List[PromptResponseClassification]]
522534
] = None
523535

536+
def _tool_deserializer_cls(self, tool: Dict[str, Any]) -> Tool:
537+
import pdb
538+
539+
pdb.set_trace()
540+
return Tool
541+
524542
def tools(self) -> List[Tool]:
525543
"""Get list of tools (AKA objects) in an Ontology."""
526544
if self._tools is None:
527545
self._tools = [
528-
Tool.from_dict(tool) for tool in self.normalized["tools"]
546+
tool_cls_from_type(tool["tool"]).from_dict(tool)
547+
for tool in self.normalized["tools"]
529548
]
530549
return self._tools
531550

@@ -581,7 +600,7 @@ class OntologyBuilder:
581600
582601
"""
583602

584-
tools: List[Tool] = field(default_factory=list)
603+
tools: List[Union[Tool, StepReasoningTool]] = field(default_factory=list)
585604
classifications: List[
586605
Union[Classification, PromptResponseClassification]
587606
] = field(default_factory=list)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
import labelbox.schema.tool_building.tool_type
2+
import labelbox.schema.tool_building.step_reasoning_tool
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
from dataclasses import dataclass, field
2+
from typing import Any, Dict, Optional
3+
4+
from labelbox.schema.tool_building.tool_type import ToolType
5+
6+
7+
@dataclass
8+
class StepReasoningVariant:
9+
id: int
10+
name: str
11+
12+
def asdict(self) -> Dict[str, Any]:
13+
return {"id": self.id, "name": self.name}
14+
15+
16+
@dataclass
17+
class IncorrectStepReasoningVariant:
18+
id: int
19+
name: str
20+
regenerate_conversations_after_incorrect_step: Optional[bool] = False
21+
rate_alternative_responses: Optional[bool] = False
22+
23+
def asdict(self) -> Dict[str, Any]:
24+
actions = []
25+
if self.regenerate_conversations_after_incorrect_step:
26+
actions.append("regenerateSteps")
27+
if self.rate_alternative_responses:
28+
actions.append("generateAlternatives")
29+
return {"id": self.id, "name": self.name, "actions": actions}
30+
31+
32+
@dataclass
33+
class StepReasoningVariants:
34+
CORRECT_STEP_ID = 0
35+
NEUTRAL_STEP_ID = 1
36+
INCORRECT_STEP_ID = 2
37+
38+
correct_step: StepReasoningVariant = field(
39+
default=StepReasoningVariant(CORRECT_STEP_ID, "Correct"), init=False
40+
)
41+
neutral_step: StepReasoningVariant = field(
42+
default=StepReasoningVariant(NEUTRAL_STEP_ID, "Neutral"), init=False
43+
)
44+
incorrect_step: IncorrectStepReasoningVariant = field(
45+
default=IncorrectStepReasoningVariant(INCORRECT_STEP_ID, "Incorrect"),
46+
)
47+
48+
def asdict(self):
49+
return [
50+
self.correct_step.asdict(),
51+
self.neutral_step.asdict(),
52+
self.incorrect_step.asdict(),
53+
]
54+
55+
56+
@dataclass
57+
class StepReasoningDefinition:
58+
variants: StepReasoningVariants = field(
59+
default_factory=StepReasoningVariants
60+
)
61+
version: int = field(default=1)
62+
title: Optional[str] = None
63+
value: Optional[str] = None
64+
color: Optional[str] = None
65+
66+
def asdict(self) -> Dict[str, Any]:
67+
result = {"variants": self.variants.asdict(), "version": self.version}
68+
if self.title is not None:
69+
result["title"] = self.title
70+
if self.value is not None:
71+
result["value"] = self.value
72+
if self.color is not None:
73+
result["color"] = self.color
74+
return result
75+
76+
77+
@dataclass
78+
class StepReasoningTool:
79+
name: str
80+
type: ToolType = field(default=ToolType.STEP_REASONING, init=False)
81+
required: bool = False
82+
schema_id: Optional[str] = None
83+
feature_schema_id: Optional[str] = None
84+
color: Optional[str] = None
85+
definition: StepReasoningDefinition = field(
86+
default_factory=StepReasoningDefinition
87+
)
88+
89+
def set_regenerate_conversations_after_incorrect_step(self):
90+
self.definition.variants.incorrect_step.regenerate_conversations_after_incorrect_step = True
91+
92+
def set_rate_alternative_responses(self):
93+
self.definition.variants.incorrect_step.rate_alternative_responses = (
94+
True
95+
)
96+
97+
def asdict(self) -> Dict[str, Any]:
98+
self.set_rate_alternative_responses()
99+
self.set_regenerate_conversations_after_incorrect_step()
100+
return {
101+
"tool": self.type.value,
102+
"name": self.name,
103+
"required": self.required,
104+
"schemaNodeId": self.schema_id,
105+
"featureSchemaId": self.feature_schema_id,
106+
"definition": self.definition.asdict(),
107+
}
108+
109+
@classmethod
110+
def from_dict(cls, dictionary: Dict[str, Any]) -> "StepReasoningTool":
111+
return cls(
112+
name=dictionary["name"],
113+
schema_id=dictionary.get("schemaNodeId", None),
114+
feature_schema_id=dictionary.get("featureSchemaId", None),
115+
required=dictionary.get("required", False),
116+
definition=StepReasoningDefinition(**dictionary["definition"]),
117+
)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from enum import Enum
2+
3+
4+
class ToolType(Enum):
5+
STEP_REASONING = "step-reasoning"

libs/labelbox/tests/integration/conftest.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
from labelbox.schema.enums import AnnotationImportState
3535
from labelbox.schema.invite import Invite
3636
from labelbox.schema.ontology_kind import OntologyKind
37+
from labelbox.schema.tool_building.step_reasoning_tool import StepReasoningTool
38+
from labelbox.schema.tool_building.tool_type import ToolType
3739
from labelbox.schema.user import User
3840

3941

@@ -575,6 +577,7 @@ def feature_schema(client, point):
575577
@pytest.fixture
576578
def chat_evaluation_ontology(client, rand_gen):
577579
ontology_name = f"test-chat-evaluation-ontology-{rand_gen(str)}"
580+
578581
ontology_builder = OntologyBuilder(
579582
tools=[
580583
Tool(
@@ -589,6 +592,7 @@ def chat_evaluation_ontology(client, rand_gen):
589592
tool=Tool.Type.MESSAGE_RANKING,
590593
name="model output multi ranking",
591594
),
595+
StepReasoningTool(name="step reasoning"),
592596
],
593597
classifications=[
594598
Classification(

libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def test_create_chat_evaluation_ontology_project(
1717
# here we are essentially testing the ontology creation which is a fixture
1818
assert ontology
1919
assert ontology.name
20-
assert len(ontology.tools()) == 3
20+
assert len(ontology.tools()) == 4
2121
for tool in ontology.tools():
2222
assert tool.schema_id
2323
assert tool.feature_schema_id
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from labelbox.schema.tool_building.step_reasoning_tool import StepReasoningTool
2+
3+
4+
def test_step_reasoning_as_dict_default():
5+
tool = StepReasoningTool(name="step reasoning")
6+
assert tool.asdict() == {
7+
"tool": "step-reasoning",
8+
"name": "step reasoning",
9+
"required": False,
10+
"schemaNodeId": None,
11+
"featureSchemaId": None,
12+
"variants": [
13+
{"id": 0, "name": "Correct"},
14+
{"id": 1, "name": "Neutral"},
15+
{"id": 2, "name": "Incorrect", "actions": []},
16+
],
17+
}
18+
19+
20+
def test_step_reasoning_as_dict_with_actions():
21+
tool = StepReasoningTool(name="step reasoning")
22+
tool.set_rate_alternative_responses()
23+
tool.set_regenerate_conversations_after_incorrect_step()
24+
assert tool.asdict() == {
25+
"tool": "step-reasoning",
26+
"name": "step reasoning",
27+
"required": False,
28+
"schemaNodeId": None,
29+
"featureSchemaId": None,
30+
"variants": [
31+
{"id": 0, "name": "Correct"},
32+
{"id": 1, "name": "Neutral"},
33+
{
34+
"id": 2,
35+
"name": "Incorrect",
36+
"actions": ["regenerateSteps", "generateAlternatives"],
37+
},
38+
],
39+
}

0 commit comments

Comments
 (0)