|
14 | 14 | # limitations under the License.
|
15 | 15 | from __future__ import annotations
|
16 | 16 |
|
| 17 | +import json |
17 | 18 | from typing import Any, Dict, List, Literal, Optional, Tuple
|
18 | 19 |
|
19 | 20 | from pydantic import BaseModel, ValidationError, model_validator, validate_call
|
| 21 | +from requests.exceptions import InvalidJSONError |
20 | 22 | from typing_extensions import Self
|
21 | 23 |
|
22 | 24 | from neo4j_graphrag.exceptions import SchemaValidationError
|
|
25 | 27 | EntityInputType,
|
26 | 28 | RelationInputType,
|
27 | 29 | )
|
| 30 | +from neo4j_graphrag.generation import SchemaExtractionTemplate, PromptTemplate |
| 31 | +from neo4j_graphrag.llm import LLMInterface |
28 | 32 |
|
29 | 33 |
|
30 | 34 | class SchemaProperty(BaseModel):
|
@@ -236,3 +240,62 @@ async def run(
|
236 | 240 | SchemaConfig: A configured schema object, constructed asynchronously.
|
237 | 241 | """
|
238 | 242 | return self.create_schema_model(entities, relations, potential_schema)
|
| 243 | + |
| 244 | + |
| 245 | +class SchemaFromText(SchemaBuilder): |
| 246 | + """ |
| 247 | + A builder class for constructing SchemaConfig objects from the output of an LLM after |
| 248 | + automatic schema extraction from text. |
| 249 | + """ |
| 250 | + |
| 251 | + def __init__( |
| 252 | + self, |
| 253 | + llm: LLMInterface, |
| 254 | + prompt_template: Optional[PromptTemplate] = None, |
| 255 | + llm_params: Optional[Dict[str, Any]] = None, |
| 256 | + ) -> None: |
| 257 | + super().__init__() |
| 258 | + self._llm: LLMInterface = llm |
| 259 | + self._prompt_template: PromptTemplate = prompt_template or SchemaExtractionTemplate() |
| 260 | + self._llm_params: dict[str, Any] = llm_params or {} |
| 261 | + |
| 262 | + @validate_call |
| 263 | + async def run(self, text: str, **kwargs: Any) -> SchemaConfig: |
| 264 | + """ |
| 265 | + Asynchronously extracts the schema from text and returns a SchemaConfig object. |
| 266 | +
|
| 267 | + Args: |
| 268 | + text (str): the text from which the schema will be inferred. |
| 269 | +
|
| 270 | + Returns: |
| 271 | + SchemaConfig: A configured schema object, extracted automatically and |
| 272 | + constructed asynchronously. |
| 273 | + """ |
| 274 | + prompt: str = self._prompt_template.format(text=text) |
| 275 | + |
| 276 | + response = await self._llm.invoke(prompt, **self._llm_params) |
| 277 | + content: str = ( |
| 278 | + response if isinstance(response, str) else getattr(response, "content", str(response)) |
| 279 | + ) |
| 280 | + |
| 281 | + try: |
| 282 | + extracted_schema: Dict[str, Any] = json.loads(content) |
| 283 | + except json.JSONDecodeError as exc: |
| 284 | + raise InvalidJSONError( |
| 285 | + "LLM response is not valid JSON." |
| 286 | + ) from exc |
| 287 | + |
| 288 | + extracted_entities: List[dict] = extracted_schema.get("entities", []) |
| 289 | + extracted_relations: Optional[List[dict]] = extracted_schema.get("relations") |
| 290 | + potential_schema: Optional[List[Tuple[str, str, str]]] = extracted_schema.get("potential_schema") |
| 291 | + |
| 292 | + entities: List[SchemaEntity] = [SchemaEntity(**e) for e in extracted_entities] |
| 293 | + relations: Optional[List[SchemaRelation]] = ( |
| 294 | + [SchemaRelation(**r) for r in extracted_relations] if extracted_relations else None |
| 295 | + ) |
| 296 | + |
| 297 | + return await super().run( |
| 298 | + entities=entities, |
| 299 | + relations=relations, |
| 300 | + potential_schema=potential_schema, |
| 301 | + ) |
0 commit comments