Skip to content

Commit 4ebe9bd

Browse files
Add unit tests for schema enforcement modes
1 parent a9f47d7 commit 4ebe9bd

File tree

1 file changed

+235
-0
lines changed

1 file changed

+235
-0
lines changed

tests/unit/experimental/components/test_entity_relation_extractor.py

Lines changed: 235 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -25,11 +25,13 @@
2525
balance_curly_braces,
2626
fix_invalid_json,
2727
)
28+
from neo4j_graphrag.experimental.components.schema import SchemaConfig
2829
from neo4j_graphrag.experimental.components.types import (
2930
DocumentInfo,
3031
Neo4jGraph,
3132
TextChunk,
3233
TextChunks,
34+
SchemaEnforcementMode,
3335
)
3436
from neo4j_graphrag.experimental.pipeline.exceptions import InvalidJSONError
3537
from neo4j_graphrag.llm import LLMInterface, LLMResponse
@@ -229,6 +231,239 @@ async def test_extractor_custom_prompt() -> None:
229231
llm.ainvoke.assert_called_once_with("this is my prompt")
230232

231233

234+
@pytest.mark.asyncio
async def test_extractor_no_schema_enforcement() -> None:
    """With enforcement mode NONE, an off-schema node is expected to be returned."""
    mock_llm = MagicMock(spec=LLMInterface)
    llm_payload = (
        '{"nodes":[{"id":"0","label":"Alien","properties":{"foo":"bar"}}],'
        '"relationships":[]}'
    )
    mock_llm.ainvoke.return_value = LLMResponse(content=llm_payload)

    extractor = LLMEntityRelationExtractor(
        llm=mock_llm,
        create_lexical_graph=False,
        enforce_schema=SchemaEnforcementMode.NONE,
    )

    # The schema only declares "Person" while the mocked LLM answers "Alien".
    person_only_schema = SchemaConfig(
        entities={"Person": {"name": "STRING"}},
        relations={},
        potential_schema=[],
    )

    text_chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)])

    graph: Neo4jGraph = await extractor.run(
        chunks=text_chunks, schema=person_only_schema
    )

    assert len(graph.nodes) == 1
    alien = graph.nodes[0]
    assert alien.label == "Alien"
    assert alien.properties == {"chunk_index": 0, "foo": "bar"}
257+
258+
259+
@pytest.mark.asyncio
async def test_extractor_schema_enforcement_when_no_schema_provided() -> None:
    """STRICT mode without a schema argument: the extracted node is expected back."""
    llm = MagicMock(spec=LLMInterface)
    llm.ainvoke.return_value = LLMResponse(
        content='{"nodes":[{"id":"0","label":"Alien","properties":{"foo":"bar"}}],'
        '"relationships":[]}'
    )

    extractor = LLMEntityRelationExtractor(
        llm=llm,
        create_lexical_graph=False,
        enforce_schema=SchemaEnforcementMode.STRICT,
    )

    chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)])

    # No schema is passed to run(), so there is nothing to validate against.
    result: Neo4jGraph = await extractor.run(chunks=chunks)

    assert len(result.nodes) == 1
    assert result.nodes[0].label == "Alien"
    assert result.nodes[0].properties == {"chunk_index": 0, "foo": "bar"}
278+
279+
280+
@pytest.mark.asyncio
async def test_extractor_schema_enforcement_invalid_nodes() -> None:
    """STRICT mode: only the node whose label is in the schema is expected back."""
    llm = MagicMock(spec=LLMInterface)
    llm.ainvoke.return_value = LLMResponse(
        content='{"nodes":[{"id":"0","label":"Alien","properties":{"foo":"bar"}},'
        '{"id":"1","label":"Person","properties":{"name":"Alice"}}],'
        '"relationships":[]}'
    )

    extractor = LLMEntityRelationExtractor(
        llm=llm,
        create_lexical_graph=False,
        enforce_schema=SchemaEnforcementMode.STRICT,
    )

    # "Alien" is absent from the schema; "Person" is declared.
    schema = SchemaConfig(
        entities={"Person": {"name": "STRING"}},
        relations={},
        potential_schema=[],
    )

    chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)])

    result: Neo4jGraph = await extractor.run(chunks=chunks, schema=schema)

    assert len(result.nodes) == 1
    assert result.nodes[0].label == "Person"
    assert result.nodes[0].properties == {"chunk_index": 0, "name": "Alice"}
304+
305+
306+
@pytest.mark.asyncio
async def test_extraction_schema_enforcement_invalid_node_properties() -> None:
    """STRICT mode: a property not declared for the entity is expected to be gone."""
    llm = MagicMock(spec=LLMInterface)
    llm.ainvoke.return_value = LLMResponse(
        content='{"nodes":[{"id":"1","label":"Person","properties":'
        '{"name":"Alice","age":30,"foo":"bar"}}],'
        '"relationships":[]}'
    )

    extractor = LLMEntityRelationExtractor(
        llm=llm,
        create_lexical_graph=False,
        enforce_schema=SchemaEnforcementMode.STRICT,
    )

    schema = SchemaConfig(
        entities={"Person": {"name": str, "age": int}},
        relations={},
        potential_schema=[],
    )

    chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)])

    result: Neo4jGraph = await extractor.run(chunks, schema=schema)

    # "foo" is removed
    assert len(result.nodes) == 1
    # Presumably "name", "age" and the extractor-added "chunk_index" remain
    # (TODO confirm); only the count is asserted here.
    assert len(result.nodes[0].properties) == 3
    assert "foo" not in result.nodes[0].properties
331+
332+
333+
@pytest.mark.asyncio
async def test_extractor_schema_enforcement_valid_nodes_with_empty_props() -> None:
    """STRICT mode: a schema-valid label with only undeclared props yields no node."""
    llm = MagicMock(spec=LLMInterface)
    llm.ainvoke.return_value = LLMResponse(
        content='{"nodes":[{"id":"1","label":"Person","properties":{"foo":"bar"}}],'
        '"relationships":[]}'
    )

    extractor = LLMEntityRelationExtractor(
        llm=llm,
        create_lexical_graph=False,
        enforce_schema=SchemaEnforcementMode.STRICT,
    )

    # "Person" is declared, but with no properties at all.
    schema = SchemaConfig(
        entities={"Person": {}},
        relations={},
        potential_schema=[],
    )

    chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)])

    result: Neo4jGraph = await extractor.run(chunks, schema=schema)

    # NOTE(review): presumably the node is dropped because "foo" is not a
    # declared property, leaving it with nothing to keep -- confirm against
    # the extractor implementation.
    assert len(result.nodes) == 0
354+
355+
356+
@pytest.mark.asyncio
async def test_extractor_schema_enforcement_invalid_relations_wrong_types() -> None:
    """STRICT mode: a relationship type absent from the schema is expected to go."""
    llm = MagicMock(spec=LLMInterface)
    llm.ainvoke.return_value = LLMResponse(
        content='{"nodes":[{"id":"1","label":"Person","properties":'
        '{"name":"Alice"}},{"id":"2","label":"Person","properties":'
        '{"name":"Bob"}}],'
        '"relationships":[{"start_node_id":"1","end_node_id":"2",'
        '"type":"FRIENDS_WITH","properties":{}}]}'
    )

    extractor = LLMEntityRelationExtractor(
        llm=llm,
        create_lexical_graph=False,
        enforce_schema=SchemaEnforcementMode.STRICT,
    )

    # The schema declares "LIKES" only; the mocked LLM emits "FRIENDS_WITH".
    schema = SchemaConfig(
        entities={"Person": {"name": str}},
        relations={"LIKES": {}},
        potential_schema=[],
    )

    chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)])

    result: Neo4jGraph = await extractor.run(chunks, schema=schema)

    assert len(result.nodes) == 2
    assert len(result.relationships) == 0
381+
382+
383+
@pytest.mark.asyncio
async def test_extractor_schema_enforcement_invalid_relations_wrong_start_node() -> (
    None
):
    """STRICT mode: a relationship violating the potential schema is expected to go."""
    llm = MagicMock(spec=LLMInterface)
    llm.ainvoke.return_value = LLMResponse(
        content='{"nodes":[{"id":"1","label":"Person","properties":{"name":"Alice"}},'
        '{"id":"2","label":"Person","properties":{"name":"Bob"}}, '
        '{"id":"3","label":"City","properties":{"name":"London"}}],'
        '"relationships":[{"start_node_id":"1","end_node_id":"2",'
        '"type":"LIVES_IN","properties":{}}]}'
    )

    extractor = LLMEntityRelationExtractor(
        llm=llm,
        create_lexical_graph=False,
        enforce_schema=SchemaEnforcementMode.STRICT,
    )

    schema = SchemaConfig(
        entities={"Person": {"name": str}, "City": {"name": str}},
        relations={"LIVES_IN": {}},
        potential_schema=[("Person", "LIVES_IN", "City")],
    )

    chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)])

    result: Neo4jGraph = await extractor.run(chunks, schema=schema)

    # NOTE(review): the mocked relationship links Person "1" to Person "2",
    # so it is the end node (Person instead of City) that breaks the
    # ("Person", "LIVES_IN", "City") pattern, despite the test's name.
    assert len(result.nodes) == 3
    assert len(result.relationships) == 0
408+
409+
410+
@pytest.mark.asyncio
async def test_extractor_schema_enforcement_invalid_relation_properties() -> None:
    """STRICT mode: undeclared relationship properties are expected to be removed."""
    llm = MagicMock(spec=LLMInterface)
    llm.ainvoke.return_value = LLMResponse(
        content='{"nodes":[{"id":"1","label":"Person","properties":{"name":"Alice"}},'
        '{"id":"2","label":"Person","properties":{"name":"Bob"}}],'
        '"relationships":[{"start_node_id":"1","end_node_id":"2",'
        '"type":"LIKES","properties":{"strength":"high","foo":"bar"}}]}'
    )

    extractor = LLMEntityRelationExtractor(
        llm=llm,
        create_lexical_graph=False,
        enforce_schema=SchemaEnforcementMode.STRICT,
    )

    # "LIKES" declares "strength" only; the mocked LLM also sends "foo".
    schema = SchemaConfig(
        entities={"Person": {"name": str}},
        relations={"LIKES": {"strength": str}},
        potential_schema=[],
    )

    chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)])

    result: Neo4jGraph = await extractor.run(chunks, schema=schema)

    assert len(result.nodes) == 2
    assert len(result.relationships) == 1
    rel = result.relationships[0]
    assert "foo" not in rel.properties
    assert rel.properties["strength"] == "high"
439+
440+
441+
@pytest.mark.asyncio
async def test_extractor_schema_enforcement_removed_relation_start_end_nodes() -> (
    None
):
    """STRICT mode: off-schema endpoints and their relationship are expected gone."""
    llm = MagicMock(spec=LLMInterface)
    llm.ainvoke.return_value = LLMResponse(
        content='{"nodes":[{"id":"1","label":"Alien","properties":{}},'
        '{"id":"2","label":"Robot","properties":{}}],'
        '"relationships":[{"start_node_id":"1","end_node_id":"2",'
        '"type":"LIKES","properties":{}}]}'
    )

    extractor = LLMEntityRelationExtractor(
        llm=llm,
        create_lexical_graph=False,
        enforce_schema=SchemaEnforcementMode.STRICT,
    )

    # Neither "Alien" nor "Robot" is a declared entity; "LIKES" itself is.
    schema = SchemaConfig(
        entities={"Person": {"name": str}},
        relations={"LIKES": {}},
        potential_schema=[("Person", "LIKES", "Person")],
    )

    chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)])

    result: Neo4jGraph = await extractor.run(chunks, schema=schema)

    assert len(result.nodes) == 0
    assert len(result.relationships) == 0
465+
466+
232467
def test_fix_invalid_json_empty_result() -> None:
233468
json_string = "invalid json"
234469

0 commit comments

Comments
 (0)