Skip to content

Commit 60d8f97

Browse files
authored
Upgrade Anthropic & clean ruff/mypy stuffs (neo4j#302)
* Upgrade Anthropic * Fix mypy and tests * Update CHANGELOG * Ruff * Mypy
1 parent 5b868aa commit 60d8f97

File tree

8 files changed

+1196
-1507
lines changed

8 files changed

+1196
-1507
lines changed

CHANGELOG.md

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,19 @@
55
### Added
66

77
- Added optional schema enforcement as a validation layer after entity and relation extraction.
8+
- Introduced a linear hybrid search ranker for HybridRetriever and HybridCypherRetriever, allowing customizable ranking with an `alpha` parameter.
89
- Introduced SearchQueryParseError for handling invalid Lucene query strings in HybridRetriever and HybridCypherRetriever.
910

11+
### Fixed
12+
13+
- Fixed config loading after module reload (usage in jupyter notebooks)
14+
15+
### Changed
16+
17+
- Qdrant retriever now fallbacks on the point ID if the `external_id_property` is not found in the payload.
18+
- Updated a few dependencies, mainly `pypdf`, `anthropic` and `cohere`.
19+
20+
1021
## 1.5.0
1122

1223
### Added
@@ -18,7 +29,7 @@
1829
- Introduced Neo4jMessageHistory and InMemoryMessageHistory classes for managing LLM message histories.
1930
- Added examples and documentation for using message history with Neo4j and in-memory storage.
2031
- Updated LLM and GraphRAG classes to support new message history classes.
21-
- Introduced a linear hybrid search ranker for HybridRetriever and HybridCypherRetriever, allowing customizable ranking with an `alpha` parameter.
32+
2233
### Changed
2334

2435
- Refactored index-related functions for improved compatibility and functionality.
@@ -311,4 +322,3 @@
311322

312323
- Updated documentation to include new custom exceptions.
313324
- Improved the use of Pydantic for input data validation for retriever objects.
314-
- Fixed config loading after module reload (usage in jupyter notebooks)

poetry.lock

Lines changed: 989 additions & 1363 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ mistralai = {version = "^1.0.3", optional = true}
5050
qdrant-client = {version = "^1.11.3", optional = true}
5151
llama-index = {version = "^0.12.0", optional = true }
5252
openai = {version = "^1.51.1", optional = true }
53-
anthropic = { version = "^0.36.0", optional = true}
53+
anthropic = {version = "^0.49.0", optional = true}
5454
sentence-transformers = {version = "^3.0.0", optional = true }
5555
ollama = {version = "^0.4.4", optional = true}
5656

src/neo4j_graphrag/experimental/components/entity_relation_extractor.py

Lines changed: 30 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -346,13 +346,11 @@ async def run(
346346
return graph
347347

348348
def validate_chunk(
349-
self,
350-
chunk_graph: Neo4jGraph,
351-
schema: SchemaConfig
349+
self, chunk_graph: Neo4jGraph, schema: SchemaConfig
352350
) -> Neo4jGraph:
353351
"""
354-
Perform validation after entity and relation extraction:
355-
- Enforce schema if schema enforcement mode is on and schema is provided
352+
Perform validation after entity and relation extraction:
353+
- Enforce schema if schema enforcement mode is on and schema is provided
356354
"""
357355
if self.enforce_schema != SchemaEnforcementMode.NONE:
358356
if not schema or not schema.entities: # schema is not provided
@@ -365,9 +363,9 @@ def validate_chunk(
365363
return chunk_graph
366364

367365
def _clean_graph(
368-
self,
369-
graph: Neo4jGraph,
370-
schema: SchemaConfig,
366+
self,
367+
graph: Neo4jGraph,
368+
schema: SchemaConfig,
371369
) -> Neo4jGraph:
372370
"""
373371
Verify that the graph conforms to the provided schema.
@@ -389,17 +387,15 @@ def _clean_graph(
389387
return Neo4jGraph(nodes=filtered_nodes, relationships=filtered_rels)
390388

391389
def _enforce_nodes(
392-
self,
393-
extracted_nodes: List[Neo4jNode],
394-
schema: SchemaConfig
390+
self, extracted_nodes: List[Neo4jNode], schema: SchemaConfig
395391
) -> List[Neo4jNode]:
396392
"""
397-
Filter extracted nodes to be conformant to the schema.
393+
Filter extracted nodes to be conformant to the schema.
398394
399-
Keep only those whose label is in schema.
400-
For each valid node, filter out properties not present in the schema.
401-
Remove a node if it ends up with no valid properties.
402-
"""
395+
Keep only those whose label is in schema.
396+
For each valid node, filter out properties not present in the schema.
397+
Remove a node if it ends up with no valid properties.
398+
"""
403399
if self.enforce_schema != SchemaEnforcementMode.STRICT:
404400
return extracted_nodes
405401

@@ -424,10 +420,10 @@ def _enforce_nodes(
424420
return valid_nodes
425421

426422
def _enforce_relationships(
427-
self,
428-
extracted_relationships: List[Neo4jRelationship],
429-
filtered_nodes: List[Neo4jNode],
430-
schema: SchemaConfig
423+
self,
424+
extracted_relationships: List[Neo4jRelationship],
425+
filtered_nodes: List[Neo4jNode],
426+
schema: SchemaConfig,
431427
) -> List[Neo4jRelationship]:
432428
"""
433429
Filter extracted nodes to be conformant to the schema.
@@ -447,12 +443,16 @@ def _enforce_relationships(
447443
potential_schema = schema.potential_schema
448444

449445
for rel in extracted_relationships:
450-
schema_relation = schema.relations.get(rel.type)
446+
schema_relation = (
447+
schema.relations.get(rel.type) if schema.relations else None
448+
)
451449
if not schema_relation:
452450
continue
453451

454-
if (rel.start_node_id not in valid_nodes or
455-
rel.end_node_id not in valid_nodes):
452+
if (
453+
rel.start_node_id not in valid_nodes
454+
or rel.end_node_id not in valid_nodes
455+
):
456456
continue
457457

458458
start_label = valid_nodes[rel.start_node_id]
@@ -461,8 +461,11 @@ def _enforce_relationships(
461461
tuple_valid = True
462462
if potential_schema:
463463
tuple_valid = (start_label, rel.type, end_label) in potential_schema
464-
reverse_tuple_valid = ((end_label, rel.type, start_label) in
465-
potential_schema)
464+
reverse_tuple_valid = (
465+
end_label,
466+
rel.type,
467+
start_label,
468+
) in potential_schema
466469

467470
if not tuple_valid and not reverse_tuple_valid:
468471
continue
@@ -483,18 +486,13 @@ def _enforce_relationships(
483486
return valid_rels
484487

485488
def _enforce_properties(
486-
self,
487-
properties: Dict[str, Any],
488-
valid_properties: List[Dict[str, Any]]
489+
self, properties: Dict[str, Any], valid_properties: List[Dict[str, Any]]
489490
) -> Dict[str, Any]:
490491
"""
491492
Filter properties.
492493
Keep only those that exist in schema (i.e., valid properties).
493494
"""
494495
valid_prop_names = {prop["name"] for prop in valid_properties}
495496
return {
496-
key: value
497-
for key, value in properties.items()
498-
if key in valid_prop_names
497+
key: value for key, value in properties.items() if key in valid_prop_names
499498
}
500-

src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
)
4040
from neo4j_graphrag.experimental.components.types import (
4141
LexicalGraphConfig,
42-
SchemaEnforcementMode
42+
SchemaEnforcementMode,
4343
)
4444
from neo4j_graphrag.experimental.pipeline.config.object_config import ComponentType
4545
from neo4j_graphrag.experimental.pipeline.config.template_pipeline.base import (

src/neo4j_graphrag/llm/cohere_llm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def invoke(
124124
except self.cohere_api_error as e:
125125
raise LLMGenerationError(e)
126126
return LLMResponse(
127-
content=res.message.content[0].text,
127+
content=res.message.content[0].text if res.message.content else "",
128128
)
129129

130130
async def ainvoke(
@@ -148,12 +148,12 @@ async def ainvoke(
148148
if isinstance(message_history, MessageHistory):
149149
message_history = message_history.messages
150150
messages = self.get_messages(input, message_history, system_instruction)
151-
res = self.async_client.chat(
151+
res = await self.async_client.chat(
152152
messages=messages,
153153
model=self.model_name,
154154
)
155155
except self.cohere_api_error as e:
156156
raise LLMGenerationError(e)
157157
return LLMResponse(
158-
content=res.message.content[0].text,
158+
content=res.message.content[0].text if res.message.content else "",
159159
)

0 commit comments

Comments
 (0)