Skip to content

Commit b19e57c

Browse files
Fix mypy issues
1 parent 52a2686 commit b19e57c

File tree

4 files changed

+32
-26
lines changed

4 files changed

+32
-26
lines changed

src/neo4j_graphrag/experimental/components/schema.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from pathlib import Path
2121

2222
from pydantic import BaseModel, ValidationError, model_validator, validate_call
23-
from requests.exceptions import InvalidJSONError
2423
from typing_extensions import Self
2524

2625
from neo4j_graphrag.exceptions import SchemaValidationError
@@ -336,10 +335,10 @@ async def run(
336335
return self.create_schema_model(entities, relations, potential_schema)
337336

338337

339-
class SchemaFromText(SchemaBuilder):
338+
class SchemaFromText(Component):
340339
"""
341-
A builder class for constructing SchemaConfig objects from the output of an LLM after
342-
automatic schema extraction from text.
340+
A component for constructing SchemaConfig objects from the output of an LLM after
341+
automatic schema extraction from text.
343342
"""
344343

345344
def __init__(
@@ -348,7 +347,6 @@ def __init__(
348347
prompt_template: Optional[PromptTemplate] = None,
349348
llm_params: Optional[Dict[str, Any]] = None,
350349
) -> None:
351-
super().__init__()
352350
self._llm: LLMInterface = llm
353351
self._prompt_template: PromptTemplate = (
354352
prompt_template or SchemaExtractionTemplate()
@@ -369,7 +367,7 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> SchemaConfi
369367
"""
370368
prompt: str = self._prompt_template.format(text=text, examples=examples)
371369

372-
response = await self._llm.invoke(prompt, **self._llm_params)
370+
response = await self._llm.ainvoke(prompt, **self._llm_params)
373371
content: str = (
374372
response
375373
if isinstance(response, str)
@@ -381,8 +379,12 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> SchemaConfi
381379
except json.JSONDecodeError as exc:
382380
raise ValueError("LLM response is not valid JSON.") from exc
383381

384-
extracted_entities: List[dict] = extracted_schema.get("entities", [])
385-
extracted_relations: Optional[List[dict]] = extracted_schema.get("relations")
382+
extracted_entities: List[Dict[str, Any]] = (
383+
extracted_schema.get("entities") or []
384+
)
385+
extracted_relations: Optional[List[Dict[str, Any]]] = extracted_schema.get(
386+
"relations"
387+
)
386388
potential_schema: Optional[List[Tuple[str, str, str]]] = extracted_schema.get(
387389
"potential_schema"
388390
)
@@ -394,7 +396,7 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> SchemaConfi
394396
else None
395397
)
396398

397-
return await super().run(
399+
return SchemaBuilder.create_schema_model(
398400
entities=entities,
399401
relations=relations,
400402
potential_schema=potential_schema,

src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig):
8181
entities: Sequence[EntityInputType] = []
8282
relations: Sequence[RelationInputType] = []
8383
potential_schema: Optional[list[tuple[str, str, str]]] = None
84-
schema: Optional[Union[SchemaConfig, dict[str, list]]] = None
84+
schema: Optional[Union[SchemaConfig, dict[str, list[Any]]]] = None
8585
enforce_schema: SchemaEnforcementMode = SchemaEnforcementMode.NONE
8686
on_error: OnError = OnError.IGNORE
8787
prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate()

src/neo4j_graphrag/experimental/pipeline/kg_builder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from __future__ import annotations
1717

18-
from typing import List, Optional, Sequence, Union
18+
from typing import List, Optional, Sequence, Union, Any
1919
import logging
2020

2121
import neo4j
@@ -92,7 +92,7 @@ def __init__(
9292
entities: Optional[Sequence[EntityInputType]] = None,
9393
relations: Optional[Sequence[RelationInputType]] = None,
9494
potential_schema: Optional[List[tuple[str, str, str]]] = None,
95-
schema: Optional[Union[SchemaConfig, dict[str, list]]] = None,
95+
schema: Optional[Union[SchemaConfig, dict[str, list[Any]]]] = None,
9696
enforce_schema: str = "NONE",
9797
from_pdf: bool = True,
9898
text_splitter: Optional[TextSplitter] = None,

tests/unit/experimental/components/test_schema.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def schema_config(
110110
valid_entities: list[SchemaEntity],
111111
valid_relations: list[SchemaRelation],
112112
potential_schema: list[tuple[str, str, str]],
113-
):
113+
) -> SchemaConfig:
114114
return schema_builder.create_schema_model(
115115
valid_entities, valid_relations, potential_schema
116116
)
@@ -445,14 +445,14 @@ def test_create_schema_model_missing_relations(
445445

446446

447447
@pytest.fixture
448-
def mock_llm():
448+
def mock_llm() -> AsyncMock:
449449
mock = AsyncMock()
450450
mock.invoke = AsyncMock()
451451
return mock
452452

453453

454454
@pytest.fixture
455-
def valid_schema_json():
455+
def valid_schema_json() -> str:
456456
return """
457457
{
458458
"entities": [
@@ -485,7 +485,7 @@ def valid_schema_json():
485485

486486

487487
@pytest.fixture
488-
def invalid_schema_json():
488+
def invalid_schema_json() -> str:
489489
return """
490490
{
491491
"entities": [
@@ -499,14 +499,14 @@ def invalid_schema_json():
499499

500500

501501
@pytest.fixture
502-
def schema_from_text(mock_llm):
502+
def schema_from_text(mock_llm: AsyncMock) -> SchemaFromText:
503503
return SchemaFromText(llm=mock_llm)
504504

505505

506506
@pytest.mark.asyncio
507507
async def test_schema_from_text_run_valid_response(
508-
schema_from_text, mock_llm, valid_schema_json
509-
):
508+
schema_from_text: SchemaFromText, mock_llm: AsyncMock, valid_schema_json: str
509+
) -> None:
510510
# configure the mock LLM to return a valid schema JSON
511511
mock_llm.invoke.return_value = valid_schema_json
512512

@@ -534,8 +534,8 @@ async def test_schema_from_text_run_valid_response(
534534

535535
@pytest.mark.asyncio
536536
async def test_schema_from_text_run_invalid_json(
537-
schema_from_text, mock_llm, invalid_schema_json
538-
):
537+
schema_from_text: SchemaFromText, mock_llm: AsyncMock, invalid_schema_json: str
538+
) -> None:
539539
# configure the mock LLM to return invalid JSON
540540
mock_llm.invoke.return_value = invalid_schema_json
541541

@@ -547,7 +547,9 @@ async def test_schema_from_text_run_invalid_json(
547547

548548

549549
@pytest.mark.asyncio
550-
async def test_schema_from_text_custom_template(mock_llm, valid_schema_json):
550+
async def test_schema_from_text_custom_template(
551+
mock_llm: AsyncMock, valid_schema_json: str
552+
) -> None:
551553
# create a custom template
552554
custom_prompt = "This is a custom prompt with text: {text}"
553555
custom_template = PromptTemplate(template=custom_prompt, expected_inputs=["text"])
@@ -567,7 +569,9 @@ async def test_schema_from_text_custom_template(mock_llm, valid_schema_json):
567569

568570

569571
@pytest.mark.asyncio
570-
async def test_schema_from_text_llm_params(mock_llm, valid_schema_json):
572+
async def test_schema_from_text_llm_params(
573+
mock_llm: AsyncMock, valid_schema_json: str
574+
) -> None:
571575
# configure custom LLM parameters
572576
llm_params = {"temperature": 0.1, "max_tokens": 500}
573577

@@ -588,7 +592,7 @@ async def test_schema_from_text_llm_params(mock_llm, valid_schema_json):
588592

589593

590594
@pytest.mark.asyncio
591-
async def test_schema_config_store_as_json(schema_config):
595+
async def test_schema_config_store_as_json(schema_config: SchemaConfig) -> None:
592596
with tempfile.TemporaryDirectory() as temp_dir:
593597
# create file path
594598
json_path = os.path.join(temp_dir, "schema.json")
@@ -614,7 +618,7 @@ async def test_schema_config_store_as_json(schema_config):
614618

615619

616620
@pytest.mark.asyncio
617-
async def test_schema_config_store_as_yaml(schema_config):
621+
async def test_schema_config_store_as_yaml(schema_config: SchemaConfig) -> None:
618622
with tempfile.TemporaryDirectory() as temp_dir:
619623
# Create file path
620624
yaml_path = os.path.join(temp_dir, "schema.yaml")
@@ -640,7 +644,7 @@ async def test_schema_config_store_as_yaml(schema_config):
640644

641645

642646
@pytest.mark.asyncio
643-
async def test_schema_config_from_file(schema_config):
647+
async def test_schema_config_from_file(schema_config: SchemaConfig) -> None:
644648
with tempfile.TemporaryDirectory() as temp_dir:
645649
# create file paths with different extensions
646650
json_path = os.path.join(temp_dir, "schema.json")

0 commit comments

Comments
 (0)