Skip to content

Commit 6d130b6

Browse files
authored
Stop incorrectly treating side-by-side JSON schema $refs as recursion (#1697)
1 parent 7bfb947 commit 6d130b6

File tree

2 files changed

+56
-13
lines changed

2 files changed

+56
-13
lines changed

pydantic_ai_slim/pydantic_ai/models/_json_schema.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def __init__(
2525
self.simplify_nullable_unions = simplify_nullable_unions
2626

2727
self.defs: dict[str, JsonSchema] = self.schema.get('$defs', {})
28-
self.refs_stack = tuple[str, ...]()
28+
self.refs_stack: list[str] = []
2929
self.recursive_refs = set[str]()
3030

3131
@abstractmethod
@@ -62,13 +62,16 @@ def walk(self) -> JsonSchema:
6262
return handled
6363

6464
def _handle(self, schema: JsonSchema) -> JsonSchema:
65+
nested_refs = 0
6566
if self.prefer_inlined_defs:
6667
while ref := schema.get('$ref'):
6768
key = re.sub(r'^#/\$defs/', '', ref)
6869
if key in self.refs_stack:
6970
self.recursive_refs.add(key)
7071
break # recursive ref can't be unpacked
71-
self.refs_stack += (key,)
72+
self.refs_stack.append(key)
73+
nested_refs += 1
74+
7275
def_schema = self.defs.get(key)
7376
if def_schema is None: # pragma: no cover
7477
raise UserError(f'Could not find $ref definition for {key}')
@@ -87,6 +90,9 @@ def _handle(self, schema: JsonSchema) -> JsonSchema:
8790
# Apply the base transform
8891
schema = self.transform(schema)
8992

93+
if nested_refs > 0:
94+
self.refs_stack = self.refs_stack[:-nested_refs]
95+
9096
return schema
9197

9298
def _handle_object(self, schema: JsonSchema) -> JsonSchema:

tests/models/test_gemini.py

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -164,9 +164,17 @@ async def test_require_response_tool(allow_model_requests: None):
164164

165165

166166
async def test_json_def_replaced(allow_model_requests: None):
167+
class Axis(BaseModel):
168+
label: str
169+
170+
class Chart(BaseModel):
171+
x_axis: Axis
172+
y_axis: Axis
173+
167174
class Location(BaseModel):
168175
lat: float
169176
lng: float = 1.1
177+
chart: Chart
170178

171179
class Locations(BaseModel):
172180
locations: list[Location]
@@ -175,15 +183,28 @@ class Locations(BaseModel):
175183
assert json_schema == snapshot(
176184
{
177185
'$defs': {
186+
'Axis': {
187+
'properties': {'label': {'title': 'Label', 'type': 'string'}},
188+
'required': ['label'],
189+
'title': 'Axis',
190+
'type': 'object',
191+
},
192+
'Chart': {
193+
'properties': {'x_axis': {'$ref': '#/$defs/Axis'}, 'y_axis': {'$ref': '#/$defs/Axis'}},
194+
'required': ['x_axis', 'y_axis'],
195+
'title': 'Chart',
196+
'type': 'object',
197+
},
178198
'Location': {
179199
'properties': {
180200
'lat': {'title': 'Lat', 'type': 'number'},
181201
'lng': {'default': 1.1, 'title': 'Lng', 'type': 'number'},
202+
'chart': {'$ref': '#/$defs/Chart'},
182203
},
183-
'required': ['lat'],
204+
'required': ['lat', 'chart'],
184205
'title': 'Location',
185206
'type': 'object',
186-
}
207+
},
187208
},
188209
'properties': {'locations': {'items': {'$ref': '#/$defs/Location'}, 'title': 'Locations', 'type': 'array'}},
189210
'required': ['locations'],
@@ -201,20 +222,36 @@ class Locations(BaseModel):
201222
mrp = ModelRequestParameters(function_tools=[], allow_text_output=True, output_tools=[output_tool])
202223
mrp = m.customize_request_parameters(mrp)
203224
assert m._get_tools(mrp) == snapshot(
204-
_GeminiTools(
205-
function_declarations=[
206-
_GeminiFunction(
207-
name='result',
208-
description='This is the tool for the final Result',
209-
parameters={
225+
{
226+
'function_declarations': [
227+
{
228+
'name': 'result',
229+
'description': 'This is the tool for the final Result',
230+
'parameters': {
210231
'properties': {
211232
'locations': {
212233
'items': {
213234
'properties': {
214235
'lat': {'type': 'number'},
215236
'lng': {'type': 'number'},
237+
'chart': {
238+
'properties': {
239+
'x_axis': {
240+
'properties': {'label': {'type': 'string'}},
241+
'required': ['label'],
242+
'type': 'object',
243+
},
244+
'y_axis': {
245+
'properties': {'label': {'type': 'string'}},
246+
'required': ['label'],
247+
'type': 'object',
248+
},
249+
},
250+
'required': ['x_axis', 'y_axis'],
251+
'type': 'object',
252+
},
216253
},
217-
'required': ['lat'],
254+
'required': ['lat', 'chart'],
218255
'type': 'object',
219256
},
220257
'type': 'array',
@@ -223,9 +260,9 @@ class Locations(BaseModel):
223260
'required': ['locations'],
224261
'type': 'object',
225262
},
226-
)
263+
}
227264
]
228-
)
265+
}
229266
)
230267

231268

0 commit comments

Comments
 (0)