Skip to content

Commit baf9302

Browse files
Add schema from text using an LLM
1 parent 7c831de commit baf9302

File tree

2 files changed

+65
-1
lines changed

2 files changed

+65
-1
lines changed

src/neo4j_graphrag/experimental/components/schema.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@
1414
# limitations under the License.
1515
from __future__ import annotations
1616

17+
import json
1718
from typing import Any, Dict, List, Literal, Optional, Tuple
1819

1920
from pydantic import BaseModel, ValidationError, model_validator, validate_call
21+
from requests.exceptions import InvalidJSONError
2022
from typing_extensions import Self
2123

2224
from neo4j_graphrag.exceptions import SchemaValidationError
@@ -25,6 +27,8 @@
2527
EntityInputType,
2628
RelationInputType,
2729
)
30+
from neo4j_graphrag.generation import SchemaExtractionTemplate, PromptTemplate
31+
from neo4j_graphrag.llm import LLMInterface
2832

2933

3034
class SchemaProperty(BaseModel):
@@ -236,3 +240,62 @@ async def run(
236240
SchemaConfig: A configured schema object, constructed asynchronously.
237241
"""
238242
return self.create_schema_model(entities, relations, potential_schema)
243+
244+
245+
class SchemaFromText(SchemaBuilder):
    """
    A builder class for constructing SchemaConfig objects from the output of an LLM after
    automatic schema extraction from text.
    """

    def __init__(
        self,
        llm: LLMInterface,
        prompt_template: Optional[PromptTemplate] = None,
        llm_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Args:
            llm (LLMInterface): the LLM used to infer the schema from text.
            prompt_template (Optional[PromptTemplate]): prompt used to drive the
                extraction; defaults to SchemaExtractionTemplate.
            llm_params (Optional[Dict[str, Any]]): extra keyword arguments forwarded
                to ``llm.invoke`` on every call.
        """
        super().__init__()
        self._llm: LLMInterface = llm
        self._prompt_template: PromptTemplate = (
            prompt_template or SchemaExtractionTemplate()
        )
        self._llm_params: Dict[str, Any] = llm_params or {}

    @validate_call
    async def run(self, text: str, **kwargs: Any) -> SchemaConfig:
        """
        Asynchronously extracts the schema from text and returns a SchemaConfig object.

        Args:
            text (str): the text from which the schema will be inferred.

        Returns:
            SchemaConfig: A configured schema object, extracted automatically and
            constructed asynchronously.

        Raises:
            SchemaValidationError: if the LLM response is not valid JSON.
        """
        prompt: str = self._prompt_template.format(text=text)

        response = await self._llm.invoke(prompt, **self._llm_params)
        # Some LLM interfaces return a plain string, others a response object
        # with a ``content`` attribute — accept both.
        content: str = (
            response
            if isinstance(response, str)
            else getattr(response, "content", str(response))
        )

        try:
            extracted_schema: Dict[str, Any] = json.loads(content)
        except json.JSONDecodeError as exc:
            # Raise the project's own exception rather than
            # requests.exceptions.InvalidJSONError: no HTTP request is involved
            # here, and callers should not need a dependency on ``requests``
            # just to handle a schema-extraction failure.
            raise SchemaValidationError("LLM response is not valid JSON.") from exc

        extracted_entities: List[Dict[str, Any]] = extracted_schema.get("entities", [])
        extracted_relations: Optional[List[Dict[str, Any]]] = extracted_schema.get(
            "relations"
        )

        # json.loads never produces tuples, so convert each triple explicitly to
        # honor the List[Tuple[str, str, str]] contract of SchemaBuilder.run.
        raw_potential_schema = extracted_schema.get("potential_schema")
        potential_schema: Optional[List[Tuple[str, str, str]]] = (
            [tuple(triple) for triple in raw_potential_schema]
            if raw_potential_schema
            else None
        )

        entities: List[SchemaEntity] = [SchemaEntity(**e) for e in extracted_entities]
        relations: Optional[List[SchemaRelation]] = (
            [SchemaRelation(**r) for r in extracted_relations]
            if extracted_relations
            else None
        )

        return await super().run(
            entities=entities,
            relations=relations,
            potential_schema=potential_schema,
        )
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from .graphrag import GraphRAG
2-
from .prompts import PromptTemplate, RagTemplate
2+
from .prompts import PromptTemplate, RagTemplate, SchemaExtractionTemplate
33

44
__all__ = [
55
"GraphRAG",
66
"PromptTemplate",
77
"RagTemplate",
8+
"SchemaExtractionTemplate"
89
]

0 commit comments

Comments
 (0)