Skip to content

Commit 9054023

Browse files
authored
Add attempt to fix invalid json before throwing error (#106)
* Add attempt to fix invalid json before throwing error * Add checks to balance braces * Add function to balance braces and brackets * Fixed multiple issues test
1 parent 2b78391 commit 9054023

File tree

3 files changed

+277
-10
lines changed

3 files changed

+277
-10
lines changed

src/neo4j_genai/experimental/components/entity_relation_extractor.py

Lines changed: 88 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import enum
2020
import json
2121
import logging
22+
import re
2223
from datetime import datetime
2324
from typing import Any, Dict, List, Union
2425

@@ -50,6 +51,79 @@ class OnError(enum.Enum):
5051
NODE_TO_CHUNK_RELATIONSHIP_TYPE = "FROM_CHUNK"
5152

5253

54+
def balance_curly_braces(json_string: str) -> str:
    """
    Return ``json_string`` with its curly braces balanced.

    Only braces outside of double-quoted string values are considered. A closing
    brace with no matching opening brace is dropped, and one closing brace is
    appended at the end of the text for every opening brace left unclosed.

    Args:
        json_string (str): A potentially malformed JSON string with unbalanced curly braces.

    Returns:
        str: A JSON string with balanced curly braces.
    """
    pieces = []
    depth = 0  # opening braces seen outside strings that are still unclosed
    inside_string = False
    escaped = False

    for ch in json_string:
        # Track whether we are inside a string, honouring backslash escapes.
        if ch == '"' and not escaped:
            inside_string = not inside_string
        elif ch == "\\" and inside_string:
            escaped = not escaped
            pieces.append(ch)
            continue
        else:
            escaped = False

        # Characters that belong to a string value are copied verbatim.
        if inside_string:
            pieces.append(ch)
            continue

        if ch == "}":
            if depth == 0:
                # Nothing to close: discard the stray brace.
                continue
            depth -= 1
        elif ch == "{":
            depth += 1
        pieces.append(ch)

    # Close whatever is still open.
    pieces.append("}" * depth)
    return "".join(pieces)
101+
102+
103+
def fix_invalid_json(invalid_json_string: str) -> str:
    """
    Apply a fixed sequence of best-effort textual repairs to a malformed JSON string.

    The repairs are, in order: quote bare object keys, quote bare identifier
    values, drop trailing commas, collapse runs of repeated curly braces, and
    balance the remaining curly braces. The result is not guaranteed to be
    valid JSON, so callers should still parse it and handle a failure.

    NOTE(review): the regex steps run over the whole text and are not aware of
    string values, so matching text inside a string is rewritten as well. Runs
    of ``}`` are collapsed to a single brace before balancing re-appends missing
    ones at the end, which can change the nesting of deeply nested objects.

    Args:
        invalid_json_string (str): Text that could not be parsed as JSON.

    Returns:
        str: The repaired text.
    """
    # Fix missing quotes around field names
    invalid_json_string = re.sub(
        r"([{,]\s*)(\w+)(\s*:)", r'\1"\2"\3', invalid_json_string
    )

    # Fix missing quotes around bare identifier values. The JSON literals null,
    # true and false are skipped only on a whole-word match (\b), so values that
    # merely start with one of them (e.g. nullable) are still quoted. Numbers
    # never match because the identifier must start with a letter or underscore.
    invalid_json_string = re.sub(
        r"(?<=:\s)(?!(?:null|true|false)\b)([a-zA-Z_][a-zA-Z0-9_]*)\s*(?=[,}])",
        r'"\1"',
        invalid_json_string,
    )

    # Remove trailing commas within arrays or objects before closing braces or brackets
    invalid_json_string = re.sub(r",\s*(?=[}\]])", "", invalid_json_string)

    # Normalize excessive curly braces
    invalid_json_string = re.sub(r"{{+", "{", invalid_json_string)
    invalid_json_string = re.sub(r"}}+", "}", invalid_json_string)

    # Balance curly braces
    return balance_curly_braces(invalid_json_string)
125+
126+
53127
class EntityRelationExtractor(Component, abc.ABC):
54128
"""Abstract class for entity relation extraction components.
55129
@@ -200,16 +274,20 @@ async def extract_for_chunk(
200274
llm_result = self.llm.invoke(prompt)
201275
try:
202276
result = json.loads(llm_result.content)
203-
except json.JSONDecodeError as e:
204-
if self.on_error == OnError.RAISE:
205-
raise LLMGenerationError(
206-
f"LLM response is not valid JSON {llm_result.content}: {e}"
207-
)
208-
else:
209-
logger.error(
210-
f"LLM response is not valid JSON {llm_result.content} for chunk_index={chunk_index}"
211-
)
212-
result = {"nodes": [], "relationships": []}
277+
except json.JSONDecodeError:
278+
fixed_content = fix_invalid_json(llm_result.content)
279+
try:
280+
result = json.loads(fixed_content)
281+
except json.JSONDecodeError as e:
282+
if self.on_error == OnError.RAISE:
283+
raise LLMGenerationError(
284+
f"LLM response is not valid JSON {fixed_content}: {e}"
285+
)
286+
else:
287+
logger.error(
288+
f"LLM response is not valid JSON {llm_result.content} for chunk_index={chunk_index}"
289+
)
290+
result = {"nodes": [], "relationships": []}
213291
try:
214292
chunk_graph = Neo4jGraph(**result)
215293
except ValidationError as e:

src/neo4j_genai/generation/prompts.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,8 @@ class ERExtractionTemplate(PromptTemplate):
140140
Do respect the source and target node types for relationship and
141141
the relationship direction.
142142
143+
Do not return any additional information other than the JSON in it.
144+
143145
Examples:
144146
{examples}
145147

tests/unit/experimental/components/test_entity_relation_extractor.py

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515
from __future__ import annotations
1616

17+
import json
1718
from unittest.mock import MagicMock
1819

1920
import pytest
@@ -22,6 +23,8 @@
2223
EntityRelationExtractor,
2324
LLMEntityRelationExtractor,
2425
OnError,
26+
balance_curly_braces,
27+
fix_invalid_json,
2528
)
2629
from neo4j_genai.experimental.components.types import (
2730
Neo4jGraph,
@@ -214,3 +217,187 @@ async def test_extractor_custom_prompt() -> None:
214217
chunks = TextChunks(chunks=[TextChunk(text="some text")])
215218
await extractor.run(chunks=chunks)
216219
llm.invoke.assert_called_once_with("this is my prompt")
220+
221+
222+
def test_fix_unquoted_keys() -> None:
    """Bare object keys are wrapped in double quotes."""
    raw = '{name: "John", age: "30"}'

    repaired = fix_invalid_json(raw)

    assert json.loads(repaired)
    assert repaired == '{"name": "John", "age": "30"}'
230+
231+
232+
def test_fix_unquoted_string_values() -> None:
    """A bare identifier value is quoted while a number is left alone."""
    raw = '{"name": John, "age": 30}'

    repaired = fix_invalid_json(raw)

    assert json.loads(repaired)
    assert repaired == '{"name": "John", "age": 30}'
240+
241+
242+
def test_remove_trailing_commas() -> None:
    """A comma directly before the closing brace is removed."""
    raw = '{"name": "John", "age": 30,}'

    repaired = fix_invalid_json(raw)

    assert json.loads(repaired)
    assert repaired == '{"name": "John", "age": 30}'
250+
251+
252+
def test_fix_excessive_braces() -> None:
    """Doubled opening and closing braces collapse to single ones."""
    raw = '{{"name": "John"}}'

    repaired = fix_invalid_json(raw)

    assert json.loads(repaired)
    assert repaired == '{"name": "John"}'
260+
261+
262+
def test_fix_multiple_issues() -> None:
    """Unquoted key, unquoted value and a trailing array comma are fixed together."""
    raw = '{name: John, "hobbies": ["reading", "swimming",], "age": 30}'

    repaired = fix_invalid_json(raw)

    assert json.loads(repaired)
    assert repaired == '{"name": "John", "hobbies": ["reading", "swimming"], "age": 30}'
270+
271+
272+
def test_fix_null_values() -> None:
    """The ``null`` literal is not quoted while a bare identifier is."""
    raw = '{"name": John, "nickname": null}'

    repaired = fix_invalid_json(raw)

    assert json.loads(repaired)
    assert repaired == '{"name": "John", "nickname": null}'
280+
281+
282+
def test_fix_numeric_values() -> None:
    """Integer and decimal values pass through unchanged."""
    raw = '{"age": 30, "score": 95.5}'

    repaired = fix_invalid_json(raw)

    assert json.loads(repaired)
    assert repaired == '{"age": 30, "score": 95.5}'
290+
291+
292+
def test_balance_curly_braces_missing_closing() -> None:
    """Two unclosed objects get their closing braces appended."""
    raw = '{"name": "John", "hobbies": {"reading": "yes"'

    balanced = balance_curly_braces(raw)

    assert json.loads(balanced)
    assert balanced == '{"name": "John", "hobbies": {"reading": "yes"}}'
300+
301+
302+
def test_balance_curly_braces_extra_closing() -> None:
    """A surplus closing brace at the end is dropped."""
    raw = '{"name": "John", "hobbies": {"reading": "yes"}}}'

    balanced = balance_curly_braces(raw)

    assert json.loads(balanced)
    assert balanced == '{"name": "John", "hobbies": {"reading": "yes"}}'
310+
311+
312+
def test_balance_curly_braces_balanced_input() -> None:
    """Already balanced input is returned unchanged."""
    raw = '{"name": "John", "hobbies": {"reading": "yes"}, "age": 30}'

    balanced = balance_curly_braces(raw)

    assert json.loads(balanced)
    assert balanced == raw
320+
321+
322+
def test_balance_curly_braces_nested_structure() -> None:
    """A balanced three-level nesting is returned unchanged."""
    raw = '{"person": {"name": "John", "hobbies": {"reading": "yes"}}}'

    balanced = balance_curly_braces(raw)

    assert json.loads(balanced)
    assert balanced == raw
330+
331+
332+
def test_balance_curly_braces_unbalanced_nested() -> None:
    """A nested structure missing one closing brace gets it appended."""
    raw = '{"person": {"name": "John", "hobbies": {"reading": "yes"}}'

    balanced = balance_curly_braces(raw)

    assert json.loads(balanced)
    assert balanced == '{"person": {"name": "John", "hobbies": {"reading": "yes"}}}'
340+
341+
342+
def test_balance_curly_braces_unmatched_openings() -> None:
    """Every unmatched opening brace is closed at the end of the text."""
    raw = '{"name": "John", "hobbies": {"reading": "yes"'

    balanced = balance_curly_braces(raw)

    assert json.loads(balanced)
    assert balanced == '{"name": "John", "hobbies": {"reading": "yes"}}'
350+
351+
352+
def test_balance_curly_braces_unmatched_closings() -> None:
    """A closing brace without a matching opening brace is discarded."""
    raw = '{"name": "John", "hobbies": {"reading": "yes"}}}'

    balanced = balance_curly_braces(raw)

    assert json.loads(balanced)
    assert balanced == '{"name": "John", "hobbies": {"reading": "yes"}}'
360+
361+
362+
def test_balance_curly_braces_complex_structure() -> None:
    """A balanced structure with mixed value types is returned unchanged."""
    raw = '{"name": "John", "details": {"age": 30, "hobbies": {"reading": "yes"}}}'

    balanced = balance_curly_braces(raw)

    assert json.loads(balanced)
    assert balanced == raw
372+
373+
374+
def test_balance_curly_braces_incorrect_nested_closings() -> None:
    """A surplus closing brace after sibling nested objects is dropped."""
    raw = '{"key1": {"key2": {"reading": "yes"}}, "key3": {"age": 30}}}'

    balanced = balance_curly_braces(raw)

    assert json.loads(balanced)
    assert balanced == '{"key1": {"key2": {"reading": "yes"}}, "key3": {"age": 30}}'
382+
383+
384+
def test_balance_curly_braces_braces_inside_string() -> None:
    """Braces that appear inside a string value are not counted."""
    raw = '{"name": "John", "example": "a{b}c", "age": 30}'

    balanced = balance_curly_braces(raw)

    assert json.loads(balanced)
    assert balanced == raw
392+
393+
394+
def test_balance_curly_braces_unbalanced_with_string() -> None:
    """Missing braces are appended while braces inside a string are ignored."""
    raw = '{"name": "John", "example": "a{b}c", "hobbies": {"reading": "yes"'

    balanced = balance_curly_braces(raw)

    assert json.loads(balanced)
    assert balanced == '{"name": "John", "example": "a{b}c", "hobbies": {"reading": "yes"}}'

0 commit comments

Comments
 (0)