Skip to content

Commit 7f0dc54

Browse files
author
Val Brodsky
committed
Update FactCheckingTool to match the latest api
1 parent 11e89eb commit 7f0dc54

File tree

1 file changed

+43
-23
lines changed

1 file changed

+43
-23
lines changed

libs/labelbox/src/labelbox/schema/tool_building/fact_checking_tool.py

Lines changed: 43 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,14 @@
55

66
from labelbox.schema.tool_building.tool_type import ToolType
77
from labelbox.schema.tool_building.variant import (
8-
Variant,
98
VariantWithActions,
109
)
1110

1211

12+
class FactCheckingActions(Enum):
13+
WRITE_JUSTIFICATION = "writeJustification"
14+
15+
1316
class UnsupportedStepActions(Enum):
1417
WRITE_JUSTIFICATION = "writeJustification"
1518

@@ -29,43 +32,53 @@ class FactCheckingVariants:
2932
NOTE do not change the variants directly
3033
"""
3134

32-
accurate_step: Variant = field(
33-
default_factory=lambda: Variant(id=0, name="Accurate")
35+
accurate_step: VariantWithActions = field(
36+
default_factory=lambda: VariantWithActions(
37+
id=0,
38+
name="Accurate",
39+
_available_actions={action.value for action in FactCheckingActions},
40+
actions=[action.value for action in FactCheckingActions],
41+
)
3442
)
35-
inaccurate_step: Variant = field(
36-
default_factory=lambda: Variant(id=1, name="Inaccurate")
43+
44+
inaccurate_step: VariantWithActions = field(
45+
default_factory=lambda: VariantWithActions(
46+
id=1,
47+
name="Inaccurate",
48+
_available_actions={action.value for action in FactCheckingActions},
49+
actions=[action.value for action in FactCheckingActions],
50+
)
3751
)
38-
disputed_step: Variant = field(
39-
default_factory=lambda: Variant(id=2, name="Disputed")
52+
disputed_step: VariantWithActions = field(
53+
default_factory=lambda: VariantWithActions(
54+
id=2,
55+
name="Disputed",
56+
_available_actions={action.value for action in FactCheckingActions},
57+
actions=[action.value for action in FactCheckingActions],
58+
)
4059
)
4160
unsupported_step: VariantWithActions = field(
4261
default_factory=lambda: VariantWithActions(
4362
id=3,
4463
name="Unsupported",
45-
_available_actions={
46-
action.value for action in UnsupportedStepActions
47-
},
48-
actions=[UnsupportedStepActions.WRITE_JUSTIFICATION.value],
64+
_available_actions=set(),
65+
actions=[],
4966
)
5067
)
5168
cant_confidently_assess_step: VariantWithActions = field(
5269
default_factory=lambda: VariantWithActions(
5370
id=4,
5471
name="Can't confidently assess",
55-
_available_actions={
56-
action.value for action in CanConfidentlyAssessStepActions
57-
},
58-
actions=[CanConfidentlyAssessStepActions.WRITE_JUSTIFICATION.value],
72+
_available_actions=set(),
73+
actions=[],
5974
)
6075
)
6176
no_factual_information_step: VariantWithActions = field(
6277
default_factory=lambda: VariantWithActions(
6378
id=5,
6479
name="No factual information",
65-
_available_actions={
66-
action.value for action in NoFactualInformationStepActions
67-
},
68-
actions=[NoFactualInformationStepActions.WRITE_JUSTIFICATION.value],
80+
_available_actions=set(),
81+
actions=[],
6982
)
7083
)
7184

@@ -90,11 +103,11 @@ def from_dict(cls, dictionary: List[Dict[str, Any]]):
90103

91104
for variant in dictionary:
92105
if variant["id"] == 0:
93-
accurate_step = Variant(**variant)
106+
accurate_step = VariantWithActions(**variant)
94107
elif variant["id"] == 1:
95-
inaccurate_step = Variant(**variant)
108+
inaccurate_step = VariantWithActions(**variant)
96109
elif variant["id"] == 2:
97-
disputed_step = Variant(**variant)
110+
disputed_step = VariantWithActions(**variant)
98111
elif variant["id"] == 3:
99112
unsupported_step = VariantWithActions(**variant)
100113
elif variant["id"] == 4:
@@ -126,6 +139,12 @@ def from_dict(cls, dictionary: List[Dict[str, Any]]):
126139

127140
@dataclass
128141
class FactCheckingDefinition:
142+
name: str
143+
type: ToolType = field(default=ToolType.FACT_CHECKING, init=False)
144+
required: bool = False
145+
schema_id: Optional[str] = None
146+
feature_schema_id: Optional[str] = None
147+
color: Optional[str] = None
129148
variants: FactCheckingVariants = field(default_factory=FactCheckingVariants)
130149
version: int = field(default=1)
131150
title: Optional[str] = None
@@ -145,10 +164,11 @@ def asdict(self) -> Dict[str, Any]:
145164

146165
@classmethod
147166
def from_dict(cls, dictionary: Dict[str, Any]) -> "FactCheckingDefinition":
167+
name = dictionary["name"]
148168
variants = FactCheckingVariants.from_dict(dictionary["variants"])
149169
title = dictionary.get("title", None)
150170
value = dictionary.get("value", None)
151-
return cls(variants=variants, title=title, value=value)
171+
return cls(name=name, variants=variants, title=title, value=value)
152172

153173

154174
@dataclass

0 commit comments

Comments
 (0)