Skip to content

Commit 9207b07

Browse files
author
Val Brodsky
committed
Update to support create_ontology_from_feature_schemas
1 parent 9bea518 commit 9207b07

File tree

9 files changed

+155
-36
lines changed

9 files changed

+155
-36
lines changed

libs/labelbox/src/labelbox/client.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
FeatureSchema,
5858
Ontology,
5959
PromptResponseClassification,
60-
Tool,
60+
tool_type_cls_from_type,
6161
)
6262
from labelbox.schema.ontology_kind import (
6363
EditorTaskType,
@@ -1106,7 +1106,8 @@ def create_ontology_from_feature_schemas(
11061106
if "tool" in feature_schema.normalized:
11071107
tool = feature_schema.normalized["tool"]
11081108
try:
1109-
Tool.Type(tool)
1109+
tool_type_cls = tool_type_cls_from_type(tool)
1110+
tool_type_cls(tool)
11101111
tools.append(feature_schema.normalized)
11111112
except ValueError:
11121113
raise ValueError(

libs/labelbox/src/labelbox/schema/ontology.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -492,14 +492,16 @@ def add_classification(self, classification: Classification) -> None:
492492

493493
def tool_cls_from_type(tool_type: str):
494494
if tool_type.lower() == ToolType.STEP_REASONING.value:
495-
from labelbox.schema.tool_building.step_reasoning_tool import (
496-
StepReasoningTool,
497-
)
498-
499495
return StepReasoningTool
500496
return Tool
501497

502498

499+
def tool_type_cls_from_type(tool_type: str):
500+
if tool_type.lower() == ToolType.STEP_REASONING.value:
501+
return ToolType
502+
return Tool.Type
503+
504+
503505
class Ontology(DbObject):
504506
"""An ontology specifies which tools and classifications are available
505507
to a project. This is read only for now.
@@ -533,12 +535,6 @@ def __init__(self, *args, **kwargs) -> None:
533535
Union[List[Classification], List[PromptResponseClassification]]
534536
] = None
535537

536-
def _tool_deserializer_cls(self, tool: Dict[str, Any]) -> Tool:
537-
import pdb
538-
539-
pdb.set_trace()
540-
return Tool
541-
542538
def tools(self) -> List[Tool]:
543539
"""Get list of tools (AKA objects) in an Ontology."""
544540
if self._tools is None:
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
import labelbox.schema.tool_building.tool_type
2-
import labelbox.schema.tool_building.step_reasoning_tool
2+
import labelbox.schema.tool_building.step_reasoning_tool

libs/labelbox/src/labelbox/schema/tool_building/step_reasoning_tool.py

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass, field
2-
from typing import Any, Dict, Optional
2+
from typing import Any, Dict, List, Optional
33

44
from labelbox.schema.tool_building.tool_type import ToolType
55

@@ -28,6 +28,19 @@ def asdict(self) -> Dict[str, Any]:
2828
actions.append("generateAlternatives")
2929
return {"id": self.id, "name": self.name, "actions": actions}
3030

31+
@classmethod
32+
def from_dict(
33+
cls, dictionary: Dict[str, Any]
34+
) -> "IncorrectStepReasoningVariant":
35+
return cls(
36+
id=dictionary["id"],
37+
name=dictionary["name"],
38+
regenerate_conversations_after_incorrect_step="regenerateSteps"
39+
in dictionary.get("actions", []),
40+
rate_alternative_responses="generateAlternatives"
41+
in dictionary.get("actions", []),
42+
)
43+
3144

3245
@dataclass
3346
class StepReasoningVariants:
@@ -36,13 +49,15 @@ class StepReasoningVariants:
3649
INCORRECT_STEP_ID = 2
3750

3851
correct_step: StepReasoningVariant = field(
39-
default=StepReasoningVariant(CORRECT_STEP_ID, "Correct"), init=False
52+
default=StepReasoningVariant(id=CORRECT_STEP_ID, name="Correct")
4053
)
4154
neutral_step: StepReasoningVariant = field(
42-
default=StepReasoningVariant(NEUTRAL_STEP_ID, "Neutral"), init=False
55+
default=StepReasoningVariant(id=NEUTRAL_STEP_ID, name="Neutral")
4356
)
4457
incorrect_step: IncorrectStepReasoningVariant = field(
45-
default=IncorrectStepReasoningVariant(INCORRECT_STEP_ID, "Incorrect"),
58+
default=IncorrectStepReasoningVariant(
59+
id=INCORRECT_STEP_ID, name="Incorrect"
60+
),
4661
)
4762

4863
def asdict(self):
@@ -52,6 +67,31 @@ def asdict(self):
5267
self.incorrect_step.asdict(),
5368
]
5469

70+
@classmethod
71+
def from_dict(cls, dictionary: List[Dict[str, Any]]):
72+
correct_step = None
73+
neutral_step = None
74+
incorrect_step = None
75+
76+
for variant in dictionary:
77+
if variant["id"] == cls.CORRECT_STEP_ID:
78+
correct_step = StepReasoningVariant(**variant)
79+
elif variant["id"] == cls.NEUTRAL_STEP_ID:
80+
neutral_step = StepReasoningVariant(**variant)
81+
elif variant["id"] == cls.INCORRECT_STEP_ID:
82+
incorrect_step = IncorrectStepReasoningVariant.from_dict(
83+
variant
84+
)
85+
86+
if not all([correct_step, neutral_step, incorrect_step]):
87+
raise ValueError("Invalid step reasoning variants")
88+
89+
return cls(
90+
correct_step=correct_step, # type: ignore
91+
neutral_step=neutral_step, # type: ignore
92+
incorrect_step=incorrect_step, # type: ignore
93+
)
94+
5595

5696
@dataclass
5797
class StepReasoningDefinition:
@@ -73,6 +113,14 @@ def asdict(self) -> Dict[str, Any]:
73113
result["color"] = self.color
74114
return result
75115

116+
@classmethod
117+
def from_dict(cls, dictionary: Dict[str, Any]) -> "StepReasoningDefinition":
118+
variants = StepReasoningVariants.from_dict(dictionary["variants"])
119+
title = dictionary.get("title", None)
120+
value = dictionary.get("value", None)
121+
color = dictionary.get("color", None)
122+
return cls(variants=variants, title=title, value=value, color=color)
123+
76124

77125
@dataclass
78126
class StepReasoningTool:
@@ -113,5 +161,7 @@ def from_dict(cls, dictionary: Dict[str, Any]) -> "StepReasoningTool":
113161
schema_id=dictionary.get("schemaNodeId", None),
114162
feature_schema_id=dictionary.get("featureSchemaId", None),
115163
required=dictionary.get("required", False),
116-
definition=StepReasoningDefinition(**dictionary["definition"]),
164+
definition=StepReasoningDefinition.from_dict(
165+
dictionary["definition"]
166+
),
117167
)

libs/labelbox/tests/integration/conftest.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -643,14 +643,12 @@ def chat_evaluation_ontology(client, rand_gen):
643643
),
644644
],
645645
)
646-
647646
ontology = client.create_ontology(
648647
ontology_name,
649648
ontology_builder.asdict(),
650649
media_type=MediaType.Conversational,
651650
ontology_kind=OntologyKind.ModelEvaluation,
652651
)
653-
654652
yield ontology
655653

656654
try:

libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def test_create_chat_evaluation_ontology_project(
4343

4444

4545
def test_create_chat_evaluation_ontology_project_existing_dataset(
46-
client, chat_evaluation_ontology, chat_evaluation_project_append_to_dataset
46+
chat_evaluation_ontology, chat_evaluation_project_append_to_dataset
4747
):
4848
ontology = chat_evaluation_ontology
4949

@@ -85,6 +85,26 @@ def tools_json():
8585
"schemaNodeId": None,
8686
"featureSchemaId": None,
8787
},
88+
{
89+
"tool": "step-reasoning",
90+
"name": "step reasoning",
91+
"required": True,
92+
"schemaNodeId": None,
93+
"featureSchemaId": None,
94+
"color": "#0000ff",
95+
"definition": {
96+
"variants": [
97+
{"id": 0, "name": "Correct"},
98+
{"id": 1, "name": "Neutral"},
99+
{
100+
"id": 2,
101+
"name": "Incorrect",
102+
"actions": ["regenerateSteps", "generateAlternatives"],
103+
},
104+
],
105+
"version": 1,
106+
},
107+
},
88108
]
89109

90110
return tools

libs/labelbox/tests/integration/test_offline_chat_evaluation_project.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33

44
def test_create_offline_chat_evaluation_project(
5-
client,
65
rand_gen,
76
offline_chat_evaluation_project,
87
chat_evaluation_ontology,

libs/labelbox/tests/integration/test_ontology.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from labelbox import MediaType, OntologyBuilder, Tool
77
from labelbox.orm.model import Entity
8+
from labelbox.schema.tool_building.step_reasoning_tool import StepReasoningTool
89

910

1011
def test_feature_schema_is_not_archived(client, ontology):
@@ -322,3 +323,47 @@ def test_unarchive_feature_schema_node_for_non_existing_ontology(
322323
client.unarchive_feature_schema_node(
323324
"invalid-ontology", feature_schema_to_unarchive["featureSchemaId"]
324325
)
326+
327+
328+
def test_step_reasoning_ontology(chat_evaluation_ontology):
329+
ontology = chat_evaluation_ontology
330+
step_reasoning_tool = None
331+
for tool in ontology.normalized["tools"]:
332+
if tool["tool"] == "step-reasoning":
333+
step_reasoning_tool = tool
334+
break
335+
assert step_reasoning_tool is not None
336+
assert step_reasoning_tool["definition"]["variants"] == [
337+
{"id": 0, "name": "Correct"},
338+
{"id": 1, "name": "Neutral"},
339+
{
340+
"id": 2,
341+
"name": "Incorrect",
342+
"actions": ["regenerateSteps", "generateAlternatives"],
343+
},
344+
]
345+
assert step_reasoning_tool["definition"]["version"] == 1
346+
assert step_reasoning_tool["schemaNodeId"] is not None
347+
assert step_reasoning_tool["featureSchemaId"] is not None
348+
349+
step_reasoning_tool = None
350+
for tool in ontology.tools():
351+
if isinstance(tool, StepReasoningTool):
352+
step_reasoning_tool = tool
353+
break
354+
assert step_reasoning_tool is not None
355+
assert step_reasoning_tool.definition.variants.asdict() == [
356+
{
357+
"id": 0,
358+
"name": "Correct",
359+
},
360+
{
361+
"id": 1,
362+
"name": "Neutral",
363+
},
364+
{
365+
"id": 2,
366+
"name": "Incorrect",
367+
"actions": ["regenerateSteps", "generateAlternatives"],
368+
},
369+
]

libs/labelbox/tests/unit/test_unit_step_reasoning_tool.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,18 @@ def test_step_reasoning_as_dict_default():
99
"required": False,
1010
"schemaNodeId": None,
1111
"featureSchemaId": None,
12-
"variants": [
13-
{"id": 0, "name": "Correct"},
14-
{"id": 1, "name": "Neutral"},
15-
{"id": 2, "name": "Incorrect", "actions": []},
16-
],
12+
"definition": {
13+
"variants": [
14+
{"id": 0, "name": "Correct"},
15+
{"id": 1, "name": "Neutral"},
16+
{
17+
"id": 2,
18+
"name": "Incorrect",
19+
"actions": ["regenerateSteps", "generateAlternatives"],
20+
},
21+
],
22+
"version": 1,
23+
},
1724
}
1825

1926

@@ -27,13 +34,16 @@ def test_step_reasoning_as_dict_with_actions():
2734
"required": False,
2835
"schemaNodeId": None,
2936
"featureSchemaId": None,
30-
"variants": [
31-
{"id": 0, "name": "Correct"},
32-
{"id": 1, "name": "Neutral"},
33-
{
34-
"id": 2,
35-
"name": "Incorrect",
36-
"actions": ["regenerateSteps", "generateAlternatives"],
37-
},
38-
],
37+
"definition": {
38+
"variants": [
39+
{"id": 0, "name": "Correct"},
40+
{"id": 1, "name": "Neutral"},
41+
{
42+
"id": 2,
43+
"name": "Incorrect",
44+
"actions": ["regenerateSteps", "generateAlternatives"],
45+
},
46+
],
47+
"version": 1,
48+
},
3949
}

0 commit comments

Comments
 (0)