Skip to content

Commit bc04ae3

Browse files
author
Val Brodsky
committed
Refactor StepReasoning to also reuse Variants
1 parent 2bce775 commit bc04ae3

File tree

6 files changed

+147
-106
lines changed

6 files changed

+147
-106
lines changed

libs/labelbox/src/labelbox/schema/tool_building/fact_checking_tool.py

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import warnings
12
from dataclasses import dataclass, field
2-
from typing import Any, Dict, List, Optional, Set
3+
from enum import Enum
4+
from typing import Any, Dict, List, Optional
35

46
from labelbox.schema.tool_building.tool_type import ToolType
57
from labelbox.schema.tool_building.variant import (
@@ -8,6 +10,18 @@
810
)
911

1012

13+
class UnsupportedStepActions(Enum):
14+
WRITE_JUSTIFICATION = "writeJustification"
15+
16+
17+
class CanConfidentlyAssessStepActions(Enum):
18+
WRITE_JUSTIFICATION = "writeJustification"
19+
20+
21+
class NoFactualInformationStepActions(Enum):
22+
WRITE_JUSTIFICATION = "writeJustification"
23+
24+
1125
@dataclass
1226
class FactCheckingVariants:
1327
"""
@@ -26,21 +40,32 @@ class FactCheckingVariants:
2640
)
2741
unsupported_step: VariantWithActions = field(
2842
default_factory=lambda: VariantWithActions(
29-
id=3, name="Unsupported", _available_actions={"writeJustification"}
43+
id=3,
44+
name="Unsupported",
45+
_available_actions={
46+
action.value for action in UnsupportedStepActions
47+
},
48+
actions=[UnsupportedStepActions.WRITE_JUSTIFICATION.value],
3049
)
3150
)
3251
cant_confidently_assess_step: VariantWithActions = field(
3352
default_factory=lambda: VariantWithActions(
3453
id=4,
3554
name="Can't confidently assess",
36-
_available_actions={"writeJustification"},
55+
_available_actions={
56+
action.value for action in CanConfidentlyAssessStepActions
57+
},
58+
actions=[CanConfidentlyAssessStepActions.WRITE_JUSTIFICATION.value],
3759
)
3860
)
3961
no_factual_information_step: VariantWithActions = field(
4062
default_factory=lambda: VariantWithActions(
4163
id=5,
4264
name="No factual information",
43-
_available_actions={"writeJustification"},
65+
_available_actions={
66+
action.value for action in NoFactualInformationStepActions
67+
},
68+
actions=[NoFactualInformationStepActions.WRITE_JUSTIFICATION.value],
4469
)
4570
)
4671

@@ -138,23 +163,31 @@ class FactCheckingTool:
138163
default_factory=FactCheckingDefinition
139164
)
140165

166+
def __post_init__(self):
167+
warnings.warn(
168+
"This feature is experimental and subject to change.",
169+
)
170+
141171
def set_unsupported_step_actions(
142-
self, actions: Set[str] = {"writeJustification"}
172+
self, actions: List[UnsupportedStepActions]
143173
) -> None:
144-
self.definition.variants.unsupported_step.set_actions(actions)
174+
actions_values = [action.value for action in actions]
175+
self.definition.variants.unsupported_step.set_actions(actions_values)
145176

146177
def set_cant_confidently_assess_step_actions(
147-
self, actions: Set[str] = {"writeJustification"}
178+
self, actions: List[CanConfidentlyAssessStepActions]
148179
) -> None:
180+
actions_values = [action.value for action in actions]
149181
self.definition.variants.cant_confidently_assess_step.set_actions(
150-
actions
182+
actions_values
151183
)
152184

153185
def set_no_factual_information_step_actions(
154-
self, actions: Set[str] = {"writeJustification"}
186+
self, actions: List[NoFactualInformationStepActions]
155187
) -> None:
188+
actions_values = [action.value for action in actions]
156189
self.definition.variants.no_factual_information_step.set_actions(
157-
actions
190+
actions_values
158191
)
159192

160193
def asdict(self) -> Dict[str, Any]:

libs/labelbox/src/labelbox/schema/tool_building/step_reasoning_tool.py

Lines changed: 29 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,64 +1,15 @@
11
import warnings
22
from dataclasses import dataclass, field
3+
from enum import Enum
34
from typing import Any, Dict, List, Optional
45

56
from labelbox.schema.tool_building.tool_type import ToolType
7+
from labelbox.schema.tool_building.variant import Variant, VariantWithActions
68

79

8-
@dataclass
9-
class StepReasoningVariant:
10-
id: int
11-
name: str
12-
13-
def asdict(self) -> Dict[str, Any]:
14-
return {"id": self.id, "name": self.name}
15-
16-
17-
@dataclass
18-
class IncorrectStepReasoningVariant:
19-
id: int
20-
name: str
21-
regenerate_conversations_after_incorrect_step: Optional[bool] = True
22-
rate_alternative_responses: Optional[bool] = True
23-
24-
def asdict(self) -> Dict[str, Any]:
25-
actions = []
26-
if self.regenerate_conversations_after_incorrect_step:
27-
actions.append("regenerateSteps")
28-
if self.rate_alternative_responses:
29-
actions.append("generateAndRateAlternativeSteps")
30-
return {"id": self.id, "name": self.name, "actions": actions}
31-
32-
@classmethod
33-
def from_dict(
34-
cls, dictionary: Dict[str, Any]
35-
) -> "IncorrectStepReasoningVariant":
36-
return cls(
37-
id=dictionary["id"],
38-
name=dictionary["name"],
39-
regenerate_conversations_after_incorrect_step="regenerateSteps"
40-
in dictionary.get("actions", []),
41-
rate_alternative_responses="generateAndRateAlternativeSteps"
42-
in dictionary.get("actions", []),
43-
)
44-
45-
46-
def _create_correct_step() -> StepReasoningVariant:
47-
return StepReasoningVariant(
48-
id=StepReasoningVariants.CORRECT_STEP_ID, name="Correct"
49-
)
50-
51-
52-
def _create_neutral_step() -> StepReasoningVariant:
53-
return StepReasoningVariant(
54-
id=StepReasoningVariants.NEUTRAL_STEP_ID, name="Neutral"
55-
)
56-
57-
58-
def _create_incorrect_step() -> IncorrectStepReasoningVariant:
59-
return IncorrectStepReasoningVariant(
60-
id=StepReasoningVariants.INCORRECT_STEP_ID, name="Incorrect"
61-
)
10+
class IncorrectStepActions(Enum):
11+
REGENERATE_STEPS = "regenerateSteps"
12+
GENERATE_AND_RATE_ALTERNATIVE_STEPS = "generateAndRateAlternativeSteps"
6213

6314

6415
@dataclass
@@ -68,18 +19,22 @@ class StepReasoningVariants:
6819
Currently the options are correct, neutral, and incorrect
6920
"""
7021

71-
CORRECT_STEP_ID = 0
72-
NEUTRAL_STEP_ID = 1
73-
INCORRECT_STEP_ID = 2
74-
75-
correct_step: StepReasoningVariant = field(
76-
default_factory=_create_correct_step
22+
correct_step: Variant = field(
23+
default_factory=lambda: Variant(id=0, name="Correct")
7724
)
78-
neutral_step: StepReasoningVariant = field(
79-
default_factory=_create_neutral_step
25+
neutral_step: Variant = field(
26+
default_factory=lambda: Variant(id=1, name="Neutral")
8027
)
81-
incorrect_step: IncorrectStepReasoningVariant = field(
82-
default_factory=_create_incorrect_step
28+
29+
incorrect_step: VariantWithActions = field(
30+
default_factory=lambda: VariantWithActions(
31+
id=2,
32+
name="Incorrect",
33+
_available_actions={
34+
action.value for action in IncorrectStepActions
35+
},
36+
actions=["regenerateSteps"], # regenerateSteps is on by default
37+
)
8338
)
8439

8540
def asdict(self):
@@ -96,14 +51,12 @@ def from_dict(cls, dictionary: List[Dict[str, Any]]):
9651
incorrect_step = None
9752

9853
for variant in dictionary:
99-
if variant["id"] == cls.CORRECT_STEP_ID:
100-
correct_step = StepReasoningVariant(**variant)
101-
elif variant["id"] == cls.NEUTRAL_STEP_ID:
102-
neutral_step = StepReasoningVariant(**variant)
103-
elif variant["id"] == cls.INCORRECT_STEP_ID:
104-
incorrect_step = IncorrectStepReasoningVariant.from_dict(
105-
variant
106-
)
54+
if variant["id"] == 0:
55+
correct_step = Variant(**variant)
56+
elif variant["id"] == 1:
57+
neutral_step = Variant(**variant)
58+
elif variant["id"] == 2:
59+
incorrect_step = VariantWithActions(**variant)
10760

10861
if not all([correct_step, neutral_step, incorrect_step]):
10962
raise ValueError("Invalid step reasoning variants")
@@ -162,20 +115,12 @@ def __post_init__(self):
162115
"This feature is experimental and subject to change.",
163116
)
164117

165-
def reset_regenerate_conversations_after_incorrect_step(self):
118+
def set_incorrect_step_actions(self, actions: List[IncorrectStepActions]):
166119
"""
167-
For live models, the default action will invoke the model to generate alternatives if a step is marked as incorrect
168-
This method will reset the action to not regenerate the conversation
120+
For live models, will invoke the model to generate alternatives if a step is marked as incorrect
169121
"""
170-
self.definition.variants.incorrect_step.regenerate_conversations_after_incorrect_step = False
171-
172-
def reset_rate_alternative_responses(self):
173-
"""
174-
For live models, will require labelers to rate the alternatives generated by the model
175-
"""
176-
self.definition.variants.incorrect_step.rate_alternative_responses = (
177-
False
178-
)
122+
actions_values = [action.value for action in actions]
123+
self.definition.variants.incorrect_step.set_actions(actions_values)
179124

180125
def asdict(self) -> Dict[str, Any]:
181126
return {

libs/labelbox/src/labelbox/schema/tool_building/tool_type_mapping.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,10 @@
55

66
def map_tool_type_to_tool_cls(tool_type_str: str):
77
if not ToolType.valid(tool_type_str):
8-
raise ValueError(f"Invalid tool type {tool_type_str}")
8+
return None
99

1010
tool_type = ToolType(tool_type_str.lower())
1111
if tool_type == ToolType.STEP_REASONING:
1212
return StepReasoningTool
1313
elif tool_type == ToolType.FACT_CHECKING:
1414
return FactCheckingTool
15-
16-
return None

libs/labelbox/src/labelbox/schema/tool_building/variant.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ class VariantWithActions:
2222
actions: List[str] = field(default_factory=list)
2323
_available_actions: Set[str] = field(default_factory=set)
2424

25-
def set_actions(self, actions: Set[str]) -> None:
25+
def set_actions(self, actions: List[str]) -> None:
26+
self.actions = []
2627
for action in actions:
2728
if action in self._available_actions:
2829
self.actions.append(action)
@@ -31,8 +32,11 @@ def reset_actions(self) -> None:
3132
self.actions = []
3233

3334
def asdict(self) -> Dict[str, Any]:
34-
return {
35+
data = {
3536
"id": self.id,
3637
"name": self.name,
37-
"actions": list(set(self.actions)),
3838
}
39+
if len(self.actions) > 0:
40+
data["actions"] = self.actions
41+
42+
return data

libs/labelbox/tests/integration/test_ontology.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from labelbox import MediaType, OntologyBuilder, Tool
77
from labelbox.orm.model import Entity
8+
from labelbox.schema.tool_building.fact_checking_tool import FactCheckingTool
89
from labelbox.schema.tool_building.step_reasoning_tool import StepReasoningTool
910

1011

@@ -339,7 +340,7 @@ def test_step_reasoning_ontology(chat_evaluation_ontology):
339340
{
340341
"id": 2,
341342
"name": "Incorrect",
342-
"actions": [],
343+
"actions": ["regenerateSteps"],
343344
},
344345
]
345346
assert step_reasoning_tool["definition"]["version"] == 1
@@ -364,6 +365,59 @@ def test_step_reasoning_ontology(chat_evaluation_ontology):
364365
{
365366
"id": 2,
366367
"name": "Incorrect",
367-
"actions": [],
368+
"actions": ["regenerateSteps"],
369+
},
370+
]
371+
372+
373+
def test_fact_checking_ontology(chat_evaluation_ontology):
374+
ontology = chat_evaluation_ontology
375+
fact_checking = None
376+
for tool in ontology.normalized["tools"]:
377+
if tool["tool"] == "fact-checking":
378+
fact_checking = tool
379+
break
380+
assert fact_checking is not None
381+
assert fact_checking["definition"]["variants"] == [
382+
{"id": 0, "name": "Accurate"},
383+
{"id": 1, "name": "Inaccurate"},
384+
{"id": 2, "name": "Disputed"},
385+
{"id": 3, "name": "Unsupported", "actions": ["writeJustification"]},
386+
{
387+
"id": 4,
388+
"name": "Can't confidently assess",
389+
"actions": ["writeJustification"],
390+
},
391+
{
392+
"id": 5,
393+
"name": "No factual information",
394+
"actions": ["writeJustification"],
395+
},
396+
]
397+
assert fact_checking["definition"]["version"] == 1
398+
assert fact_checking["schemaNodeId"] is not None
399+
assert fact_checking["featureSchemaId"] is not None
400+
401+
fact_checking = None
402+
for tool in ontology.tools():
403+
if isinstance(tool, FactCheckingTool):
404+
fact_checking = tool
405+
break
406+
assert fact_checking is not None
407+
408+
assert fact_checking.definition.variants.asdict() == [
409+
{"id": 0, "name": "Accurate"},
410+
{"id": 1, "name": "Inaccurate"},
411+
{"id": 2, "name": "Disputed"},
412+
{"id": 3, "name": "Unsupported", "actions": ["writeJustification"]},
413+
{
414+
"id": 4,
415+
"name": "Can't confidently assess",
416+
"actions": ["writeJustification"],
417+
},
418+
{
419+
"id": 5,
420+
"name": "No factual information",
421+
"actions": ["writeJustification"],
368422
},
369423
]

libs/labelbox/tests/unit/test_unit_step_reasoning_tool.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
from labelbox.schema.tool_building.step_reasoning_tool import StepReasoningTool
1+
from labelbox.schema.tool_building.step_reasoning_tool import (
2+
IncorrectStepActions,
3+
StepReasoningTool,
4+
)
25

36

47
def test_step_reasoning_as_dict_default():
@@ -16,7 +19,7 @@ def test_step_reasoning_as_dict_default():
1619
{
1720
"id": 2,
1821
"name": "Incorrect",
19-
"actions": [],
22+
"actions": ["regenerateSteps"],
2023
},
2124
],
2225
"version": 1,
@@ -26,8 +29,12 @@ def test_step_reasoning_as_dict_default():
2629

2730
def test_step_reasoning_as_dict_with_actions():
2831
tool = StepReasoningTool(name="step reasoning")
29-
tool.reset_rate_alternative_responses()
30-
tool.reset_regenerate_conversations_after_incorrect_step()
32+
tool.set_incorrect_step_actions(
33+
[
34+
IncorrectStepActions.REGENERATE_STEPS,
35+
IncorrectStepActions.GENERATE_AND_RATE_ALTERNATIVE_STEPS,
36+
]
37+
)
3138
assert tool.asdict() == {
3239
"tool": "step-reasoning",
3340
"name": "step reasoning",

0 commit comments

Comments
 (0)