Skip to content

[PLT-1684] Vb/step reasoning ontology plt 1684 #1879

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/labelbox/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ Labelbox Python SDK Documentation
search-filters
send-to-annotate-params
slice
step_reasoning_tool
task
task-queue
user
Expand Down
6 changes: 6 additions & 0 deletions docs/labelbox/step_reasoning_tool.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Step Reasoning Tool
===============================================================================================

.. automodule:: labelbox.schema.tool_building.step_reasoning_tool
:members:
:show-inheritance:
5 changes: 3 additions & 2 deletions libs/labelbox/src/labelbox/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
FeatureSchema,
Ontology,
PromptResponseClassification,
Tool,
tool_type_cls_from_type,
)
from labelbox.schema.ontology_kind import (
EditorTaskType,
Expand Down Expand Up @@ -1098,7 +1098,8 @@ def create_ontology_from_feature_schemas(
if "tool" in feature_schema.normalized:
tool = feature_schema.normalized["tool"]
try:
Tool.Type(tool)
tool_type_cls = tool_type_cls_from_type(tool)
tool_type_cls(tool)
tools.append(feature_schema.normalized)
except ValueError:
raise ValueError(
Expand Down
25 changes: 20 additions & 5 deletions libs/labelbox/src/labelbox/schema/ontology.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

from labelbox.orm.db_object import DbObject
from labelbox.orm.model import Field, Relationship
from labelbox.schema.tool_building.step_reasoning_tool import StepReasoningTool
from labelbox.schema.tool_building.tool_type import ToolType

FeatureSchemaId: Type[str] = Annotated[
str, StringConstraints(min_length=25, max_length=25)
Expand Down Expand Up @@ -187,7 +189,7 @@ def __post_init__(self):
@classmethod
def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]:
return cls(
class_type=cls.Type(dictionary["type"]),
class_type=Classification.Type(dictionary["type"]),
name=dictionary["name"],
instructions=dictionary["instructions"],
required=dictionary.get("required", False),
Expand Down Expand Up @@ -351,7 +353,7 @@ class Type(Enum):
@classmethod
def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]:
return cls(
class_type=cls.Type(dictionary["type"]),
class_type=PromptResponseClassification.Type(dictionary["type"]),
name=dictionary["name"],
instructions=dictionary["instructions"],
required=True, # always required
Expand Down Expand Up @@ -458,7 +460,7 @@ def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]:
schema_id=dictionary.get("schemaNodeId", None),
feature_schema_id=dictionary.get("featureSchemaId", None),
required=dictionary.get("required", False),
tool=cls.Type(dictionary["tool"]),
tool=Tool.Type(dictionary["tool"]),
classifications=[
Classification.from_dict(c)
for c in dictionary["classifications"]
Expand Down Expand Up @@ -488,6 +490,18 @@ def add_classification(self, classification: Classification) -> None:
self.classifications.append(classification)


def tool_cls_from_type(tool_type: str):
    """Return the class used to model a tool of the given type.

    Args:
        tool_type: The tool type string; it is compared case-insensitively
            against the step reasoning tool type value.

    Returns:
        ``StepReasoningTool`` for the step reasoning type, ``Tool`` for
        every other value.
    """
    is_step_reasoning = tool_type.lower() == ToolType.STEP_REASONING.value
    return StepReasoningTool if is_step_reasoning else Tool


def tool_type_cls_from_type(tool_type: str):
    """Return the enum class that declares the given tool type.

    Args:
        tool_type: The tool type string; it is compared case-insensitively
            against the step reasoning tool type value.

    Returns:
        ``ToolType`` for the step reasoning type, ``Tool.Type`` for every
        other value.
    """
    # NOTE(review): the match below ignores case, while an enum lookup by
    # value on the returned class is case-sensitive -- confirm callers never
    # pass a mixed-case string to the returned enum.
    is_step_reasoning = tool_type.lower() == ToolType.STEP_REASONING.value
    return ToolType if is_step_reasoning else Tool.Type


class Ontology(DbObject):
"""An ontology specifies which tools and classifications are available
to a project. This is read only for now.
Expand Down Expand Up @@ -525,7 +539,8 @@ def tools(self) -> List[Tool]:
"""Get list of tools (AKA objects) in an Ontology."""
if self._tools is None:
self._tools = [
Tool.from_dict(tool) for tool in self.normalized["tools"]
tool_cls_from_type(tool["tool"]).from_dict(tool)
for tool in self.normalized["tools"]
]
return self._tools

Expand Down Expand Up @@ -581,7 +596,7 @@ class OntologyBuilder:

"""

tools: List[Tool] = field(default_factory=list)
tools: List[Union[Tool, StepReasoningTool]] = field(default_factory=list)
classifications: List[
Union[Classification, PromptResponseClassification]
] = field(default_factory=list)
Expand Down
2 changes: 2 additions & 0 deletions libs/labelbox/src/labelbox/schema/tool_building/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
import labelbox.schema.tool_building.tool_type
import labelbox.schema.tool_building.step_reasoning_tool
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

from labelbox.schema.tool_building.tool_type import ToolType


@dataclass
class StepReasoningVariant:
    """A single step-evaluation option made of a numeric id and a name."""

    id: int
    name: str

    def asdict(self) -> Dict[str, Any]:
        """Return the variant as a dictionary with ``id`` and ``name`` keys."""
        payload: Dict[str, Any] = dict(id=self.id, name=self.name)
        return payload


@dataclass
class IncorrectStepReasoningVariant:
    """The "incorrect" step-evaluation option and its follow-up actions.

    Each boolean flag corresponds to one entry of the serialized ``actions``
    list: ``regenerate_conversations_after_incorrect_step`` maps to
    ``"regenerateSteps"`` and ``rate_alternative_responses`` maps to
    ``"generateAndRateAlternativeSteps"``.
    """

    id: int
    name: str
    regenerate_conversations_after_incorrect_step: Optional[bool] = True
    rate_alternative_responses: Optional[bool] = False

    def asdict(self) -> Dict[str, Any]:
        """Serialize to a dict with ``id``, ``name`` and enabled ``actions``.

        Returns:
            A dictionary whose ``actions`` list contains one string per
            enabled flag, in a fixed order.
        """
        actions: List[str] = []
        if self.regenerate_conversations_after_incorrect_step:
            actions.append("regenerateSteps")
        if self.rate_alternative_responses:
            actions.append("generateAndRateAlternativeSteps")
        return {"id": self.id, "name": self.name, "actions": actions}

    @classmethod
    def from_dict(
        cls, dictionary: Dict[str, Any]
    ) -> "IncorrectStepReasoningVariant":
        """Build a variant from its serialized form.

        Args:
            dictionary: Mapping with ``id`` and ``name`` keys and an optional
                ``actions`` list.

        Returns:
            A variant whose flags reflect which action strings are listed.

        Raises:
            KeyError: If ``id`` or ``name`` is missing.
        """
        # A missing key and an explicit ``None`` both mean "no actions";
        # normalising once avoids a TypeError from ``in None``.
        actions = dictionary.get("actions") or []
        return cls(
            id=dictionary["id"],
            name=dictionary["name"],
            regenerate_conversations_after_incorrect_step=(
                "regenerateSteps" in actions
            ),
            rate_alternative_responses=(
                "generateAndRateAlternativeSteps" in actions
            ),
        )


def _create_correct_step() -> StepReasoningVariant:
    """Build the default variant named "Correct" with the correct-step id."""
    step_id = StepReasoningVariants.CORRECT_STEP_ID
    return StepReasoningVariant(id=step_id, name="Correct")


def _create_neutral_step() -> StepReasoningVariant:
    """Build the default variant named "Neutral" with the neutral-step id."""
    step_id = StepReasoningVariants.NEUTRAL_STEP_ID
    return StepReasoningVariant(id=step_id, name="Neutral")


def _create_incorrect_step() -> IncorrectStepReasoningVariant:
    """Build the default variant named "Incorrect" with the incorrect-step id.

    The action flags are left at the defaults declared on
    ``IncorrectStepReasoningVariant``.
    """
    step_id = StepReasoningVariants.INCORRECT_STEP_ID
    return IncorrectStepReasoningVariant(id=step_id, name="Incorrect")


@dataclass
class StepReasoningVariants:
    """
    This class is used to define the possible options for evaluating a step.
    Currently the options are correct, neutral, and incorrect.
    """

    # Deliberately unannotated: class-level constants, not dataclass fields.
    CORRECT_STEP_ID = 0
    NEUTRAL_STEP_ID = 1
    INCORRECT_STEP_ID = 2

    correct_step: StepReasoningVariant = field(
        default_factory=_create_correct_step
    )
    neutral_step: StepReasoningVariant = field(
        default_factory=_create_neutral_step
    )
    incorrect_step: IncorrectStepReasoningVariant = field(
        default_factory=_create_incorrect_step
    )

    def asdict(self) -> List[Dict[str, Any]]:
        """Serialize the three variants as a list: correct, neutral, incorrect."""
        return [
            self.correct_step.asdict(),
            self.neutral_step.asdict(),
            self.incorrect_step.asdict(),
        ]

    @classmethod
    def from_dict(
        cls, dictionary: List[Dict[str, Any]]
    ) -> "StepReasoningVariants":
        """Build the variant set from its serialized list form.

        Args:
            dictionary: A list of variant dictionaries, each with at least
                ``id`` and ``name`` keys. Entries whose ``id`` matches none
                of the known ids are ignored.

        Returns:
            A ``StepReasoningVariants`` holding the three parsed variants.

        Raises:
            KeyError: If an entry lacks ``id``, or a matched entry lacks
                ``name``.
            ValueError: If the correct, neutral or incorrect variant is
                absent from the list.
        """
        correct_step: Optional[StepReasoningVariant] = None
        neutral_step: Optional[StepReasoningVariant] = None
        incorrect_step: Optional[IncorrectStepReasoningVariant] = None

        for variant in dictionary:
            variant_id = variant["id"]
            # Pick out only the fields each variant declares, so that extra
            # keys in the payload (for example "actions") do not blow up the
            # constructor with an unexpected-keyword TypeError.
            if variant_id == cls.CORRECT_STEP_ID:
                correct_step = StepReasoningVariant(
                    id=variant_id, name=variant["name"]
                )
            elif variant_id == cls.NEUTRAL_STEP_ID:
                neutral_step = StepReasoningVariant(
                    id=variant_id, name=variant["name"]
                )
            elif variant_id == cls.INCORRECT_STEP_ID:
                incorrect_step = IncorrectStepReasoningVariant.from_dict(
                    variant
                )

        # Test for presence explicitly instead of relying on truthiness.
        if correct_step is None or neutral_step is None or incorrect_step is None:
            raise ValueError("Invalid step reasoning variants")

        return cls(
            correct_step=correct_step,
            neutral_step=neutral_step,
            incorrect_step=incorrect_step,
        )


@dataclass
class StepReasoningDefinition:
    """The ``definition`` payload of a step reasoning tool.

    Holds the evaluation variants, a version number (1 by default) and an
    optional title and value.
    """

    variants: StepReasoningVariants = field(
        default_factory=StepReasoningVariants
    )
    version: int = field(default=1)
    title: Optional[str] = None
    value: Optional[str] = None

    def asdict(self) -> Dict[str, Any]:
        """Serialize the definition.

        Returns:
            A dictionary with ``variants`` and ``version``; ``title`` and
            ``value`` are included only when they are not ``None``.
        """
        result: Dict[str, Any] = {
            "variants": self.variants.asdict(),
            "version": self.version,
        }
        if self.title is not None:
            result["title"] = self.title
        if self.value is not None:
            result["value"] = self.value
        return result

    @classmethod
    def from_dict(cls, dictionary: Dict[str, Any]) -> "StepReasoningDefinition":
        """Build a definition from its serialized form.

        Args:
            dictionary: Mapping with a required ``variants`` list and
                optional ``version``, ``title`` and ``value`` entries.

        Returns:
            The parsed definition.

        Raises:
            KeyError: If ``variants`` is missing.
        """
        variants = StepReasoningVariants.from_dict(dictionary["variants"])
        # Carry the serialized version through so that an asdict()/from_dict()
        # round trip does not silently reset it; a missing or null entry
        # falls back to the field default of 1.
        version = dictionary.get("version")
        if version is None:
            version = 1
        return cls(
            variants=variants,
            version=version,
            title=dictionary.get("title"),
            value=dictionary.get("value"),
        )


@dataclass
class StepReasoningTool:
    """
    Use this class in OntologyBuilder to create a tool for step reasoning.
    The definition field lists the possible options to evaluate a step.
    """

    name: str
    # Fixed to the step reasoning type; excluded from the constructor.
    type: ToolType = field(default=ToolType.STEP_REASONING, init=False)
    required: bool = False
    schema_id: Optional[str] = None
    feature_schema_id: Optional[str] = None
    # NOTE(review): ``color`` is neither written by asdict() nor read by
    # from_dict() -- confirm that this is intended.
    color: Optional[str] = None
    definition: StepReasoningDefinition = field(
        default_factory=StepReasoningDefinition
    )

    def reset_regenerate_conversations_after_incorrect_step(self):
        """
        For live models, the default action will invoke the model to generate
        alternatives if a step is marked as incorrect.
        This method will reset the action to not regenerate the conversation.
        """
        incorrect_variant = self.definition.variants.incorrect_step
        incorrect_variant.regenerate_conversations_after_incorrect_step = False

    def set_rate_alternative_responses(self):
        """
        For live models, will require labelers to rate the alternatives
        generated by the model.
        """
        incorrect_variant = self.definition.variants.incorrect_step
        incorrect_variant.rate_alternative_responses = True

    def asdict(self) -> Dict[str, Any]:
        """Serialize the tool, including its nested definition, to a dict."""
        serialized: Dict[str, Any] = {}
        serialized["tool"] = self.type.value
        serialized["name"] = self.name
        serialized["required"] = self.required
        serialized["schemaNodeId"] = self.schema_id
        serialized["featureSchemaId"] = self.feature_schema_id
        serialized["definition"] = self.definition.asdict()
        return serialized

    @classmethod
    def from_dict(cls, dictionary: Dict[str, Any]) -> "StepReasoningTool":
        """Build a tool from its serialized form.

        ``name`` and ``definition`` are required keys; ``schemaNodeId``,
        ``featureSchemaId`` and ``required`` are optional.
        """
        tool_name = dictionary["name"]
        schema_node_id = dictionary.get("schemaNodeId")
        feature_schema_id = dictionary.get("featureSchemaId")
        is_required = dictionary.get("required", False)
        parsed_definition = StepReasoningDefinition.from_dict(
            dictionary["definition"]
        )
        return cls(
            name=tool_name,
            schema_id=schema_node_id,
            feature_schema_id=feature_schema_id,
            required=is_required,
            definition=parsed_definition,
        )
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from enum import Enum


class ToolType(Enum):
    """Tool type identifiers declared for the tool_building package."""

    # String value stored under the "tool" key of a step reasoning tool.
    STEP_REASONING = "step-reasoning"
6 changes: 4 additions & 2 deletions libs/labelbox/tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
)
from labelbox.schema.data_row import DataRowMetadataField
from labelbox.schema.ontology_kind import OntologyKind
from labelbox.schema.tool_building.step_reasoning_tool import StepReasoningTool
from labelbox.schema.tool_building.tool_type import ToolType
from labelbox.schema.user import User


Expand Down Expand Up @@ -562,6 +564,7 @@ def feature_schema(client, point):
@pytest.fixture
def chat_evaluation_ontology(client, rand_gen):
ontology_name = f"test-chat-evaluation-ontology-{rand_gen(str)}"

ontology_builder = OntologyBuilder(
tools=[
Tool(
Expand All @@ -576,6 +579,7 @@ def chat_evaluation_ontology(client, rand_gen):
tool=Tool.Type.MESSAGE_RANKING,
name="model output multi ranking",
),
StepReasoningTool(name="step reasoning"),
],
classifications=[
Classification(
Expand Down Expand Up @@ -626,14 +630,12 @@ def chat_evaluation_ontology(client, rand_gen):
),
],
)

ontology = client.create_ontology(
ontology_name,
ontology_builder.asdict(),
media_type=MediaType.Conversational,
ontology_kind=OntologyKind.ModelEvaluation,
)

yield ontology

try:
Expand Down
Loading
Loading