diff --git a/docs/labelbox/index.rst b/docs/labelbox/index.rst index fa694119b..15ff9a0a9 100644 --- a/docs/labelbox/index.rst +++ b/docs/labelbox/index.rst @@ -46,6 +46,7 @@ Labelbox Python SDK Documentation search-filters send-to-annotate-params slice + step_reasoning_tool task task-queue user diff --git a/docs/labelbox/step_reasoning_tool.rst b/docs/labelbox/step_reasoning_tool.rst new file mode 100644 index 000000000..b363589e8 --- /dev/null +++ b/docs/labelbox/step_reasoning_tool.rst @@ -0,0 +1,6 @@ +Step Reasoning Tool +=============================================================================================== + +.. automodule:: labelbox.schema.tool_building.step_reasoning_tool + :members: + :show-inheritance: \ No newline at end of file diff --git a/libs/labelbox/mypy.ini b/libs/labelbox/mypy.ini index a9b715cf9..b09c45d33 100644 --- a/libs/labelbox/mypy.ini +++ b/libs/labelbox/mypy.ini @@ -12,5 +12,5 @@ ignore_errors = True [mypy-lbox.exceptions] ignore_missing_imports = True -[mypy-lbox.call_info"] +[mypy-lbox.call_info] ignore_missing_imports = True diff --git a/libs/labelbox/src/labelbox/client.py b/libs/labelbox/src/labelbox/client.py index 52a044a3b..d4376d9b4 100644 --- a/libs/labelbox/src/labelbox/client.py +++ b/libs/labelbox/src/labelbox/client.py @@ -56,7 +56,7 @@ FeatureSchema, Ontology, PromptResponseClassification, - Tool, + tool_type_cls_from_type, ) from labelbox.schema.ontology_kind import ( EditorTaskType, @@ -1098,7 +1098,8 @@ def create_ontology_from_feature_schemas( if "tool" in feature_schema.normalized: tool = feature_schema.normalized["tool"] try: - Tool.Type(tool) + tool_type_cls = tool_type_cls_from_type(tool) + tool_type_cls(tool) tools.append(feature_schema.normalized) except ValueError: raise ValueError( diff --git a/libs/labelbox/src/labelbox/schema/ontology.py b/libs/labelbox/src/labelbox/schema/ontology.py index a3b388ef2..3acd9e1e2 100644 --- a/libs/labelbox/src/labelbox/schema/ontology.py +++ 
b/libs/labelbox/src/labelbox/schema/ontology.py @@ -12,6 +12,8 @@ from labelbox.orm.db_object import DbObject from labelbox.orm.model import Field, Relationship +from labelbox.schema.tool_building.step_reasoning_tool import StepReasoningTool +from labelbox.schema.tool_building.tool_type import ToolType FeatureSchemaId: Type[str] = Annotated[ str, StringConstraints(min_length=25, max_length=25) @@ -187,7 +189,7 @@ def __post_init__(self): @classmethod def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]: return cls( - class_type=cls.Type(dictionary["type"]), + class_type=Classification.Type(dictionary["type"]), name=dictionary["name"], instructions=dictionary["instructions"], required=dictionary.get("required", False), @@ -351,7 +353,7 @@ class Type(Enum): @classmethod def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]: return cls( - class_type=cls.Type(dictionary["type"]), + class_type=PromptResponseClassification.Type(dictionary["type"]), name=dictionary["name"], instructions=dictionary["instructions"], required=True, # always required @@ -458,7 +460,7 @@ def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]: schema_id=dictionary.get("schemaNodeId", None), feature_schema_id=dictionary.get("featureSchemaId", None), required=dictionary.get("required", False), - tool=cls.Type(dictionary["tool"]), + tool=Tool.Type(dictionary["tool"]), classifications=[ Classification.from_dict(c) for c in dictionary["classifications"] @@ -488,6 +490,18 @@ def add_classification(self, classification: Classification) -> None: self.classifications.append(classification) +def tool_cls_from_type(tool_type: str): + if tool_type.lower() == ToolType.STEP_REASONING.value: + return StepReasoningTool + return Tool + + +def tool_type_cls_from_type(tool_type: str): + if tool_type.lower() == ToolType.STEP_REASONING.value: + return ToolType + return Tool.Type + + class Ontology(DbObject): """An ontology specifies which tools and classifications are available to a 
project. This is read only for now. @@ -525,7 +539,8 @@ def tools(self) -> List[Tool]: """Get list of tools (AKA objects) in an Ontology.""" if self._tools is None: self._tools = [ - Tool.from_dict(tool) for tool in self.normalized["tools"] + tool_cls_from_type(tool["tool"]).from_dict(tool) + for tool in self.normalized["tools"] ] return self._tools @@ -581,7 +596,7 @@ class OntologyBuilder: """ - tools: List[Tool] = field(default_factory=list) + tools: List[Union[Tool, StepReasoningTool]] = field(default_factory=list) classifications: List[ Union[Classification, PromptResponseClassification] ] = field(default_factory=list) diff --git a/libs/labelbox/src/labelbox/schema/tool_building/__init__.py b/libs/labelbox/src/labelbox/schema/tool_building/__init__.py new file mode 100644 index 000000000..45098ef84 --- /dev/null +++ b/libs/labelbox/src/labelbox/schema/tool_building/__init__.py @@ -0,0 +1,2 @@ +import labelbox.schema.tool_building.tool_type +import labelbox.schema.tool_building.step_reasoning_tool diff --git a/libs/labelbox/src/labelbox/schema/tool_building/step_reasoning_tool.py b/libs/labelbox/src/labelbox/schema/tool_building/step_reasoning_tool.py new file mode 100644 index 000000000..7b0536cec --- /dev/null +++ b/libs/labelbox/src/labelbox/schema/tool_building/step_reasoning_tool.py @@ -0,0 +1,200 @@ +import warnings +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +from labelbox.schema.tool_building.tool_type import ToolType + + +@dataclass +class StepReasoningVariant: + id: int + name: str + + def asdict(self) -> Dict[str, Any]: + return {"id": self.id, "name": self.name} + + +@dataclass +class IncorrectStepReasoningVariant: + id: int + name: str + regenerate_conversations_after_incorrect_step: Optional[bool] = True + rate_alternative_responses: Optional[bool] = False + + def asdict(self) -> Dict[str, Any]: + actions = [] + if self.regenerate_conversations_after_incorrect_step: + actions.append("regenerateSteps") 
+ if self.rate_alternative_responses: + actions.append("generateAndRateAlternativeSteps") + return {"id": self.id, "name": self.name, "actions": actions} + + @classmethod + def from_dict( + cls, dictionary: Dict[str, Any] + ) -> "IncorrectStepReasoningVariant": + return cls( + id=dictionary["id"], + name=dictionary["name"], + regenerate_conversations_after_incorrect_step="regenerateSteps" + in dictionary.get("actions", []), + rate_alternative_responses="generateAndRateAlternativeSteps" + in dictionary.get("actions", []), + ) + + +def _create_correct_step() -> StepReasoningVariant: + return StepReasoningVariant( + id=StepReasoningVariants.CORRECT_STEP_ID, name="Correct" + ) + + +def _create_neutral_step() -> StepReasoningVariant: + return StepReasoningVariant( + id=StepReasoningVariants.NEUTRAL_STEP_ID, name="Neutral" + ) + + +def _create_incorrect_step() -> IncorrectStepReasoningVariant: + return IncorrectStepReasoningVariant( + id=StepReasoningVariants.INCORRECT_STEP_ID, name="Incorrect" + ) + + +@dataclass +class StepReasoningVariants: + """ + This class is used to define the possible options for evaluating a step + Currently the options are correct, neutral, and incorrect + """ + + CORRECT_STEP_ID = 0 + NEUTRAL_STEP_ID = 1 + INCORRECT_STEP_ID = 2 + + correct_step: StepReasoningVariant = field( + default_factory=_create_correct_step + ) + neutral_step: StepReasoningVariant = field( + default_factory=_create_neutral_step + ) + incorrect_step: IncorrectStepReasoningVariant = field( + default_factory=_create_incorrect_step + ) + + def asdict(self): + return [ + self.correct_step.asdict(), + self.neutral_step.asdict(), + self.incorrect_step.asdict(), + ] + + @classmethod + def from_dict(cls, dictionary: List[Dict[str, Any]]): + correct_step = None + neutral_step = None + incorrect_step = None + + for variant in dictionary: + if variant["id"] == cls.CORRECT_STEP_ID: + correct_step = StepReasoningVariant(**variant) + elif variant["id"] == cls.NEUTRAL_STEP_ID: + 
neutral_step = StepReasoningVariant(**variant) + elif variant["id"] == cls.INCORRECT_STEP_ID: + incorrect_step = IncorrectStepReasoningVariant.from_dict( + variant + ) + + if not all([correct_step, neutral_step, incorrect_step]): + raise ValueError("Invalid step reasoning variants") + + return cls( + correct_step=correct_step, # type: ignore + neutral_step=neutral_step, # type: ignore + incorrect_step=incorrect_step, # type: ignore + ) + + +@dataclass +class StepReasoningDefinition: + variants: StepReasoningVariants = field( + default_factory=StepReasoningVariants + ) + version: int = field(default=1) + title: Optional[str] = None + value: Optional[str] = None + + def asdict(self) -> Dict[str, Any]: + result = {"variants": self.variants.asdict(), "version": self.version} + if self.title is not None: + result["title"] = self.title + if self.value is not None: + result["value"] = self.value + return result + + @classmethod + def from_dict(cls, dictionary: Dict[str, Any]) -> "StepReasoningDefinition": + variants = StepReasoningVariants.from_dict(dictionary["variants"]) + title = dictionary.get("title", None) + value = dictionary.get("value", None) + return cls(variants=variants, title=title, value=value) + + +@dataclass +class StepReasoningTool: + """ + Use this class in OntologyBuilder to create a tool for step reasoning + The definition field lists the possible options to evaluate a step + """ + + name: str + type: ToolType = field(default=ToolType.STEP_REASONING, init=False) + required: bool = False + schema_id: Optional[str] = None + feature_schema_id: Optional[str] = None + color: Optional[str] = None + definition: StepReasoningDefinition = field( + default_factory=StepReasoningDefinition + ) + + def __post_init__(self): + warnings.warn( + "This feature is experimental and subject to change.", + ) + + def reset_regenerate_conversations_after_incorrect_step(self): + """ + For live models, the default action will invoke the model to generate alternatives if a step 
is marked as incorrect + This method will reset the action to not regenerate the conversation + """ + self.definition.variants.incorrect_step.regenerate_conversations_after_incorrect_step = False + + def set_rate_alternative_responses(self): + """ + For live models, will require labelers to rate the alternatives generated by the model + """ + self.definition.variants.incorrect_step.rate_alternative_responses = ( + True + ) + + def asdict(self) -> Dict[str, Any]: + return { + "tool": self.type.value, + "name": self.name, + "required": self.required, + "schemaNodeId": self.schema_id, + "featureSchemaId": self.feature_schema_id, + "definition": self.definition.asdict(), + } + + @classmethod + def from_dict(cls, dictionary: Dict[str, Any]) -> "StepReasoningTool": + return cls( + name=dictionary["name"], + schema_id=dictionary.get("schemaNodeId", None), + feature_schema_id=dictionary.get("featureSchemaId", None), + required=dictionary.get("required", False), + definition=StepReasoningDefinition.from_dict( + dictionary["definition"] + ), + ) diff --git a/libs/labelbox/src/labelbox/schema/tool_building/tool_type.py b/libs/labelbox/src/labelbox/schema/tool_building/tool_type.py new file mode 100644 index 000000000..bbe8f231f --- /dev/null +++ b/libs/labelbox/src/labelbox/schema/tool_building/tool_type.py @@ -0,0 +1,5 @@ +from enum import Enum + + +class ToolType(Enum): + STEP_REASONING = "step-reasoning" diff --git a/libs/labelbox/tests/integration/conftest.py b/libs/labelbox/tests/integration/conftest.py index c248bf67e..8e138f4a1 100644 --- a/libs/labelbox/tests/integration/conftest.py +++ b/libs/labelbox/tests/integration/conftest.py @@ -21,6 +21,7 @@ ) from labelbox.schema.data_row import DataRowMetadataField from labelbox.schema.ontology_kind import OntologyKind +from labelbox.schema.tool_building.step_reasoning_tool import StepReasoningTool from labelbox.schema.user import User @@ -562,6 +563,7 @@ def feature_schema(client, point): @pytest.fixture def 
chat_evaluation_ontology(client, rand_gen): ontology_name = f"test-chat-evaluation-ontology-{rand_gen(str)}" + ontology_builder = OntologyBuilder( tools=[ Tool( @@ -576,6 +578,7 @@ def chat_evaluation_ontology(client, rand_gen): tool=Tool.Type.MESSAGE_RANKING, name="model output multi ranking", ), + StepReasoningTool(name="step reasoning"), ], classifications=[ Classification( @@ -626,14 +629,12 @@ def chat_evaluation_ontology(client, rand_gen): ), ], ) - ontology = client.create_ontology( ontology_name, ontology_builder.asdict(), media_type=MediaType.Conversational, ontology_kind=OntologyKind.ModelEvaluation, ) - yield ontology try: diff --git a/libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py b/libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py index 2c02b77ac..bde58808b 100644 --- a/libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py +++ b/libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py @@ -15,7 +15,7 @@ def test_create_chat_evaluation_ontology_project( # here we are essentially testing the ontology creation which is a fixture assert ontology assert ontology.name - assert len(ontology.tools()) == 3 + assert len(ontology.tools()) == 4 for tool in ontology.tools(): assert tool.schema_id assert tool.feature_schema_id @@ -41,7 +41,7 @@ def test_create_chat_evaluation_ontology_project( def test_create_chat_evaluation_ontology_project_existing_dataset( - client, chat_evaluation_ontology, chat_evaluation_project_append_to_dataset + chat_evaluation_ontology, chat_evaluation_project_append_to_dataset ): ontology = chat_evaluation_ontology @@ -83,6 +83,29 @@ def tools_json(): "schemaNodeId": None, "featureSchemaId": None, }, + { + "tool": "step-reasoning", + "name": "step reasoning", + "required": True, + "schemaNodeId": None, + "featureSchemaId": None, + "color": "#0000ff", + "definition": { + "variants": [ + {"id": 0, "name": "Correct"}, + {"id": 1, "name": "Neutral"}, + { + 
"id": 2, + "name": "Incorrect", + "actions": [ + "regenerateSteps", + "generateAndRateAlternativeSteps", + ], + }, + ], + "version": 1, + }, + }, ] return tools diff --git a/libs/labelbox/tests/integration/test_offline_chat_evaluation_project.py b/libs/labelbox/tests/integration/test_offline_chat_evaluation_project.py index bb1756afb..8cc8ebcb3 100644 --- a/libs/labelbox/tests/integration/test_offline_chat_evaluation_project.py +++ b/libs/labelbox/tests/integration/test_offline_chat_evaluation_project.py @@ -2,7 +2,6 @@ def test_create_offline_chat_evaluation_project( - client, rand_gen, offline_chat_evaluation_project, chat_evaluation_ontology, diff --git a/libs/labelbox/tests/integration/test_ontology.py b/libs/labelbox/tests/integration/test_ontology.py index c7c7c270c..acb4e7bb1 100644 --- a/libs/labelbox/tests/integration/test_ontology.py +++ b/libs/labelbox/tests/integration/test_ontology.py @@ -5,6 +5,7 @@ from labelbox import MediaType, OntologyBuilder, Tool from labelbox.orm.model import Entity +from labelbox.schema.tool_building.step_reasoning_tool import StepReasoningTool def test_feature_schema_is_not_archived(client, ontology): @@ -322,3 +323,47 @@ def test_unarchive_feature_schema_node_for_non_existing_ontology( client.unarchive_feature_schema_node( "invalid-ontology", feature_schema_to_unarchive["featureSchemaId"] ) + + +def test_step_reasoning_ontology(chat_evaluation_ontology): + ontology = chat_evaluation_ontology + step_reasoning_tool = None + for tool in ontology.normalized["tools"]: + if tool["tool"] == "step-reasoning": + step_reasoning_tool = tool + break + assert step_reasoning_tool is not None + assert step_reasoning_tool["definition"]["variants"] == [ + {"id": 0, "name": "Correct"}, + {"id": 1, "name": "Neutral"}, + { + "id": 2, + "name": "Incorrect", + "actions": ["regenerateSteps"], + }, + ] + assert step_reasoning_tool["definition"]["version"] == 1 + assert step_reasoning_tool["schemaNodeId"] is not None + assert 
step_reasoning_tool["featureSchemaId"] is not None + + step_reasoning_tool = None + for tool in ontology.tools(): + if isinstance(tool, StepReasoningTool): + step_reasoning_tool = tool + break + assert step_reasoning_tool is not None + assert step_reasoning_tool.definition.variants.asdict() == [ + { + "id": 0, + "name": "Correct", + }, + { + "id": 1, + "name": "Neutral", + }, + { + "id": 2, + "name": "Incorrect", + "actions": ["regenerateSteps"], + }, + ] diff --git a/libs/labelbox/tests/unit/test_unit_step_reasoning_tool.py b/libs/labelbox/tests/unit/test_unit_step_reasoning_tool.py new file mode 100644 index 000000000..4d6986b86 --- /dev/null +++ b/libs/labelbox/tests/unit/test_unit_step_reasoning_tool.py @@ -0,0 +1,51 @@ +from labelbox.schema.tool_building.step_reasoning_tool import StepReasoningTool + + +def test_step_reasoning_as_dict_default(): + tool = StepReasoningTool(name="step reasoning") + assert tool.asdict() == { + "tool": "step-reasoning", + "name": "step reasoning", + "required": False, + "schemaNodeId": None, + "featureSchemaId": None, + "definition": { + "variants": [ + {"id": 0, "name": "Correct"}, + {"id": 1, "name": "Neutral"}, + { + "id": 2, + "name": "Incorrect", + "actions": ["regenerateSteps"], + }, + ], + "version": 1, + }, + } + + +def test_step_reasoning_as_dict_with_actions(): + tool = StepReasoningTool(name="step reasoning") + tool.set_rate_alternative_responses() + tool.reset_regenerate_conversations_after_incorrect_step() + assert tool.asdict() == { + "tool": "step-reasoning", + "name": "step reasoning", + "required": False, + "schemaNodeId": None, + "featureSchemaId": None, + "definition": { + "variants": [ + {"id": 0, "name": "Correct"}, + {"id": 1, "name": "Neutral"}, + { + "id": 2, + "name": "Incorrect", + "actions": [ + "generateAndRateAlternativeSteps", + ], + }, + ], + "version": 1, + }, + }