|
14 | 14 | # limitations under the License.
|
15 | 15 | from __future__ import annotations
|
16 | 16 |
|
| 17 | +import json |
| 18 | +from unittest.mock import AsyncMock |
| 19 | + |
17 | 20 | import pytest
|
18 | 21 | from neo4j_graphrag.exceptions import SchemaValidationError
|
19 | 22 | from neo4j_graphrag.experimental.components.schema import (
|
20 | 23 | SchemaBuilder,
|
21 | 24 | SchemaEntity,
|
22 | 25 | SchemaProperty,
|
23 | 26 | SchemaRelation,
|
| 27 | + SchemaFromText, |
| 28 | + SchemaConfig, |
24 | 29 | )
|
25 | 30 | from pydantic import ValidationError
|
| 31 | +import os |
| 32 | +import tempfile |
| 33 | +import yaml |
| 34 | +from pathlib import Path |
| 35 | + |
| 36 | +from neo4j_graphrag.generation import PromptTemplate |
26 | 37 |
|
27 | 38 |
|
28 | 39 | @pytest.fixture
|
@@ -93,6 +104,18 @@ def schema_builder() -> SchemaBuilder:
|
93 | 104 | return SchemaBuilder()
|
94 | 105 |
|
95 | 106 |
|
@pytest.fixture
def schema_config(
    schema_builder: SchemaBuilder,
    valid_entities: list[SchemaEntity],
    valid_relations: list[SchemaRelation],
    potential_schema: list[tuple[str, str, str]],
):
    """Build a schema model from the shared entity, relation and triple fixtures."""
    model = schema_builder.create_schema_model(
        valid_entities,
        valid_relations,
        potential_schema,
    )
    return model
| 117 | + |
| 118 | + |
96 | 119 | def test_create_schema_model_valid_data(
|
97 | 120 | schema_builder: SchemaBuilder,
|
98 | 121 | valid_entities: list[SchemaEntity],
|
@@ -419,3 +442,224 @@ def test_create_schema_model_missing_relations(
|
419 | 442 | assert "Relations must also be provided when using a potential schema." in str(
|
420 | 443 | exc_info.value
|
421 | 444 | ), "Should fail due to missing relations"
|
| 445 | + |
| 446 | + |
@pytest.fixture
def mock_llm():
    """Provide a stand-in LLM whose ``invoke`` attribute is an awaitable mock."""
    llm = AsyncMock()
    # Set explicitly so the tests can configure ``invoke.return_value``.
    llm.invoke = AsyncMock()
    return llm
| 452 | + |
| 453 | + |
@pytest.fixture
def valid_schema_json():
    """Return a well-formed JSON document describing a small schema.

    The payload holds two entities (Person, Organization), one relation
    (WORKS_FOR) and a single potential-schema triple linking them.
    """
    payload = '''
    {
        "entities": [
            {
                "label": "Person",
                "properties": [
                    {"name": "name", "type": "STRING"}
                ]
            },
            {
                "label": "Organization",
                "properties": [
                    {"name": "name", "type": "STRING"}
                ]
            }
        ],
        "relations": [
            {
                "label": "WORKS_FOR",
                "properties": [
                    {"name": "since", "type": "DATE"}
                ]
            }
        ],
        "potential_schema": [
            ["Person", "WORKS_FOR", "Organization"]
        ]
    }
    '''
    return payload
| 485 | + |
| 486 | + |
@pytest.fixture
def invalid_schema_json():
    """Return a deliberately malformed JSON document.

    The text contains trailing commas and a bare, unquoted line, so it
    cannot be parsed as JSON.
    """
    payload = '''
    {
        "entities": [
            {
                "label": "Person",
            },
        ],
        invalid json content
    }
    '''
    return payload
| 499 | + |
| 500 | + |
@pytest.fixture
def schema_from_text(mock_llm):
    """Provide a ``SchemaFromText`` component wired to the mocked LLM."""
    component = SchemaFromText(llm=mock_llm)
    return component
| 504 | + |
| 505 | + |
@pytest.mark.asyncio
async def test_schema_from_text_run_valid_response(
    schema_from_text, mock_llm, valid_schema_json
):
    """Check that a valid JSON answer from the LLM is turned into a schema."""
    input_text = "Sample text for extraction"

    # the mocked LLM answers with a well-formed schema document
    mock_llm.invoke.return_value = valid_schema_json

    extracted = await schema_from_text.run(text=input_text)

    # exactly one LLM call, whose first positional argument is the prompt
    mock_llm.invoke.assert_called_once()
    sent_prompt = mock_llm.invoke.call_args.args[0]
    assert isinstance(sent_prompt, str)
    assert input_text in sent_prompt

    # both entities from the JSON document are present
    assert len(extracted.entities) == 2
    assert "Person" in extracted.entities
    assert "Organization" in extracted.entities

    # the single relation is present
    assert extracted.relations is not None
    assert "WORKS_FOR" in extracted.relations

    # the single triple is present
    assert extracted.potential_schema is not None
    assert len(extracted.potential_schema) == 1
    expected_triple = ("Person", "WORKS_FOR", "Organization")
    assert extracted.potential_schema[0] == expected_triple
| 531 | + |
| 532 | + |
@pytest.mark.asyncio
async def test_schema_from_text_run_invalid_json(
    schema_from_text, mock_llm, invalid_schema_json
):
    """Check that an unparseable LLM answer surfaces as a ``ValueError``."""
    # the mocked LLM answers with text that is not valid JSON
    mock_llm.invoke.return_value = invalid_schema_json

    # the error message must mention that the response is not valid JSON
    with pytest.raises(ValueError, match="not valid JSON"):
        await schema_from_text.run(text="Sample text for extraction")
| 543 | + |
| 544 | + |
@pytest.mark.asyncio
async def test_schema_from_text_custom_template(mock_llm, valid_schema_json):
    """Check that a user-supplied prompt template is what reaches the LLM."""
    # create a custom template
    custom_prompt = "This is a custom prompt with text: {text}"
    custom_template = PromptTemplate(template=custom_prompt, expected_inputs=["text"])

    # create SchemaFromText with the custom template
    schema_from_text = SchemaFromText(llm=mock_llm, prompt_template=custom_template)

    # the mock must return parseable JSON so that run() completes
    mock_llm.invoke.return_value = valid_schema_json

    # run the schema extraction
    await schema_from_text.run(text="Sample text")

    # guard first: without a recorded call, call_args is None and the
    # subscript below would fail with an unrelated TypeError
    mock_llm.invoke.assert_called_once()

    # verify the custom prompt was passed to the LLM
    prompt_sent_to_llm = mock_llm.invoke.call_args[0][0]
    assert "This is a custom prompt with text" in prompt_sent_to_llm
| 563 | + |
| 564 | + |
@pytest.mark.asyncio
async def test_schema_from_text_llm_params(mock_llm, valid_schema_json):
    """Check that ``llm_params`` are forwarded to the LLM as keyword arguments."""
    custom_params = {"temperature": 0.1, "max_tokens": 500}
    component = SchemaFromText(llm=mock_llm, llm_params=custom_params)

    # the mocked LLM answers with a well-formed schema document
    mock_llm.invoke.return_value = valid_schema_json

    await component.run(text="Sample text")

    # exactly one LLM call, carrying the custom parameters as keywords
    mock_llm.invoke.assert_called_once()
    forwarded = mock_llm.invoke.call_args.kwargs
    assert forwarded["temperature"] == 0.1
    assert forwarded["max_tokens"] == 500
| 584 | + |
| 585 | + |
def test_schema_config_store_as_json(schema_config):
    """Check that ``store_as_json`` writes a non-empty, parseable JSON file.

    The test awaits nothing, so it is a plain synchronous test and does not
    need the asyncio marker.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        # create file path
        json_path = os.path.join(temp_dir, "schema.json")

        # store the schema config
        schema_config.store_as_json(json_path)

        # verify the file exists and has content
        assert os.path.exists(json_path)
        assert os.path.getsize(json_path) > 0

        # verify the content is valid JSON and contains expected data
        with open(json_path, "r") as f:
            data = json.load(f)

        assert "entities" in data
        assert "PERSON" in data["entities"]
        person = data["entities"]["PERSON"]
        assert "properties" in person
        assert "description" in person
        assert person["description"] == "An individual human being."
| 607 | + |
| 608 | + |
def test_schema_config_store_as_yaml(schema_config):
    """Check that ``store_as_yaml`` writes a non-empty, parseable YAML file.

    The test awaits nothing, so it is a plain synchronous test and does not
    need the asyncio marker.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        # create file path
        yaml_path = os.path.join(temp_dir, "schema.yaml")

        # store the schema config
        schema_config.store_as_yaml(yaml_path)

        # verify the file exists and has content
        assert os.path.exists(yaml_path)
        assert os.path.getsize(yaml_path) > 0

        # verify the content is valid YAML and contains expected data
        with open(yaml_path, "r") as f:
            data = yaml.safe_load(f)

        assert "entities" in data
        assert "PERSON" in data["entities"]
        person = data["entities"]["PERSON"]
        assert "properties" in person
        assert "description" in person
        assert person["description"] == "An individual human being."
| 630 | + |
| 631 | + |
def test_schema_config_from_file(schema_config):
    """Check that ``SchemaConfig.from_file`` loads .json, .yaml and .yml files.

    Also checks that a file with an unsupported extension is rejected with a
    ``ValueError``. The test awaits nothing, so it is a plain synchronous
    test and does not need the asyncio marker.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        # create file paths with different extensions
        json_path = os.path.join(temp_dir, "schema.json")
        yaml_path = os.path.join(temp_dir, "schema.yaml")
        yml_path = os.path.join(temp_dir, "schema.yml")

        # store the schema config in the different formats
        schema_config.store_as_json(json_path)
        schema_config.store_as_yaml(yaml_path)
        schema_config.store_as_yaml(yml_path)

        # load using from_file, which is given only the path
        json_schema = SchemaConfig.from_file(json_path)
        yaml_schema = SchemaConfig.from_file(yaml_path)
        yml_schema = SchemaConfig.from_file(yml_path)

        # simple verification that the objects were loaded correctly
        assert isinstance(json_schema, SchemaConfig)
        assert isinstance(yaml_schema, SchemaConfig)
        assert isinstance(yml_schema, SchemaConfig)

        # verify basic structure is intact
        assert "entities" in json_schema.model_dump()
        assert "entities" in yaml_schema.model_dump()
        assert "entities" in yml_schema.model_dump()

        # verify an unsupported extension raises the correct error:
        # the content is JSON, only the .txt extension is unsupported
        txt_path = os.path.join(temp_dir, "schema.txt")
        schema_config.store_as_json(txt_path)

        with pytest.raises(ValueError, match="Unsupported file format"):
            SchemaConfig.from_file(txt_path)
0 commit comments