From ef129ef94fed5ef9183485e0f0a28b460e624b5d Mon Sep 17 00:00:00 2001 From: Val Brodsky Date: Fri, 18 Oct 2024 16:27:23 -0700 Subject: [PATCH 1/3] Add step by step reasoning ontology tool --- libs/labelbox/src/labelbox/schema/ontology.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/libs/labelbox/src/labelbox/schema/ontology.py b/libs/labelbox/src/labelbox/schema/ontology.py index 3acd9e1e2..76cb91c0c 100644 --- a/libs/labelbox/src/labelbox/schema/ontology.py +++ b/libs/labelbox/src/labelbox/schema/ontology.py @@ -492,6 +492,10 @@ def add_classification(self, classification: Classification) -> None: def tool_cls_from_type(tool_type: str): if tool_type.lower() == ToolType.STEP_REASONING.value: + from labelbox.schema.tool_building.step_reasoning_tool import ( + StepReasoningTool, + ) + return StepReasoningTool return Tool @@ -535,6 +539,12 @@ def __init__(self, *args, **kwargs) -> None: Union[List[Classification], List[PromptResponseClassification]] ] = None + def _tool_deserializer_cls(self, tool: Dict[str, Any]) -> Tool: + import pdb + + pdb.set_trace() + return Tool + def tools(self) -> List[Tool]: """Get list of tools (AKA objects) in an Ontology.""" if self._tools is None: From 37d6dc0615016b83e614fd348ce8720ee72067af Mon Sep 17 00:00:00 2001 From: Val Brodsky Date: Tue, 22 Oct 2024 15:51:46 -0700 Subject: [PATCH 2/3] Update to support create_ontology_from_feature_schemas --- libs/labelbox/src/labelbox/schema/ontology.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/libs/labelbox/src/labelbox/schema/ontology.py b/libs/labelbox/src/labelbox/schema/ontology.py index 76cb91c0c..3acd9e1e2 100644 --- a/libs/labelbox/src/labelbox/schema/ontology.py +++ b/libs/labelbox/src/labelbox/schema/ontology.py @@ -492,10 +492,6 @@ def add_classification(self, classification: Classification) -> None: def tool_cls_from_type(tool_type: str): if tool_type.lower() == ToolType.STEP_REASONING.value: - from labelbox.schema.tool_building.step_reasoning_tool import ( - StepReasoningTool, - ) - return StepReasoningTool return Tool @@ -539,12 +535,6 @@ def __init__(self, *args, **kwargs) -> None: Union[List[Classification], List[PromptResponseClassification]] ] = None - def _tool_deserializer_cls(self, tool: Dict[str, Any]) -> Tool: - import pdb - - pdb.set_trace() - return Tool - def tools(self) -> List[Tool]: """Get list of tools (AKA objects) in an Ontology.""" if self._tools is None: From d4dd8447bbcd1ca95c4f666b7e16160b12179417 Mon Sep 17 00:00:00 2001 From: Val Brodsky Date: Fri, 25 Oct 2024 15:21:38 -0700 Subject: [PATCH 3/3] Add FactCheckingTool Refactor StepReasoning to also reuse Variants --- docs/labelbox/fact-checking-tool.rst | 6 + docs/labelbox/index.rst | 1 + libs/labelbox/src/labelbox/__init__.py | 4 +- libs/labelbox/src/labelbox/schema/ontology.py | 20 +- .../labelbox/schema/tool_building/__init__.py | 2 + .../tool_building/base_step_reasoning_tool.py | 109 ++++++++ .../tool_building/fact_checking_tool.py | 84 +++++++ .../tool_building/step_reasoning_tool.py | 233 +++--------------- .../schema/tool_building/tool_type.py | 9 + .../schema/tool_building/tool_type_mapping.py | 14 ++ libs/labelbox/tests/integration/conftest.py | 4 +- .../test_chat_evaluation_ontology_project.py | 2 +- .../tests/integration/test_ontology.py | 95 +++++-- .../unit/test_unit_fact_checking_tool.py | 43 ++++ .../unit/test_unit_step_ontology_variants.py | 31 +++ .../unit/test_unit_step_reasoning_tool.py | 57 ++++- 16 files changed, 480 insertions(+), 234 deletions(-) create mode 100644 docs/labelbox/fact-checking-tool.rst create mode 100644 libs/labelbox/src/labelbox/schema/tool_building/base_step_reasoning_tool.py create mode 100644 libs/labelbox/src/labelbox/schema/tool_building/fact_checking_tool.py create mode 100644 libs/labelbox/src/labelbox/schema/tool_building/tool_type_mapping.py create mode 100644 libs/labelbox/tests/unit/test_unit_fact_checking_tool.py create mode 100644 libs/labelbox/tests/unit/test_unit_step_ontology_variants.py diff --git a/docs/labelbox/fact-checking-tool.rst b/docs/labelbox/fact-checking-tool.rst new file mode 100644 index 000000000..75eea4e22 --- /dev/null +++ b/docs/labelbox/fact-checking-tool.rst @@ -0,0 +1,6 @@ +Fact Checking Tool +=============================================================================================== + +.. automodule:: labelbox.schema.tool_building.fact_checking_tool + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/labelbox/index.rst b/docs/labelbox/index.rst index 15ff9a0a9..f28de02fe 100644 --- a/docs/labelbox/index.rst +++ b/docs/labelbox/index.rst @@ -19,6 +19,7 @@ Labelbox Python SDK Documentation enums exceptions export-task + fact-checking-tool foundry-client foundry-model identifiable diff --git a/libs/labelbox/src/labelbox/__init__.py b/libs/labelbox/src/labelbox/__init__.py index 60940fa22..74fc047ed 100644 --- a/libs/labelbox/src/labelbox/__init__.py +++ b/libs/labelbox/src/labelbox/__init__.py @@ -55,8 +55,8 @@ ResponseOption, Tool, ) -from labelbox.schema.ontology import PromptResponseClassification -from labelbox.schema.ontology import ResponseOption +from labelbox.schema.tool_building.fact_checking_tool import FactCheckingTool +from labelbox.schema.tool_building.step_reasoning_tool import StepReasoningTool from labelbox.schema.role import Role, ProjectRole from labelbox.schema.invite import Invite, InviteLimit from labelbox.schema.data_row_metadata import ( diff --git a/libs/labelbox/src/labelbox/schema/ontology.py b/libs/labelbox/src/labelbox/schema/ontology.py index 3acd9e1e2..607bf8fec 100644 --- a/libs/labelbox/src/labelbox/schema/ontology.py +++ b/libs/labelbox/src/labelbox/schema/ontology.py @@ -12,8 +12,12 @@ from labelbox.orm.db_object import DbObject from labelbox.orm.model import Field, Relationship +from labelbox.schema.tool_building.fact_checking_tool import FactCheckingTool from labelbox.schema.tool_building.step_reasoning_tool import StepReasoningTool from labelbox.schema.tool_building.tool_type import ToolType +from labelbox.schema.tool_building.tool_type_mapping import ( + map_tool_type_to_tool_cls, +) FeatureSchemaId: Type[str] = Annotated[ str, StringConstraints(min_length=25, max_length=25) @@ -490,14 +494,20 @@ def add_classification(self, classification: Classification) -> None: self.classifications.append(classification) +""" +The following 2 functions help to bridge the gap between the step reasoning all other tool ontologies. +""" + + def tool_cls_from_type(tool_type: str): - if tool_type.lower() == ToolType.STEP_REASONING.value: - return StepReasoningTool + tool_cls = map_tool_type_to_tool_cls(tool_type) + if tool_cls is not None: + return tool_cls return Tool def tool_type_cls_from_type(tool_type: str): - if tool_type.lower() == ToolType.STEP_REASONING.value: + if ToolType.valid(tool_type): return ToolType return Tool.Type @@ -596,7 +606,9 @@ class OntologyBuilder: """ - tools: List[Union[Tool, StepReasoningTool]] = field(default_factory=list) + tools: List[Union[Tool, StepReasoningTool, FactCheckingTool]] = field( + default_factory=list + ) classifications: List[ Union[Classification, PromptResponseClassification] ] = field(default_factory=list) diff --git a/libs/labelbox/src/labelbox/schema/tool_building/__init__.py b/libs/labelbox/src/labelbox/schema/tool_building/__init__.py index 45098ef84..dab325388 100644 --- a/libs/labelbox/src/labelbox/schema/tool_building/__init__.py +++ b/libs/labelbox/src/labelbox/schema/tool_building/__init__.py @@ -1,2 +1,4 @@ import labelbox.schema.tool_building.tool_type import labelbox.schema.tool_building.step_reasoning_tool +import labelbox.schema.tool_building.fact_checking_tool +import labelbox.schema.tool_building.tool_type_mapping diff --git a/libs/labelbox/src/labelbox/schema/tool_building/base_step_reasoning_tool.py b/libs/labelbox/src/labelbox/schema/tool_building/base_step_reasoning_tool.py new file mode 100644 index 000000000..c9c9809d0 --- /dev/null +++ b/libs/labelbox/src/labelbox/schema/tool_building/base_step_reasoning_tool.py @@ -0,0 +1,109 @@ +import warnings +from abc import ABC +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Set + +from labelbox.schema.tool_building.tool_type import ToolType + + +@dataclass +class _Variant: + id: int + name: str + actions: List[str] = field(default_factory=list) + _available_actions: Set[str] = field(default_factory=set) + + def set_actions(self, actions: List[str]) -> None: + self.actions = [] + for action in actions: + if action in self._available_actions: + self.actions.append(action) + else: + warnings.warn( + f"Variant ID {self.id} {action} is an invalid action, skipping" + ) + + def reset_actions(self) -> None: + self.actions = [] + + def asdict(self) -> Dict[str, Any]: + return { + "id": self.id, + "name": self.name, + "actions": self.actions, + } + + def _post_init(self): + # Call set_actions to remove any invalid actions + self.set_actions(self.actions) + + +@dataclass +class _Definition: + variants: List[_Variant] + version: int = field(default=1) + title: Optional[str] = None + value: Optional[str] = None + + def __post_init__(self): + if self.version != 1: + raise ValueError("Invalid version") + + def asdict(self) -> Dict[str, Any]: + result = { + "variants": [variant.asdict() for variant in self.variants], + "version": self.version, + } + if self.title is not None: + result["title"] = self.title + if self.value is not None: + result["value"] = self.value + return result + + @classmethod + def from_dict(cls, dictionary: Dict[str, Any]) -> "_Definition": + variants = [_Variant(**variant) for variant in dictionary["variants"]] + title = dictionary.get("title", None) + value = dictionary.get("value", None) + return cls(variants=variants, title=title, value=value) + + +@dataclass +class _BaseStepReasoningTool(ABC): + name: str + definition: _Definition + type: Optional[ToolType] = None + schema_id: Optional[str] = None + feature_schema_id: Optional[str] = None + color: Optional[str] = None + required: bool = False + + def __post_init__(self): + warnings.warn( + "This feature is experimental and subject to change.", + ) + + if not self.name.strip(): + raise ValueError("Name is required") + + def asdict(self) -> Dict[str, Any]: + return { + "tool": self.type.value if self.type else None, + "name": self.name, + "required": self.required, + "schemaNodeId": self.schema_id, + "featureSchemaId": self.feature_schema_id, + "definition": self.definition.asdict(), + "color": self.color, + } + + @classmethod + def from_dict(cls, dictionary: Dict[str, Any]) -> "_BaseStepReasoningTool": + return cls( + name=dictionary["name"], + schema_id=dictionary.get("schemaNodeId", None), + feature_schema_id=dictionary.get("featureSchemaId", None), + required=dictionary.get("required", False), + definition=_Definition.from_dict(dictionary["definition"]), + color=dictionary.get("color", None), + ) diff --git a/libs/labelbox/src/labelbox/schema/tool_building/fact_checking_tool.py b/libs/labelbox/src/labelbox/schema/tool_building/fact_checking_tool.py new file mode 100644 index 000000000..4a02a482c --- /dev/null +++ b/libs/labelbox/src/labelbox/schema/tool_building/fact_checking_tool.py @@ -0,0 +1,84 @@ +from dataclasses import dataclass, field +from enum import Enum + +from labelbox.schema.tool_building.base_step_reasoning_tool import ( + _BaseStepReasoningTool, + _Definition, + _Variant, +) +from labelbox.schema.tool_building.tool_type import ToolType + + +class FactCheckingActions(Enum): + WRITE_JUSTIFICATION = "justification" + + +def build_fact_checking_definition(): + accurate_step = _Variant( + id=0, + name="Accurate", + _available_actions={action.value for action in FactCheckingActions}, + actions=[action.value for action in FactCheckingActions], + ) + inaccurate_step = _Variant( + id=1, + name="Inaccurate", + _available_actions={action.value for action in FactCheckingActions}, + actions=[action.value for action in FactCheckingActions], + ) + disputed_step = _Variant( + id=2, + name="Disputed", + _available_actions={action.value for action in FactCheckingActions}, + actions=[action.value for action in FactCheckingActions], + ) + unsupported_step = _Variant( + id=3, + name="Unsupported", + _available_actions=set(), + actions=[], + ) + cant_confidently_assess_step = _Variant( + id=4, + name="Can't confidently assess", + _available_actions=set(), + actions=[], + ) + no_factual_information_step = _Variant( + id=5, + name="No factual information", + _available_actions=set(), + actions=[], + ) + variants = [ + accurate_step, + inaccurate_step, + disputed_step, + unsupported_step, + cant_confidently_assess_step, + no_factual_information_step, + ] + return _Definition(variants=variants) + + +@dataclass +class FactCheckingTool(_BaseStepReasoningTool): + """ + Use this class in OntologyBuilder to create a tool for fact checking + """ + + type: ToolType = field(default=ToolType.FACT_CHECKING, init=False) + definition: _Definition = field( + default_factory=build_fact_checking_definition + ) + + def __post_init__(self): + super().__post_init__() + # Set available actions for variants 0, 1, 2 'out of band' since they are not passed in the definition + self._set_variant_available_actions() + + def _set_variant_available_actions(self): + for variant in self.definition.variants: + if variant.id in [0, 1, 2]: + for action in FactCheckingActions: + variant._available_actions.add(action.value) diff --git a/libs/labelbox/src/labelbox/schema/tool_building/step_reasoning_tool.py b/libs/labelbox/src/labelbox/schema/tool_building/step_reasoning_tool.py index 084816e55..c69c6f22a 100644 --- a/libs/labelbox/src/labelbox/schema/tool_building/step_reasoning_tool.py +++ b/libs/labelbox/src/labelbox/schema/tool_building/step_reasoning_tool.py @@ -1,218 +1,55 @@ -import warnings from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional +from enum import Enum +from labelbox.schema.tool_building.base_step_reasoning_tool import ( + _BaseStepReasoningTool, + _Definition, + _Variant, +) from labelbox.schema.tool_building.tool_type import ToolType -@dataclass -class StepReasoningVariant: - id: int - name: str - actions: List[str] = field(default_factory=list) - - def asdict(self) -> Dict[str, Any]: - return {"id": self.id, "name": self.name, "actions": self.actions} - - -@dataclass -class IncorrectStepReasoningVariant: - id: int - name: str - regenerate_steps: Optional[bool] = True - generate_and_rate_alternative_steps: Optional[bool] = True - rewrite_step: Optional[bool] = True - justification: Optional[bool] = True - - def asdict(self) -> Dict[str, Any]: - actions = [] - if self.regenerate_steps: - actions.append("regenerateSteps") - if self.generate_and_rate_alternative_steps: - actions.append("generateAndRateAlternativeSteps") - if self.rewrite_step: - actions.append("rewriteStep") - if self.justification: - actions.append("justification") - return {"id": self.id, "name": self.name, "actions": actions} +class IncorrectStepActions(Enum): + REGENERATE_STEPS = "regenerateSteps" + GENERATE_AND_RATE_ALTERNATIVE_STEPS = "generateAndRateAlternativeSteps" + REWRITE_STEP = "rewriteStep" + JUSTIFICATION = "justification" - @classmethod - def from_dict( - cls, dictionary: Dict[str, Any] - ) -> "IncorrectStepReasoningVariant": - return cls( - id=dictionary["id"], - name=dictionary["name"], - regenerate_steps="regenerateSteps" in dictionary.get("actions", []), - generate_and_rate_alternative_steps="generateAndRateAlternativeSteps" - in dictionary.get("actions", []), - rewrite_step="rewriteStep" in dictionary.get("actions", []), - justification="justification" in dictionary.get("actions", []), - ) - -def _create_correct_step() -> StepReasoningVariant: - return StepReasoningVariant( - id=StepReasoningVariants.CORRECT_STEP_ID, name="Correct" - ) - - -def _create_neutral_step() -> StepReasoningVariant: - return StepReasoningVariant( - id=StepReasoningVariants.NEUTRAL_STEP_ID, name="Neutral" - ) - - -def _create_incorrect_step() -> IncorrectStepReasoningVariant: - return IncorrectStepReasoningVariant( - id=StepReasoningVariants.INCORRECT_STEP_ID, name="Incorrect" +def build_step_reasoning_definition(): + correct_step = _Variant(id=0, name="Correct", actions=[]) + neutral_step = _Variant(id=1, name="Neutral", actions=[]) + incorrect_step = _Variant( + id=2, + name="Incorrect", + _available_actions={action.value for action in IncorrectStepActions}, + actions=[action.value for action in IncorrectStepActions], ) + variants = [correct_step, neutral_step, incorrect_step] + return _Definition(variants=variants) @dataclass -class StepReasoningVariants: - """ - This class is used to define the possible options for evaluating a step - Currently the options are correct, neutral, and incorrect - """ - - CORRECT_STEP_ID = 0 - NEUTRAL_STEP_ID = 1 - INCORRECT_STEP_ID = 2 - - correct_step: StepReasoningVariant = field( - default_factory=_create_correct_step - ) - neutral_step: StepReasoningVariant = field( - default_factory=_create_neutral_step - ) - incorrect_step: IncorrectStepReasoningVariant = field( - default_factory=_create_incorrect_step - ) - - def asdict(self): - return [ - self.correct_step.asdict(), - self.neutral_step.asdict(), - self.incorrect_step.asdict(), - ] - - @classmethod - def from_dict(cls, dictionary: List[Dict[str, Any]]): - correct_step = None - neutral_step = None - incorrect_step = None - - for variant in dictionary: - if variant["id"] == cls.CORRECT_STEP_ID: - correct_step = StepReasoningVariant(**variant) - elif variant["id"] == cls.NEUTRAL_STEP_ID: - neutral_step = StepReasoningVariant(**variant) - elif variant["id"] == cls.INCORRECT_STEP_ID: - incorrect_step = IncorrectStepReasoningVariant.from_dict( - variant - ) - - if not all([correct_step, neutral_step, incorrect_step]): - raise ValueError("Invalid step reasoning variants") - - return cls( - correct_step=correct_step, # type: ignore - neutral_step=neutral_step, # type: ignore - incorrect_step=incorrect_step, # type: ignore - ) - - -@dataclass -class StepReasoningDefinition: - variants: StepReasoningVariants = field( - default_factory=StepReasoningVariants - ) - version: int = field(default=1) - title: Optional[str] = None - value: Optional[str] = None - - def asdict(self) -> Dict[str, Any]: - result = {"variants": self.variants.asdict(), "version": self.version} - if self.title is not None: - result["title"] = self.title - if self.value is not None: - result["value"] = self.value - return result - - @classmethod - def from_dict(cls, dictionary: Dict[str, Any]) -> "StepReasoningDefinition": - variants = StepReasoningVariants.from_dict(dictionary["variants"]) - title = dictionary.get("title", None) - value = dictionary.get("value", None) - return cls(variants=variants, title=title, value=value) - - -@dataclass -class StepReasoningTool: +class StepReasoningTool(_BaseStepReasoningTool): """ Use this class in OntologyBuilder to create a tool for step reasoning The definition field lists the possible options to evaulate a step + + NOTE: color attribute is for backward compatibility only and should not be set directly """ - name: str type: ToolType = field(default=ToolType.STEP_REASONING, init=False) - required: bool = False - schema_id: Optional[str] = None - feature_schema_id: Optional[str] = None - color: Optional[str] = None - definition: StepReasoningDefinition = field( - default_factory=StepReasoningDefinition + definition: _Definition = field( + default_factory=build_step_reasoning_definition ) def __post_init__(self): - warnings.warn( - "This feature is experimental and subject to change.", - ) - - def reset_regenerate_steps(self): - """ - For live models, the default acation will invoke the model to generate alternatives if a step is marked as incorrect - This method will reset the action to not regenerate the conversation - """ - self.definition.variants.incorrect_step.regenerate_steps = False - - def reset_generate_and_rate_alternative_steps(self): - """ - For live models, will require labelers to rate the alternatives generated by the model - """ - self.definition.variants.incorrect_step.generate_and_rate_alternative_steps = False - - def reset_rewrite_step(self): - """ - For live models, will require labelers to rewrite the conversation - """ - self.definition.variants.incorrect_step.rewrite_step = False - - def reset_justification(self): - """ - For live models, will require labelers to provide a justification for their evaluation - """ - self.definition.variants.incorrect_step.justification = False - - def asdict(self) -> Dict[str, Any]: - return { - "tool": self.type.value, - "name": self.name, - "required": self.required, - "schemaNodeId": self.schema_id, - "featureSchemaId": self.feature_schema_id, - "definition": self.definition.asdict(), - } - - @classmethod - def from_dict(cls, dictionary: Dict[str, Any]) -> "StepReasoningTool": - return cls( - name=dictionary["name"], - schema_id=dictionary.get("schemaNodeId", None), - feature_schema_id=dictionary.get("featureSchemaId", None), - required=dictionary.get("required", False), - definition=StepReasoningDefinition.from_dict( - dictionary["definition"] - ), - ) + super().__post_init__() + # Set available actions for variants 0, 1, 2 'out of band' since they are not passed in the definition + self._set_variant_available_actions() + + def _set_variant_available_actions(self): + for variant in self.definition.variants: + if variant.id == 2: + for action in IncorrectStepActions: + variant._available_actions.add(action.value) diff --git a/libs/labelbox/src/labelbox/schema/tool_building/tool_type.py b/libs/labelbox/src/labelbox/schema/tool_building/tool_type.py index bbe8f231f..faab626ff 100644 --- a/libs/labelbox/src/labelbox/schema/tool_building/tool_type.py +++ b/libs/labelbox/src/labelbox/schema/tool_building/tool_type.py @@ -3,3 +3,12 @@ class ToolType(Enum): STEP_REASONING = "step-reasoning" + FACT_CHECKING = "fact-checking" + + @classmethod + def valid(cls, tool_type: str) -> bool: + try: + ToolType(tool_type.lower()) + return True + except ValueError: + return False diff --git a/libs/labelbox/src/labelbox/schema/tool_building/tool_type_mapping.py b/libs/labelbox/src/labelbox/schema/tool_building/tool_type_mapping.py new file mode 100644 index 000000000..68bfb4890 --- /dev/null +++ b/libs/labelbox/src/labelbox/schema/tool_building/tool_type_mapping.py @@ -0,0 +1,14 @@ +from labelbox.schema.tool_building.fact_checking_tool import FactCheckingTool +from labelbox.schema.tool_building.step_reasoning_tool import StepReasoningTool +from labelbox.schema.tool_building.tool_type import ToolType + + +def map_tool_type_to_tool_cls(tool_type_str: str): + if not ToolType.valid(tool_type_str): + return None + + tool_type = ToolType(tool_type_str.lower()) + if tool_type == ToolType.STEP_REASONING: + return StepReasoningTool + elif tool_type == ToolType.FACT_CHECKING: + return FactCheckingTool diff --git a/libs/labelbox/tests/integration/conftest.py b/libs/labelbox/tests/integration/conftest.py index 8e138f4a1..acc400c21 100644 --- a/libs/labelbox/tests/integration/conftest.py +++ b/libs/labelbox/tests/integration/conftest.py @@ -12,16 +12,17 @@ Classification, Client, Dataset, + FactCheckingTool, MediaType, OntologyBuilder, Option, PromptResponseClassification, ResponseOption, + StepReasoningTool, Tool, ) from labelbox.schema.data_row import DataRowMetadataField from labelbox.schema.ontology_kind import OntologyKind -from labelbox.schema.tool_building.step_reasoning_tool import StepReasoningTool from labelbox.schema.user import User @@ -579,6 +580,7 @@ def chat_evaluation_ontology(client, rand_gen): name="model output multi ranking", ), StepReasoningTool(name="step reasoning"), + FactCheckingTool(name="fact checking"), ], classifications=[ Classification( diff --git a/libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py b/libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py index 4a3e6d511..c5db9760c 100644 --- a/libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py +++ b/libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py @@ -15,7 +15,7 @@ def test_create_chat_evaluation_ontology_project( # here we are essentially testing the ontology creation which is a fixture assert ontology assert ontology.name - assert len(ontology.tools()) == 4 + assert len(ontology.tools()) == 5 for tool in ontology.tools(): assert tool.schema_id assert tool.feature_schema_id diff --git a/libs/labelbox/tests/integration/test_ontology.py b/libs/labelbox/tests/integration/test_ontology.py index 8531be310..f87197b62 100644 --- a/libs/labelbox/tests/integration/test_ontology.py +++ b/libs/labelbox/tests/integration/test_ontology.py @@ -5,6 +5,7 @@ from labelbox import MediaType, OntologyBuilder, Tool from labelbox.orm.model import Entity +from labelbox.schema.tool_building.fact_checking_tool import FactCheckingTool from labelbox.schema.tool_building.step_reasoning_tool import StepReasoningTool @@ -357,25 +358,89 @@ def test_step_reasoning_ontology(chat_evaluation_ontology): step_reasoning_tool = tool break assert step_reasoning_tool is not None - assert step_reasoning_tool.definition.variants.asdict() == [ + + assert step_reasoning_tool.definition.asdict() == { + "title": "step reasoning", + "value": "step_reasoning", + "variants": [ + { + "id": 0, + "name": "Correct", + "actions": [], + }, + { + "id": 1, + "name": "Neutral", + "actions": [], + }, + { + "id": 2, + "name": "Incorrect", + "actions": [ + "regenerateSteps", + "generateAndRateAlternativeSteps", + "rewriteStep", + "justification", + ], + }, + ], + "version": 1, + } + + +def test_fact_checking_ontology(chat_evaluation_ontology): + ontology = chat_evaluation_ontology + fact_checking = None + for tool in ontology.normalized["tools"]: + if tool["tool"] == "fact-checking": + fact_checking = tool + break + assert fact_checking is not None + assert fact_checking["definition"]["variants"] == [ + {"id": 0, "name": "Accurate", "actions": ["justification"]}, + {"id": 1, "name": "Inaccurate", "actions": ["justification"]}, + {"id": 2, "name": "Disputed", "actions": ["justification"]}, + {"id": 3, "name": "Unsupported", "actions": []}, { - "id": 0, - "name": "Correct", + "id": 4, + "name": "Can't confidently assess", "actions": [], }, { - "id": 1, - "name": "Neutral", + "id": 5, + "name": "No factual information", "actions": [], }, - { - "id": 2, - "name": "Incorrect", - "actions": [ - "regenerateSteps", - "generateAndRateAlternativeSteps", - "rewriteStep", - "justification", - ], - }, ] + assert fact_checking["definition"]["version"] == 1 + assert fact_checking["schemaNodeId"] is not None + assert fact_checking["featureSchemaId"] is not None + + fact_checking = None + for tool in ontology.tools(): + if isinstance(tool, FactCheckingTool): + fact_checking = tool + break + assert fact_checking is not None + + assert fact_checking.definition.asdict() == { + "title": "fact checking", + "value": "fact_checking", + "variants": [ + {"id": 0, "name": "Accurate", "actions": ["justification"]}, + {"id": 1, "name": "Inaccurate", "actions": ["justification"]}, + {"id": 2, "name": "Disputed", "actions": ["justification"]}, + {"id": 3, "name": "Unsupported", "actions": []}, + { + "id": 4, + "name": "Can't confidently assess", + "actions": [], + }, + { + "id": 5, + "name": "No factual information", + "actions": [], + }, + ], + "version": 1, + } diff --git a/libs/labelbox/tests/unit/test_unit_fact_checking_tool.py b/libs/labelbox/tests/unit/test_unit_fact_checking_tool.py new file mode 100644 index 000000000..019f1355b --- /dev/null +++ b/libs/labelbox/tests/unit/test_unit_fact_checking_tool.py @@ -0,0 +1,43 @@ +from labelbox.schema.tool_building.fact_checking_tool import FactCheckingTool + + +def test_fact_checking_as_dict_default(): + tool = FactCheckingTool(name="Fact Checking Tool") + + # Get the dictionary representation + tool_dict = tool.asdict() + + # Expected dictionary structure + expected_dict = { + "tool": "fact-checking", + "name": "Fact Checking Tool", + "required": False, + "schemaNodeId": None, + "featureSchemaId": None, + "color": None, + "definition": { + "variants": [ + {"id": 0, "name": "Accurate", "actions": ["justification"]}, + {"id": 1, "name": "Inaccurate", "actions": ["justification"]}, + {"id": 2, "name": "Disputed", "actions": ["justification"]}, + { + "id": 3, + "name": "Unsupported", + "actions": [], + }, + { + "id": 4, + "name": "Can't confidently assess", + "actions": [], + }, + { + "id": 5, + "name": "No factual information", + "actions": [], + }, + ], + "version": 1, + }, + } + + assert tool_dict == expected_dict diff --git a/libs/labelbox/tests/unit/test_unit_step_ontology_variants.py b/libs/labelbox/tests/unit/test_unit_step_ontology_variants.py new file mode 100644 index 000000000..faddcac9f --- /dev/null +++ b/libs/labelbox/tests/unit/test_unit_step_ontology_variants.py @@ -0,0 +1,31 @@ +from labelbox.schema.tool_building.base_step_reasoning_tool import _Variant + + +def test_variant(): + variant = _Variant( + id=0, name="Correct", _available_actions={"regenerateSteps"} + ) + variant.set_actions(["regenerateSteps"]) + assert variant.asdict() == { + "id": 0, + "name": "Correct", + "actions": ["regenerateSteps"], + } + + assert variant._available_actions == {"regenerateSteps"} + variant.reset_actions() + assert variant.asdict() == { + "id": 0, + "name": "Correct", + "actions": [], + } + + +def test_variant_actions(): + variant = _Variant( + id=0, name="Correct", _available_actions={"regenerateSteps"} + ) + variant.set_actions(["regenerateSteps"]) + assert variant.actions == ["regenerateSteps"] + variant.set_actions(["invalidAction"]) + assert variant.actions == [] diff --git a/libs/labelbox/tests/unit/test_unit_step_reasoning_tool.py b/libs/labelbox/tests/unit/test_unit_step_reasoning_tool.py index 1bb10e672..825aeaa74 100644 --- a/libs/labelbox/tests/unit/test_unit_step_reasoning_tool.py +++ b/libs/labelbox/tests/unit/test_unit_step_reasoning_tool.py @@ -1,14 +1,31 @@ -from labelbox.schema.tool_building.step_reasoning_tool import StepReasoningTool +import pytest + +from labelbox.schema.tool_building.step_reasoning_tool import ( + StepReasoningTool, +) + + +def test_validations(): + with pytest.raises(ValueError): + StepReasoningTool(name="") def test_step_reasoning_as_dict_default(): tool = StepReasoningTool(name="step reasoning") + assert tool.definition.variants[2]._available_actions == { + "regenerateSteps", + "generateAndRateAlternativeSteps", + "rewriteStep", + "justification", + } + assert tool.asdict() == { "tool": "step-reasoning", "name": "step reasoning", "required": False, "schemaNodeId": None, "featureSchemaId": None, + "color": None, "definition": { "variants": [ {"id": 0, "name": "Correct", "actions": []}, @@ -29,18 +46,13 @@ def test_step_reasoning_as_dict_default(): } -def test_step_reasoning_as_dict_with_actions(): - tool = StepReasoningTool(name="step reasoning") - tool.reset_generate_and_rate_alternative_steps() - tool.reset_regenerate_steps() - tool.reset_rewrite_step() - tool.reset_justification() - assert tool.asdict() == { - "tool": "step-reasoning", - "name": "step reasoning", +def test_from_dict(): + dict = { + "schemaNodeId": "cm3pdkupv0ah8070h2ujo74th", + "featureSchemaId": "cm3pdkupv0ah7070hg7svdeeo", "required": False, - "schemaNodeId": None, - "featureSchemaId": None, + "name": "step reasoning", + "tool": "step-reasoning", "definition": { "variants": [ {"id": 0, "name": "Correct", "actions": []}, @@ -48,9 +60,28 @@ def test_step_reasoning_as_dict_with_actions(): { "id": 2, "name": "Incorrect", - "actions": [], + "actions": [ + "regenerateSteps", + "generateAndRateAlternativeSteps", + "rewriteStep", + "justification", + ], }, ], "version": 1, + "title": "step reasoning", + "value": "step_reasoning", + "color": "#ff0000", }, + "color": "#ff0000", + "archived": 0, + "classifications": [], + "kind": "StepReasoning", + } + tool = StepReasoningTool.from_dict(dict) + assert tool.definition.variants[2]._available_actions == { + "generateAndRateAlternativeSteps", + "justification", + "rewriteStep", + "regenerateSteps", }