1
1
import os
2
2
import time
3
3
from .entities import *
4
- from typing import Dict , List , Optional , Tuple
4
+ from typing import Dict , Optional , List , Tuple
5
5
from falkordb import FalkorDB , Path , Node , QueryResult
6
6
7
7
# Configure the logger
@@ -194,12 +194,12 @@ def get_sub_graph(self, l: int) -> dict:
194
194
return sub_graph
195
195
196
196
197
- def get_neighbors (self , node_id : int , rel : Optional [str ] = None , lbl : Optional [str ] = None ) -> Dict [str , List [dict ]]:
197
+ def get_neighbors (self , node_ids : List [ int ] , rel : Optional [str ] = None , lbl : Optional [str ] = None ) -> Dict [str , List [dict ]]:
198
198
"""
199
- Fetch the neighbors of a given node in the graph based on relationship type and/or label.
199
+ Fetch the neighbors of a given nodes in the graph based on relationship type and/or label.
200
200
201
201
Args:
202
- node_id ( int): The ID of the source node .
202
+ node_ids (List[ int] ): The IDs of the source nodes .
203
203
rel (str, optional): The type of relationship to filter by. Defaults to None.
204
204
lbl (str, optional): The label of the destination node to filter by. Defaults to None.
205
205
@@ -208,8 +208,8 @@ def get_neighbors(self, node_id: int, rel: Optional[str] = None, lbl: Optional[s
208
208
"""
209
209
210
210
# Validate inputs
211
- if not isinstance (node_id , int ):
212
- raise ValueError ("node_id must be an integer" )
211
+ if not all ( isinstance (node_id , int ) for node_id in node_ids ):
212
+ raise ValueError ("node_ids must be an integer list " )
213
213
214
214
# Build relationship and label query parts
215
215
rel_query = f":{ rel } " if rel else ""
@@ -218,7 +218,7 @@ def get_neighbors(self, node_id: int, rel: Optional[str] = None, lbl: Optional[s
218
218
# Parameterized Cypher query to find neighbors
219
219
query = f"""
220
220
MATCH (n)-[e{ rel_query } ]->(dest{ lbl_query } )
221
- WHERE ID(n) = $node_id
221
+ WHERE ID(n) IN $node_ids
222
222
RETURN e, dest
223
223
"""
224
224
@@ -227,7 +227,7 @@ def get_neighbors(self, node_id: int, rel: Optional[str] = None, lbl: Optional[s
227
227
228
228
try :
229
229
# Execute the graph query with node_id parameter
230
- result_set = self ._query (query , {'node_id ' : node_id }).result_set
230
+ result_set = self ._query (query , {'node_ids ' : node_ids }).result_set
231
231
232
232
# Iterate over the result set and process nodes and edges
233
233
for edge , destination_node in result_set :
@@ -237,7 +237,7 @@ def get_neighbors(self, node_id: int, rel: Optional[str] = None, lbl: Optional[s
237
237
return neighbors
238
238
239
239
except Exception as e :
240
- logging .error (f"Error fetching neighbors for node { node_id } : { e } " )
240
+ logging .error (f"Error fetching neighbors for node { node_ids } : { e } " )
241
241
return {'nodes' : [], 'edges' : []}
242
242
243
243
0 commit comments