Skip to content

Commit 45601d4

Browse files
author
Val Brodsky
committed
Add FactCheckingTool
1 parent f9ed608 commit 45601d4

File tree

10 files changed

+341
-4
lines changed

10 files changed

+341
-4
lines changed

docs/labelbox/fact-checking-tool.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
Fact Checking Tool
2+
===============================================================================================
3+
4+
.. automodule:: labelbox.schema.tool_building.fact_checking_tool
5+
:members:
6+
:show-inheritance:

docs/labelbox/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ Labelbox Python SDK Documentation
1919
enums
2020
exceptions
2121
export-task
22+
fact-checking-tool
2223
foundry-client
2324
foundry-model
2425
identifiable

libs/labelbox/src/labelbox/schema/ontology.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,12 @@
1212

1313
from labelbox.orm.db_object import DbObject
1414
from labelbox.orm.model import Field, Relationship
15+
from labelbox.schema.tool_building.fact_checking_tool import FactCheckingTool
1516
from labelbox.schema.tool_building.step_reasoning_tool import StepReasoningTool
1617
from labelbox.schema.tool_building.tool_type import ToolType
18+
from labelbox.schema.tool_building.tool_type_mapping import (
19+
map_tool_type_to_tool_cls,
20+
)
1721

1822
FeatureSchemaId: Type[str] = Annotated[
1923
str, StringConstraints(min_length=25, max_length=25)
@@ -491,13 +495,14 @@ def add_classification(self, classification: Classification) -> None:
491495

492496

493497
def tool_cls_from_type(tool_type: str):
494-
if tool_type.lower() == ToolType.STEP_REASONING.value:
495-
return StepReasoningTool
498+
tool_cls = map_tool_type_to_tool_cls(tool_type)
499+
if tool_cls is not None:
500+
return tool_cls
496501
return Tool
497502

498503

499504
def tool_type_cls_from_type(tool_type: str):
500-
if tool_type.lower() == ToolType.STEP_REASONING.value:
505+
if ToolType.valid(tool_type):
501506
return ToolType
502507
return Tool.Type
503508

@@ -596,7 +601,9 @@ class OntologyBuilder:
596601
597602
"""
598603

599-
tools: List[Union[Tool, StepReasoningTool]] = field(default_factory=list)
604+
tools: List[Union[Tool, StepReasoningTool, FactCheckingTool]] = field(
605+
default_factory=list
606+
)
600607
classifications: List[
601608
Union[Classification, PromptResponseClassification]
602609
] = field(default_factory=list)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
11
import labelbox.schema.tool_building.tool_type
2+
import labelbox.schema.tool_building.variant
23
import labelbox.schema.tool_building.step_reasoning_tool
4+
import labelbox.schema.tool_building.fact_checking_tool
5+
import labelbox.schema.tool_building.tool_type_mapping
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
from dataclasses import dataclass, field
2+
from typing import Any, Dict, List, Optional, Set
3+
4+
from labelbox.schema.tool_building.tool_type import ToolType
5+
from labelbox.schema.tool_building.variant import (
6+
Variant,
7+
VariantWithActions,
8+
)
9+
10+
11+
@dataclass
class FactCheckingVariants:
    """
    The fixed set of verdicts offered by the fact-checking tool.

    Variants are identified by id: 0 "Accurate", 1 "Inaccurate",
    2 "Disputed", 3 "Unsupported", 4 "Can't confidently assess" and
    5 "No factual information". Variants 3-5 are ``VariantWithActions`` and
    may enable the "writeJustification" action.
    """

    accurate_step: Variant = field(
        default_factory=lambda: Variant(id=0, name="Accurate")
    )
    inaccurate_step: Variant = field(
        default_factory=lambda: Variant(id=1, name="Inaccurate")
    )
    disputed_step: Variant = field(
        default_factory=lambda: Variant(id=2, name="Disputed")
    )
    unsupported_step: VariantWithActions = field(
        default_factory=lambda: VariantWithActions(
            id=3, name="Unsupported", _available_actions={"writeJustification"}
        )
    )
    cant_confidently_assess_step: VariantWithActions = field(
        default_factory=lambda: VariantWithActions(
            id=4,
            name="Can't confidently assess",
            _available_actions={"writeJustification"},
        )
    )
    no_factual_information_step: VariantWithActions = field(
        default_factory=lambda: VariantWithActions(
            id=5,
            name="No factual information",
            _available_actions={"writeJustification"},
        )
    )

    def asdict(self) -> List[Dict[str, Any]]:
        """Return the six variants as a list of dicts, ordered by variant id."""
        return [
            self.accurate_step.asdict(),
            self.inaccurate_step.asdict(),
            self.disputed_step.asdict(),
            self.unsupported_step.asdict(),
            self.cant_confidently_assess_step.asdict(),
            self.no_factual_information_step.asdict(),
        ]

    @classmethod
    def from_dict(
        cls, dictionary: List[Dict[str, Any]]
    ) -> "FactCheckingVariants":
        """
        Build the variants from the list-of-dicts form produced by ``asdict``.

        Args:
            dictionary: One dict per variant, each with at least ``id`` and
                ``name``; variants 3-5 may also carry ``actions``. Entries
                with an unknown id are ignored.

        Raises:
            ValueError: If any of the six expected variant ids is absent.
        """
        # Variant id -> (field name, whether the variant supports actions).
        field_by_id = {
            0: ("accurate_step", False),
            1: ("inaccurate_step", False),
            2: ("disputed_step", False),
            3: ("unsupported_step", True),
            4: ("cant_confidently_assess_step", True),
            5: ("no_factual_information_step", True),
        }

        parsed: Dict[str, Any] = {}
        for variant in dictionary:
            known = field_by_id.get(variant["id"])
            if known is None:
                continue
            field_name, with_actions = known
            if with_actions:
                # The serialized form does not carry the available actions,
                # so restore the same set the defaults use; otherwise
                # set_actions() on a deserialized variant would accept nothing.
                parsed[field_name] = VariantWithActions(
                    id=variant["id"],
                    name=variant["name"],
                    actions=list(variant.get("actions") or []),
                    _available_actions={"writeJustification"},
                )
            else:
                # Pick the known keys explicitly so extra keys in the payload
                # do not break construction.
                parsed[field_name] = Variant(
                    id=variant["id"], name=variant["name"]
                )

        if len(parsed) != len(field_by_id):
            raise ValueError("Missing variant")

        return cls(**parsed)
100+
101+
102+
@dataclass
class FactCheckingDefinition:
    """
    The ``definition`` payload of a fact-checking tool.

    Holds the variants and a version number, plus an optional title and
    value that are only serialized when set.
    """

    variants: FactCheckingVariants = field(default_factory=FactCheckingVariants)
    version: int = field(default=1)
    title: Optional[str] = None
    value: Optional[str] = None

    def asdict(self) -> Dict[str, Any]:
        """Return the definition as a dict, omitting unset title/value."""
        result = {"variants": self.variants.asdict(), "version": self.version}
        if self.title is not None:
            result["title"] = self.title
        if self.value is not None:
            result["value"] = self.value
        return result

    @classmethod
    def from_dict(cls, dictionary: Dict[str, Any]) -> "FactCheckingDefinition":
        """
        Build a definition from the dict form produced by ``asdict``.

        Args:
            dictionary: Must contain ``variants``; ``version``, ``title`` and
                ``value`` are optional.

        Raises:
            KeyError: If ``variants`` is missing.
            ValueError: If the variants list is incomplete.
        """
        return cls(
            variants=FactCheckingVariants.from_dict(dictionary["variants"]),
            # Keep the serialized version instead of silently resetting it
            # to the default on every round trip.
            version=dictionary.get("version", 1),
            title=dictionary.get("title"),
            value=dictionary.get("value"),
        )
123+
124+
125+
@dataclass
class FactCheckingTool:
    """
    Use this class in OntologyBuilder to create a tool for step fact checking
    """

    name: str
    type: ToolType = field(default=ToolType.FACT_CHECKING, init=False)
    required: bool = False
    schema_id: Optional[str] = None
    feature_schema_id: Optional[str] = None
    # NOTE(review): color is neither written by asdict() nor read by
    # from_dict() -- confirm whether that is intentional.
    color: Optional[str] = None
    definition: FactCheckingDefinition = field(
        default_factory=FactCheckingDefinition
    )

    @staticmethod
    def _actions_or_default(actions: Optional[Set[str]]) -> Set[str]:
        """Return ``actions``, or a fresh default set when it is None."""
        return {"writeJustification"} if actions is None else actions

    def set_unsupported_step_actions(
        self, actions: Optional[Set[str]] = None
    ) -> None:
        """
        Set the actions of the "Unsupported" variant.

        Args:
            actions: Actions to enable; defaults to {"writeJustification"}.
        """
        self.definition.variants.unsupported_step.set_actions(
            self._actions_or_default(actions)
        )

    def set_cant_confidently_assess_step_actions(
        self, actions: Optional[Set[str]] = None
    ) -> None:
        """
        Set the actions of the "Can't confidently assess" variant.

        Args:
            actions: Actions to enable; defaults to {"writeJustification"}.
        """
        self.definition.variants.cant_confidently_assess_step.set_actions(
            self._actions_or_default(actions)
        )

    def set_no_factual_information_step_actions(
        self, actions: Optional[Set[str]] = None
    ) -> None:
        """
        Set the actions of the "No factual information" variant.

        Args:
            actions: Actions to enable; defaults to {"writeJustification"}.
        """
        self.definition.variants.no_factual_information_step.set_actions(
            self._actions_or_default(actions)
        )

    def asdict(self) -> Dict[str, Any]:
        """Return the tool as the dict form used in an ontology definition."""
        return {
            "tool": self.type.value,
            "name": self.name,
            "required": self.required,
            "schemaNodeId": self.schema_id,
            "featureSchemaId": self.feature_schema_id,
            "definition": self.definition.asdict(),
        }

    @classmethod
    def from_dict(cls, dictionary: Dict[str, Any]) -> "FactCheckingTool":
        """
        Build a tool from the dict form produced by ``asdict``.

        Args:
            dictionary: Must contain ``name`` and ``definition``;
                ``schemaNodeId``, ``featureSchemaId`` and ``required`` are
                optional.

        Raises:
            KeyError: If ``name`` or ``definition`` is missing.
        """
        return cls(
            name=dictionary["name"],
            schema_id=dictionary.get("schemaNodeId", None),
            feature_schema_id=dictionary.get("featureSchemaId", None),
            required=dictionary.get("required", False),
            definition=FactCheckingDefinition.from_dict(
                dictionary["definition"]
            ),
        )

libs/labelbox/src/labelbox/schema/tool_building/tool_type.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,12 @@
33

44
class ToolType(Enum):
    """Tool types that are handled by dedicated tool classes."""

    STEP_REASONING = "step-reasoning"
    FACT_CHECKING = "fact-checking"

    @classmethod
    def valid(cls, tool_type: str) -> bool:
        """
        Return True if ``tool_type`` names a member of this enum.

        The comparison is case-insensitive. Anything that is not a string
        (for example ``None``) is reported as invalid rather than raising.
        """
        if not isinstance(tool_type, str):
            return False
        try:
            cls(tool_type.lower())
        except ValueError:
            return False
        return True
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from labelbox.schema.tool_building.fact_checking_tool import FactCheckingTool
2+
from labelbox.schema.tool_building.step_reasoning_tool import StepReasoningTool
3+
from labelbox.schema.tool_building.tool_type import ToolType
4+
5+
6+
def map_tool_type_to_tool_cls(tool_type_str: str):
    """
    Return the dedicated tool class for a tool type string.

    Args:
        tool_type_str: The tool type, matched case-insensitively against
            ``ToolType``.

    Returns:
        ``StepReasoningTool`` or ``FactCheckingTool`` for the matching type,
        or ``None`` when the string is not a ``ToolType`` value.
    """
    # Unknown types return None rather than raising, so a caller can fall
    # back to its own default class.
    if not ToolType.valid(tool_type_str):
        return None

    tool_cls_by_type = {
        ToolType.STEP_REASONING: StepReasoningTool,
        ToolType.FACT_CHECKING: FactCheckingTool,
    }
    return tool_cls_by_type.get(ToolType(tool_type_str.lower()))
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from dataclasses import dataclass, field
2+
from typing import Any, Dict, List, Set
3+
4+
5+
@dataclass
class Variant:
    """
    One selectable option of a step-by-step reasoning or fact-checking tool.
    """

    id: int
    name: str

    def asdict(self) -> Dict[str, Any]:
        """Return the variant as a plain dict holding its id and name."""
        payload: Dict[str, Any] = {}
        payload["id"] = self.id
        payload["name"] = self.name
        return payload
16+
17+
18+
@dataclass
class VariantWithActions:
    """
    A variant that can additionally carry actions.

    ``actions`` holds the currently enabled actions; ``_available_actions``
    is the set of actions this variant accepts in ``set_actions``.
    """

    id: int
    name: str
    actions: List[str] = field(default_factory=list)
    _available_actions: Set[str] = field(default_factory=set)

    def set_actions(self, actions: Set[str]) -> None:
        """
        Replace the enabled actions with the accepted subset of ``actions``.

        Entries that are not in ``_available_actions`` are dropped. Earlier
        selections are discarded, so the result reflects this call only.
        """
        self.actions = [
            action for action in actions if action in self._available_actions
        ]

    def reset_actions(self) -> None:
        """Disable all actions."""
        self.actions = []

    def asdict(self) -> Dict[str, Any]:
        """
        Return the variant as a dict with ``id``, ``name`` and ``actions``.

        Actions are de-duplicated and sorted so the output is deterministic.
        """
        return {
            "id": self.id,
            "name": self.name,
            "actions": sorted(set(self.actions)),
        }

libs/labelbox/tests/integration/conftest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
)
2222
from labelbox.schema.data_row import DataRowMetadataField
2323
from labelbox.schema.ontology_kind import OntologyKind
24+
from labelbox.schema.tool_building.fact_checking_tool import FactCheckingTool
2425
from labelbox.schema.tool_building.step_reasoning_tool import StepReasoningTool
2526
from labelbox.schema.user import User
2627

@@ -579,6 +580,7 @@ def chat_evaluation_ontology(client, rand_gen):
579580
name="model output multi ranking",
580581
),
581582
StepReasoningTool(name="step reasoning"),
583+
FactCheckingTool(name="fact checking"),
582584
],
583585
classifications=[
584586
Classification(

0 commit comments

Comments
 (0)