7
7
import pandas as pd
8
8
import pickle as pkl
9
9
import sys
10
+ import warnings
11
+ import scipy .stats
12
+ import numpy as np
10
13
repo_dir = dirname (dirname (os .path .abspath (__file__ )))
11
14
12
15
def get_results_df (results_dir , use_cached = False ) -> pd .DataFrame :
@@ -35,7 +38,7 @@ def get_main_args_list(experiment_filename='01_train_model.py'):
35
38
36
39
Params
37
40
------
38
- fname : str
41
+ experiment_filename : str
39
42
Full path + name of the experiments script, e.g. /home/user/tree-prompt/experiments/01_train_model.py
40
43
"""
41
44
if experiment_filename .endswith ('.py' ):
@@ -51,7 +54,7 @@ def fill_missing_args_with_default(df, experiment_filename='01_train_model.py'):
51
54
Params
52
55
------
53
56
54
- fname : str
57
+ experiment_filename : str
55
58
Full path + name of the experiments script, e.g. /home/user/tree-prompt/experiments/01_train_model.py
56
59
"""
57
60
if experiment_filename .endswith ('.py' ):
@@ -82,4 +85,39 @@ def delete_runs_in_dataframe(df: pd.DataFrame, actually_delete=False, directory_
82
85
num_deleted += 1
83
86
except :
84
87
pass
85
- print (f'Deleted { num_deleted } /{ df .shape [0 ]} directories.' )
88
+ print (f'Deleted { num_deleted } /{ df .shape [0 ]} directories.' )
89
+
90
+
91
def average_over_seeds(df: pd.DataFrame, experiment_filename='01_train_model.py', key_to_average_over='seed'):
    """Returns values averaged over seed.
    Standard errors of the mean are added with columns suffixed with _err
    For example, 'accuracy_test' yields two columns
        'accuracy_test' now holds the mean value
        'accuracy_test_err' now holds the standard error of the mean

    Params
    ------
    df: pd.DataFrame
        One row per run; must contain every column named by
        get_main_args_list(experiment_filename) other than key_to_average_over
    experiment_filename: str
        Full path + name of the experiments script, e.g. /home/user/tree-prompt/experiments/01_train_model.py
        This is used to get the names of the arguments to aggregate over
    key_to_average_over: str
        Name of the argument that is averaged out (excluded from the grouping keys)

    Returns
    -------
    df_avg: pd.DataFrame
        One row per unique combination of the grouping keys. Each numeric,
        non-key column of df appears twice: under its own name (mean) and
        with an '_err' suffix (standard error of the mean).
    """
    def sem(x):
        '''Compute standard error of the mean (ddof=0), ignoring NaNs
        '''
        # NOTE: this function must stay named `sem`; pandas uses the function
        # name as the second-level column label that is matched below.
        values = pd.Series(x).dropna().astype(float)
        if values.size == 0:
            return np.nan
        with warnings.catch_warnings():
            # catch_warnings() alone only saves/restores the filter state;
            # an explicit filter is needed to actually silence warnings
            # raised for degenerate groups.
            warnings.simplefilter('ignore')
            return scipy.stats.sem(values, ddof=0)

    group_keys = [
        k for k in get_main_args_list(experiment_filename)
        if k != key_to_average_over
    ]

    # Select the numeric columns explicitly instead of passing `numeric_only`
    # to aggregate(): with a list of functions that keyword is either ignored
    # or forwarded to each function, depending on the pandas version.
    value_cols = [
        c for c in df.select_dtypes(include='number').columns
        if c not in group_keys
    ]

    df_avg = (
        df
        .groupby(by=group_keys)[value_cols]
        .aggregate(['mean', sem])
        .reset_index()
    )
    # Flatten the (column, statistic) MultiIndex: mean keeps the original
    # name, sem gets the '_err' suffix; group keys have an empty second level.
    df_avg.columns = [x[0] + '_err' if x[1] == 'sem' else x[0] for x in df_avg.columns]
    return df_avg
0 commit comments