Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/neo4j_graphrag/experimental/components/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

import json
import re

import neo4j
import logging
Expand Down Expand Up @@ -554,6 +555,15 @@ def _filter_relationships_without_labels(
relationship_types, "relationship type"
)

def _clean_json_content(self, content: str) -> str:
    """Return the JSON payload of an LLM response without Markdown fences.

    LLMs frequently wrap JSON answers in a fenced code block, optionally
    tagged with a language (in any letter case) and optionally surrounded
    by explanatory prose. This helper extracts the text that should be
    handed to ``json.loads``.

    Args:
        content: Raw text returned by the LLM.

    Returns:
        The content of the first complete fenced code block when one is
        present; otherwise the input with any leading/trailing fence
        markers and surrounding whitespace removed. Plain JSON is returned
        unchanged apart from whitespace stripping.
    """
    content = content.strip()

    # Preferred path: a complete fenced block whose opening and closing
    # fences sit on their own lines. Extracting the block body (instead of
    # deleting the markers) also drops any prose before or after the block
    # and any language tag, whatever its case.
    fenced_block = re.search(
        r"^[ \t]*```[\w+.-]*[ \t\r]*\n(.*?)\n[ \t]*```[ \t\r]*$",
        content,
        flags=re.MULTILINE | re.DOTALL,
    )
    if fenced_block is not None:
        return fenced_block.group(1).strip()

    # Fallback for incomplete or single-line fences (e.g. the closing fence
    # glued to the last JSON line): strip the markers individually.
    content = re.sub(
        r"^```(?:json)?\s*", "", content, flags=re.MULTILINE | re.IGNORECASE
    )
    content = re.sub(r"```\s*$", "", content, flags=re.MULTILINE)

    return content.strip()

@validate_call
async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema:
"""
Expand All @@ -575,6 +585,9 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
# Re-raise the LLMGenerationError
raise LLMGenerationError("Failed to generate schema from text") from e

# Clean response
content = self._clean_json_content(content)

try:
extracted_schema: Dict[str, Any] = json.loads(content)

Expand Down
80 changes: 80 additions & 0 deletions tests/unit/experimental/components/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -960,6 +960,86 @@ async def test_schema_from_text_filters_relationships_without_labels(
assert ("Person", "MANAGES", "Organization") in schema.patterns


@pytest.fixture
def valid_schema_json_with_markdown() -> str:
    """Provide a schema JSON document wrapped in a ```json Markdown fence."""
    fenced_schema = """```json
{
"node_types": [
{
"label": "Person",
"properties": [
{"name": "name", "type": "STRING"}
]
},
{
"label": "Organization",
"properties": [
{"name": "name", "type": "STRING"}
]
}
],
"relationship_types": [
{
"label": "WORKS_FOR",
"properties": [
{"name": "since", "type": "DATE"}
]
}
],
"patterns": [
["Person", "WORKS_FOR", "Organization"]
]
}
```"""
    return fenced_schema


@pytest.fixture
def valid_schema_json_with_markdown_no_language() -> str:
    """Provide a schema JSON document in a Markdown fence with no language tag."""
    fenced_schema = """```
{
"node_types": [
{
"label": "Person",
"properties": [
{"name": "name", "type": "STRING"}
]
}
]
}
```"""
    return fenced_schema


def test_clean_json_content_markdown_with_json_language(
    schema_from_text: SchemaFromTextExtractor,
) -> None:
    """A response fenced with a ``json`` language tag is reduced to bare JSON."""
    expected_json = '{"node_types": [{"label": "Person"}]}'
    fenced_response = """```json
{"node_types": [{"label": "Person"}]}
```"""

    assert schema_from_text._clean_json_content(fenced_response) == expected_json


def test_clean_json_content_markdown_without_language(
    schema_from_text: SchemaFromTextExtractor,
) -> None:
    """A response fenced without a language tag is reduced to bare JSON."""
    expected_json = '{"node_types": [{"label": "Person"}]}'
    fenced_response = """```
{"node_types": [{"label": "Person"}]}
```"""

    assert schema_from_text._clean_json_content(fenced_response) == expected_json


def test_clean_json_content_plain_json(
    schema_from_text: SchemaFromTextExtractor,
) -> None:
    """Unfenced JSON passes through the cleaner unchanged."""
    plain_json = '{"node_types": [{"label": "Person"}]}'

    result = schema_from_text._clean_json_content(plain_json)

    assert result == '{"node_types": [{"label": "Person"}]}'


@pytest.mark.asyncio
@patch("neo4j_graphrag.experimental.components.schema.get_structured_schema")
async def test_schema_from_existing_graph(mock_get_structured_schema: Mock) -> None:
Expand Down
Loading