
Commit e70fe78

fix: fix lint errors
1 parent: 8b99b57

File tree

7 files changed: +45, -33 lines

graphgen/operators/judge.py

Lines changed: 1 addition & 1 deletion

@@ -6,7 +6,7 @@
 from templates import STATEMENT_JUDGEMENT_PROMPT


-async def judge_statement(
+async def judge_statement(  # pylint: disable=too-many-statements
     student_llm_client: OpenAIModel,
     graph_storage: NetworkXStorage,
     rephrase_storage: JsonKVStorage,

graphgen/operators/split_graph.py

Lines changed: 1 addition & 1 deletion

@@ -224,7 +224,7 @@ def _sort_edges(edges: list, edge_sampling: str) -> list:
         raise ValueError(f"Invalid edge sampling: {edge_sampling}")
     return edges

-async def get_batches_with_strategy(
+async def get_batches_with_strategy(  # pylint: disable=too-many-arguments
     nodes: list,
     edges: list,
     graph_storage: NetworkXStorage,
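
Note: the two changes above only append inline pylint pragmas; the function bodies are untouched. A minimal illustration (with a hypothetical function, not from this repo) of how such a pragma scopes to a single definition; pylint's default limit for too-many-arguments (R0913) is five parameters:

def build_report(a, b, c, d, e, f):  # pylint: disable=too-many-arguments
    # The disable comment silences R0913 for this definition only;
    # the rest of the module is still checked.
    return [a, b, c, d, e, f]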

graphgen/operators/traverse_graph.py

Lines changed: 28 additions & 18 deletions

@@ -100,6 +100,33 @@ def get_loss_tercile(losses: list) -> (float, float):

    return losses[q1_index], losses[q2_index]

+def assign_difficulty(subgraphs: list, difficulty_order: list) -> list:
+    """
+    Assign difficulty to subgraphs based on the loss
+
+    :param subgraphs
+    :param difficulty_order
+    :return
+    """
+    losses = []
+    for subgraph in subgraphs:
+        loss = get_average_loss(subgraph)
+        losses.append(loss)
+    q1, q2 = get_loss_tercile(losses)
+
+    for i, subgraph in enumerate(subgraphs):
+        loss = get_average_loss(subgraph)
+        if loss < q1:
+            # easy
+            subgraphs[i] = (subgraph[0], subgraph[1], difficulty_order[0])
+        elif loss < q2:
+            # medium
+            subgraphs[i] = (subgraph[0], subgraph[1], difficulty_order[1])
+        else:
+            # hard
+            subgraphs[i] = (subgraph[0], subgraph[1], difficulty_order[2])
+    return subgraphs
+
 def get_average_loss(batch: tuple) -> float:
     if loss_strategy == "only_edge":
         return sum(edge[2]['loss'] for edge in batch[1]) / len(batch[1])

@@ -258,24 +285,7 @@ async def _process_single_batch(
        traverse_strategy
    )

-    losses = []
-    for batch in processing_batches:
-        loss = get_average_loss(batch)
-        losses.append(loss)
-    q1, q2 = get_loss_tercile(losses)
-
-    difficulty_order = traverse_strategy.difficulty_order
-    for i, batch in enumerate(processing_batches):
-        loss = get_average_loss(batch)
-        if loss < q1:
-            # easy
-            processing_batches[i] = (batch[0], batch[1], difficulty_order[0])
-        elif loss < q2:
-            # medium
-            processing_batches[i] = (batch[0], batch[1], difficulty_order[1])
-        else:
-            # hard
-            processing_batches[i] = (batch[0], batch[1], difficulty_order[2])
+    processing_batches = assign_difficulty(processing_batches, traverse_strategy.difficulty_order)

    for result in tqdm_async(asyncio.as_completed(
        [_process_single_batch(batch) for batch in processing_batches]
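
For context, a self-contained sketch of the tercile split that assign_difficulty performs. The loss values are invented, and the body of get_loss_tercile is an assumption (the diff above only shows its return statement), so treat this as an illustration rather than the repository's exact code:

def get_loss_tercile(losses: list) -> (float, float):
    # Assumed implementation: index the sorted losses at the 1/3 and 2/3 marks.
    losses = sorted(losses)
    q1_index = len(losses) // 3
    q2_index = 2 * len(losses) // 3
    return losses[q1_index], losses[q2_index]

losses = [0.12, 0.35, 0.48, 0.70, 0.91, 1.20]  # toy per-subgraph average losses
q1, q2 = get_loss_tercile(losses)
difficulty_order = ["easy", "medium", "hard"]
labels = [difficulty_order[0] if loss < q1
          else difficulty_order[1] if loss < q2
          else difficulty_order[2]
          for loss in losses]
print(labels)  # ['easy', 'easy', 'medium', 'medium', 'hard', 'hard']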

judge.py

Lines changed: 2 additions & 2 deletions

@@ -4,7 +4,7 @@
 from dotenv import load_dotenv

 from models import NetworkXStorage, JsonKVStorage, OpenAIModel
-from graphgen.operators import judge_relations
+from graphgen.operators import judge_statement

 sys_path = os.path.abspath(os.path.dirname(__file__))

@@ -33,7 +33,7 @@
     namespace="rephrase"
 )

-new_graph = asyncio.run(judge_relations(llm_client, graph_storage, rephrase_storage, re_judge=True))
+new_graph = asyncio.run(judge_statement(llm_client, graph_storage, rephrase_storage, re_judge=True))

 graph_file = asyncio.run(graph_storage.get_graph())
webui/app.py

Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@
 import gradio as gr

 from models import TraverseStrategy, NetworkXStorage, Tokenizer
-from charts import plot_pre_length_distribution, plot_post_synth_length_distribution, plot_loss_distribution
+from webui.charts import plot_pre_length_distribution, plot_post_synth_length_distribution, plot_loss_distribution
 from graphgen.operators.split_graph import get_batches_with_strategy
 from utils import create_event_loop
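
A hedged note on this import change: the chart helpers live in the webui package, so the package-qualified form resolves when the app is launched from the project root, while the bare form only works if the webui/ directory itself happens to be on sys.path. Illustrative sketch:

# Hypothetical reproduction of the lint failure the old import triggers when
# pylint is run from the repository root (E0401, import-error):
#   from charts import plot_loss_distribution   # fails: no top-level `charts`
from webui.charts import plot_loss_distribution  # resolves via the package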

webui/charts/plot_metric_trend.py

Lines changed: 4 additions & 4 deletions

@@ -3,15 +3,15 @@
 import numpy as np
 from scipy.interpolate import make_interp_spline

-def plot_metric_trend(df, x, y):
-    fig = px.line(df, x=x, y=y,
+def plot_metric_trend(dataframe, x, y):
+    fig = px.line(dataframe, x=x, y=y,
                   color='max length',
                   markers=True,
                   color_discrete_sequence=['#925EB0', '#7E99F4', '#CC7C71', '#7AB656'])  # A5AEB7

-    fig.update_xaxes(tickvals=df[x], ticktext=[f'{int(val * 100)}%' for val in df[x].unique()])
+    fig.update_xaxes(tickvals=dataframe[x], ticktext=[f'{int(val * 100)}%' for val in dataframe[x].unique()])

-    avg = df.groupby(x)[y].mean().reset_index()
+    avg = dataframe.groupby(x)[y].mean().reset_index()
     avg['max length'] = 'Average'

     x_smooth = np.linspace(avg[x].min(), avg[x].max(), 500)
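
The rename from df to dataframe is a pure naming fix; behavior is unchanged. For reference, a runnable sketch of the smoothed-average pattern this module appears to implement; the column names and values below are invented for illustration, and the plotly styling is omitted:

import numpy as np
import pandas as pd
from scipy.interpolate import make_interp_spline

dataframe = pd.DataFrame({
    'ratio':      [0.1, 0.1, 0.3, 0.3, 0.5, 0.5],        # hypothetical x column
    'score':      [0.62, 0.58, 0.71, 0.69, 0.80, 0.78],  # hypothetical metric
    'max length': [256, 512, 256, 512, 256, 512],
})
# Average the metric per x value, then draw a smooth curve through the means.
avg = dataframe.groupby('ratio')['score'].mean().reset_index()
x_smooth = np.linspace(avg['ratio'].min(), avg['ratio'].max(), 500)
y_smooth = make_interp_spline(avg['ratio'], avg['score'], k=2)(x_smooth)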

webui/charts/plot_rephrase_process.py

Lines changed: 8 additions & 6 deletions

@@ -40,12 +40,14 @@ def analyse_log(log_info: dict) -> list:

     logs = [log_item for log_item in logs if log_item['log_level'] == 'INFO']

+    break_index = 0
     for i, log_item in enumerate(logs):
         match = re.search(r'(\d+) nodes and (\d+) edges processed', log_item['message'])
         if match:
+            break_index = i
             break

-    logs = logs[i:]
+    logs = logs[break_index:]
     assert len(logs) % 3 == 0

     # Group the log entries in threes

@@ -93,8 +95,8 @@ def plot_pre_length_distribution(stats: list[dict]):
     length_distribution = defaultdict(int)

     # Complete all the counting in a single pass
-    for item in stats:
-        bin_start = (item['pre_length'] // bin_size) * bin_size
+    for stat in stats:
+        bin_start = (stat['pre_length'] // bin_size) * bin_size
         bin_key = f"{bin_start}-{bin_start + bin_size}"
         length_distribution[bin_key] += 1

@@ -145,16 +147,16 @@ def plot_post_synth_length_distribution(stats: list[dict]):
         return go.Figure()

     # Compute the maximum length and determine the bins
-    max_length = max(item['post_length'] for item in stats)
+    max_length = max(stat['post_length'] for stat in stats)
     bin_size = 50
     max_length = ((max_length // bin_size) + 1) * bin_size

     # Use a defaultdict to avoid missing-key checks
     length_distribution = defaultdict(int)

     # Complete all the counting in a single pass
-    for item in stats:
-        bin_start = (item['post_length'] // bin_size) * bin_size
+    for stat in stats:
+        bin_start = (stat['post_length'] // bin_size) * bin_size
         bin_key = f"{bin_start}-{bin_start + bin_size}"
         length_distribution[bin_key] += 1
0 commit comments