
Commit 0c8d787

feat(graphgen): use average loss when choosing nodes & edges
1 parent: db7a071

File tree: 1 file changed (+23, -15)


graphgen/operators/split_graph.py

Lines changed: 23 additions & 15 deletions
@@ -26,9 +26,10 @@ async def _get_node_info(
 
 def _get_level_n_edges_by_max_width(
     edge_adj_list: dict,
+    node_dict: dict,
     edges: list,
-    src_id: str,
-    tgt_id: str,
+    nodes,
+    src_edge: tuple,
     max_depth: int,
     bidirectional: bool,
     max_extra_edges: int,
@@ -39,15 +40,18 @@ def _get_level_n_edges_by_max_width(
     n is decided by max_depth in traverse_strategy
 
     :param edge_adj_list
+    :param node_dict
     :param edges
-    :param src_id
-    :param tgt_id
+    :param nodes
+    :param src_edge
     :param max_depth
     :param bidirectional
     :param max_extra_edges
     :param edge_sampling
     :return: level n edges
     """
+    src_id, tgt_id, _ = src_edge
+
     level_n_edges = []
 
     start_nodes = {tgt_id} if not bidirectional else {src_id, tgt_id}
@@ -66,7 +70,8 @@ def _get_level_n_edges_by_max_width(
             break
 
         if len(candidate_edges) >= max_extra_edges:
-            candidate_edges = _sort_edges(candidate_edges, edge_sampling)[:max_extra_edges]
+            er_tuples = [([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge) for edge in candidate_edges]
+            candidate_edges = _sort_edges(er_tuples, edge_sampling)[:max_extra_edges]
             for edge in candidate_edges:
                 level_n_edges.append(edge)
                 edge[2]["visited"] = True
@@ -138,7 +143,8 @@ def _get_level_n_edges_by_max_tokens(
         if not candidate_edges:
             break
 
-        candidate_edges = _sort_edges(candidate_edges, edge_sampling)
+        er_tuples = [([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge) for edge in candidate_edges]
+        candidate_edges = _sort_edges(er_tuples, edge_sampling)
         for edge in candidate_edges:
             max_tokens -= edge[2]["length"]
             if not edge[0] in temp_nodes:
@@ -166,22 +172,24 @@ def _get_level_n_edges_by_max_tokens(
     return level_n_edges
 
 
-def _sort_edges(edges: list, edge_sampling: str) -> list:
+def _sort_edges(er_tuples: list, edge_sampling: str) -> list:
     """
     Sort edges with edge sampling strategy
 
-    :param edges: total edges
+    :param er_tuples: [(nodes:list, edge:tuple)]
     :param edge_sampling: edge sampling strategy (random, min_loss, max_loss)
     :return: sorted edges
     """
     if edge_sampling == "random":
-        random.shuffle(edges)
+        er_tuples = random.sample(er_tuples, len(er_tuples))
     elif edge_sampling == "min_loss":
-        edges = sorted(edges, key=lambda x: x[2]["loss"])
+        er_tuples = sorted(er_tuples, key=lambda x: sum(node[1]["loss"] for node in x[0]) + x[1][2]["loss"])
     elif edge_sampling == "max_loss":
-        edges = sorted(edges, key=lambda x: x[2]["loss"], reverse=True)
+        er_tuples = sorted(er_tuples, key=lambda x: sum(node[1]["loss"] for node in x[0]) + x[1][2]["loss"],
+                           reverse=True)
     else:
         raise ValueError(f"Invalid edge sampling: {edge_sampling}")
+    edges = [edge for _, edge in er_tuples]
     return edges
 
 async def get_batches_with_strategy(
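
For clarity, a self-contained sketch of the rewritten _sort_edges behaviour with toy data (the values and the sort_er_tuples name are hypothetical; the repository's own function is the one in the hunk above). Because every tuple combines exactly two node losses and one edge loss, ranking by their sum orders candidates the same way an average would, which is presumably what the commit title's "average loss" refers to:

import random

def sort_er_tuples(er_tuples: list, edge_sampling: str) -> list:
    # Mirrors the updated _sort_edges: rank edges by node + node + edge loss.
    if edge_sampling == "random":
        er_tuples = random.sample(er_tuples, len(er_tuples))
    elif edge_sampling == "min_loss":
        er_tuples = sorted(er_tuples, key=lambda x: sum(n[1]["loss"] for n in x[0]) + x[1][2]["loss"])
    elif edge_sampling == "max_loss":
        er_tuples = sorted(er_tuples, key=lambda x: sum(n[1]["loss"] for n in x[0]) + x[1][2]["loss"],
                           reverse=True)
    else:
        raise ValueError(f"Invalid edge sampling: {edge_sampling}")
    return [edge for _, edge in er_tuples]

nodes = [("a", {"loss": 0.9}), ("b", {"loss": 0.5}), ("c", {"loss": 0.1})]
edges = [("a", "b", {"loss": 0.4}), ("b", "c", {"loss": 0.3})]
node_dict = {name: i for i, (name, _) in enumerate(nodes)}
er_tuples = [([nodes[node_dict[e[0]]], nodes[node_dict[e[1]]]], e) for e in edges]

# "min_loss" puts ("b", "c") first: 0.5 + 0.1 + 0.3 = 0.9 beats 0.9 + 0.5 + 0.4 = 1.8.
print([(e[0], e[1]) for e in sort_er_tuples(er_tuples, "min_loss")])  # [('b', 'c'), ('a', 'b')]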
@@ -199,8 +207,6 @@ async def get_batches_with_strategy(
     max_depth = traverse_strategy.max_depth
     edge_sampling = traverse_strategy.edge_sampling
 
-    edges = _sort_edges(edges, edge_sampling)
-
     # build the adjacency list
     edge_adj_list = defaultdict(list)
     node_dict = {}
@@ -220,6 +226,9 @@ async def get_cached_node_info(node_id: str) -> dict:
     for i, (node_name, _) in enumerate(nodes):
         node_dict[node_name] = i
 
+    er_tuples = [([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge) for edge in edges]
+    edges = _sort_edges(er_tuples, edge_sampling)
+
     for edge in tqdm_async(edges, desc="Preparing batches"):
         if "visited" in edge[2] and edge[2]["visited"]:
             continue
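
The relocation in this hunk matters for ordering: the er_tuples join needs node_dict to map every node name to its index in nodes, so the global sort (removed from the top of get_batches_with_strategy in the earlier hunk) can only run after the enumeration loop above. A minimal sketch of that dependency, again with hypothetical data:

nodes = [("a", {"loss": 0.9}), ("b", {"loss": 0.5})]
edges = [("a", "b", {"loss": 0.4})]

node_dict = {}
for i, (node_name, _) in enumerate(nodes):
    node_dict[node_name] = i  # the name-to-index map must exist first

# Only now can each edge be joined with its endpoint records before sorting;
# doing this before the loop would raise a KeyError on the empty node_dict.
er_tuples = [([nodes[node_dict[e[0]]], nodes[node_dict[e[1]]]], e) for e in edges]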
@@ -238,7 +247,7 @@ async def get_cached_node_info(node_id: str) -> dict:
 
         if expand_method == "max_width":
             level_n_edges = _get_level_n_edges_by_max_width(
-                edge_adj_list, edges, src_id, tgt_id, max_depth,
+                edge_adj_list, node_dict, edges, nodes, edge, max_depth,
                 traverse_strategy.bidirectional, traverse_strategy.max_extra_edges,
                 edge_sampling
             )
@@ -260,7 +269,6 @@ async def get_cached_node_info(node_id: str) -> dict:
 
         processing_batches.append((_process_nodes, _process_edges))
 
-l
     # isolate nodes
     isolated_node_strategy = traverse_strategy.isolated_node_strategy
     if isolated_node_strategy == "add":
