Skip to content

Commit 01b292c

Browse files
author
Val Brodsky
committed
Update to support create_ontology_from_feature_schemas
1 parent 9bea518 commit 01b292c

File tree

8 files changed

+151
-24
lines changed

8 files changed

+151
-24
lines changed

libs/labelbox/src/labelbox/client.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
FeatureSchema,
5858
Ontology,
5959
PromptResponseClassification,
60-
Tool,
60+
tool_type_cls_from_type,
6161
)
6262
from labelbox.schema.ontology_kind import (
6363
EditorTaskType,
@@ -1106,7 +1106,8 @@ def create_ontology_from_feature_schemas(
11061106
if "tool" in feature_schema.normalized:
11071107
tool = feature_schema.normalized["tool"]
11081108
try:
1109-
Tool.Type(tool)
1109+
tool_type_cls = tool_type_cls_from_type(tool)
1110+
tool_type_cls(tool)
11101111
tools.append(feature_schema.normalized)
11111112
except ValueError:
11121113
raise ValueError(

libs/labelbox/src/labelbox/schema/ontology.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,12 @@ def tool_cls_from_type(tool_type: str):
500500
return Tool
501501

502502

503+
def tool_type_cls_from_type(tool_type: str):
504+
if tool_type.lower() == ToolType.STEP_REASONING.value:
505+
return ToolType
506+
return Tool.Type
507+
508+
503509
class Ontology(DbObject):
504510
"""An ontology specifies which tools and classifications are available
505511
to a project. This is read only for now.

libs/labelbox/src/labelbox/schema/tool_building/step_reasoning_tool.py

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass, field
2-
from typing import Any, Dict, Optional
2+
from typing import Any, Dict, List, Optional
33

44
from labelbox.schema.tool_building.tool_type import ToolType
55

@@ -28,6 +28,19 @@ def asdict(self) -> Dict[str, Any]:
2828
actions.append("generateAlternatives")
2929
return {"id": self.id, "name": self.name, "actions": actions}
3030

31+
@classmethod
32+
def from_dict(
33+
cls, dictionary: Dict[str, Any]
34+
) -> "IncorrectStepReasoningVariant":
35+
return cls(
36+
id=dictionary["id"],
37+
name=dictionary["name"],
38+
regenerate_conversations_after_incorrect_step="regenerateSteps"
39+
in dictionary.get("actions", []),
40+
rate_alternative_responses="generateAlternatives"
41+
in dictionary.get("actions", []),
42+
)
43+
3144

3245
@dataclass
3346
class StepReasoningVariants:
@@ -36,10 +49,10 @@ class StepReasoningVariants:
3649
INCORRECT_STEP_ID = 2
3750

3851
correct_step: StepReasoningVariant = field(
39-
default=StepReasoningVariant(CORRECT_STEP_ID, "Correct"), init=False
52+
default=StepReasoningVariant(CORRECT_STEP_ID, "Correct")
4053
)
4154
neutral_step: StepReasoningVariant = field(
42-
default=StepReasoningVariant(NEUTRAL_STEP_ID, "Neutral"), init=False
55+
default=StepReasoningVariant(NEUTRAL_STEP_ID, "Neutral")
4356
)
4457
incorrect_step: IncorrectStepReasoningVariant = field(
4558
default=IncorrectStepReasoningVariant(INCORRECT_STEP_ID, "Incorrect"),
@@ -52,6 +65,31 @@ def asdict(self):
5265
self.incorrect_step.asdict(),
5366
]
5467

68+
@classmethod
69+
def from_dict(cls, dictionary: List[Dict[str, Any]]):
70+
correct_step = None
71+
neutral_step = None
72+
incorrect_step = None
73+
74+
for variant in dictionary:
75+
if variant["id"] == cls.CORRECT_STEP_ID:
76+
correct_step = StepReasoningVariant(**variant)
77+
elif variant["id"] == cls.NEUTRAL_STEP_ID:
78+
neutral_step = StepReasoningVariant(**variant)
79+
elif variant["id"] == cls.INCORRECT_STEP_ID:
80+
incorrect_step = IncorrectStepReasoningVariant.from_dict(
81+
variant
82+
)
83+
84+
if not all([correct_step, neutral_step, incorrect_step]):
85+
raise ValueError("Invalid step reasoning variants")
86+
87+
return cls(
88+
correct_step=correct_step, # type: ignore
89+
neutral_step=neutral_step, # type: ignore
90+
incorrect_step=incorrect_step, # type: ignore
91+
)
92+
5593

5694
@dataclass
5795
class StepReasoningDefinition:
@@ -73,6 +111,14 @@ def asdict(self) -> Dict[str, Any]:
73111
result["color"] = self.color
74112
return result
75113

114+
@classmethod
115+
def from_dict(cls, dictionary: Dict[str, Any]) -> "StepReasoningDefinition":
116+
variants = StepReasoningVariants.from_dict(dictionary["variants"])
117+
title = dictionary.get("title", None)
118+
value = dictionary.get("value", None)
119+
color = dictionary.get("color", None)
120+
return cls(variants=variants, title=title, value=value, color=color)
121+
76122

77123
@dataclass
78124
class StepReasoningTool:
@@ -113,5 +159,7 @@ def from_dict(cls, dictionary: Dict[str, Any]) -> "StepReasoningTool":
113159
schema_id=dictionary.get("schemaNodeId", None),
114160
feature_schema_id=dictionary.get("featureSchemaId", None),
115161
required=dictionary.get("required", False),
116-
definition=StepReasoningDefinition(**dictionary["definition"]),
162+
definition=StepReasoningDefinition.from_dict(
163+
dictionary["definition"]
164+
),
117165
)

libs/labelbox/tests/integration/conftest.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -643,14 +643,12 @@ def chat_evaluation_ontology(client, rand_gen):
643643
),
644644
],
645645
)
646-
647646
ontology = client.create_ontology(
648647
ontology_name,
649648
ontology_builder.asdict(),
650649
media_type=MediaType.Conversational,
651650
ontology_kind=OntologyKind.ModelEvaluation,
652651
)
653-
654652
yield ontology
655653

656654
try:

libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def test_create_chat_evaluation_ontology_project(
4343

4444

4545
def test_create_chat_evaluation_ontology_project_existing_dataset(
46-
client, chat_evaluation_ontology, chat_evaluation_project_append_to_dataset
46+
chat_evaluation_ontology, chat_evaluation_project_append_to_dataset
4747
):
4848
ontology = chat_evaluation_ontology
4949

@@ -85,6 +85,26 @@ def tools_json():
8585
"schemaNodeId": None,
8686
"featureSchemaId": None,
8787
},
88+
{
89+
"tool": "step-reasoning",
90+
"name": "step reasoning",
91+
"required": True,
92+
"schemaNodeId": None,
93+
"featureSchemaId": None,
94+
"color": "#0000ff",
95+
"definition": {
96+
"variants": [
97+
{"id": 0, "name": "Correct"},
98+
{"id": 1, "name": "Neutral"},
99+
{
100+
"id": 2,
101+
"name": "Incorrect",
102+
"actions": ["regenerateSteps", "generateAlternatives"],
103+
},
104+
],
105+
"version": 1,
106+
},
107+
},
88108
]
89109

90110
return tools

libs/labelbox/tests/integration/test_offline_chat_evaluation_project.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33

44
def test_create_offline_chat_evaluation_project(
5-
client,
65
rand_gen,
76
offline_chat_evaluation_project,
87
chat_evaluation_ontology,

libs/labelbox/tests/integration/test_ontology.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from labelbox import MediaType, OntologyBuilder, Tool
77
from labelbox.orm.model import Entity
8+
from labelbox.schema.tool_building.step_reasoning_tool import StepReasoningTool
89

910

1011
def test_feature_schema_is_not_archived(client, ontology):
@@ -322,3 +323,47 @@ def test_unarchive_feature_schema_node_for_non_existing_ontology(
322323
client.unarchive_feature_schema_node(
323324
"invalid-ontology", feature_schema_to_unarchive["featureSchemaId"]
324325
)
326+
327+
328+
def test_step_reasoning_ontology(chat_evaluation_ontology):
329+
ontology = chat_evaluation_ontology
330+
step_reasoning_tool = None
331+
for tool in ontology.normalized["tools"]:
332+
if tool["tool"] == "step-reasoning":
333+
step_reasoning_tool = tool
334+
break
335+
assert step_reasoning_tool is not None
336+
assert step_reasoning_tool["definition"]["variants"] == [
337+
{"id": 0, "name": "Correct"},
338+
{"id": 1, "name": "Neutral"},
339+
{
340+
"id": 2,
341+
"name": "Incorrect",
342+
"actions": ["regenerateSteps", "generateAlternatives"],
343+
},
344+
]
345+
assert step_reasoning_tool["definition"]["version"] == 1
346+
assert step_reasoning_tool["schemaNodeId"] is not None
347+
assert step_reasoning_tool["featureSchemaId"] is not None
348+
349+
step_reasoning_tool = None
350+
for tool in ontology.tools():
351+
if isinstance(tool, StepReasoningTool):
352+
step_reasoning_tool = tool
353+
break
354+
assert step_reasoning_tool is not None
355+
assert step_reasoning_tool.definition.variants.asdict() == [
356+
{
357+
"id": 0,
358+
"name": "Correct",
359+
},
360+
{
361+
"id": 1,
362+
"name": "Neutral",
363+
},
364+
{
365+
"id": 2,
366+
"name": "Incorrect",
367+
"actions": ["regenerateSteps", "generateAlternatives"],
368+
},
369+
]

libs/labelbox/tests/unit/test_unit_step_reasoning_tool.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,18 @@ def test_step_reasoning_as_dict_default():
99
"required": False,
1010
"schemaNodeId": None,
1111
"featureSchemaId": None,
12-
"variants": [
13-
{"id": 0, "name": "Correct"},
14-
{"id": 1, "name": "Neutral"},
15-
{"id": 2, "name": "Incorrect", "actions": []},
16-
],
12+
"definition": {
13+
"variants": [
14+
{"id": 0, "name": "Correct"},
15+
{"id": 1, "name": "Neutral"},
16+
{
17+
"id": 2,
18+
"name": "Incorrect",
19+
"actions": ["regenerateSteps", "generateAlternatives"],
20+
},
21+
],
22+
"version": 1,
23+
},
1724
}
1825

1926

@@ -27,13 +34,16 @@ def test_step_reasoning_as_dict_with_actions():
2734
"required": False,
2835
"schemaNodeId": None,
2936
"featureSchemaId": None,
30-
"variants": [
31-
{"id": 0, "name": "Correct"},
32-
{"id": 1, "name": "Neutral"},
33-
{
34-
"id": 2,
35-
"name": "Incorrect",
36-
"actions": ["regenerateSteps", "generateAlternatives"],
37-
},
38-
],
37+
"definition": {
38+
"variants": [
39+
{"id": 0, "name": "Correct"},
40+
{"id": 1, "name": "Neutral"},
41+
{
42+
"id": 2,
43+
"name": "Incorrect",
44+
"actions": ["regenerateSteps", "generateAlternatives"],
45+
},
46+
],
47+
"version": 1,
48+
},
3949
}

0 commit comments

Comments
 (0)