Skip to content

Commit b013415

Browse files
committed
add func to average over seeds + save sem
1 parent 2fece1e commit b013415

File tree

1 file changed

+41
-3
lines changed

1 file changed

+41
-3
lines changed

imodelsx/process_results.py

Lines changed: 41 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,9 @@
77
import pandas as pd
88
import pickle as pkl
99
import sys
10+
import warnings
11+
import scipy.stats
12+
import numpy as np
1013
repo_dir = dirname(dirname(os.path.abspath(__file__)))
1114

1215
def get_results_df(results_dir, use_cached=False) -> pd.DataFrame:
@@ -35,7 +38,7 @@ def get_main_args_list(experiment_filename='01_train_model.py'):
3538
3639
Params
3740
------
38-
fname: str
41+
experiment_filename: str
3942
Full path + name of the experiments script, e.g. /home/user/tree-prompt/experiments/01_train_model.py
4043
"""
4144
if experiment_filename.endswith('.py'):
@@ -51,7 +54,7 @@ def fill_missing_args_with_default(df, experiment_filename='01_train_model.py'):
5154
Params
5255
------
5356
54-
fname: str
57+
experiment_filename: str
5558
Full path + name of the experiments script, e.g. /home/user/tree-prompt/experiments/01_train_model.py
5659
"""
5760
if experiment_filename.endswith('.py'):
@@ -82,4 +85,39 @@ def delete_runs_in_dataframe(df: pd.DataFrame, actually_delete=False, directory_
8285
num_deleted += 1
8386
except:
8487
pass
85-
print(f'Deleted {num_deleted}/{df.shape[0]} directories.')
88+
print(f'Deleted {num_deleted}/{df.shape[0]} directories.')
89+
90+
91+
def average_over_seeds(df: pd.DataFrame, experiment_filename='01_train_model.py', key_to_average_over='seed'):
    """Returns values averaged over seed.
    Standard errors of the mean are added with columns suffixed with _err
    For example, 'accuracy_test' yields two columns
        'accuracy_test' now holds the mean value
        'accuracy_test_err' now holds the standard error of the mean

    Rows are grouped by every argument name returned by
    get_main_args_list(experiment_filename) except key_to_average_over.
    Every numeric/boolean column that is not a grouping key is aggregated
    (this includes the key_to_average_over column itself when it is numeric).

    Params
    ------
    df: pd.DataFrame
        One row per run; must contain a column for each grouping key
    experiment_filename: str
        Full path + name of the experiments script, e.g. /home/user/tree-prompt/experiments/01_train_model.py
        This is used to get the names of the arguments to aggregate over
    key_to_average_over: str
        Name of the argument that is averaged out (excluded from the grouping keys)

    Returns
    -------
    df_avg: pd.DataFrame
        One row per unique combination of grouping keys, with flat column names
    """
    def sem(x):
        '''Compute standard error of the mean (ddof=0), ignoring NaNs
        '''
        with warnings.catch_warnings():
            # catch_warnings() alone only saves/restores filter state; the
            # filter must be installed explicitly to actually silence the
            # warnings raised for tiny or all-NaN groups
            warnings.simplefilter('ignore')
            # nan_policy='omit' matches the NaN-skipping behavior of the mean;
            # float() normalizes masked results returned by older scipy versions
            return float(scipy.stats.sem(x, ddof=0, nan_policy='omit'))

    group_keys = [
        k for k in get_main_args_list(experiment_filename)
        if k != key_to_average_over
    ]

    # Select the columns to aggregate explicitly: groupby.aggregate has no
    # numeric_only parameter, so passing it is either ignored or forwarded
    # to the aggregation functions depending on the pandas version
    value_cols = [
        c for c in df.select_dtypes(include=['number', 'bool']).columns
        if c not in group_keys
    ]

    df_avg = (
        df
        .groupby(by=group_keys)[value_cols]
        .aggregate(['mean', sem])
        .reset_index()
    )

    # Flatten the (column, aggregation) MultiIndex: mean keeps the original
    # name, sem gets the '_err' suffix, grouping keys (empty second level)
    # keep their name
    df_avg.columns = [
        name + '_err' if agg_name == 'sem' else name
        for name, agg_name in df_avg.columns
    ]
    return df_avg

0 commit comments

Comments
 (0)