Skip to content

Commit 073191b

Browse files
author
Val Brodsky
committed
Refactor StepReasoning to also reuse Variants
1 parent 80798d4 commit 073191b

File tree

1 file changed

+22
-74
lines changed

1 file changed

+22
-74
lines changed

libs/labelbox/src/labelbox/schema/tool_building/step_reasoning_tool.py

Lines changed: 22 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -3,62 +3,7 @@
33
from typing import Any, Dict, List, Optional
44

55
from labelbox.schema.tool_building.tool_type import ToolType
6-
7-
8-
@dataclass
9-
class StepReasoningVariant:
10-
id: int
11-
name: str
12-
13-
def asdict(self) -> Dict[str, Any]:
14-
return {"id": self.id, "name": self.name}
15-
16-
17-
@dataclass
18-
class IncorrectStepReasoningVariant:
19-
id: int
20-
name: str
21-
regenerate_conversations_after_incorrect_step: Optional[bool] = True
22-
rate_alternative_responses: Optional[bool] = True
23-
24-
def asdict(self) -> Dict[str, Any]:
25-
actions = []
26-
if self.regenerate_conversations_after_incorrect_step:
27-
actions.append("regenerateSteps")
28-
if self.rate_alternative_responses:
29-
actions.append("generateAndRateAlternativeSteps")
30-
return {"id": self.id, "name": self.name, "actions": actions}
31-
32-
@classmethod
33-
def from_dict(
34-
cls, dictionary: Dict[str, Any]
35-
) -> "IncorrectStepReasoningVariant":
36-
return cls(
37-
id=dictionary["id"],
38-
name=dictionary["name"],
39-
regenerate_conversations_after_incorrect_step="regenerateSteps"
40-
in dictionary.get("actions", []),
41-
rate_alternative_responses="generateAndRateAlternativeSteps"
42-
in dictionary.get("actions", []),
43-
)
44-
45-
46-
def _create_correct_step() -> StepReasoningVariant:
47-
return StepReasoningVariant(
48-
id=StepReasoningVariants.CORRECT_STEP_ID, name="Correct"
49-
)
50-
51-
52-
def _create_neutral_step() -> StepReasoningVariant:
53-
return StepReasoningVariant(
54-
id=StepReasoningVariants.NEUTRAL_STEP_ID, name="Neutral"
55-
)
56-
57-
58-
def _create_incorrect_step() -> IncorrectStepReasoningVariant:
59-
return IncorrectStepReasoningVariant(
60-
id=StepReasoningVariants.INCORRECT_STEP_ID, name="Incorrect"
61-
)
6+
from labelbox.schema.tool_building.variant import Variant, VariantWithActions
627

638

649
@dataclass
@@ -68,18 +13,23 @@ class StepReasoningVariants:
6813
Currently the options are correct, neutral, and incorrect
6914
"""
7015

71-
CORRECT_STEP_ID = 0
72-
NEUTRAL_STEP_ID = 1
73-
INCORRECT_STEP_ID = 2
74-
75-
correct_step: StepReasoningVariant = field(
76-
default_factory=_create_correct_step
16+
correct_step: Variant = field(
17+
default_factory=lambda: Variant(id=0, name="Correct")
7718
)
78-
neutral_step: StepReasoningVariant = field(
79-
default_factory=_create_neutral_step
19+
neutral_step: Variant = field(
20+
default_factory=lambda: Variant(id=1, name="Neutral")
8021
)
81-
incorrect_step: IncorrectStepReasoningVariant = field(
82-
default_factory=_create_incorrect_step
22+
23+
incorrect_step: VariantWithActions = field(
24+
default_factory=lambda: VariantWithActions(
25+
id=2,
26+
name="Incorrect",
27+
_available_actions={
28+
"regenerateSteps",
29+
"generateAndRateAlternativeSteps",
30+
},
31+
actions=["regenerateSteps"], # regenerateSteps is on by default
32+
)
8333
)
8434

8535
def asdict(self):
@@ -96,14 +46,12 @@ def from_dict(cls, dictionary: List[Dict[str, Any]]):
9646
incorrect_step = None
9747

9848
for variant in dictionary:
99-
if variant["id"] == cls.CORRECT_STEP_ID:
100-
correct_step = StepReasoningVariant(**variant)
101-
elif variant["id"] == cls.NEUTRAL_STEP_ID:
102-
neutral_step = StepReasoningVariant(**variant)
103-
elif variant["id"] == cls.INCORRECT_STEP_ID:
104-
incorrect_step = IncorrectStepReasoningVariant.from_dict(
105-
variant
106-
)
49+
if variant["id"] == 0:
50+
correct_step = Variant(**variant)
51+
elif variant["id"] == 1:
52+
neutral_step = Variant(**variant)
53+
elif variant["id"] == 2:
54+
incorrect_step = VariantWithActions(**variant)
10755

10856
if not all([correct_step, neutral_step, incorrect_step]):
10957
raise ValueError("Invalid step reasoning variants")

0 commit comments

Comments
 (0)