Skip to content

Commit ba4fdb3

Browse files
fix(graphgen): sort before adj construction
1 parent 0c8d787 commit ba4fdb3

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

graphgen/operators/split_graph.py

Lines changed: 10 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -72,6 +72,7 @@ def _get_level_n_edges_by_max_width(
72 72
if len(candidate_edges) >= max_extra_edges:
73 73
er_tuples = [([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge) for edge in candidate_edges]
74 74
candidate_edges = _sort_edges(er_tuples, edge_sampling)[:max_extra_edges]
75 +
75 76
for edge in candidate_edges:
76 77
level_n_edges.append(edge)
77 78
edge[2]["visited"] = True
@@ -145,6 +146,7 @@ def _get_level_n_edges_by_max_tokens(
145 146

146 147
er_tuples = [([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge) for edge in candidate_edges]
147 148
candidate_edges = _sort_edges(er_tuples, edge_sampling)
149 +
148 150
for edge in candidate_edges:
149 151
max_tokens -= edge[2]["length"]
150 152
if not edge[0] in temp_nodes:
@@ -219,16 +221,16 @@ async def get_cached_node_info(node_id: str) -> dict:
219 221
node_cache[node_id] = await _get_node_info(node_id, graph_storage)
220 222
return node_cache[node_id]
221 223

222 -
for i, (src, tgt, _) in enumerate(edges):
223 -
edge_adj_list[src].append(i)
224 -
edge_adj_list[tgt].append(i)
225 -
226 224
for i, (node_name, _) in enumerate(nodes):
227 225
node_dict[node_name] = i
228 226

229 227
er_tuples = [([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge) for edge in edges]
230 228
edges = _sort_edges(er_tuples, edge_sampling)
231 229

230 +
for i, (src, tgt, _) in enumerate(edges):
231 +
edge_adj_list[src].append(i)
232 +
edge_adj_list[tgt].append(i)
233 +
232 234
for edge in tqdm_async(edges, desc="Preparing batches"):
233 235
if "visited" in edge[2] and edge[2]["visited"]:
234 236
continue
@@ -269,10 +271,14 @@ async def get_cached_node_info(node_id: str) -> dict:
269 271

270 272
processing_batches.append((_process_nodes, _process_edges))
271 273

274 +
logger.info("Processing batches: %d", len(processing_batches))
275 +
272 276
# isolate nodes
273 277
isolated_node_strategy = traverse_strategy.isolated_node_strategy
274 278
if isolated_node_strategy == "add":
275 279
processing_batches = await _add_isolated_nodes(nodes, processing_batches, graph_storage)
280 +
logger.info("Processing batches after adding isolated nodes: %d", len(processing_batches))
281 +
276 282
return processing_batches
277 283

278 284
async def _add_isolated_nodes(

0 commit comments

Comments
 (0)