Skip to content

Commit fa8a6af

Browse files
Bug fixes
1 parent 49452d4 commit fa8a6af

File tree

2 files changed

+16
-11
lines changed

2 files changed

+16
-11
lines changed

src/neo4j_graphrag/experimental/components/schema.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,13 @@ def store_as_yaml(self, file_path: str) -> None:
146146
Args:
147147
file_path (str): The path where the schema configuration will be saved.
148148
"""
149+
# create a copy of the data and convert tuples to lists for YAML compatibility
150+
data = self.model_dump()
151+
if data.get('potential_schema'):
152+
data['potential_schema'] = [list(item) for item in data['potential_schema']]
153+
149154
with open(file_path, 'w') as f:
150-
yaml.dump(self.model_dump(), f, default_flow_style=False, sort_keys=False)
155+
yaml.dump(data, f, default_flow_style=False, sort_keys=False)
151156

152157
@classmethod
153158
def from_file(cls, file_path: Union[str, Path]) -> Self:
@@ -347,18 +352,18 @@ def __init__(
347352
self._llm_params: dict[str, Any] = llm_params or {}
348353

349354
@validate_call
350-
async def run(self, text: str, **kwargs: Any) -> SchemaConfig:
355+
async def run(self, text: str, examples:str = "", **kwargs: Any) -> SchemaConfig:
351356
"""
352357
Asynchronously extracts the schema from text and returns a SchemaConfig object.
353358
354359
Args:
355360
text (str): the text from which the schema will be inferred.
356-
361+
examples (str): examples to guide schema extraction.
357362
Returns:
358363
SchemaConfig: A configured schema object, extracted automatically and
359364
constructed asynchronously.
360365
"""
361-
prompt: str = self._prompt_template.format(text=text)
366+
prompt: str = self._prompt_template.format(text=text, examples=examples)
362367

363368
response = await self._llm.invoke(prompt, **self._llm_params)
364369
content: str = (

src/neo4j_graphrag/generation/prompts.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -219,10 +219,10 @@ class SchemaExtractionTemplate(PromptTemplate):
219219
220220
For example, if the text says "Alice lives in London", the output JSON object should
221221
adhere to the following format:
222-
{"entities": [{"label": "Person", "properties": [{"name": "name", "type": "STRING"}]},
223-
{"label": "City", "properties":[{"name": "name", "type": "STRING"}]}],
224-
"relations": [{"label": "LIVES_IN"}],
225-
"potential_schema":[[ "Person", "LIVES_IN", "City"]]}
222+
{{"entities": [{{"label": "Person", "properties": [{{"name": "name", "type": "STRING"}}]}},
223+
{{"label": "City", "properties":[{{"name": "name", "type": "STRING"}}]}}],
224+
"relations": [{{"label": "LIVES_IN"}}],
225+
"potential_schema":[[ "Person", "LIVES_IN", "City"]]}}
226226
227227
More examples:
228228
{examples}
@@ -233,8 +233,8 @@ class SchemaExtractionTemplate(PromptTemplate):
233233
EXPECTED_INPUTS = ["text"]
234234

235235
def format(
236-
self,
237-
examples: str,
238-
text: str = "",
236+
self,
237+
text: str = "",
238+
examples: str = "",
239239
) -> str:
240240
return super().format(text=text, examples=examples)

0 commit comments

Comments
 (0)