Skip to content

Commit 9258682

Browse files
committed
[Trace Analysis] Added a column mapping to get rid of static column references
1 parent b128d63 commit 9258682

File tree

1 file changed

+75
-43
lines changed

1 file changed

+75
-43
lines changed

thinkbench/trace_analysis/trace_analyzer.py

Lines changed: 75 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,33 @@ def write_to_json_file(data, suffix: str = ""):
5959

6060
@staticmethod
6161
def write_to_xlsx(analysis_results: List[Dict]):
62+
padding_top_left = 1
63+
64+
cols = [
65+
"question_id",
66+
"question",
67+
"reasoning",
68+
"answer_sentences",
69+
"automatic_extraction",
70+
"manual_extraction",
71+
"trace_label_correct",
72+
"model_choice",
73+
"model_choice_correct",
74+
"correct_answer",
75+
"labels_match",
76+
"comment"
77+
]
78+
79+
def col_id(col_name: str, ignore_padding: bool = False):
80+
col_id = cols.index(col_name)
81+
if not ignore_padding:
82+
col_id += 1
83+
84+
return col_id
85+
86+
def col_letter(col_name: str):
87+
return chr(65 + col_id(col_name, ignore_padding=False))
88+
6289
pastel_red = "#FFCCCC"
6390
pastel_green = "#CCFFCC"
6491

@@ -87,31 +114,34 @@ def write_to_xlsx(analysis_results: List[Dict]):
87114
'right': 5
88115
})
89116

117+
# Set column widths
90118
for col_num, col_data in enumerate(analysis_result["result_rows"][0].keys()):
91-
worksheet.set_column(col_num + 1, col_num + 1, 10)
92-
93-
worksheet.set_column('C:C', 60) # reasoning
94-
worksheet.set_column('D:D', 40) # answer sentence
95-
worksheet.set_column('L:L', 40) # comment
119+
worksheet.set_column(col_num + padding_top_left, col_num + padding_top_left, 10)
120+
worksheet.set_column(f"{col_letter('question')}:{col_letter('question')}", 30)
121+
worksheet.set_column(f"{col_letter('reasoning')}:{col_letter('reasoning')}", 60)
122+
worksheet.set_column(f"{col_letter('answer_sentences')}:{col_letter('answer_sentences')}", 30)
123+
worksheet.set_column(f"{col_letter('comment')}:{col_letter('comment')}", 30)
96124

97-
# Table content
125+
# Write table content
98126
for row_num, row_data in enumerate(analysis_result["result_rows"]):
99127
for col_num, cell_data in enumerate(list(row_data.values())):
100-
if col_num == 3:
128+
if col_num == col_id("automatic_extraction", True):
101129
worksheet.write(row_num + 2, col_num + 1, cell_data, thick_border_left_format)
102-
elif col_num == 5:
103-
worksheet.write_formula(row_num + 2, col_num + 1, f'=IF(OR($E{row_num + 3}=$J{row_num + 3}, $F{row_num + 3}=$J{row_num + 3}), 1, 0)', thick_border_right_format)
104-
elif col_num == 7 or col_num == 8:
130+
elif col_num == col_id("trace_label_correct", True):
131+
worksheet.write_formula(row_num + 2, col_num + 1, f'=IF(OR('
132+
f'${col_letter("automatic_extraction")}{row_num + 3}=${col_letter("correct_answer")}{row_num + 3}, '
133+
f'${col_letter("manual_extraction")}{row_num + 3}=${col_letter("correct_answer")}{row_num + 3}), 1, 0)', thick_border_right_format)
134+
elif col_num == col_id("model_choice_correct", True) or col_num == col_id("correct_answer", True):
105135
worksheet.write(row_num + 2, col_num + 1, cell_data, thick_border_right_format)
106-
elif col_num == 9:
107-
worksheet.write_formula(row_num + 2, col_num + 1, f'=IF(OR($E{row_num + 3}=$H{row_num + 3}, $F{row_num + 3}=$H{row_num + 3}), 1, 0)', cell_format)
136+
elif col_num == col_id("labels_match", True):
137+
worksheet.write_formula(row_num + 2, col_num + 1, f'=IF(OR(${col_letter("automatic_extraction")}{row_num + 3}=${col_letter("model_choice")}{row_num + 3}, ${col_letter("manual_extraction")}{row_num + 3}=${col_letter("model_choice")}{row_num + 3}), 1, 0)', cell_format)
108138
else:
109139
worksheet.write(row_num + 2, col_num + 1, cell_data, cell_format)
110140

111141
# Define the table range (start_row, start_col, end_row, end_col)
112-
start_row = 1
113-
start_col = 1
114-
end_row = len(analysis_result["result_rows"]) + 2 # 1 space, 1 header
142+
start_row = padding_top_left
143+
start_col = padding_top_left
144+
end_row = len(analysis_result["result_rows"]) + padding_top_left + 1 # 1 space, 1 header
115145
end_col = len(analysis_result["result_rows"][0].values())
116146

117147
# Add the table with headers
@@ -125,88 +155,88 @@ def write_to_xlsx(analysis_results: List[Dict]):
125155
for row_num in range(start_row + 1, end_row): # Skip header and total row
126156
# Color answer sentences dark red if they are empty
127157
worksheet.conditional_format(
128-
row_num, 3, # Column D
129-
row_num, 3,
158+
row_num, col_id("answer_sentences"),
159+
row_num, col_id("answer_sentences"),
130160
{
131161
'type': 'formula',
132-
'criteria': f'=$D{row_num + 1}=""',
162+
'criteria': f'=${col_letter("answer_sentences")}{row_num + 1}=""',
133163
'format': workbook.add_format({'bg_color': 'red'})
134164
}
135165
)
136166

137167
# Color trace_label_correct red if automatic and manual extraction are wrong
138168
worksheet.conditional_format(
139-
row_num, 6, # Column G
140-
row_num, 6,
169+
row_num, col_id("trace_label_correct"),
170+
row_num, col_id("trace_label_correct"),
141171
{
142172
'type': 'formula',
143-
'criteria': f'=AND($E{row_num + 1}<>$J{row_num + 1},$F{row_num + 1}<>$J{row_num + 1})',
173+
'criteria': f'=AND(${col_letter("automatic_extraction")}{row_num + 1}<>${col_letter("correct_answer")}{row_num + 1},${col_letter("manual_extraction")}{row_num + 1}<>${col_letter("correct_answer")}{row_num + 1})',
144174
'format': workbook.add_format({'bg_color': pastel_red})
145175
}
146176
)
147177

148178
# Color trace_label_correct green if automatic or manual extraction are right
149179
worksheet.conditional_format(
150-
row_num, 6, # Column G
151-
row_num, 6,
180+
row_num, col_id("trace_label_correct"),
181+
row_num, col_id("trace_label_correct"),
152182
{
153183
'type': 'formula',
154-
'criteria': f'=OR($E{row_num + 1}=$J{row_num + 1}, $F{row_num + 1}=$J{row_num + 1})',
184+
'criteria': f'=OR(${col_letter("automatic_extraction")}{row_num + 1}=${col_letter("correct_answer")}{row_num + 1}, ${col_letter("manual_extraction")}{row_num + 1}=${col_letter("correct_answer")}{row_num + 1})',
155185
'format': workbook.add_format({'bg_color': pastel_green})
156186
}
157187
)
158188

159189
# Color manual extraction dark red if it is necessary (automatic extraction failed (empty or undecisive))
160190
worksheet.conditional_format(
161-
row_num, 5, # Column F
162-
row_num, 5,
191+
row_num, col_id("manual_extraction"),
192+
row_num, col_id("manual_extraction"),
163193
{
164194
'type': 'formula',
165-
'criteria': f'=AND($F{row_num + 1}="", OR($E{row_num + 1}="", ISNUMBER(SEARCH("#", $E{row_num + 1}))))',
195+
'criteria': f'=AND(${col_letter("manual_extraction")}{row_num + 1}="", OR(${col_letter("automatic_extraction")}{row_num + 1}="", ISNUMBER(SEARCH("#", ${col_letter("automatic_extraction")}{row_num + 1}))))',
166196
'format': workbook.add_format({'bg_color': 'red'})
167197
}
168198
)
169199

170200
# Color model_choice_correct red if it is wrong
171201
worksheet.conditional_format(
172-
row_num, 8, # Column I
173-
row_num, 8,
202+
row_num, col_id("model_choice_correct"),
203+
row_num, col_id("model_choice_correct"),
174204
{
175205
'type': 'formula',
176-
'criteria': f'=$H{row_num + 1}<>$J{row_num + 1}',
206+
'criteria': f'=${col_letter("model_choice")}{row_num + 1}<>${col_letter("correct_answer")}{row_num + 1}',
177207
'format': workbook.add_format({'bg_color': pastel_red})
178208
}
179209
)
180210

181211
# Color the model_choice_correct green if it is right
182212
worksheet.conditional_format(
183-
row_num, 8, # Column I
184-
row_num, 8,
213+
row_num, col_id("model_choice_correct"),
214+
row_num, col_id("model_choice_correct"),
185215
{
186216
'type': 'formula',
187-
'criteria': f'=$H{row_num + 1}=$J{row_num + 1}',
217+
'criteria': f'=${col_letter("model_choice")}{row_num + 1}=${col_letter("correct_answer")}{row_num + 1}',
188218
'format': workbook.add_format({'bg_color': pastel_green})
189219
}
190220
)
191221

192-
# Color labels_do_match green if they match
222+
# Color labels_match green if they match
193223
worksheet.conditional_format(
194-
row_num, 10, # Column I
195-
row_num, 10,
224+
row_num, col_id("labels_match"),
225+
row_num, col_id("labels_match"),
196226
{
197227
'type': 'formula',
198-
'criteria': f'=OR($E{row_num + 1}=$H{row_num + 1}, $F{row_num + 1}=$H{row_num + 1})',
228+
'criteria': f'=OR(${col_letter("automatic_extraction")}{row_num + 1}=${col_letter("model_choice")}{row_num + 1}, ${col_letter("manual_extraction")}{row_num + 1}=${col_letter("model_choice")}{row_num + 1})',
199229
'format': workbook.add_format({'bg_color': 'green'})
200230
}
201231
)
202232

203233
# Color labels_do_match green if they do not match
204234
worksheet.conditional_format(
205-
row_num, 10, # Column I
206-
row_num, 10,
235+
row_num, col_id("labels_match"),
236+
row_num, col_id("labels_match"),
207237
{
208238
'type': 'formula',
209-
'criteria': f'=AND($E{row_num + 1}<>$H{row_num + 1}, $F{row_num + 1}<>$H{row_num + 1})',
239+
'criteria': f'=AND(${col_letter("automatic_extraction")}{row_num + 1}<>${col_letter("model_choice")}{row_num + 1}, ${col_letter("manual_extraction")}{row_num + 1}<>${col_letter("model_choice")}{row_num + 1})',
210240
'format': workbook.add_format({'bg_color': 'red'})
211241
}
212242
)
@@ -261,6 +291,7 @@ def analyze_trace_label_match(test_result):
261291
formatted_single_results = [
262292
{
263293
"question_id": single_result.data["question_id"],
294+
"question": single_result.data["question"],
264295
"labels": single_result.data["labels"],
265296
"reasoning": single_result.data["completions"][0]["reasoning"]["text"],
266297
"model_choice": single_result.data["model_choice"],
@@ -279,9 +310,10 @@ def analyze_trace_label_match(test_result):
279310

280311
answer_sentence_indicators = []
281312

282-
if model_name == "llama-2-7b-chat":
313+
if model_name == "llama-2-7b-chat" or model_name == "llama-2-13b-chat":
283314
answer_sentence_indicators = [
284315
"the correct answer is",
316+
"the correct answer among",
285317
"The correct answer is",
286318
"is the correct answer",
287319
"the best answer",
@@ -297,16 +329,16 @@ def analyze_trace_label_match(test_result):
297329
for formatted_single_result in formatted_single_results:
298330
table_row = {
299331
"question_id": formatted_single_result["question_id"],
332+
"question": formatted_single_result["question"],
300333
"reasoning": formatted_single_result["reasoning"],
301-
#"is_extractable": "",
302334
"answer_sentences": "",
303335
"automatic_extraction": "",
304336
"manual_extraction": "",
305337
"trace_label_correct": "",
306338
"model_choice": formatted_single_result["model_choice"],
307339
"model_choice_correct": formatted_single_result["is_correct"],
308340
"correct_answer": formatted_single_result["correct_answer"],
309-
"labels_do_match": "",
341+
"labels_match": "",
310342
"comment": ""
311343
}
312344

0 commit comments

Comments
 (0)