Skip to content

Commit b52bed4

Browse files
Add unit tests
1 parent fa8a6af commit b52bed4

File tree

1 file changed

+244
-0
lines changed

1 file changed

+244
-0
lines changed

tests/unit/experimental/components/test_schema.py

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,26 @@
1414
# limitations under the License.
1515
from __future__ import annotations
1616

17+
import json
18+
from unittest.mock import AsyncMock
19+
1720
import pytest
1821
from neo4j_graphrag.exceptions import SchemaValidationError
1922
from neo4j_graphrag.experimental.components.schema import (
2023
SchemaBuilder,
2124
SchemaEntity,
2225
SchemaProperty,
2326
SchemaRelation,
27+
SchemaFromText,
28+
SchemaConfig,
2429
)
2530
from pydantic import ValidationError
31+
import os
32+
import tempfile
33+
import yaml
34+
from pathlib import Path
35+
36+
from neo4j_graphrag.generation import PromptTemplate
2637

2738

2839
@pytest.fixture
@@ -93,6 +104,18 @@ def schema_builder() -> SchemaBuilder:
93104
return SchemaBuilder()
94105

95106

107+
@pytest.fixture
108+
def schema_config(
109+
schema_builder: SchemaBuilder,
110+
valid_entities: list[SchemaEntity],
111+
valid_relations: list[SchemaRelation],
112+
potential_schema: list[tuple[str, str, str]],
113+
):
114+
return schema_builder.create_schema_model(
115+
valid_entities, valid_relations, potential_schema
116+
)
117+
118+
96119
def test_create_schema_model_valid_data(
97120
schema_builder: SchemaBuilder,
98121
valid_entities: list[SchemaEntity],
@@ -419,3 +442,224 @@ def test_create_schema_model_missing_relations(
419442
assert "Relations must also be provided when using a potential schema." in str(
420443
exc_info.value
421444
), "Should fail due to missing relations"
445+
446+
447+
@pytest.fixture
448+
def mock_llm():
449+
mock = AsyncMock()
450+
mock.invoke = AsyncMock()
451+
return mock
452+
453+
454+
@pytest.fixture
455+
def valid_schema_json():
456+
return '''
457+
{
458+
"entities": [
459+
{
460+
"label": "Person",
461+
"properties": [
462+
{"name": "name", "type": "STRING"}
463+
]
464+
},
465+
{
466+
"label": "Organization",
467+
"properties": [
468+
{"name": "name", "type": "STRING"}
469+
]
470+
}
471+
],
472+
"relations": [
473+
{
474+
"label": "WORKS_FOR",
475+
"properties": [
476+
{"name": "since", "type": "DATE"}
477+
]
478+
}
479+
],
480+
"potential_schema": [
481+
["Person", "WORKS_FOR", "Organization"]
482+
]
483+
}
484+
'''
485+
486+
487+
@pytest.fixture
488+
def invalid_schema_json():
489+
return '''
490+
{
491+
"entities": [
492+
{
493+
"label": "Person",
494+
},
495+
],
496+
invalid json content
497+
}
498+
'''
499+
500+
501+
@pytest.fixture
502+
def schema_from_text(mock_llm):
503+
return SchemaFromText(llm=mock_llm)
504+
505+
506+
@pytest.mark.asyncio
507+
async def test_schema_from_text_run_valid_response(schema_from_text, mock_llm, valid_schema_json):
508+
# configure the mock LLM to return a valid schema JSON
509+
mock_llm.invoke.return_value = valid_schema_json
510+
511+
# run the schema extraction
512+
schema_config = await schema_from_text.run(text="Sample text for extraction")
513+
514+
# verify the LLM was called with a prompt
515+
mock_llm.invoke.assert_called_once()
516+
prompt_arg = mock_llm.invoke.call_args[0][0]
517+
assert isinstance(prompt_arg, str)
518+
assert "Sample text for extraction" in prompt_arg
519+
520+
# verify the schema was correctly extracted
521+
assert len(schema_config.entities) == 2
522+
assert "Person" in schema_config.entities
523+
assert "Organization" in schema_config.entities
524+
525+
assert schema_config.relations is not None
526+
assert "WORKS_FOR" in schema_config.relations
527+
528+
assert schema_config.potential_schema is not None
529+
assert len(schema_config.potential_schema) == 1
530+
assert schema_config.potential_schema[0] == ("Person", "WORKS_FOR", "Organization")
531+
532+
533+
@pytest.mark.asyncio
534+
async def test_schema_from_text_run_invalid_json(schema_from_text, mock_llm, invalid_schema_json):
535+
# configure the mock LLM to return invalid JSON
536+
mock_llm.invoke.return_value = invalid_schema_json
537+
538+
# verify that running with invalid JSON raises a ValueError
539+
with pytest.raises(ValueError) as exc_info:
540+
await schema_from_text.run(text="Sample text for extraction")
541+
542+
assert "not valid JSON" in str(exc_info.value)
543+
544+
545+
@pytest.mark.asyncio
546+
async def test_schema_from_text_custom_template(mock_llm, valid_schema_json):
547+
# create a custom template
548+
custom_prompt = "This is a custom prompt with text: {text}"
549+
custom_template = PromptTemplate(template=custom_prompt, expected_inputs=["text"])
550+
551+
# create SchemaFromText with the custom template
552+
schema_from_text = SchemaFromText(llm=mock_llm, prompt_template=custom_template)
553+
554+
# configure mock LLM to return valid JSON and capture the prompt that was sent to it
555+
mock_llm.invoke.return_value = valid_schema_json
556+
557+
# run the schema extraction
558+
await schema_from_text.run(text="Sample text")
559+
560+
# verify the custom prompt was passed to the LLM
561+
prompt_sent_to_llm = mock_llm.invoke.call_args[0][0]
562+
assert "This is a custom prompt with text" in prompt_sent_to_llm
563+
564+
565+
@pytest.mark.asyncio
566+
async def test_schema_from_text_llm_params(mock_llm, valid_schema_json):
567+
# configure custom LLM parameters
568+
llm_params = {"temperature": 0.1, "max_tokens": 500}
569+
570+
# create SchemaFromText with custom LLM parameters
571+
schema_from_text = SchemaFromText(llm=mock_llm, llm_params=llm_params)
572+
573+
# configure the mock LLM to return a valid schema JSON
574+
mock_llm.invoke.return_value = valid_schema_json
575+
576+
# run the schema extraction
577+
await schema_from_text.run(text="Sample text")
578+
579+
# verify the LLM was called with the custom parameters
580+
mock_llm.invoke.assert_called_once()
581+
call_kwargs = mock_llm.invoke.call_args[1]
582+
assert call_kwargs["temperature"] == 0.1
583+
assert call_kwargs["max_tokens"] == 500
584+
585+
586+
@pytest.mark.asyncio
587+
async def test_schema_config_store_as_json(schema_config):
588+
with tempfile.TemporaryDirectory() as temp_dir:
589+
# create file path
590+
json_path = os.path.join(temp_dir, "schema.json")
591+
592+
# store the schema config
593+
schema_config.store_as_json(json_path)
594+
595+
# verify the file exists and has content
596+
assert os.path.exists(json_path)
597+
assert os.path.getsize(json_path) > 0
598+
599+
# verify the content is valid JSON and contains expected data
600+
with open(json_path, 'r') as f:
601+
data = json.load(f)
602+
assert "entities" in data
603+
assert "PERSON" in data["entities"]
604+
assert "properties" in data["entities"]["PERSON"]
605+
assert "description" in data["entities"]["PERSON"]
606+
assert data["entities"]["PERSON"]["description"] == "An individual human being."
607+
608+
609+
@pytest.mark.asyncio
610+
async def test_schema_config_store_as_yaml(schema_config):
611+
with tempfile.TemporaryDirectory() as temp_dir:
612+
# Create file path
613+
yaml_path = os.path.join(temp_dir, "schema.yaml")
614+
615+
# Store the schema config
616+
schema_config.store_as_yaml(yaml_path)
617+
618+
# Verify the file exists and has content
619+
assert os.path.exists(yaml_path)
620+
assert os.path.getsize(yaml_path) > 0
621+
622+
# Verify the content is valid YAML and contains expected data
623+
with open(yaml_path, 'r') as f:
624+
data = yaml.safe_load(f)
625+
assert "entities" in data
626+
assert "PERSON" in data["entities"]
627+
assert "properties" in data["entities"]["PERSON"]
628+
assert "description" in data["entities"]["PERSON"]
629+
assert data["entities"]["PERSON"]["description"] == "An individual human being."
630+
631+
632+
@pytest.mark.asyncio
633+
async def test_schema_config_from_file(schema_config):
634+
with tempfile.TemporaryDirectory() as temp_dir:
635+
# create file paths with different extensions
636+
json_path = os.path.join(temp_dir, "schema.json")
637+
yaml_path = os.path.join(temp_dir, "schema.yaml")
638+
yml_path = os.path.join(temp_dir, "schema.yml")
639+
640+
# store the schema config in the different formats
641+
schema_config.store_as_json(json_path)
642+
schema_config.store_as_yaml(yaml_path)
643+
schema_config.store_as_yaml(yml_path)
644+
645+
# load using from_file which should detect the format based on extension
646+
json_schema = SchemaConfig.from_file(json_path)
647+
yaml_schema = SchemaConfig.from_file(yaml_path)
648+
yml_schema = SchemaConfig.from_file(yml_path)
649+
650+
# simple verification that the objects were loaded correctly
651+
assert isinstance(json_schema, SchemaConfig)
652+
assert isinstance(yaml_schema, SchemaConfig)
653+
assert isinstance(yml_schema, SchemaConfig)
654+
655+
# verify basic structure is intact
656+
assert "entities" in json_schema.model_dump()
657+
assert "entities" in yaml_schema.model_dump()
658+
assert "entities" in yml_schema.model_dump()
659+
660+
# verify an unsupported extension raises the correct error
661+
txt_path = os.path.join(temp_dir, "schema.txt")
662+
schema_config.store_as_json(txt_path) # Store as JSON but with .txt extension
663+
664+
with pytest.raises(ValueError, match="Unsupported file format"):
665+
SchemaConfig.from_file(txt_path)

0 commit comments

Comments
 (0)