Skip to content

Commit b29d1c0

Browse files
author
Val Brodsky
committed
Update to support create_ontology_from_feature_schemas
1 parent ff31b93 commit b29d1c0

File tree

9 files changed

+181
-41
lines changed

9 files changed

+181
-41
lines changed

libs/labelbox/src/labelbox/client.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
FeatureSchema,
5757
Ontology,
5858
PromptResponseClassification,
59-
Tool,
59+
tool_type_cls_from_type,
6060
)
6161
from labelbox.schema.ontology_kind import (
6262
EditorTaskType,
@@ -1098,7 +1098,8 @@ def create_ontology_from_feature_schemas(
10981098
if "tool" in feature_schema.normalized:
10991099
tool = feature_schema.normalized["tool"]
11001100
try:
1101-
Tool.Type(tool)
1101+
tool_type_cls = tool_type_cls_from_type(tool)
1102+
tool_type_cls(tool)
11021103
tools.append(feature_schema.normalized)
11031104
except ValueError:
11041105
raise ValueError(

libs/labelbox/src/labelbox/schema/ontology.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ class Type(Enum):
353353
@classmethod
354354
def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]:
355355
return cls(
356-
class_type=Type(dictionary["type"]),
356+
class_type=PromptResponseClassification.Type(dictionary["type"]),
357357
name=dictionary["name"],
358358
instructions=dictionary["instructions"],
359359
required=True, # always required
@@ -492,14 +492,16 @@ def add_classification(self, classification: Classification) -> None:
492492

493493
def tool_cls_from_type(tool_type: str):
494494
if tool_type.lower() == ToolType.STEP_REASONING.value:
495-
from labelbox.schema.tool_building.step_reasoning_tool import (
496-
StepReasoningTool,
497-
)
498-
499495
return StepReasoningTool
500496
return Tool
501497

502498

499+
def tool_type_cls_from_type(tool_type: str):
500+
if tool_type.lower() == ToolType.STEP_REASONING.value:
501+
return ToolType
502+
return Tool.Type
503+
504+
503505
class Ontology(DbObject):
504506
"""An ontology specifies which tools and classifications are available
505507
to a project. This is read only for now.
@@ -533,12 +535,6 @@ def __init__(self, *args, **kwargs) -> None:
533535
Union[List[Classification], List[PromptResponseClassification]]
534536
] = None
535537

536-
def _tool_deserializer_cls(self, tool: Dict[str, Any]) -> Tool:
537-
import pdb
538-
539-
pdb.set_trace()
540-
return Tool
541-
542538
def tools(self) -> List[Tool]:
543539
"""Get list of tools (AKA objects) in an Ontology."""
544540
if self._tools is None:
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
import labelbox.schema.tool_building.tool_type
2-
import labelbox.schema.tool_building.step_reasoning_tool
2+
import labelbox.schema.tool_building.step_reasoning_tool

libs/labelbox/src/labelbox/schema/tool_building/step_reasoning_tool.py

Lines changed: 71 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass, field
2-
from typing import Any, Dict, Optional
2+
from typing import Any, Dict, List, Optional
33

44
from labelbox.schema.tool_building.tool_type import ToolType
55

@@ -25,9 +25,40 @@ def asdict(self) -> Dict[str, Any]:
2525
if self.regenerate_conversations_after_incorrect_step:
2626
actions.append("regenerateSteps")
2727
if self.rate_alternative_responses:
28-
actions.append("generateAlternatives")
28+
actions.append("generateAndRateAlternativeSteps")
2929
return {"id": self.id, "name": self.name, "actions": actions}
3030

31+
@classmethod
32+
def from_dict(
33+
cls, dictionary: Dict[str, Any]
34+
) -> "IncorrectStepReasoningVariant":
35+
return cls(
36+
id=dictionary["id"],
37+
name=dictionary["name"],
38+
regenerate_conversations_after_incorrect_step="regenerateSteps"
39+
in dictionary.get("actions", []),
40+
rate_alternative_responses="generateAndRateAlternativeSteps"
41+
in dictionary.get("actions", []),
42+
)
43+
44+
45+
def _create_correct_step() -> StepReasoningVariant:
46+
return StepReasoningVariant(
47+
id=StepReasoningVariants.CORRECT_STEP_ID, name="Correct"
48+
)
49+
50+
51+
def _create_neutral_step() -> StepReasoningVariant:
52+
return StepReasoningVariant(
53+
id=StepReasoningVariants.NEUTRAL_STEP_ID, name="Neutral"
54+
)
55+
56+
57+
def _create_incorrect_step() -> IncorrectStepReasoningVariant:
58+
return IncorrectStepReasoningVariant(
59+
id=StepReasoningVariants.INCORRECT_STEP_ID, name="Incorrect"
60+
)
61+
3162

3263
@dataclass
3364
class StepReasoningVariants:
@@ -36,13 +67,13 @@ class StepReasoningVariants:
3667
INCORRECT_STEP_ID = 2
3768

3869
correct_step: StepReasoningVariant = field(
39-
default=StepReasoningVariant(CORRECT_STEP_ID, "Correct"), init=False
70+
default_factory=_create_correct_step
4071
)
4172
neutral_step: StepReasoningVariant = field(
42-
default=StepReasoningVariant(NEUTRAL_STEP_ID, "Neutral"), init=False
73+
default_factory=_create_neutral_step
4374
)
4475
incorrect_step: IncorrectStepReasoningVariant = field(
45-
default=IncorrectStepReasoningVariant(INCORRECT_STEP_ID, "Incorrect"),
76+
default_factory=_create_incorrect_step
4677
)
4778

4879
def asdict(self):
@@ -52,6 +83,31 @@ def asdict(self):
5283
self.incorrect_step.asdict(),
5384
]
5485

86+
@classmethod
87+
def from_dict(cls, dictionary: List[Dict[str, Any]]):
88+
correct_step = None
89+
neutral_step = None
90+
incorrect_step = None
91+
92+
for variant in dictionary:
93+
if variant["id"] == cls.CORRECT_STEP_ID:
94+
correct_step = StepReasoningVariant(**variant)
95+
elif variant["id"] == cls.NEUTRAL_STEP_ID:
96+
neutral_step = StepReasoningVariant(**variant)
97+
elif variant["id"] == cls.INCORRECT_STEP_ID:
98+
incorrect_step = IncorrectStepReasoningVariant.from_dict(
99+
variant
100+
)
101+
102+
if not all([correct_step, neutral_step, incorrect_step]):
103+
raise ValueError("Invalid step reasoning variants")
104+
105+
return cls(
106+
correct_step=correct_step, # type: ignore
107+
neutral_step=neutral_step, # type: ignore
108+
incorrect_step=incorrect_step, # type: ignore
109+
)
110+
55111

56112
@dataclass
57113
class StepReasoningDefinition:
@@ -61,18 +117,22 @@ class StepReasoningDefinition:
61117
version: int = field(default=1)
62118
title: Optional[str] = None
63119
value: Optional[str] = None
64-
color: Optional[str] = None
65120

66121
def asdict(self) -> Dict[str, Any]:
67122
result = {"variants": self.variants.asdict(), "version": self.version}
68123
if self.title is not None:
69124
result["title"] = self.title
70125
if self.value is not None:
71126
result["value"] = self.value
72-
if self.color is not None:
73-
result["color"] = self.color
74127
return result
75128

129+
@classmethod
130+
def from_dict(cls, dictionary: Dict[str, Any]) -> "StepReasoningDefinition":
131+
variants = StepReasoningVariants.from_dict(dictionary["variants"])
132+
title = dictionary.get("title", None)
133+
value = dictionary.get("value", None)
134+
return cls(variants=variants, title=title, value=value)
135+
76136

77137
@dataclass
78138
class StepReasoningTool:
@@ -113,5 +173,7 @@ def from_dict(cls, dictionary: Dict[str, Any]) -> "StepReasoningTool":
113173
schema_id=dictionary.get("schemaNodeId", None),
114174
feature_schema_id=dictionary.get("featureSchemaId", None),
115175
required=dictionary.get("required", False),
116-
definition=StepReasoningDefinition(**dictionary["definition"]),
176+
definition=StepReasoningDefinition.from_dict(
177+
dictionary["definition"]
178+
),
117179
)

libs/labelbox/tests/integration/conftest.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -630,14 +630,12 @@ def chat_evaluation_ontology(client, rand_gen):
630630
),
631631
],
632632
)
633-
634633
ontology = client.create_ontology(
635634
ontology_name,
636635
ontology_builder.asdict(),
637636
media_type=MediaType.Conversational,
638637
ontology_kind=OntologyKind.ModelEvaluation,
639638
)
640-
641639
yield ontology
642640

643641
try:

libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def test_create_chat_evaluation_ontology_project(
4141

4242

4343
def test_create_chat_evaluation_ontology_project_existing_dataset(
44-
client, chat_evaluation_ontology, chat_evaluation_project_append_to_dataset
44+
chat_evaluation_ontology, chat_evaluation_project_append_to_dataset
4545
):
4646
ontology = chat_evaluation_ontology
4747

@@ -83,6 +83,29 @@ def tools_json():
8383
"schemaNodeId": None,
8484
"featureSchemaId": None,
8585
},
86+
{
87+
"tool": "step-reasoning",
88+
"name": "step reasoning",
89+
"required": True,
90+
"schemaNodeId": None,
91+
"featureSchemaId": None,
92+
"color": "#0000ff",
93+
"definition": {
94+
"variants": [
95+
{"id": 0, "name": "Correct"},
96+
{"id": 1, "name": "Neutral"},
97+
{
98+
"id": 2,
99+
"name": "Incorrect",
100+
"actions": [
101+
"regenerateSteps",
102+
"generateAndRateAlternativeSteps",
103+
],
104+
},
105+
],
106+
"version": 1,
107+
},
108+
},
86109
]
87110

88111
return tools

libs/labelbox/tests/integration/test_offline_chat_evaluation_project.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33

44
def test_create_offline_chat_evaluation_project(
5-
client,
65
rand_gen,
76
offline_chat_evaluation_project,
87
chat_evaluation_ontology,

libs/labelbox/tests/integration/test_ontology.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from labelbox import MediaType, OntologyBuilder, Tool
77
from labelbox.orm.model import Entity
8+
from labelbox.schema.tool_building.step_reasoning_tool import StepReasoningTool
89

910

1011
def test_feature_schema_is_not_archived(client, ontology):
@@ -322,3 +323,47 @@ def test_unarchive_feature_schema_node_for_non_existing_ontology(
322323
client.unarchive_feature_schema_node(
323324
"invalid-ontology", feature_schema_to_unarchive["featureSchemaId"]
324325
)
326+
327+
328+
def test_step_reasoning_ontology(chat_evaluation_ontology):
329+
ontology = chat_evaluation_ontology
330+
step_reasoning_tool = None
331+
for tool in ontology.normalized["tools"]:
332+
if tool["tool"] == "step-reasoning":
333+
step_reasoning_tool = tool
334+
break
335+
assert step_reasoning_tool is not None
336+
assert step_reasoning_tool["definition"]["variants"] == [
337+
{"id": 0, "name": "Correct"},
338+
{"id": 1, "name": "Neutral"},
339+
{
340+
"id": 2,
341+
"name": "Incorrect",
342+
"actions": ["regenerateSteps", "generateAndRateAlternativeSteps"],
343+
},
344+
]
345+
assert step_reasoning_tool["definition"]["version"] == 1
346+
assert step_reasoning_tool["schemaNodeId"] is not None
347+
assert step_reasoning_tool["featureSchemaId"] is not None
348+
349+
step_reasoning_tool = None
350+
for tool in ontology.tools():
351+
if isinstance(tool, StepReasoningTool):
352+
step_reasoning_tool = tool
353+
break
354+
assert step_reasoning_tool is not None
355+
assert step_reasoning_tool.definition.variants.asdict() == [
356+
{
357+
"id": 0,
358+
"name": "Correct",
359+
},
360+
{
361+
"id": 1,
362+
"name": "Neutral",
363+
},
364+
{
365+
"id": 2,
366+
"name": "Incorrect",
367+
"actions": ["regenerateSteps", "generateAndRateAlternativeSteps"],
368+
},
369+
]

libs/labelbox/tests/unit/test_unit_step_reasoning_tool.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,21 @@ def test_step_reasoning_as_dict_default():
99
"required": False,
1010
"schemaNodeId": None,
1111
"featureSchemaId": None,
12-
"variants": [
13-
{"id": 0, "name": "Correct"},
14-
{"id": 1, "name": "Neutral"},
15-
{"id": 2, "name": "Incorrect", "actions": []},
16-
],
12+
"definition": {
13+
"variants": [
14+
{"id": 0, "name": "Correct"},
15+
{"id": 1, "name": "Neutral"},
16+
{
17+
"id": 2,
18+
"name": "Incorrect",
19+
"actions": [
20+
"regenerateSteps",
21+
"generateAndRateAlternativeSteps",
22+
],
23+
},
24+
],
25+
"version": 1,
26+
},
1727
}
1828

1929

@@ -27,13 +37,19 @@ def test_step_reasoning_as_dict_with_actions():
2737
"required": False,
2838
"schemaNodeId": None,
2939
"featureSchemaId": None,
30-
"variants": [
31-
{"id": 0, "name": "Correct"},
32-
{"id": 1, "name": "Neutral"},
33-
{
34-
"id": 2,
35-
"name": "Incorrect",
36-
"actions": ["regenerateSteps", "generateAlternatives"],
37-
},
38-
],
40+
"definition": {
41+
"variants": [
42+
{"id": 0, "name": "Correct"},
43+
{"id": 1, "name": "Neutral"},
44+
{
45+
"id": 2,
46+
"name": "Incorrect",
47+
"actions": [
48+
"regenerateSteps",
49+
"generateAndRateAlternativeSteps",
50+
],
51+
},
52+
],
53+
"version": 1,
54+
},
3955
}

0 commit comments

Comments
 (0)