Skip to content

Upgrade Anthropic & clean ruff/mypy stuffs #302

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,19 @@
### Added

- Added optional schema enforcement as a validation layer after entity and relation extraction.
- Introduced a linear hybrid search ranker for HybridRetriever and HybridCypherRetriever, allowing customizable ranking with an `alpha` parameter.
- Introduced SearchQueryParseError for handling invalid Lucene query strings in HybridRetriever and HybridCypherRetriever.

### Fixed

- Fixed config loading after module reload (usage in jupyter notebooks)

### Changed

- Qdrant retriever now fallbacks on the point ID if the `external_id_property` is not found in the payload.
- Updated a few dependencies, mainly `pypdf`, `anthropic` and `cohere`.


## 1.5.0

### Added
Expand All @@ -18,7 +29,7 @@
- Introduced Neo4jMessageHistory and InMemoryMessageHistory classes for managing LLM message histories.
- Added examples and documentation for using message history with Neo4j and in-memory storage.
- Updated LLM and GraphRAG classes to support new message history classes.
- Introduced a linear hybrid search ranker for HybridRetriever and HybridCypherRetriever, allowing customizable ranking with an `alpha` parameter.

### Changed

- Refactored index-related functions for improved compatibility and functionality.
Expand Down Expand Up @@ -311,4 +322,3 @@

- Updated documentation to include new custom exceptions.
- Improved the use of Pydantic for input data validation for retriever objects.
- Fixed config loading after module reload (usage in jupyter notebooks)
2,352 changes: 989 additions & 1,363 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ mistralai = {version = "^1.0.3", optional = true}
qdrant-client = {version = "^1.11.3", optional = true}
llama-index = {version = "^0.12.0", optional = true }
openai = {version = "^1.51.1", optional = true }
anthropic = { version = "^0.36.0", optional = true}
anthropic = {version = "^0.49.0", optional = true}
sentence-transformers = {version = "^3.0.0", optional = true }
ollama = {version = "^0.4.4", optional = true}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,13 +346,11 @@ async def run(
return graph

def validate_chunk(
self,
chunk_graph: Neo4jGraph,
schema: SchemaConfig
self, chunk_graph: Neo4jGraph, schema: SchemaConfig
) -> Neo4jGraph:
"""
Perform validation after entity and relation extraction:
- Enforce schema if schema enforcement mode is on and schema is provided
Perform validation after entity and relation extraction:
- Enforce schema if schema enforcement mode is on and schema is provided
"""
if self.enforce_schema != SchemaEnforcementMode.NONE:
if not schema or not schema.entities: # schema is not provided
Expand All @@ -365,9 +363,9 @@ def validate_chunk(
return chunk_graph

def _clean_graph(
self,
graph: Neo4jGraph,
schema: SchemaConfig,
self,
graph: Neo4jGraph,
schema: SchemaConfig,
) -> Neo4jGraph:
"""
Verify that the graph conforms to the provided schema.
Expand All @@ -389,17 +387,15 @@ def _clean_graph(
return Neo4jGraph(nodes=filtered_nodes, relationships=filtered_rels)

def _enforce_nodes(
self,
extracted_nodes: List[Neo4jNode],
schema: SchemaConfig
self, extracted_nodes: List[Neo4jNode], schema: SchemaConfig
) -> List[Neo4jNode]:
"""
Filter extracted nodes to be conformant to the schema.
Filter extracted nodes to be conformant to the schema.

Keep only those whose label is in schema.
For each valid node, filter out properties not present in the schema.
Remove a node if it ends up with no valid properties.
"""
Keep only those whose label is in schema.
For each valid node, filter out properties not present in the schema.
Remove a node if it ends up with no valid properties.
"""
if self.enforce_schema != SchemaEnforcementMode.STRICT:
return extracted_nodes

Expand All @@ -424,10 +420,10 @@ def _enforce_nodes(
return valid_nodes

def _enforce_relationships(
self,
extracted_relationships: List[Neo4jRelationship],
filtered_nodes: List[Neo4jNode],
schema: SchemaConfig
self,
extracted_relationships: List[Neo4jRelationship],
filtered_nodes: List[Neo4jNode],
schema: SchemaConfig,
) -> List[Neo4jRelationship]:
"""
Filter extracted nodes to be conformant to the schema.
Expand All @@ -447,12 +443,16 @@ def _enforce_relationships(
potential_schema = schema.potential_schema

for rel in extracted_relationships:
schema_relation = schema.relations.get(rel.type)
schema_relation = (
schema.relations.get(rel.type) if schema.relations else None
)
if not schema_relation:
continue

if (rel.start_node_id not in valid_nodes or
rel.end_node_id not in valid_nodes):
if (
rel.start_node_id not in valid_nodes
or rel.end_node_id not in valid_nodes
):
continue

start_label = valid_nodes[rel.start_node_id]
Expand All @@ -461,8 +461,11 @@ def _enforce_relationships(
tuple_valid = True
if potential_schema:
tuple_valid = (start_label, rel.type, end_label) in potential_schema
reverse_tuple_valid = ((end_label, rel.type, start_label) in
potential_schema)
reverse_tuple_valid = (
end_label,
rel.type,
start_label,
) in potential_schema

if not tuple_valid and not reverse_tuple_valid:
continue
Expand All @@ -483,18 +486,13 @@ def _enforce_relationships(
return valid_rels

def _enforce_properties(
self,
properties: Dict[str, Any],
valid_properties: List[Dict[str, Any]]
self, properties: Dict[str, Any], valid_properties: List[Dict[str, Any]]
) -> Dict[str, Any]:
"""
Filter properties.
Keep only those that exist in schema (i.e., valid properties).
"""
valid_prop_names = {prop["name"] for prop in valid_properties}
return {
key: value
for key, value in properties.items()
if key in valid_prop_names
key: value for key, value in properties.items() if key in valid_prop_names
}

Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
)
from neo4j_graphrag.experimental.components.types import (
LexicalGraphConfig,
SchemaEnforcementMode
SchemaEnforcementMode,
)
from neo4j_graphrag.experimental.pipeline.config.object_config import ComponentType
from neo4j_graphrag.experimental.pipeline.config.template_pipeline.base import (
Expand Down
6 changes: 3 additions & 3 deletions src/neo4j_graphrag/llm/cohere_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def invoke(
except self.cohere_api_error as e:
raise LLMGenerationError(e)
return LLMResponse(
content=res.message.content[0].text,
content=res.message.content[0].text if res.message.content else "",
)

async def ainvoke(
Expand All @@ -148,12 +148,12 @@ async def ainvoke(
if isinstance(message_history, MessageHistory):
message_history = message_history.messages
messages = self.get_messages(input, message_history, system_instruction)
res = self.async_client.chat(
res = await self.async_client.chat(
messages=messages,
model=self.model_name,
)
except self.cohere_api_error as e:
raise LLMGenerationError(e)
return LLMResponse(
content=res.message.content[0].text,
content=res.message.content[0].text if res.message.content else "",
)
Loading