Skip to content

Commit b7928fa

Browse files
authored
Merge pull request #26 from FalkorDB/get_neighbors-get-ids-list
get_neighbors get list of ids
2 parents 4c9da39 + dbe2424 commit b7928fa

File tree

2 files changed

+23
-28
lines changed

2 files changed

+23
-28
lines changed

api/graph.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
import time
33
from .entities import *
4-
from typing import Dict, List, Optional, Tuple
4+
from typing import Dict, Optional, List, Tuple
55
from falkordb import FalkorDB, Path, Node, QueryResult
66

77
# Configure the logger
@@ -194,12 +194,12 @@ def get_sub_graph(self, l: int) -> dict:
194194
return sub_graph
195195

196196

197-
def get_neighbors(self, node_id: int, rel: Optional[str] = None, lbl: Optional[str] = None) -> Dict[str, List[dict]]:
197+
def get_neighbors(self, node_ids: List[int], rel: Optional[str] = None, lbl: Optional[str] = None) -> Dict[str, List[dict]]:
198198
"""
199-
Fetch the neighbors of a given node in the graph based on relationship type and/or label.
199+
Fetch the neighbors of a given nodes in the graph based on relationship type and/or label.
200200
201201
Args:
202-
node_id (int): The ID of the source node.
202+
node_ids (List[int]): The IDs of the source nodes.
203203
rel (str, optional): The type of relationship to filter by. Defaults to None.
204204
lbl (str, optional): The label of the destination node to filter by. Defaults to None.
205205
@@ -208,8 +208,8 @@ def get_neighbors(self, node_id: int, rel: Optional[str] = None, lbl: Optional[s
208208
"""
209209

210210
# Validate inputs
211-
if not isinstance(node_id, int):
212-
raise ValueError("node_id must be an integer")
211+
if not all(isinstance(node_id, int) for node_id in node_ids):
212+
raise ValueError("node_ids must be an integer list")
213213

214214
# Build relationship and label query parts
215215
rel_query = f":{rel}" if rel else ""
@@ -218,7 +218,7 @@ def get_neighbors(self, node_id: int, rel: Optional[str] = None, lbl: Optional[s
218218
# Parameterized Cypher query to find neighbors
219219
query = f"""
220220
MATCH (n)-[e{rel_query}]->(dest{lbl_query})
221-
WHERE ID(n) = $node_id
221+
WHERE ID(n) IN $node_ids
222222
RETURN e, dest
223223
"""
224224

@@ -227,7 +227,7 @@ def get_neighbors(self, node_id: int, rel: Optional[str] = None, lbl: Optional[s
227227

228228
try:
229229
# Execute the graph query with node_id parameter
230-
result_set = self._query(query, {'node_id': node_id}).result_set
230+
result_set = self._query(query, {'node_ids': node_ids}).result_set
231231

232232
# Iterate over the result set and process nodes and edges
233233
for edge, destination_node in result_set:
@@ -237,7 +237,7 @@ def get_neighbors(self, node_id: int, rel: Optional[str] = None, lbl: Optional[s
237237
return neighbors
238238

239239
except Exception as e:
240-
logging.error(f"Error fetching neighbors for node {node_id}: {e}")
240+
logging.error(f"Error fetching neighbors for node {node_ids}: {e}")
241241
return {'nodes': [], 'edges': []}
242242

243243

api/index.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -79,51 +79,47 @@ def graph_entities():
7979
return jsonify({"status": "Internal server error"}), 500
8080

8181

82-
@app.route('/get_neighbors', methods=['GET'])
82+
@app.route('/get_neighbors', methods=['POST'])
8383
@token_required # Apply token authentication decorator
8484
def get_neighbors():
8585
"""
86-
Endpoint to get neighbors of a specific node in the graph.
87-
Expects 'repo' and 'node_id' as query parameters.
86+
Endpoint to get neighbors of a nodes list in the graph.
87+
Expects 'repo' and 'node_ids' as body parameters.
8888
8989
Returns:
9090
JSON response containing neighbors or error messages.
9191
"""
9292

93+
# Get JSON data from the request
94+
data = request.get_json()
95+
9396
# Get query parameters
94-
repo = request.args.get('repo')
95-
node_id = request.args.get('node_id')
97+
repo = data.get('repo')
98+
node_ids = data.get('node_ids')
9699

97100
# Validate 'repo' parameter
98101
if not repo:
99102
logging.error("Repository name is missing in the request.")
100103
return jsonify({"status": "Repository name is required."}), 400
101104

102-
# Validate 'node_id' parameter
103-
if not node_id:
104-
logging.error("Node ID is missing in the request.")
105-
return jsonify({"status": "Node ID is required."}), 400
105+
# Validate 'node_ids' parameter
106+
if not node_ids:
107+
logging.error("Node IDs is missing in the request.")
108+
return jsonify({"status": "Node IDs is required."}), 400
106109

107110
# Validate repo exists
108111
if not graph_exists(repo):
109112
logging.error(f"Missing project {repo}")
110113
return jsonify({"status": f"Missing project {repo}"}), 400
111114

112-
# Try converting node_id to an integer
113-
try:
114-
node_id = int(node_id)
115-
except ValueError:
116-
logging.error(f"Invalid node ID: {node_id}. It must be an integer.")
117-
return jsonify({"status": "Invalid node ID. It must be an integer."}), 400
118-
119115
# Initialize the graph with the provided repository
120116
g = Graph(repo)
121117

122118
# Fetch the neighbors of the specified node
123-
neighbors = g.get_neighbors(node_id)
119+
neighbors = g.get_neighbors(node_ids)
124120

125121
# Log and return the neighbors
126-
logging.info(f"Successfully retrieved neighbors for node ID {node_id} in repo '{repo}'.")
122+
logging.info(f"Successfully retrieved neighbors for node IDs {node_ids} in repo '{repo}'.")
127123

128124
response = {
129125
'status': 'success',
@@ -170,7 +166,6 @@ def auto_complete():
170166

171167
return jsonify(response), 200
172168

173-
174169
@app.route('/list_repos', methods=['GET'])
175170
@token_required # Apply token authentication decorator
176171
def list_repos():

0 commit comments

Comments
 (0)