Skip to content

Commit 18e8e2a

Browse files
authored
Add flexibility for lexical graph config to SimpleKGPipeline (#209)
* Add flexibility for lexical graph config to SimpleKGPipeline * Update CHANGELOG and update E2E test * Revert LLMEntityRelationExtractor changes and move lexical_graph_config to pipe_inputs in SimpleKGPipeline
1 parent 508323a commit 18e8e2a

File tree

4 files changed

+86
-2
lines changed

4 files changed

+86
-2
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
## Next
44

5+
### Added
6+
- Introduced optional lexical graph configuration for SimpleKGPipeline, enhancing flexibility in customizing node labels and relationship types in the lexical graph.
7+
58
## 1.2.0
69

710
### Added

src/neo4j_graphrag/experimental/pipeline/kg_builder.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
4040
FixedSizeSplitter,
4141
)
42+
from neo4j_graphrag.experimental.components.types import LexicalGraphConfig
4243
from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError
4344
from neo4j_graphrag.experimental.pipeline.pipeline import Pipeline, PipelineResult
4445
from neo4j_graphrag.generation.prompts import ERExtractionTemplate
@@ -59,6 +60,7 @@ class SimpleKGPipelineConfig(BaseModel):
5960
on_error: OnError = OnError.RAISE
6061
prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate()
6162
perform_entity_resolution: bool = True
63+
lexical_graph_config: Optional[LexicalGraphConfig] = None
6264

6365
model_config = ConfigDict(arbitrary_types_allowed=True)
6466

@@ -84,6 +86,7 @@ class SimpleKGPipeline:
8486
on_error (str): Error handling strategy for the Entity and relation extractor. Defaults to "IGNORE", where chunk will be ignored if extraction fails. Possible values: "RAISE" or "IGNORE".
8587
perform_entity_resolution (bool): Merge entities with same label and name. Default: True
8688
prompt_template (str): A custom prompt template to use for extraction.
89+
lexical_graph_config (Optional[LexicalGraphConfig], optional): Lexical graph configuration to customize node labels and relationship types in the lexical graph.
8790
"""
8891

8992
def __init__(
@@ -101,6 +104,7 @@ def __init__(
101104
on_error: str = "IGNORE",
102105
prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate(),
103106
perform_entity_resolution: bool = True,
107+
lexical_graph_config: Optional[LexicalGraphConfig] = None,
104108
):
105109
self.entities = [SchemaEntity(label=label) for label in entities or []]
106110
self.relations = [SchemaRelation(label=label) for label in relations or []]
@@ -127,6 +131,7 @@ def __init__(
127131
prompt_template=prompt_template,
128132
embedder=embedder,
129133
perform_entity_resolution=perform_entity_resolution,
134+
lexical_graph_config=lexical_graph_config,
130135
)
131136

132137
self.from_pdf = config.from_pdf
@@ -141,6 +146,7 @@ def __init__(
141146
)
142147
self.prompt_template = config.prompt_template
143148
self.perform_entity_resolution = config.perform_entity_resolution
149+
self.lexical_graph_config = config.lexical_graph_config
144150

145151
self.pipeline = self._build_pipeline()
146152

@@ -252,4 +258,9 @@ def _prepare_inputs(
252258
else:
253259
pipe_inputs["splitter"] = {"text": text}
254260

261+
if self.lexical_graph_config:
262+
pipe_inputs["extractor"] = {
263+
"lexical_graph_config": self.lexical_graph_config
264+
}
265+
255266
return pipe_inputs

tests/e2e/test_simplekgpipeline_e2e.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import neo4j
2121
import pytest
2222
from neo4j_graphrag.embeddings.base import Embedder
23+
from neo4j_graphrag.experimental.components.types import LexicalGraphConfig
2324
from neo4j_graphrag.experimental.pipeline.kg_builder import SimpleKGPipeline
2425
from neo4j_graphrag.llm import LLMInterface, LLMResponse
2526

@@ -111,6 +112,11 @@ async def test_pipeline_builder_happy_path(
111112
("ORGANIZATION", "LED_BY", "PERSON"),
112113
]
113114

115+
# Additional arguments
116+
lexical_graph_config = LexicalGraphConfig(chunk_node_label="chunkNodeLabel")
117+
from_pdf = False
118+
on_error = "RAISE"
119+
114120
# Create an instance of the SimpleKGPipeline
115121
kg_builder_text = SimpleKGPipeline(
116122
llm=llm,
@@ -119,8 +125,9 @@ async def test_pipeline_builder_happy_path(
119125
entities=entities,
120126
relations=relations,
121127
potential_schema=potential_schema,
122-
from_pdf=False,
123-
on_error="RAISE",
128+
from_pdf=from_pdf,
129+
on_error=on_error,
130+
lexical_graph_config=lexical_graph_config,
124131
)
125132

126133
# Run the knowledge graph building process with text input

tests/unit/experimental/pipeline/test_kg_builder.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from neo4j_graphrag.embeddings import Embedder
2121
from neo4j_graphrag.experimental.components.entity_relation_extractor import OnError
2222
from neo4j_graphrag.experimental.components.schema import SchemaEntity, SchemaRelation
23+
from neo4j_graphrag.experimental.components.types import LexicalGraphConfig
2324
from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError
2425
from neo4j_graphrag.experimental.pipeline.kg_builder import SimpleKGPipeline
2526
from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
@@ -316,3 +317,65 @@ def test_simple_kg_pipeline_no_entity_resolution(_: Mock) -> None:
316317
)
317318

318319
assert "resolver" not in kg_builder.pipeline
320+
321+
322+
@mock.patch(
323+
"neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version",
324+
return_value=(5, 23, 0),
325+
)
326+
@pytest.mark.asyncio
327+
def test_simple_kg_pipeline_lexical_graph_config_attribute(_: Mock) -> None:
328+
llm = MagicMock(spec=LLMInterface)
329+
driver = MagicMock(spec=neo4j.Driver)
330+
embedder = MagicMock(spec=Embedder)
331+
332+
lexical_graph_config = LexicalGraphConfig()
333+
kg_builder = SimpleKGPipeline(
334+
llm=llm,
335+
driver=driver,
336+
embedder=embedder,
337+
on_error="IGNORE",
338+
lexical_graph_config=lexical_graph_config,
339+
)
340+
341+
assert kg_builder.lexical_graph_config == lexical_graph_config
342+
343+
344+
@mock.patch(
345+
"neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version",
346+
return_value=(5, 23, 0),
347+
)
348+
@pytest.mark.asyncio
349+
async def test_knowledge_graph_builder_with_lexical_graph_config(_: Mock) -> None:
350+
llm = MagicMock(spec=LLMInterface)
351+
driver = MagicMock(spec=neo4j.Driver)
352+
embedder = MagicMock(spec=Embedder)
353+
354+
chunk_node_label = "TestChunk"
355+
document_nodel_label = "TestDocument"
356+
lexical_graph_config = LexicalGraphConfig(
357+
chunk_node_label=chunk_node_label, document_node_label=document_nodel_label
358+
)
359+
360+
kg_builder = SimpleKGPipeline(
361+
llm=llm,
362+
driver=driver,
363+
embedder=embedder,
364+
from_pdf=False,
365+
lexical_graph_config=lexical_graph_config,
366+
)
367+
368+
text_input = "May thy knife chip and shatter."
369+
370+
with patch.object(
371+
kg_builder.pipeline,
372+
"run",
373+
return_value=PipelineResult(run_id="test_run", result=None),
374+
) as mock_run:
375+
await kg_builder.run_async(text=text_input)
376+
377+
pipe_inputs = mock_run.call_args[0][0]
378+
assert "extractor" in pipe_inputs
379+
assert pipe_inputs["extractor"] == {
380+
"lexical_graph_config": lexical_graph_config
381+
}

0 commit comments

Comments
 (0)