Skip to content

Commit df64dc2

Browse files
author
Val Brodsky
committed
Refactor StepReasoning to also reuse Variants
1 parent 301d817 commit df64dc2

File tree

1 file changed

+22
-74
lines changed

1 file changed

+22
-74
lines changed

libs/labelbox/src/labelbox/schema/tool_building/step_reasoning_tool.py

Lines changed: 22 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -2,62 +2,7 @@
22
from typing import Any, Dict, List, Optional
33

44
from labelbox.schema.tool_building.tool_type import ToolType
5-
6-
7-
@dataclass
8-
class StepReasoningVariant:
9-
id: int
10-
name: str
11-
12-
def asdict(self) -> Dict[str, Any]:
13-
return {"id": self.id, "name": self.name}
14-
15-
16-
@dataclass
17-
class IncorrectStepReasoningVariant:
18-
id: int
19-
name: str
20-
regenerate_conversations_after_incorrect_step: Optional[bool] = False
21-
rate_alternative_responses: Optional[bool] = False
22-
23-
def asdict(self) -> Dict[str, Any]:
24-
actions = []
25-
if self.regenerate_conversations_after_incorrect_step:
26-
actions.append("regenerateSteps")
27-
if self.rate_alternative_responses:
28-
actions.append("generateAndRateAlternativeSteps")
29-
return {"id": self.id, "name": self.name, "actions": actions}
30-
31-
@classmethod
32-
def from_dict(
33-
cls, dictionary: Dict[str, Any]
34-
) -> "IncorrectStepReasoningVariant":
35-
return cls(
36-
id=dictionary["id"],
37-
name=dictionary["name"],
38-
regenerate_conversations_after_incorrect_step="regenerateSteps"
39-
in dictionary.get("actions", []),
40-
rate_alternative_responses="generateAndRateAlternativeSteps"
41-
in dictionary.get("actions", []),
42-
)
43-
44-
45-
def _create_correct_step() -> StepReasoningVariant:
46-
return StepReasoningVariant(
47-
id=StepReasoningVariants.CORRECT_STEP_ID, name="Correct"
48-
)
49-
50-
51-
def _create_neutral_step() -> StepReasoningVariant:
52-
return StepReasoningVariant(
53-
id=StepReasoningVariants.NEUTRAL_STEP_ID, name="Neutral"
54-
)
55-
56-
57-
def _create_incorrect_step() -> IncorrectStepReasoningVariant:
58-
return IncorrectStepReasoningVariant(
59-
id=StepReasoningVariants.INCORRECT_STEP_ID, name="Incorrect"
60-
)
5+
from labelbox.schema.tool_building.variant import Variant, VariantWithActions
616

627

638
@dataclass
@@ -67,18 +12,23 @@ class StepReasoningVariants:
6712
Currently the options are correct, neutral, and incorrect
6813
"""
6914

70-
CORRECT_STEP_ID = 0
71-
NEUTRAL_STEP_ID = 1
72-
INCORRECT_STEP_ID = 2
73-
74-
correct_step: StepReasoningVariant = field(
75-
default_factory=_create_correct_step
15+
correct_step: Variant = field(
16+
default_factory=lambda: Variant(id=0, name="Correct")
7617
)
77-
neutral_step: StepReasoningVariant = field(
78-
default_factory=_create_neutral_step
18+
neutral_step: Variant = field(
19+
default_factory=lambda: Variant(id=1, name="Neutral")
7920
)
80-
incorrect_step: IncorrectStepReasoningVariant = field(
81-
default_factory=_create_incorrect_step
21+
22+
incorrect_step: VariantWithActions = field(
23+
default_factory=lambda: VariantWithActions(
24+
id=2,
25+
name="Incorrect",
26+
_available_actions={
27+
"regenerateSteps",
28+
"generateAndRateAlternativeSteps",
29+
},
30+
actions=["regenerateSteps"], # regenerateSteps is on by default
31+
)
8232
)
8333

8434
def asdict(self):
@@ -95,14 +45,12 @@ def from_dict(cls, dictionary: List[Dict[str, Any]]):
9545
incorrect_step = None
9646

9747
for variant in dictionary:
98-
if variant["id"] == cls.CORRECT_STEP_ID:
99-
correct_step = StepReasoningVariant(**variant)
100-
elif variant["id"] == cls.NEUTRAL_STEP_ID:
101-
neutral_step = StepReasoningVariant(**variant)
102-
elif variant["id"] == cls.INCORRECT_STEP_ID:
103-
incorrect_step = IncorrectStepReasoningVariant.from_dict(
104-
variant
105-
)
48+
if variant["id"] == 0:
49+
correct_step = Variant(**variant)
50+
elif variant["id"] == 1:
51+
neutral_step = Variant(**variant)
52+
elif variant["id"] == 2:
53+
incorrect_step = VariantWithActions(**variant)
10654

10755
if not all([correct_step, neutral_step, incorrect_step]):
10856
raise ValueError("Invalid step reasoning variants")

0 commit comments

Comments
 (0)