Skip to content

Commit d77eb7c

Browse files
authored
fix: add items to parameter schema (#62)
* fix: add items to parameter schema * update tests * resolve comments
1 parent 8a17e91 commit d77eb7c

File tree

3 files changed

+46
-15
lines changed

3 files changed

+46
-15
lines changed

src/toolbox_langchain/utils.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class ParameterSchema(BaseModel):
3030
type: str
3131
description: str
3232
authSources: Optional[list[str]] = None
33+
items: Optional["ParameterSchema"] = None
3334

3435

3536
class ToolSchema(BaseModel):
@@ -100,27 +101,28 @@ def _schema_to_model(model_name: str, schema: list[ParameterSchema]) -> Type[Bas
100101
(
101102
# TODO: Remove the hardcoded optional types once optional fields
102103
# are supported by Toolbox.
103-
Optional[_parse_type(field.type)],
104+
Optional[_parse_type(field)],
104105
Field(description=field.description),
105106
),
106107
)
107108

108109
return create_model(model_name, **field_definitions)
109110

110111

111-
def _parse_type(type_: str) -> Any:
112+
def _parse_type(schema_: ParameterSchema) -> Any:
112113
"""
113114
Converts a schema type to a JSON type.
114115
115116
Args:
116-
type_: The type name to convert.
117+
schema_: The ParameterSchema to convert.
117118
118119
Returns:
119120
A valid JSON type.
120121
121122
Raises:
122123
ValueError: If the given type is not supported.
123124
"""
125+
type_ = schema_.type
124126

125127
if type_ == "string":
126128
return str
@@ -131,7 +133,10 @@ def _parse_type(type_: str) -> Any:
131133
elif type_ == "boolean":
132134
return bool
133135
elif type_ == "array":
134-
return list[Union[str, int, float, bool]]
136+
if isinstance(schema_, ParameterSchema) and schema_.items:
137+
return list[_parse_type(schema_.items)] # type: ignore
138+
else:
139+
raise ValueError(f"Schema missing field items")
135140
else:
136141
raise ValueError(f"Unsupported schema type: {type_}")
137142

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Contains pytest fixtures that are accessible from all
15+
"""Contains pytest fixtures that are accessible from all
1616
files present in the same directory."""
1717

1818
from __future__ import annotations

tests/test_utils.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -165,21 +165,47 @@ def test_schema_to_model_empty(self):
165165
assert len(model.model_fields) == 0
166166

167167
@pytest.mark.parametrize(
168-
"type_string, expected_type",
168+
"parameter_schema, expected_type",
169169
[
170-
("string", str),
171-
("integer", int),
172-
("float", float),
173-
("boolean", bool),
174-
("array", list[Union[str, int, float, bool]]),
170+
(ParameterSchema(name="foo", description="bar", type="string"), str),
171+
(ParameterSchema(name="foo", description="bar", type="integer"), int),
172+
(ParameterSchema(name="foo", description="bar", type="float"), float),
173+
(ParameterSchema(name="foo", description="bar", type="boolean"), bool),
174+
(
175+
ParameterSchema(
176+
name="foo",
177+
description="bar",
178+
type="array",
179+
items=ParameterSchema(
180+
name="foo", description="bar", type="integer"
181+
),
182+
),
183+
list[int],
184+
),
175185
],
176186
)
177-
def test_parse_type(self, type_string, expected_type):
178-
assert _parse_type(type_string) == expected_type
187+
def test_parse_type(self, parameter_schema, expected_type):
188+
assert _parse_type(parameter_schema) == expected_type
179189

180-
def test_parse_type_invalid(self):
190+
@pytest.mark.parametrize(
191+
"fail_parameter_schema",
192+
[
193+
(ParameterSchema(name="foo", description="bar", type="invalid")),
194+
(
195+
ParameterSchema(
196+
name="foo",
197+
description="bar",
198+
type="array",
199+
items=ParameterSchema(
200+
name="foo", description="bar", type="invalid"
201+
),
202+
)
203+
),
204+
],
205+
)
206+
def test_parse_type_invalid(self, fail_parameter_schema):
181207
with pytest.raises(ValueError):
182-
_parse_type("invalid")
208+
_parse_type(fail_parameter_schema)
183209

184210
@pytest.mark.asyncio
185211
@patch("aiohttp.ClientSession.post")

0 commit comments

Comments
 (0)