Skip to content

Commit cbc6d95

Browse files
authored
[PLT-1684] Vb/step reasoning ontology plt 1684 (#1879)
1 parent 0a3e3b1 commit cbc6d95

File tree

13 files changed

+362
-13
lines changed

13 files changed

+362
-13
lines changed

docs/labelbox/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ Labelbox Python SDK Documentation
4646
search-filters
4747
send-to-annotate-params
4848
slice
49+
step_reasoning_tool
4950
task
5051
task-queue
5152
user

docs/labelbox/step_reasoning_tool.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
Step Reasoning Tool
2+
===============================================================================================
3+
4+
.. automodule:: labelbox.schema.tool_building.step_reasoning_tool
5+
:members:
6+
:show-inheritance:

libs/labelbox/mypy.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,5 @@ ignore_errors = True
1212
[mypy-lbox.exceptions]
1313
ignore_missing_imports = True
1414

15-
[mypy-lbox.call_info"]
15+
[mypy-lbox.call_info]
1616
ignore_missing_imports = True

libs/labelbox/src/labelbox/client.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
FeatureSchema,
5757
Ontology,
5858
PromptResponseClassification,
59-
Tool,
59+
tool_type_cls_from_type,
6060
)
6161
from labelbox.schema.ontology_kind import (
6262
EditorTaskType,
@@ -1098,7 +1098,8 @@ def create_ontology_from_feature_schemas(
10981098
if "tool" in feature_schema.normalized:
10991099
tool = feature_schema.normalized["tool"]
11001100
try:
1101-
Tool.Type(tool)
1101+
tool_type_cls = tool_type_cls_from_type(tool)
1102+
tool_type_cls(tool)
11021103
tools.append(feature_schema.normalized)
11031104
except ValueError:
11041105
raise ValueError(

libs/labelbox/src/labelbox/schema/ontology.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
from labelbox.orm.db_object import DbObject
1414
from labelbox.orm.model import Field, Relationship
15+
from labelbox.schema.tool_building.step_reasoning_tool import StepReasoningTool
16+
from labelbox.schema.tool_building.tool_type import ToolType
1517

1618
FeatureSchemaId: Type[str] = Annotated[
1719
str, StringConstraints(min_length=25, max_length=25)
@@ -187,7 +189,7 @@ def __post_init__(self):
187189
@classmethod
188190
def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]:
189191
return cls(
190-
class_type=cls.Type(dictionary["type"]),
192+
class_type=Classification.Type(dictionary["type"]),
191193
name=dictionary["name"],
192194
instructions=dictionary["instructions"],
193195
required=dictionary.get("required", False),
@@ -351,7 +353,7 @@ class Type(Enum):
351353
@classmethod
352354
def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]:
353355
return cls(
354-
class_type=cls.Type(dictionary["type"]),
356+
class_type=PromptResponseClassification.Type(dictionary["type"]),
355357
name=dictionary["name"],
356358
instructions=dictionary["instructions"],
357359
required=True, # always required
@@ -458,7 +460,7 @@ def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]:
458460
schema_id=dictionary.get("schemaNodeId", None),
459461
feature_schema_id=dictionary.get("featureSchemaId", None),
460462
required=dictionary.get("required", False),
461-
tool=cls.Type(dictionary["tool"]),
463+
tool=Tool.Type(dictionary["tool"]),
462464
classifications=[
463465
Classification.from_dict(c)
464466
for c in dictionary["classifications"]
@@ -488,6 +490,18 @@ def add_classification(self, classification: Classification) -> None:
488490
self.classifications.append(classification)
489491

490492

493+
def tool_cls_from_type(tool_type: str):
494+
if tool_type.lower() == ToolType.STEP_REASONING.value:
495+
return StepReasoningTool
496+
return Tool
497+
498+
499+
def tool_type_cls_from_type(tool_type: str):
500+
if tool_type.lower() == ToolType.STEP_REASONING.value:
501+
return ToolType
502+
return Tool.Type
503+
504+
491505
class Ontology(DbObject):
492506
"""An ontology specifies which tools and classifications are available
493507
to a project. This is read only for now.
@@ -525,7 +539,8 @@ def tools(self) -> List[Tool]:
525539
"""Get list of tools (AKA objects) in an Ontology."""
526540
if self._tools is None:
527541
self._tools = [
528-
Tool.from_dict(tool) for tool in self.normalized["tools"]
542+
tool_cls_from_type(tool["tool"]).from_dict(tool)
543+
for tool in self.normalized["tools"]
529544
]
530545
return self._tools
531546

@@ -581,7 +596,7 @@ class OntologyBuilder:
581596
582597
"""
583598

584-
tools: List[Tool] = field(default_factory=list)
599+
tools: List[Union[Tool, StepReasoningTool]] = field(default_factory=list)
585600
classifications: List[
586601
Union[Classification, PromptResponseClassification]
587602
] = field(default_factory=list)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
import labelbox.schema.tool_building.tool_type
2+
import labelbox.schema.tool_building.step_reasoning_tool
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
import warnings
2+
from dataclasses import dataclass, field
3+
from typing import Any, Dict, List, Optional
4+
5+
from labelbox.schema.tool_building.tool_type import ToolType
6+
7+
8+
@dataclass
9+
class StepReasoningVariant:
10+
id: int
11+
name: str
12+
13+
def asdict(self) -> Dict[str, Any]:
14+
return {"id": self.id, "name": self.name}
15+
16+
17+
@dataclass
18+
class IncorrectStepReasoningVariant:
19+
id: int
20+
name: str
21+
regenerate_conversations_after_incorrect_step: Optional[bool] = True
22+
rate_alternative_responses: Optional[bool] = False
23+
24+
def asdict(self) -> Dict[str, Any]:
25+
actions = []
26+
if self.regenerate_conversations_after_incorrect_step:
27+
actions.append("regenerateSteps")
28+
if self.rate_alternative_responses:
29+
actions.append("generateAndRateAlternativeSteps")
30+
return {"id": self.id, "name": self.name, "actions": actions}
31+
32+
@classmethod
33+
def from_dict(
34+
cls, dictionary: Dict[str, Any]
35+
) -> "IncorrectStepReasoningVariant":
36+
return cls(
37+
id=dictionary["id"],
38+
name=dictionary["name"],
39+
regenerate_conversations_after_incorrect_step="regenerateSteps"
40+
in dictionary.get("actions", []),
41+
rate_alternative_responses="generateAndRateAlternativeSteps"
42+
in dictionary.get("actions", []),
43+
)
44+
45+
46+
def _create_correct_step() -> StepReasoningVariant:
47+
return StepReasoningVariant(
48+
id=StepReasoningVariants.CORRECT_STEP_ID, name="Correct"
49+
)
50+
51+
52+
def _create_neutral_step() -> StepReasoningVariant:
53+
return StepReasoningVariant(
54+
id=StepReasoningVariants.NEUTRAL_STEP_ID, name="Neutral"
55+
)
56+
57+
58+
def _create_incorrect_step() -> IncorrectStepReasoningVariant:
59+
return IncorrectStepReasoningVariant(
60+
id=StepReasoningVariants.INCORRECT_STEP_ID, name="Incorrect"
61+
)
62+
63+
64+
@dataclass
65+
class StepReasoningVariants:
66+
"""
67+
This class is used to define the possible options for evaluating a step
68+
Currently the options are correct, neutral, and incorrect
69+
"""
70+
71+
CORRECT_STEP_ID = 0
72+
NEUTRAL_STEP_ID = 1
73+
INCORRECT_STEP_ID = 2
74+
75+
correct_step: StepReasoningVariant = field(
76+
default_factory=_create_correct_step
77+
)
78+
neutral_step: StepReasoningVariant = field(
79+
default_factory=_create_neutral_step
80+
)
81+
incorrect_step: IncorrectStepReasoningVariant = field(
82+
default_factory=_create_incorrect_step
83+
)
84+
85+
def asdict(self):
86+
return [
87+
self.correct_step.asdict(),
88+
self.neutral_step.asdict(),
89+
self.incorrect_step.asdict(),
90+
]
91+
92+
@classmethod
93+
def from_dict(cls, dictionary: List[Dict[str, Any]]):
94+
correct_step = None
95+
neutral_step = None
96+
incorrect_step = None
97+
98+
for variant in dictionary:
99+
if variant["id"] == cls.CORRECT_STEP_ID:
100+
correct_step = StepReasoningVariant(**variant)
101+
elif variant["id"] == cls.NEUTRAL_STEP_ID:
102+
neutral_step = StepReasoningVariant(**variant)
103+
elif variant["id"] == cls.INCORRECT_STEP_ID:
104+
incorrect_step = IncorrectStepReasoningVariant.from_dict(
105+
variant
106+
)
107+
108+
if not all([correct_step, neutral_step, incorrect_step]):
109+
raise ValueError("Invalid step reasoning variants")
110+
111+
return cls(
112+
correct_step=correct_step, # type: ignore
113+
neutral_step=neutral_step, # type: ignore
114+
incorrect_step=incorrect_step, # type: ignore
115+
)
116+
117+
118+
@dataclass
119+
class StepReasoningDefinition:
120+
variants: StepReasoningVariants = field(
121+
default_factory=StepReasoningVariants
122+
)
123+
version: int = field(default=1)
124+
title: Optional[str] = None
125+
value: Optional[str] = None
126+
127+
def asdict(self) -> Dict[str, Any]:
128+
result = {"variants": self.variants.asdict(), "version": self.version}
129+
if self.title is not None:
130+
result["title"] = self.title
131+
if self.value is not None:
132+
result["value"] = self.value
133+
return result
134+
135+
@classmethod
136+
def from_dict(cls, dictionary: Dict[str, Any]) -> "StepReasoningDefinition":
137+
variants = StepReasoningVariants.from_dict(dictionary["variants"])
138+
title = dictionary.get("title", None)
139+
value = dictionary.get("value", None)
140+
return cls(variants=variants, title=title, value=value)
141+
142+
143+
@dataclass
144+
class StepReasoningTool:
145+
"""
146+
Use this class in OntologyBuilder to create a tool for step reasoning
147+
The definition field lists the possible options to evaulate a step
148+
"""
149+
150+
name: str
151+
type: ToolType = field(default=ToolType.STEP_REASONING, init=False)
152+
required: bool = False
153+
schema_id: Optional[str] = None
154+
feature_schema_id: Optional[str] = None
155+
color: Optional[str] = None
156+
definition: StepReasoningDefinition = field(
157+
default_factory=StepReasoningDefinition
158+
)
159+
160+
def __post_init__(self):
161+
warnings.warn(
162+
"This feature is experimental and subject to change.",
163+
)
164+
165+
def reset_regenerate_conversations_after_incorrect_step(self):
166+
"""
167+
For live models, the default acation will invoke the model to generate alternatives if a step is marked as incorrect
168+
This method will reset the action to not regenerate the conversation
169+
"""
170+
self.definition.variants.incorrect_step.regenerate_conversations_after_incorrect_step = False
171+
172+
def set_rate_alternative_responses(self):
173+
"""
174+
For live models, will require labelers to rate the alternatives generated by the model
175+
"""
176+
self.definition.variants.incorrect_step.rate_alternative_responses = (
177+
True
178+
)
179+
180+
def asdict(self) -> Dict[str, Any]:
181+
return {
182+
"tool": self.type.value,
183+
"name": self.name,
184+
"required": self.required,
185+
"schemaNodeId": self.schema_id,
186+
"featureSchemaId": self.feature_schema_id,
187+
"definition": self.definition.asdict(),
188+
}
189+
190+
@classmethod
191+
def from_dict(cls, dictionary: Dict[str, Any]) -> "StepReasoningTool":
192+
return cls(
193+
name=dictionary["name"],
194+
schema_id=dictionary.get("schemaNodeId", None),
195+
feature_schema_id=dictionary.get("featureSchemaId", None),
196+
required=dictionary.get("required", False),
197+
definition=StepReasoningDefinition.from_dict(
198+
dictionary["definition"]
199+
),
200+
)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from enum import Enum
2+
3+
4+
class ToolType(Enum):
5+
STEP_REASONING = "step-reasoning"

libs/labelbox/tests/integration/conftest.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
)
2222
from labelbox.schema.data_row import DataRowMetadataField
2323
from labelbox.schema.ontology_kind import OntologyKind
24+
from labelbox.schema.tool_building.step_reasoning_tool import StepReasoningTool
2425
from labelbox.schema.user import User
2526

2627

@@ -562,6 +563,7 @@ def feature_schema(client, point):
562563
@pytest.fixture
563564
def chat_evaluation_ontology(client, rand_gen):
564565
ontology_name = f"test-chat-evaluation-ontology-{rand_gen(str)}"
566+
565567
ontology_builder = OntologyBuilder(
566568
tools=[
567569
Tool(
@@ -576,6 +578,7 @@ def chat_evaluation_ontology(client, rand_gen):
576578
tool=Tool.Type.MESSAGE_RANKING,
577579
name="model output multi ranking",
578580
),
581+
StepReasoningTool(name="step reasoning"),
579582
],
580583
classifications=[
581584
Classification(
@@ -626,14 +629,12 @@ def chat_evaluation_ontology(client, rand_gen):
626629
),
627630
],
628631
)
629-
630632
ontology = client.create_ontology(
631633
ontology_name,
632634
ontology_builder.asdict(),
633635
media_type=MediaType.Conversational,
634636
ontology_kind=OntologyKind.ModelEvaluation,
635637
)
636-
637638
yield ontology
638639

639640
try:

0 commit comments

Comments
 (0)