Skip to content

Commit 523df15

Browse files
committed
Fixes
1 parent 65183d7 commit 523df15

File tree

4 files changed

+55
-48
lines changed

4 files changed

+55
-48
lines changed

src/neo4j_graphrag/experimental/components/graph_pruning.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
Neo4jGraph,
2929
Neo4jNode,
3030
Neo4jRelationship,
31+
LexicalGraphConfig,
3132
)
3233
from neo4j_graphrag.experimental.pipeline import Component, DataModel
3334

@@ -135,9 +136,14 @@ async def run(
135136
self,
136137
graph: Neo4jGraph,
137138
schema: Optional[GraphSchema] = None,
139+
lexical_graph_config: Optional[LexicalGraphConfig] = None,
138140
) -> GraphPruningResult:
141+
if lexical_graph_config is None:
142+
lexical_graph_config = LexicalGraphConfig()
139143
if schema is not None:
140-
new_graph, pruning_stats = self._clean_graph(graph, schema)
144+
new_graph, pruning_stats = self._clean_graph(
145+
graph, schema, lexical_graph_config
146+
)
141147
else:
142148
new_graph = graph
143149
pruning_stats = PruningStats()
@@ -150,6 +156,7 @@ def _clean_graph(
150156
self,
151157
graph: Neo4jGraph,
152158
schema: GraphSchema,
159+
lexical_graph_config: LexicalGraphConfig,
153160
) -> tuple[Neo4jGraph, PruningStats]:
154161
"""
155162
Verify that the graph conforms to the provided schema.
@@ -162,6 +169,7 @@ def _clean_graph(
162169
filtered_nodes = self._enforce_nodes(
163170
graph.nodes,
164171
schema,
172+
lexical_graph_config,
165173
pruning_stats,
166174
)
167175
if not filtered_nodes:
@@ -174,6 +182,7 @@ def _clean_graph(
174182
graph.relationships,
175183
filtered_nodes,
176184
schema,
185+
lexical_graph_config,
177186
pruning_stats,
178187
)
179188

@@ -216,6 +225,7 @@ def _enforce_nodes(
216225
self,
217226
extracted_nodes: list[Neo4jNode],
218227
schema: GraphSchema,
228+
lexical_graph_config: LexicalGraphConfig,
219229
pruning_stats: PruningStats,
220230
) -> list[Neo4jNode]:
221231
"""
@@ -228,6 +238,9 @@ def _enforce_nodes(
228238
"""
229239
valid_nodes = []
230240
for node in extracted_nodes:
241+
if node.label in lexical_graph_config.lexical_graph_node_labels:
242+
valid_nodes.append(node)
243+
continue
231244
schema_entity = schema.node_type_from_label(node.label)
232245
new_node = self._validate_node(
233246
node,
@@ -319,6 +332,7 @@ def _enforce_relationships(
319332
extracted_relationships: list[Neo4jRelationship],
320333
filtered_nodes: list[Neo4jNode],
321334
schema: GraphSchema,
335+
lexical_graph_config: LexicalGraphConfig,
322336
pruning_stats: PruningStats,
323337
) -> list[Neo4jRelationship]:
324338
"""
@@ -334,6 +348,9 @@ def _enforce_relationships(
334348
valid_rels = []
335349
valid_nodes = {node.id: node.label for node in filtered_nodes}
336350
for rel in extracted_relationships:
351+
if rel.type in lexical_graph_config.lexical_graph_relationship_types:
352+
valid_rels.append(rel)
353+
continue
337354
schema_relation = schema.relationship_type_from_label(rel.type)
338355
new_rel = self._validate_relationship(
339356
rel,

src/neo4j_graphrag/experimental/components/schema.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,7 @@ def create_schema_model(
334334
node_types: Sequence[NodeType],
335335
relationship_types: Optional[Sequence[RelationshipType]] = None,
336336
patterns: Optional[Sequence[Tuple[str, str, str]]] = None,
337+
**kwargs: Any,
337338
) -> GraphSchema:
338339
"""
339340
Creates a GraphSchema object from Lists of Entity and Relation objects
@@ -343,6 +344,7 @@ def create_schema_model(
343344
node_types (Sequence[NodeType]): List or tuple of NodeType objects.
344345
relationship_types (Optional[Sequence[RelationshipType]]): List or tuple of RelationshipType objects.
345346
patterns (Optional[Sequence[Tuple[str, str, str]]]): List or tuples of triplets: (source_entity_label, relation_label, target_entity_label).
347+
kwargs: other arguments passed to GraphSchema validator.
346348
347349
Returns:
348350
GraphSchema: A configured schema object.
@@ -353,17 +355,19 @@ def create_schema_model(
353355
node_types=node_types,
354356
relationship_types=relationship_types or (),
355357
patterns=patterns or (),
358+
**kwargs,
356359
)
357360
)
358-
except (ValidationError, SchemaValidationError) as e:
359-
raise SchemaValidationError(e) from e
361+
except ValidationError as e:
362+
raise SchemaValidationError() from e
360363

361364
@validate_call
362365
async def run(
363366
self,
364367
node_types: Sequence[NodeType],
365368
relationship_types: Optional[Sequence[RelationshipType]] = None,
366369
patterns: Optional[Sequence[Tuple[str, str, str]]] = None,
370+
**kwargs: Any,
367371
) -> GraphSchema:
368372
"""
369373
Asynchronously constructs and returns a GraphSchema object.
@@ -376,7 +380,12 @@ async def run(
376380
Returns:
377381
GraphSchema: A configured schema object, constructed asynchronously.
378382
"""
379-
return self.create_schema_model(node_types, relationship_types, patterns)
383+
return self.create_schema_model(
384+
node_types,
385+
relationship_types,
386+
patterns,
387+
**kwargs,
388+
)
380389

381390

382391
class SchemaFromTextExtractor(Component):

src/neo4j_graphrag/experimental/components/types.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,14 @@ class LexicalGraphConfig(BaseModel):
174174
def lexical_graph_node_labels(self) -> tuple[str, ...]:
175175
return self.document_node_label, self.chunk_node_label
176176

177+
@property
178+
def lexical_graph_relationship_types(self) -> tuple[str, ...]:
179+
return (
180+
self.chunk_to_document_relationship_type,
181+
self.next_chunk_relationship_type,
182+
self.node_to_chunk_relationship_type,
183+
)
184+
177185

178186
class GraphResult(DataModel):
179187
graph: Neo4jGraph

src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py

Lines changed: 17 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -184,66 +184,33 @@ def _get_schema(self) -> Union[SchemaBuilder, SchemaFromTextExtractor]:
184184
return SchemaFromTextExtractor(llm=self.get_default_llm())
185185
return SchemaBuilder()
186186

187-
def _process_schema_with_precedence(
188-
self,
189-
) -> Tuple[
190-
Tuple[NodeType, ...],
191-
Tuple[RelationshipType, ...] | None,
192-
Optional[Tuple[Tuple[str, str, str], ...]] | None,
193-
]:
187+
def _process_schema_with_precedence(self) -> GraphSchema:
194188
"""
195189
Process schema inputs according to precedence rules:
196190
1. If schema is provided as GraphSchema object, use it
197191
2. If schema is provided as dictionary, extract from it
198192
3. Otherwise, use individual schema components
199193
200194
Returns:
201-
Tuple of (node_types, relationship_types, patterns)
195+
A GraphSchema object
202196
"""
203197
if self.schema_ is not None:
204-
# schema takes precedence over individual components
205-
node_types = self.schema_.node_types
198+
return self.schema_
206199

207-
# handle case where relations could be None
208-
if self.schema_.relationship_types is not None:
209-
relationship_types = self.schema_.relationship_types
210-
else:
211-
relationship_types = None
212-
213-
patterns = self.schema_.patterns
214-
else:
215-
# use individual components
216-
node_types = tuple(
217-
[NodeType.model_validate(e) for e in self.entities]
218-
if self.entities
219-
else []
220-
)
221-
relationship_types = (
222-
tuple([RelationshipType.model_validate(r) for r in self.relations])
223-
if self.relations is not None
224-
else None
225-
)
226-
patterns = (
227-
tuple(self.potential_schema) if self.potential_schema else tuple()
228-
)
229-
230-
return node_types, relationship_types, patterns
200+
return GraphSchema(
201+
node_types=tuple(self.entities) if self.entities else tuple(),
202+
relationship_types=tuple(self.relations) if self.relations else tuple(),
203+
patterns=tuple(self.potential_schema) if self.potential_schema else tuple(),
204+
)
231205

232206
def _get_run_params_for_schema(self) -> dict[str, Any]:
233207
if not self.has_user_provided_schema():
234208
# for automatic extraction, the text parameter is needed (will flow through the pipeline connections)
235209
return {}
236210
else:
237211
# process schema components according to precedence rules
238-
node_types, relationship_types, patterns = (
239-
self._process_schema_with_precedence()
240-
)
241-
242-
return {
243-
"node_types": node_types,
244-
"relationship_types": relationship_types,
245-
"patterns": patterns,
246-
}
212+
schema = self._process_schema_with_precedence()
213+
return schema.model_dump()
247214

248215
def _get_extractor(self) -> EntityRelationExtractor:
249216
return LLMEntityRelationExtractor(
@@ -368,7 +335,13 @@ def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]:
368335
run_params = {}
369336
if self.lexical_graph_config:
370337
run_params["extractor"] = {
371-
"lexical_graph_config": self.lexical_graph_config
338+
"lexical_graph_config": self.lexical_graph_config,
339+
}
340+
run_params["writer"] = {
341+
"lexical_graph_config": self.lexical_graph_config,
342+
}
343+
run_params["pruner"] = {
344+
"lexical_graph_config": self.lexical_graph_config,
372345
}
373346
text = user_input.get("text")
374347
file_path = user_input.get("file_path")

0 commit comments

Comments
 (0)