Skip to content

Commit 52a2686

Browse files
Ruff
1 parent 30c273d commit 52a2686

File tree

5 files changed

+138
-94
lines changed

5 files changed

+138
-94
lines changed

src/neo4j_graphrag/experimental/components/schema.py

Lines changed: 40 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -136,81 +136,83 @@ def store_as_json(self, file_path: str) -> None:
136136
Args:
137137
file_path (str): The path where the schema configuration will be saved.
138138
"""
139-
with open(file_path, 'w') as f:
139+
with open(file_path, "w") as f:
140140
json.dump(self.model_dump(), f, indent=2)
141-
141+
142142
def store_as_yaml(self, file_path: str) -> None:
143143
"""
144144
Save the schema configuration to a YAML file.
145145
146146
Args:
147147
file_path (str): The path where the schema configuration will be saved.
148-
"""
148+
"""
149149
# create a copy of the data and convert tuples to lists for YAML compatibility
150150
data = self.model_dump()
151-
if data.get('potential_schema'):
152-
data['potential_schema'] = [list(item) for item in data['potential_schema']]
153-
154-
with open(file_path, 'w') as f:
151+
if data.get("potential_schema"):
152+
data["potential_schema"] = [list(item) for item in data["potential_schema"]]
153+
154+
with open(file_path, "w") as f:
155155
yaml.dump(data, f, default_flow_style=False, sort_keys=False)
156-
156+
157157
@classmethod
158158
def from_file(cls, file_path: Union[str, Path]) -> Self:
159159
"""
160160
Load a schema configuration from a file (either JSON or YAML).
161-
161+
162162
The file format is automatically detected based on the file extension.
163-
163+
164164
Args:
165165
file_path (Union[str, Path]): The path to the schema configuration file.
166-
166+
167167
Returns:
168168
SchemaConfig: The loaded schema configuration.
169169
"""
170170
file_path = Path(file_path)
171-
171+
172172
if not file_path.exists():
173173
raise FileNotFoundError(f"Schema file not found: {file_path}")
174-
175-
if file_path.suffix.lower() in ['.json']:
174+
175+
if file_path.suffix.lower() in [".json"]:
176176
return cls.from_json(file_path)
177-
elif file_path.suffix.lower() in ['.yaml', '.yml']:
177+
elif file_path.suffix.lower() in [".yaml", ".yml"]:
178178
return cls.from_yaml(file_path)
179179
else:
180-
raise ValueError(f"Unsupported file format: {file_path.suffix}. Use .json, .yaml, or .yml")
181-
180+
raise ValueError(
181+
f"Unsupported file format: {file_path.suffix}. Use .json, .yaml, or .yml"
182+
)
183+
182184
@classmethod
183185
def from_json(cls, file_path: Union[str, Path]) -> Self:
184186
"""
185187
Load a schema configuration from a JSON file.
186-
188+
187189
Args:
188190
file_path (Union[str, Path]): The path to the JSON schema configuration file.
189-
191+
190192
Returns:
191193
SchemaConfig: The loaded schema configuration.
192194
"""
193-
with open(file_path, 'r') as f:
195+
with open(file_path, "r") as f:
194196
try:
195197
data = json.load(f)
196198
return cls.model_validate(data)
197199
except json.JSONDecodeError as e:
198200
raise ValueError(f"Invalid JSON file: {e}")
199201
except ValidationError as e:
200202
raise SchemaValidationError(f"Schema validation failed: {e}")
201-
203+
202204
@classmethod
203205
def from_yaml(cls, file_path: Union[str, Path]) -> Self:
204206
"""
205207
Load a schema configuration from a YAML file.
206-
208+
207209
Args:
208210
file_path (Union[str, Path]): The path to the YAML schema configuration file.
209-
211+
210212
Returns:
211213
SchemaConfig: The loaded schema configuration.
212214
"""
213-
with open(file_path, 'r') as f:
215+
with open(file_path, "r") as f:
214216
try:
215217
data = yaml.safe_load(f)
216218
return cls.model_validate(data)
@@ -348,11 +350,13 @@ def __init__(
348350
) -> None:
349351
super().__init__()
350352
self._llm: LLMInterface = llm
351-
self._prompt_template: PromptTemplate = prompt_template or SchemaExtractionTemplate()
353+
self._prompt_template: PromptTemplate = (
354+
prompt_template or SchemaExtractionTemplate()
355+
)
352356
self._llm_params: dict[str, Any] = llm_params or {}
353357

354358
@validate_call
355-
async def run(self, text: str, examples:str = "", **kwargs: Any) -> SchemaConfig:
359+
async def run(self, text: str, examples: str = "", **kwargs: Any) -> SchemaConfig:
356360
"""
357361
Asynchronously extracts the schema from text and returns a SchemaConfig object.
358362
@@ -367,23 +371,27 @@ async def run(self, text: str, examples:str = "", **kwargs: Any) -> SchemaConfig
367371

368372
response = await self._llm.invoke(prompt, **self._llm_params)
369373
content: str = (
370-
response if isinstance(response, str) else getattr(response, "content", str(response))
374+
response
375+
if isinstance(response, str)
376+
else getattr(response, "content", str(response))
371377
)
372378

373379
try:
374380
extracted_schema: Dict[str, Any] = json.loads(content)
375381
except json.JSONDecodeError as exc:
376-
raise ValueError(
377-
"LLM response is not valid JSON."
378-
) from exc
382+
raise ValueError("LLM response is not valid JSON.") from exc
379383

380384
extracted_entities: List[dict] = extracted_schema.get("entities", [])
381385
extracted_relations: Optional[List[dict]] = extracted_schema.get("relations")
382-
potential_schema: Optional[List[Tuple[str, str, str]]] = extracted_schema.get("potential_schema")
386+
potential_schema: Optional[List[Tuple[str, str, str]]] = extracted_schema.get(
387+
"potential_schema"
388+
)
383389

384390
entities: List[SchemaEntity] = [SchemaEntity(**e) for e in extracted_entities]
385391
relations: Optional[List[SchemaRelation]] = (
386-
[SchemaRelation(**r) for r in extracted_relations] if extracted_relations else None
392+
[SchemaRelation(**r) for r in extracted_relations]
393+
if extracted_relations
394+
else None
387395
)
388396

389397
return await super().run(

src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py

Lines changed: 50 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@
5959

6060
logger = logging.getLogger(__name__)
6161

62-
T = TypeVar('T', bound='SimpleKGPipelineConfig')
62+
T = TypeVar("T", bound="SimpleKGPipelineConfig")
63+
6364

6465
class SimpleKGPipelineConfig(TemplatePipelineConfig):
6566
COMPONENTS: ClassVar[list[str]] = [
@@ -94,40 +95,47 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig):
9495
text_splitter: Optional[ComponentType] = None
9596

9697
model_config = ConfigDict(arbitrary_types_allowed=True)
97-
98-
@model_validator(mode='after')
98+
99+
@model_validator(mode="after")
99100
def handle_schema_precedence(self) -> T:
100101
"""Handle schema precedence and warnings"""
101102
self._process_schema_parameters()
102103
return self
103-
104+
104105
def _process_schema_parameters(self) -> None:
105106
"""
106107
Process schema parameters and handle precedence between 'schema' parameter and individual components.
107108
Also logs warnings for deprecated usage.
108109
"""
109110
# check if both schema and individual components are provided
110-
has_individual_schema_components = any([self.entities, self.relations, self.potential_schema])
111-
111+
has_individual_schema_components = any(
112+
[self.entities, self.relations, self.potential_schema]
113+
)
114+
112115
if has_individual_schema_components and self.schema is not None:
113116
logger.warning(
114117
"Both 'schema' and individual schema components (entities, relations, potential_schema) "
115118
"were provided. The 'schema' parameter takes precedence. In the future, individual "
116119
"components will be removed. Please use only the 'schema' parameter.",
117-
stacklevel=2
120+
stacklevel=2,
118121
)
119-
122+
120123
elif has_individual_schema_components:
121124
logger.warning(
122125
"The 'entities', 'relations', and 'potential_schema' parameters are deprecated "
123126
"and will be removed in a future version. "
124127
"Please use the 'schema' parameter instead.",
125-
stacklevel=2
128+
stacklevel=2,
126129
)
127130

128131
def has_user_provided_schema(self) -> bool:
129132
"""Check if the user has provided schema information"""
130-
return bool(self.entities or self.relations or self.potential_schema or self.schema is not None)
133+
return bool(
134+
self.entities
135+
or self.relations
136+
or self.potential_schema
137+
or self.schema is not None
138+
)
131139

132140
def _get_pdf_loader(self) -> Optional[PdfLoader]:
133141
if not self.from_pdf:
@@ -165,13 +173,17 @@ def _get_schema(self) -> Union[SchemaBuilder, SchemaFromText]:
165173
return SchemaFromText(llm=self.get_default_llm())
166174
return SchemaBuilder()
167175

168-
def _process_schema_with_precedence(self) -> tuple[list[SchemaEntity], list[SchemaRelation], Optional[list[tuple[str, str, str]]]]:
176+
def _process_schema_with_precedence(
177+
self,
178+
) -> tuple[
179+
list[SchemaEntity], list[SchemaRelation], Optional[list[tuple[str, str, str]]]
180+
]:
169181
"""
170182
Process schema inputs according to precedence rules:
171183
1. If schema is provided as SchemaConfig object, use it
172184
2. If schema is provided as dictionary, extract from it
173185
3. Otherwise, use individual schema components
174-
186+
175187
Returns:
176188
Tuple of (entities, relations, potential_schema)
177189
"""
@@ -184,15 +196,29 @@ def _process_schema_with_precedence(self) -> tuple[list[SchemaEntity], list[Sche
184196
potential_schema = self.schema.potential_schema
185197
else:
186198
# extract from dictionary
187-
entities = [SchemaEntity.from_text_or_dict(e) for e in self.schema.get("entities", [])]
188-
relations = [SchemaRelation.from_text_or_dict(r) for r in self.schema.get("relations", [])]
199+
entities = [
200+
SchemaEntity.from_text_or_dict(e)
201+
for e in self.schema.get("entities", [])
202+
]
203+
relations = [
204+
SchemaRelation.from_text_or_dict(r)
205+
for r in self.schema.get("relations", [])
206+
]
189207
potential_schema = self.schema.get("potential_schema")
190208
else:
191209
# use individual components
192-
entities = [SchemaEntity.from_text_or_dict(e) for e in self.entities] if self.entities else []
193-
relations = [SchemaRelation.from_text_or_dict(r) for r in self.relations] if self.relations else []
210+
entities = (
211+
[SchemaEntity.from_text_or_dict(e) for e in self.entities]
212+
if self.entities
213+
else []
214+
)
215+
relations = (
216+
[SchemaRelation.from_text_or_dict(r) for r in self.relations]
217+
if self.relations
218+
else []
219+
)
194220
potential_schema = self.potential_schema
195-
221+
196222
return entities, relations, potential_schema
197223

198224
def _get_run_params_for_schema(self) -> dict[str, Any]:
@@ -201,8 +227,10 @@ def _get_run_params_for_schema(self) -> dict[str, Any]:
201227
return {}
202228
else:
203229
# process schema components according to precedence rules
204-
entities, relations, potential_schema = self._process_schema_with_precedence()
205-
230+
entities, relations, potential_schema = (
231+
self._process_schema_with_precedence()
232+
)
233+
206234
return {
207235
"entities": entities,
208236
"relations": relations,
@@ -248,7 +276,7 @@ def _get_connections(self) -> list[ConnectionDefinition]:
248276
input_config={"text": "pdf_loader.text"},
249277
)
250278
)
251-
279+
252280
# handle automatic schema extraction
253281
if self.auto_schema_extraction and not self.has_user_provided_schema():
254282
connections.append(
@@ -258,7 +286,7 @@ def _get_connections(self) -> list[ConnectionDefinition]:
258286
input_config={"text": "pdf_loader.text"},
259287
)
260288
)
261-
289+
262290
connections.append(
263291
ConnectionDefinition(
264292
start="schema",
@@ -279,7 +307,7 @@ def _get_connections(self) -> list[ConnectionDefinition]:
279307
input_config={"text": "text"}, # use the original text input
280308
)
281309
)
282-
310+
283311
connections.append(
284312
ConnectionDefinition(
285313
start="schema",

src/neo4j_graphrag/experimental/pipeline/kg_builder.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747

4848
logger = logging.getLogger(__name__)
4949

50+
5051
class SimpleKGPipeline:
5152
"""
5253
A class to simplify the process of building a knowledge graph from text documents.
@@ -123,7 +124,9 @@ def __init__(
123124
perform_entity_resolution=perform_entity_resolution,
124125
lexical_graph_config=lexical_graph_config,
125126
neo4j_database=neo4j_database,
126-
auto_schema_extraction=not bool(schema or entities or relations or potential_schema),
127+
auto_schema_extraction=not bool(
128+
schema or entities or relations or potential_schema
129+
),
127130
)
128131
except (ValidationError, ValueError) as e:
129132
raise PipelineDefinitionError() from e
Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,4 @@
11
from .graphrag import GraphRAG
22
from .prompts import PromptTemplate, RagTemplate, SchemaExtractionTemplate
33

4-
__all__ = [
5-
"GraphRAG",
6-
"PromptTemplate",
7-
"RagTemplate",
8-
"SchemaExtractionTemplate"
9-
]
4+
__all__ = ["GraphRAG", "PromptTemplate", "RagTemplate", "SchemaExtractionTemplate"]

0 commit comments

Comments
 (0)