Skip to content

Commit bdd82f7

Browse files
committed
WIP: make all components configurable as parameters in SimpleKGPipeline
1 parent e99ebb0 commit bdd82f7

File tree

6 files changed

53 additions and 22 deletions (+53 / -22 lines changed)

6 files changed

+53
-22
lines changed

src/neo4j_graphrag/experimental/pipeline/config/object_config.py

Lines changed: 11 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -244,7 +244,10 @@ def parse(self, resolved_data: dict[str, Any] | None = None) -> Embedder:
244244
return self.root.parse(resolved_data)
245245

246246

247-
class ComponentConfig(ObjectConfig[Component]):
247+
ComponentT = TypeVar("ComponentT", bound=Component)
248+
249+
250+
class ComponentConfig(ObjectConfig[ComponentT]):
248251
"""A config model for all components.
249252
250253
In addition to the object config, components can have pre-defined parameters
@@ -261,14 +264,17 @@ def get_run_params(self, resolved_data: dict[str, Any]) -> dict[str, Any]:
261264
return self.resolve_params(self.run_params_)
262265

263266

264-
class ComponentType(RootModel): # type: ignore[type-arg]
265-
root: Union[Component, ComponentConfig]
267+
class ComponentType(
268+
RootModel[Generic[ComponentT]],
269+
# Generic[ComponentT]
270+
):
271+
root: Union[ComponentT, ComponentConfig[ComponentT]]
266272

267273
model_config = ConfigDict(arbitrary_types_allowed=True)
268274

269-
def parse(self, resolved_data: dict[str, Any] | None = None) -> Component:
275+
def parse(self, resolved_data: dict[str, Any] | None = None) -> ComponentT:
270276
if isinstance(self.root, Component):
271-
return self.root
277+
return self.root # type: ignore
272278
return self.root.parse(resolved_data)
273279

274280
def get_run_params(self, resolved_data: dict[str, Any]) -> dict[str, Any]:

src/neo4j_graphrag/experimental/pipeline/config/pipeline_config.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
from pydantic import field_validator
2121

2222
from neo4j_graphrag.embeddings import Embedder
23+
from neo4j_graphrag.experimental.pipeline import Component
2324
from neo4j_graphrag.experimental.pipeline.config.base import AbstractConfig
2425
from neo4j_graphrag.experimental.pipeline.config.object_config import (
2526
ComponentType,
@@ -83,7 +84,7 @@ def validate_embedders(
8384
return embedders
8485

8586
def _resolve_component_definition(
86-
self, name: str, config: ComponentType
87+
self, name: str, config: ComponentType[Component]
8788
) -> ComponentDefinition:
8889
component = config.parse(self._global_data)
8990
if hasattr(config.root, "run_params_"):
@@ -188,7 +189,7 @@ class PipelineConfig(AbstractPipelineConfig):
188189
"""Configuration class for raw pipelines.
189190
Config must contain all components and connections."""
190191

191-
component_config: dict[str, ComponentType]
192+
component_config: dict[str, ComponentType[Component]]
192193
connection_config: list[ConnectionDefinition]
193194
template_: Literal[PipelineType.NONE] = PipelineType.NONE
194195

src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py

Lines changed: 18 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@
2323
OnError,
2424
)
2525
from neo4j_graphrag.experimental.components.kg_writer import KGWriter, Neo4jWriter
26-
from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader
26+
from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader, DataLoader
2727
from neo4j_graphrag.experimental.components.resolver import (
2828
EntityResolver,
2929
SinglePropertyExactMatchResolver,
@@ -81,17 +81,21 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig):
8181
lexical_graph_config: Optional[LexicalGraphConfig] = None
8282
neo4j_database: Optional[str] = None
8383

84-
pdf_loader: Optional[ComponentType] = None
85-
kg_writer: Optional[ComponentType] = None
86-
text_splitter: Optional[ComponentType] = None
84+
pdf_loader: Optional[ComponentType[DataLoader]] = None
85+
schema_builder: Optional[ComponentType[SchemaBuilder]] = None
86+
text_splitter: Optional[ComponentType[TextSplitter]] = None
87+
chunk_embedder: Optional[ComponentType[TextChunkEmbedder]] = None
88+
extractor: Optional[ComponentType[EntityRelationExtractor]] = None
89+
kg_writer: Optional[ComponentType[KGWriter]] = None
90+
resolver: Optional[list[ComponentType[EntityResolver]]] = None
8791

8892
model_config = ConfigDict(arbitrary_types_allowed=True)
8993

90-
def _get_pdf_loader(self) -> Optional[PdfLoader]:
94+
def _get_pdf_loader(self) -> Optional[DataLoader]:
9195
if not self.from_pdf:
9296
return None
9397
if self.pdf_loader:
94-
return self.pdf_loader.parse(self._global_data) # type: ignore
98+
return self.pdf_loader.parse(self._global_data)
9599
return PdfLoader()
96100

97101
def _get_run_params_for_pdf_loader(self) -> dict[str, Any]:
@@ -103,7 +107,7 @@ def _get_run_params_for_pdf_loader(self) -> dict[str, Any]:
103107

104108
def _get_splitter(self) -> TextSplitter:
105109
if self.text_splitter:
106-
return self.text_splitter.parse(self._global_data) # type: ignore
110+
return self.text_splitter.parse(self._global_data)
107111
return FixedSizeSplitter()
108112

109113
def _get_run_params_for_splitter(self) -> dict[str, Any]:
@@ -112,9 +116,13 @@ def _get_run_params_for_splitter(self) -> dict[str, Any]:
112116
return {}
113117

114118
def _get_chunk_embedder(self) -> TextChunkEmbedder:
119+
if self.chunk_embedder:
120+
return self.chunk_embedder.parse(self._global_data)
115121
return TextChunkEmbedder(embedder=self.get_default_embedder())
116122

117123
def _get_schema(self) -> SchemaBuilder:
124+
if self.schema_builder:
125+
return self.schema_builder.parse(self._global_data)
118126
return SchemaBuilder()
119127

120128
def _get_run_params_for_schema(self) -> dict[str, Any]:
@@ -125,6 +133,8 @@ def _get_run_params_for_schema(self) -> dict[str, Any]:
125133
}
126134

127135
def _get_extractor(self) -> EntityRelationExtractor:
136+
if self.extractor:
137+
return self.extractor.parse(self._global_data)
128138
return LLMEntityRelationExtractor(
129139
llm=self.get_default_llm(),
130140
prompt_template=self.prompt_template,
@@ -134,7 +144,7 @@ def _get_extractor(self) -> EntityRelationExtractor:
134144

135145
def _get_writer(self) -> KGWriter:
136146
if self.kg_writer:
137-
return self.kg_writer.parse(self._global_data) # type: ignore
147+
return self.kg_writer.parse(self._global_data)
138148
return Neo4jWriter(
139149
driver=self.get_default_neo4j_driver(),
140150
neo4j_database=self.neo4j_database,

src/neo4j_graphrag/experimental/pipeline/kg_builder.py

Lines changed: 17 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -21,9 +21,15 @@
2121
from pydantic import ValidationError
2222

2323
from neo4j_graphrag.embeddings import Embedder
24-
from neo4j_graphrag.experimental.components.entity_relation_extractor import OnError
24+
from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder
25+
from neo4j_graphrag.experimental.components.entity_relation_extractor import (
26+
OnError,
27+
EntityRelationExtractor,
28+
)
2529
from neo4j_graphrag.experimental.components.kg_writer import KGWriter
2630
from neo4j_graphrag.experimental.components.pdf_loader import DataLoader
31+
from neo4j_graphrag.experimental.components.resolver import EntityResolver
32+
from neo4j_graphrag.experimental.components.schema import SchemaBuilder
2733
from neo4j_graphrag.experimental.components.text_splitters.base import TextSplitter
2834
from neo4j_graphrag.experimental.components.types import LexicalGraphConfig
2935
from neo4j_graphrag.experimental.pipeline.config.object_config import ComponentType
@@ -82,9 +88,13 @@ def __init__(
8288
relations: Optional[Sequence[RelationInputType]] = None,
8389
potential_schema: Optional[List[tuple[str, str, str]]] = None,
8490
from_pdf: bool = True,
85-
text_splitter: Optional[TextSplitter] = None,
8691
pdf_loader: Optional[DataLoader] = None,
92+
schema_builder: Optional[SchemaBuilder] = None,
93+
text_splitter: Optional[TextSplitter] = None,
94+
chunk_embedder: Optional[TextChunkEmbedder] = None,
95+
extractor: Optional[EntityRelationExtractor] = None,
8796
kg_writer: Optional[KGWriter] = None,
97+
resolver: Optional[list[EntityResolver]] = None,
8898
on_error: str = "IGNORE",
8999
prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate(),
90100
perform_entity_resolution: bool = True,
@@ -102,8 +112,12 @@ def __init__(
102112
potential_schema=potential_schema,
103113
from_pdf=from_pdf,
104114
pdf_loader=ComponentType(pdf_loader) if pdf_loader else None,
105-
kg_writer=ComponentType(kg_writer) if kg_writer else None,
115+
schema_builder=ComponentType(schema_builder) if schema_builder else None,
106116
text_splitter=ComponentType(text_splitter) if text_splitter else None,
117+
chunk_embedder=ComponentType(chunk_embedder) if chunk_embedder else None,
118+
extractor=ComponentType(extractor) if extractor else None,
119+
kg_writer=ComponentType(kg_writer) if kg_writer else None,
120+
resolver=[ComponentType(r) for r in resolver] if resolver else None,
107121
on_error=OnError(on_error),
108122
prompt_template=prompt_template,
109123
perform_entity_resolution=perform_entity_resolution,

tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -71,7 +71,7 @@ def test_simple_kg_pipeline_config_pdf_loader_class_overwrite_but_from_pdf_is_fa
7171
def test_simple_kg_pipeline_config_pdf_loader_from_pdf_is_true_class_overwrite_from_config(
7272
mock_component_parse: Mock,
7373
) -> None:
74-
my_pdf_loader_config = ComponentConfig(
74+
my_pdf_loader_config: ComponentConfig[PdfLoader] = ComponentConfig(
7575
class_="",
7676
)
7777
my_pdf_loader = PdfLoader()
@@ -92,7 +92,7 @@ def test_simple_kg_pipeline_config_text_splitter() -> None:
9292
def test_simple_kg_pipeline_config_text_splitter_overwrite(
9393
mock_component_parse: Mock,
9494
) -> None:
95-
my_text_splitter_config = ComponentConfig(
95+
my_text_splitter_config: ComponentConfig[FixedSizeSplitter] = ComponentConfig(
9696
class_="",
9797
)
9898
my_text_splitter = FixedSizeSplitter()
@@ -184,7 +184,7 @@ def test_simple_kg_pipeline_config_writer_overwrite(
184184
_: Mock,
185185
driver: neo4j.Driver,
186186
) -> None:
187-
my_writer_config = ComponentConfig(
187+
my_writer_config: ComponentConfig[Neo4jWriter] = ComponentConfig(
188188
class_="",
189189
)
190190
my_writer = Neo4jWriter(driver, neo4j_database="my_db")

tests/unit/experimental/pipeline/config/test_pipeline_config.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -366,7 +366,7 @@ def test_abstract_pipeline_config_resolve_component_definition_with_run_params(
366366
mock_component_parse.return_value = component
367367
mock_resolve_params.return_value = {"param": "resolver param result"}
368368
config = AbstractPipelineConfig()
369-
component_type = ComponentType(
369+
component_type: ComponentType[Component] = ComponentType(
370370
ComponentConfig(class_="", params_={}, run_params_={"param1": "value1"})
371371
)
372372
component_definition = config._resolve_component_definition("name", component_type)

0 commit comments

Comments
 (0)