Skip to content

Commit 8885e2c

Browse files
Fix remaining mypy errors
1 parent 7558b56 commit 8885e2c

File tree

1 file changed

+30
-22
lines changed

1 file changed

+30
-22
lines changed

src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
from typing import Any, ClassVar, Literal, Optional, Sequence, Union, TypeVar
15+
from typing import Any, ClassVar, Literal, Optional, Sequence, Union, List, Tuple
1616
import logging
1717

18-
from pydantic import ConfigDict, model_validator
18+
from pydantic import ConfigDict, Field, model_validator
19+
from typing_extensions import Self
1920

2021
from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder
2122
from neo4j_graphrag.experimental.components.entity_relation_extractor import (
@@ -59,8 +60,6 @@
5960

6061
logger = logging.getLogger(__name__)
6162

62-
T = TypeVar("T", bound="SimpleKGPipelineConfig")
63-
6463

6564
class SimpleKGPipelineConfig(TemplatePipelineConfig):
6665
COMPONENTS: ClassVar[list[str]] = [
@@ -81,7 +80,7 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig):
8180
entities: Sequence[EntityInputType] = []
8281
relations: Sequence[RelationInputType] = []
8382
potential_schema: Optional[list[tuple[str, str, str]]] = None
84-
schema: Optional[Union[SchemaConfig, dict[str, list[Any]]]] = None # type: ignore
83+
schema_: Optional[Union[SchemaConfig, dict[str, list[Any]]]] = Field(default=None, alias="schema")
8584
enforce_schema: SchemaEnforcementMode = SchemaEnforcementMode.NONE
8685
on_error: OnError = OnError.IGNORE
8786
prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate()
@@ -97,10 +96,10 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig):
9796
model_config = ConfigDict(arbitrary_types_allowed=True)
9897

9998
@model_validator(mode="after")
100-
def handle_schema_precedence(self) -> T: # type: ignore
99+
def handle_schema_precedence(self) -> Self:
101100
"""Handle schema precedence and warnings"""
102101
self._process_schema_parameters()
103-
return self # type: ignore
102+
return self
104103

105104
def _process_schema_parameters(self) -> None:
106105
"""
@@ -112,7 +111,7 @@ def _process_schema_parameters(self) -> None:
112111
[self.entities, self.relations, self.potential_schema]
113112
)
114113

115-
if has_individual_schema_components and self.schema is not None:
114+
if has_individual_schema_components and self.schema_ is not None:
116115
logger.warning(
117116
"Both 'schema' and individual schema components (entities, relations, potential_schema) "
118117
"were provided. The 'schema' parameter takes precedence. In the future, individual "
@@ -134,7 +133,7 @@ def has_user_provided_schema(self) -> bool:
134133
self.entities
135134
or self.relations
136135
or self.potential_schema
137-
or self.schema is not None
136+
or self.schema_ is not None
138137
)
139138

140139
def _get_pdf_loader(self) -> Optional[PdfLoader]:
@@ -175,8 +174,8 @@ def _get_schema(self) -> Union[SchemaBuilder, SchemaFromTextExtractor]:
175174

176175
def _process_schema_with_precedence(
177176
self,
178-
) -> tuple[
179-
list[SchemaEntity], list[SchemaRelation], Optional[list[tuple[str, str, str]]]
177+
) -> Tuple[
178+
List[SchemaEntity], List[SchemaRelation], Optional[List[Tuple[str, str, str]]]
180179
]:
181180
"""
182181
Process schema inputs according to precedence rules:
@@ -187,28 +186,37 @@ def _process_schema_with_precedence(
187186
Returns:
188187
Tuple of (entities, relations, potential_schema)
189188
"""
190-
if self.schema is not None:
189+
if self.schema_ is not None:
191190
# schema takes precedence over individual components
192-
if isinstance(self.schema, SchemaConfig):
191+
if isinstance(self.schema_, SchemaConfig):
193192
# extract components from SchemaConfig
194-
entities = list(self.schema.entities.values())
195-
relations = list(self.schema.relations.values()) # type: ignore
196-
potential_schema = self.schema.potential_schema
193+
entity_dicts = list(self.schema_.entities.values())
194+
# convert dict values to SchemaEntity objects
195+
entities = [SchemaEntity.model_validate(e) for e in entity_dicts]
196+
197+
# handle case where relations could be None
198+
if self.schema_.relations is not None:
199+
relation_dicts = list(self.schema_.relations.values())
200+
relations = [SchemaRelation.model_validate(r) for r in relation_dicts]
201+
else:
202+
relations = []
203+
204+
potential_schema = self.schema_.potential_schema
197205
else:
198206
# extract from dictionary
199207
entities = [
200-
SchemaEntity.from_text_or_dict(e) # type: ignore
201-
for e in self.schema.get("entities", [])
208+
SchemaEntity.from_text_or_dict(e)
209+
for e in self.schema_.get("entities", [])
202210
]
203211
relations = [
204212
SchemaRelation.from_text_or_dict(r)
205-
for r in self.schema.get("relations", [])
213+
for r in self.schema_.get("relations", [])
206214
]
207-
potential_schema = self.schema.get("potential_schema")
215+
potential_schema = self.schema_.get("potential_schema")
208216
else:
209217
# use individual components
210218
entities = (
211-
[SchemaEntity.from_text_or_dict(e) for e in self.entities] # type: ignore
219+
[SchemaEntity.from_text_or_dict(e) for e in self.entities]
212220
if self.entities
213221
else []
214222
)
@@ -219,7 +227,7 @@ def _process_schema_with_precedence(
219227
)
220228
potential_schema = self.potential_schema
221229

222-
return entities, relations, potential_schema # type: ignore
230+
return entities, relations, potential_schema
223231

224232
def _get_run_params_for_schema(self) -> dict[str, Any]:
225233
if self.auto_schema_extraction and not self.has_user_provided_schema():

0 commit comments

Comments
 (0)