Skip to content

Commit b34b11c

Browse files
author
Val Brodsky
committed
Update to support create_ontology_from_feature_schemas
1 parent 9bea518 commit b34b11c

File tree

9 files changed

+181
-41
lines changed

9 files changed

+181
-41
lines changed

libs/labelbox/src/labelbox/client.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
FeatureSchema,
5858
Ontology,
5959
PromptResponseClassification,
60-
Tool,
60+
tool_type_cls_from_type,
6161
)
6262
from labelbox.schema.ontology_kind import (
6363
EditorTaskType,
@@ -1106,7 +1106,8 @@ def create_ontology_from_feature_schemas(
11061106
if "tool" in feature_schema.normalized:
11071107
tool = feature_schema.normalized["tool"]
11081108
try:
1109-
Tool.Type(tool)
1109+
tool_type_cls = tool_type_cls_from_type(tool)
1110+
tool_type_cls(tool)
11101111
tools.append(feature_schema.normalized)
11111112
except ValueError:
11121113
raise ValueError(

libs/labelbox/src/labelbox/schema/ontology.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ class Type(Enum):
353353
@classmethod
354354
def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]:
355355
return cls(
356-
class_type=Type(dictionary["type"]),
356+
class_type=PromptResponseClassification.Type(dictionary["type"]),
357357
name=dictionary["name"],
358358
instructions=dictionary["instructions"],
359359
required=True, # always required
@@ -492,14 +492,16 @@ def add_classification(self, classification: Classification) -> None:
492492

493493
def tool_cls_from_type(tool_type: str):
494494
if tool_type.lower() == ToolType.STEP_REASONING.value:
495-
from labelbox.schema.tool_building.step_reasoning_tool import (
496-
StepReasoningTool,
497-
)
498-
499495
return StepReasoningTool
500496
return Tool
501497

502498

499+
def tool_type_cls_from_type(tool_type: str):
500+
if tool_type.lower() == ToolType.STEP_REASONING.value:
501+
return ToolType
502+
return Tool.Type
503+
504+
503505
class Ontology(DbObject):
504506
"""An ontology specifies which tools and classifications are available
505507
to a project. This is read only for now.
@@ -533,12 +535,6 @@ def __init__(self, *args, **kwargs) -> None:
533535
Union[List[Classification], List[PromptResponseClassification]]
534536
] = None
535537

536-
def _tool_deserializer_cls(self, tool: Dict[str, Any]) -> Tool:
537-
import pdb
538-
539-
pdb.set_trace()
540-
return Tool
541-
542538
def tools(self) -> List[Tool]:
543539
"""Get list of tools (AKA objects) in an Ontology."""
544540
if self._tools is None:
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
import labelbox.schema.tool_building.tool_type
2-
import labelbox.schema.tool_building.step_reasoning_tool
2+
import labelbox.schema.tool_building.step_reasoning_tool

libs/labelbox/src/labelbox/schema/tool_building/step_reasoning_tool.py

Lines changed: 71 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass, field
2-
from typing import Any, Dict, Optional
2+
from typing import Any, Dict, List, Optional
33

44
from labelbox.schema.tool_building.tool_type import ToolType
55

@@ -25,9 +25,40 @@ def asdict(self) -> Dict[str, Any]:
2525
if self.regenerate_conversations_after_incorrect_step:
2626
actions.append("regenerateSteps")
2727
if self.rate_alternative_responses:
28-
actions.append("generateAlternatives")
28+
actions.append("generateAndRateAlternativeSteps")
2929
return {"id": self.id, "name": self.name, "actions": actions}
3030

31+
@classmethod
32+
def from_dict(
33+
cls, dictionary: Dict[str, Any]
34+
) -> "IncorrectStepReasoningVariant":
35+
return cls(
36+
id=dictionary["id"],
37+
name=dictionary["name"],
38+
regenerate_conversations_after_incorrect_step="regenerateSteps"
39+
in dictionary.get("actions", []),
40+
rate_alternative_responses="generateAndRateAlternativeSteps"
41+
in dictionary.get("actions", []),
42+
)
43+
44+
45+
def _create_correct_step() -> StepReasoningVariant:
46+
return StepReasoningVariant(
47+
id=StepReasoningVariants.CORRECT_STEP_ID, name="Correct"
48+
)
49+
50+
51+
def _create_neutral_step() -> StepReasoningVariant:
52+
return StepReasoningVariant(
53+
id=StepReasoningVariants.NEUTRAL_STEP_ID, name="Neutral"
54+
)
55+
56+
57+
def _create_incorrect_step() -> IncorrectStepReasoningVariant:
58+
return IncorrectStepReasoningVariant(
59+
id=StepReasoningVariants.INCORRECT_STEP_ID, name="Incorrect"
60+
)
61+
3162

3263
@dataclass
3364
class StepReasoningVariants:
@@ -36,13 +67,13 @@ class StepReasoningVariants:
3667
INCORRECT_STEP_ID = 2
3768

3869
correct_step: StepReasoningVariant = field(
39-
default=StepReasoningVariant(CORRECT_STEP_ID, "Correct"), init=False
70+
default_factory=_create_correct_step
4071
)
4172
neutral_step: StepReasoningVariant = field(
42-
default=StepReasoningVariant(NEUTRAL_STEP_ID, "Neutral"), init=False
73+
default_factory=_create_neutral_step
4374
)
4475
incorrect_step: IncorrectStepReasoningVariant = field(
45-
default=IncorrectStepReasoningVariant(INCORRECT_STEP_ID, "Incorrect"),
76+
default_factory=_create_incorrect_step
4677
)
4778

4879
def asdict(self):
@@ -52,6 +83,31 @@ def asdict(self):
5283
self.incorrect_step.asdict(),
5384
]
5485

86+
@classmethod
87+
def from_dict(cls, dictionary: List[Dict[str, Any]]):
88+
correct_step = None
89+
neutral_step = None
90+
incorrect_step = None
91+
92+
for variant in dictionary:
93+
if variant["id"] == cls.CORRECT_STEP_ID:
94+
correct_step = StepReasoningVariant(**variant)
95+
elif variant["id"] == cls.NEUTRAL_STEP_ID:
96+
neutral_step = StepReasoningVariant(**variant)
97+
elif variant["id"] == cls.INCORRECT_STEP_ID:
98+
incorrect_step = IncorrectStepReasoningVariant.from_dict(
99+
variant
100+
)
101+
102+
if not all([correct_step, neutral_step, incorrect_step]):
103+
raise ValueError("Invalid step reasoning variants")
104+
105+
return cls(
106+
correct_step=correct_step, # type: ignore
107+
neutral_step=neutral_step, # type: ignore
108+
incorrect_step=incorrect_step, # type: ignore
109+
)
110+
55111

56112
@dataclass
57113
class StepReasoningDefinition:
@@ -61,18 +117,22 @@ class StepReasoningDefinition:
61117
version: int = field(default=1)
62118
title: Optional[str] = None
63119
value: Optional[str] = None
64-
color: Optional[str] = None
65120

66121
def asdict(self) -> Dict[str, Any]:
67122
result = {"variants": self.variants.asdict(), "version": self.version}
68123
if self.title is not None:
69124
result["title"] = self.title
70125
if self.value is not None:
71126
result["value"] = self.value
72-
if self.color is not None:
73-
result["color"] = self.color
74127
return result
75128

129+
@classmethod
130+
def from_dict(cls, dictionary: Dict[str, Any]) -> "StepReasoningDefinition":
131+
variants = StepReasoningVariants.from_dict(dictionary["variants"])
132+
title = dictionary.get("title", None)
133+
value = dictionary.get("value", None)
134+
return cls(variants=variants, title=title, value=value)
135+
76136

77137
@dataclass
78138
class StepReasoningTool:
@@ -113,5 +173,7 @@ def from_dict(cls, dictionary: Dict[str, Any]) -> "StepReasoningTool":
113173
schema_id=dictionary.get("schemaNodeId", None),
114174
feature_schema_id=dictionary.get("featureSchemaId", None),
115175
required=dictionary.get("required", False),
116-
definition=StepReasoningDefinition(**dictionary["definition"]),
176+
definition=StepReasoningDefinition.from_dict(
177+
dictionary["definition"]
178+
),
117179
)

libs/labelbox/tests/integration/conftest.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -643,14 +643,12 @@ def chat_evaluation_ontology(client, rand_gen):
643643
),
644644
],
645645
)
646-
647646
ontology = client.create_ontology(
648647
ontology_name,
649648
ontology_builder.asdict(),
650649
media_type=MediaType.Conversational,
651650
ontology_kind=OntologyKind.ModelEvaluation,
652651
)
653-
654652
yield ontology
655653

656654
try:

libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def test_create_chat_evaluation_ontology_project(
4343

4444

4545
def test_create_chat_evaluation_ontology_project_existing_dataset(
46-
client, chat_evaluation_ontology, chat_evaluation_project_append_to_dataset
46+
chat_evaluation_ontology, chat_evaluation_project_append_to_dataset
4747
):
4848
ontology = chat_evaluation_ontology
4949

@@ -85,6 +85,29 @@ def tools_json():
8585
"schemaNodeId": None,
8686
"featureSchemaId": None,
8787
},
88+
{
89+
"tool": "step-reasoning",
90+
"name": "step reasoning",
91+
"required": True,
92+
"schemaNodeId": None,
93+
"featureSchemaId": None,
94+
"color": "#0000ff",
95+
"definition": {
96+
"variants": [
97+
{"id": 0, "name": "Correct"},
98+
{"id": 1, "name": "Neutral"},
99+
{
100+
"id": 2,
101+
"name": "Incorrect",
102+
"actions": [
103+
"regenerateSteps",
104+
"generateAndRateAlternativeSteps",
105+
],
106+
},
107+
],
108+
"version": 1,
109+
},
110+
},
88111
]
89112

90113
return tools

libs/labelbox/tests/integration/test_offline_chat_evaluation_project.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33

44
def test_create_offline_chat_evaluation_project(
5-
client,
65
rand_gen,
76
offline_chat_evaluation_project,
87
chat_evaluation_ontology,

libs/labelbox/tests/integration/test_ontology.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from labelbox import MediaType, OntologyBuilder, Tool
77
from labelbox.orm.model import Entity
8+
from labelbox.schema.tool_building.step_reasoning_tool import StepReasoningTool
89

910

1011
def test_feature_schema_is_not_archived(client, ontology):
@@ -322,3 +323,47 @@ def test_unarchive_feature_schema_node_for_non_existing_ontology(
322323
client.unarchive_feature_schema_node(
323324
"invalid-ontology", feature_schema_to_unarchive["featureSchemaId"]
324325
)
326+
327+
328+
def test_step_reasoning_ontology(chat_evaluation_ontology):
329+
ontology = chat_evaluation_ontology
330+
step_reasoning_tool = None
331+
for tool in ontology.normalized["tools"]:
332+
if tool["tool"] == "step-reasoning":
333+
step_reasoning_tool = tool
334+
break
335+
assert step_reasoning_tool is not None
336+
assert step_reasoning_tool["definition"]["variants"] == [
337+
{"id": 0, "name": "Correct"},
338+
{"id": 1, "name": "Neutral"},
339+
{
340+
"id": 2,
341+
"name": "Incorrect",
342+
"actions": ["regenerateSteps", "generateAndRateAlternativeSteps"],
343+
},
344+
]
345+
assert step_reasoning_tool["definition"]["version"] == 1
346+
assert step_reasoning_tool["schemaNodeId"] is not None
347+
assert step_reasoning_tool["featureSchemaId"] is not None
348+
349+
step_reasoning_tool = None
350+
for tool in ontology.tools():
351+
if isinstance(tool, StepReasoningTool):
352+
step_reasoning_tool = tool
353+
break
354+
assert step_reasoning_tool is not None
355+
assert step_reasoning_tool.definition.variants.asdict() == [
356+
{
357+
"id": 0,
358+
"name": "Correct",
359+
},
360+
{
361+
"id": 1,
362+
"name": "Neutral",
363+
},
364+
{
365+
"id": 2,
366+
"name": "Incorrect",
367+
"actions": ["regenerateSteps", "generateAndRateAlternativeSteps"],
368+
},
369+
]

libs/labelbox/tests/unit/test_unit_step_reasoning_tool.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,21 @@ def test_step_reasoning_as_dict_default():
99
"required": False,
1010
"schemaNodeId": None,
1111
"featureSchemaId": None,
12-
"variants": [
13-
{"id": 0, "name": "Correct"},
14-
{"id": 1, "name": "Neutral"},
15-
{"id": 2, "name": "Incorrect", "actions": []},
16-
],
12+
"definition": {
13+
"variants": [
14+
{"id": 0, "name": "Correct"},
15+
{"id": 1, "name": "Neutral"},
16+
{
17+
"id": 2,
18+
"name": "Incorrect",
19+
"actions": [
20+
"regenerateSteps",
21+
"generateAndRateAlternativeSteps",
22+
],
23+
},
24+
],
25+
"version": 1,
26+
},
1727
}
1828

1929

@@ -27,13 +37,19 @@ def test_step_reasoning_as_dict_with_actions():
2737
"required": False,
2838
"schemaNodeId": None,
2939
"featureSchemaId": None,
30-
"variants": [
31-
{"id": 0, "name": "Correct"},
32-
{"id": 1, "name": "Neutral"},
33-
{
34-
"id": 2,
35-
"name": "Incorrect",
36-
"actions": ["regenerateSteps", "generateAlternatives"],
37-
},
38-
],
40+
"definition": {
41+
"variants": [
42+
{"id": 0, "name": "Correct"},
43+
{"id": 1, "name": "Neutral"},
44+
{
45+
"id": 2,
46+
"name": "Incorrect",
47+
"actions": [
48+
"regenerateSteps",
49+
"generateAndRateAlternativeSteps",
50+
],
51+
},
52+
],
53+
"version": 1,
54+
},
3955
}

0 commit comments

Comments
 (0)