2
2
from typing import Any , Dict , List , Optional
3
3
4
4
from labelbox .schema .tool_building .tool_type import ToolType
5
-
6
-
7
- @dataclass
8
- class StepReasoningVariant :
9
- id : int
10
- name : str
11
-
12
- def asdict (self ) -> Dict [str , Any ]:
13
- return {"id" : self .id , "name" : self .name }
14
-
15
-
16
- @dataclass
17
- class IncorrectStepReasoningVariant :
18
- id : int
19
- name : str
20
- regenerate_conversations_after_incorrect_step : Optional [bool ] = False
21
- rate_alternative_responses : Optional [bool ] = False
22
-
23
- def asdict (self ) -> Dict [str , Any ]:
24
- actions = []
25
- if self .regenerate_conversations_after_incorrect_step :
26
- actions .append ("regenerateSteps" )
27
- if self .rate_alternative_responses :
28
- actions .append ("generateAndRateAlternativeSteps" )
29
- return {"id" : self .id , "name" : self .name , "actions" : actions }
30
-
31
- @classmethod
32
- def from_dict (
33
- cls , dictionary : Dict [str , Any ]
34
- ) -> "IncorrectStepReasoningVariant" :
35
- return cls (
36
- id = dictionary ["id" ],
37
- name = dictionary ["name" ],
38
- regenerate_conversations_after_incorrect_step = "regenerateSteps"
39
- in dictionary .get ("actions" , []),
40
- rate_alternative_responses = "generateAndRateAlternativeSteps"
41
- in dictionary .get ("actions" , []),
42
- )
43
-
44
-
45
- def _create_correct_step () -> StepReasoningVariant :
46
- return StepReasoningVariant (
47
- id = StepReasoningVariants .CORRECT_STEP_ID , name = "Correct"
48
- )
49
-
50
-
51
- def _create_neutral_step () -> StepReasoningVariant :
52
- return StepReasoningVariant (
53
- id = StepReasoningVariants .NEUTRAL_STEP_ID , name = "Neutral"
54
- )
55
-
56
-
57
- def _create_incorrect_step () -> IncorrectStepReasoningVariant :
58
- return IncorrectStepReasoningVariant (
59
- id = StepReasoningVariants .INCORRECT_STEP_ID , name = "Incorrect"
60
- )
5
+ from labelbox .schema .tool_building .variant import Variant , VariantWithActions
61
6
62
7
63
8
@dataclass
@@ -67,18 +12,23 @@ class StepReasoningVariants:
67
12
Currently the options are correct, neutral, and incorrect
68
13
"""
69
14
70
- CORRECT_STEP_ID = 0
71
- NEUTRAL_STEP_ID = 1
72
- INCORRECT_STEP_ID = 2
73
-
74
- correct_step : StepReasoningVariant = field (
75
- default_factory = _create_correct_step
15
+ correct_step : Variant = field (
16
+ default_factory = lambda : Variant (id = 0 , name = "Correct" )
76
17
)
77
- neutral_step : StepReasoningVariant = field (
78
- default_factory = _create_neutral_step
18
+ neutral_step : Variant = field (
19
+ default_factory = lambda : Variant ( id = 1 , name = "Neutral" )
79
20
)
80
- incorrect_step : IncorrectStepReasoningVariant = field (
81
- default_factory = _create_incorrect_step
21
+
22
+ incorrect_step : VariantWithActions = field (
23
+ default_factory = lambda : VariantWithActions (
24
+ id = 2 ,
25
+ name = "Incorrect" ,
26
+ _available_actions = {
27
+ "regenerateSteps" ,
28
+ "generateAndRateAlternativeSteps" ,
29
+ },
30
+ actions = ["regenerateSteps" ], # regenerateSteps is on by default
31
+ )
82
32
)
83
33
84
34
def asdict (self ):
@@ -95,14 +45,12 @@ def from_dict(cls, dictionary: List[Dict[str, Any]]):
95
45
incorrect_step = None
96
46
97
47
for variant in dictionary :
98
- if variant ["id" ] == cls .CORRECT_STEP_ID :
99
- correct_step = StepReasoningVariant (** variant )
100
- elif variant ["id" ] == cls .NEUTRAL_STEP_ID :
101
- neutral_step = StepReasoningVariant (** variant )
102
- elif variant ["id" ] == cls .INCORRECT_STEP_ID :
103
- incorrect_step = IncorrectStepReasoningVariant .from_dict (
104
- variant
105
- )
48
+ if variant ["id" ] == 0 :
49
+ correct_step = Variant (** variant )
50
+ elif variant ["id" ] == 1 :
51
+ neutral_step = Variant (** variant )
52
+ elif variant ["id" ] == 2 :
53
+ incorrect_step = VariantWithActions (** variant )
106
54
107
55
if not all ([correct_step , neutral_step , incorrect_step ]):
108
56
raise ValueError ("Invalid step reasoning variants" )
0 commit comments