Skip to content

Commit a72c314

Browse files
committed
Expose all components in SimpleKGPipeline
1 parent bdd82f7 commit a72c314

File tree

4 files changed

+26
-18
lines changed

4 files changed

+26
-18
lines changed

src/neo4j_graphrag/experimental/pipeline/config/object_config.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -244,10 +244,10 @@ def parse(self, resolved_data: dict[str, Any] | None = None) -> Embedder:
244244
return self.root.parse(resolved_data)
245245

246246

247-
ComponentT = TypeVar("ComponentT", bound=Component)
247+
ComponentGeneric = TypeVar("ComponentGeneric")
248248

249249

250-
class ComponentConfig(ObjectConfig[ComponentT]):
250+
class ComponentConfig(ObjectConfig[ComponentGeneric], Generic[ComponentGeneric]):
251251
"""A config model for all components.
252252
253253
In addition to the object config, components can have pre-defined parameters
@@ -259,25 +259,27 @@ class ComponentConfig(ObjectConfig[ComponentT]):
259259
DEFAULT_MODULE = "neo4j_graphrag.experimental.components"
260260
INTERFACE = Component
261261

262+
model_config = ConfigDict(arbitrary_types_allowed=True)
263+
262264
def get_run_params(self, resolved_data: dict[str, Any]) -> dict[str, Any]:
263265
self._global_data = resolved_data
264266
return self.resolve_params(self.run_params_)
265267

266268

267269
class ComponentType(
268-
RootModel[Generic[ComponentT]],
269-
# Generic[ComponentT]
270+
RootModel[Union[ComponentGeneric, ComponentConfig[ComponentGeneric]]],
271+
Generic[ComponentGeneric],
270272
):
271-
root: Union[ComponentT, ComponentConfig[ComponentT]]
273+
root: Union[ComponentGeneric, ComponentConfig[ComponentGeneric]]
272274

273275
model_config = ConfigDict(arbitrary_types_allowed=True)
274276

275-
def parse(self, resolved_data: dict[str, Any] | None = None) -> ComponentT:
276-
if isinstance(self.root, Component):
277-
return self.root # type: ignore
278-
return self.root.parse(resolved_data)
277+
def parse(self, resolved_data: dict[str, Any] | None = None) -> ComponentGeneric:
278+
if isinstance(self.root, ComponentConfig):
279+
return self.root.parse(resolved_data)
280+
return self.root
279281

280282
def get_run_params(self, resolved_data: dict[str, Any]) -> dict[str, Any]:
281-
if isinstance(self.root, Component):
282-
return {}
283-
return self.root.get_run_params(resolved_data)
283+
if isinstance(self.root, ComponentConfig):
284+
return self.root.get_run_params(resolved_data)
285+
return {}

src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig):
8787
chunk_embedder: Optional[ComponentType[TextChunkEmbedder]] = None
8888
extractor: Optional[ComponentType[EntityRelationExtractor]] = None
8989
kg_writer: Optional[ComponentType[KGWriter]] = None
90-
resolver: Optional[list[ComponentType[EntityResolver]]] = None
90+
resolver: Optional[ComponentType[EntityResolver]] = None
9191

9292
model_config = ConfigDict(arbitrary_types_allowed=True)
9393

@@ -158,6 +158,8 @@ def _get_run_params_for_writer(self) -> dict[str, Any]:
158158
def _get_resolver(self) -> Optional[EntityResolver]:
159159
if not self.perform_entity_resolution:
160160
return None
161+
if self.resolver:
162+
return self.resolver.parse(self._global_data)
161163
return SinglePropertyExactMatchResolver(
162164
driver=self.get_default_neo4j_driver(),
163165
neo4j_database=self.neo4j_database,

src/neo4j_graphrag/experimental/pipeline/kg_builder.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def __init__(
9494
chunk_embedder: Optional[TextChunkEmbedder] = None,
9595
extractor: Optional[EntityRelationExtractor] = None,
9696
kg_writer: Optional[KGWriter] = None,
97-
resolver: Optional[list[EntityResolver]] = None,
97+
resolver: Optional[EntityResolver] = None,
9898
on_error: str = "IGNORE",
9999
prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate(),
100100
perform_entity_resolution: bool = True,
@@ -112,12 +112,16 @@ def __init__(
112112
potential_schema=potential_schema,
113113
from_pdf=from_pdf,
114114
pdf_loader=ComponentType(pdf_loader) if pdf_loader else None,
115-
schema_builder=ComponentType(schema_builder) if schema_builder else None,
115+
schema_builder=ComponentType(schema_builder)
116+
if schema_builder
117+
else None,
116118
text_splitter=ComponentType(text_splitter) if text_splitter else None,
117-
chunk_embedder=ComponentType(chunk_embedder) if chunk_embedder else None,
119+
chunk_embedder=ComponentType(chunk_embedder)
120+
if chunk_embedder
121+
else None,
118122
extractor=ComponentType(extractor) if extractor else None,
119123
kg_writer=ComponentType(kg_writer) if kg_writer else None,
120-
resolver=[ComponentType(r) for r in resolver] if resolver else None,
124+
resolver=ComponentType(resolver) if resolver else None,
121125
on_error=OnError(on_error),
122126
prompt_template=prompt_template,
123127
perform_entity_resolution=perform_entity_resolution,

tests/unit/experimental/pipeline/config/test_pipeline_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ def test_abstract_pipeline_config_resolve_component_definition_no_run_params(
345345
) -> None:
346346
mock_component_parse.return_value = component
347347
config = AbstractPipelineConfig()
348-
component_type = ComponentType(component)
348+
component_type: ComponentType[Component] = ComponentType(component)
349349
component_definition = config._resolve_component_definition("name", component_type)
350350
assert isinstance(component_definition, ComponentDefinition)
351351
mock_component_parse.assert_called_once_with({})

0 commit comments

Comments
 (0)