Adding seed for reproducibility and sampling methods #344
```diff
@@ -433,7 +433,7 @@ def get_dataset_tasks(self, dset_df):
         return self.tasks is not None
 
     # ****************************************************************************************
-    def split_dataset(self):
+    def split_dataset(self, random_state=None, seed=None):
         """Splits the dataset into paired training/validation and test subsets, according to the split strategy
         selected by the model params. For traditional train/valid/test splits, there is only one training/validation
         pair. For k-fold cross-validation splits, there are k different train/valid pairs; the validation sets are
```
```diff
@@ -452,7 +452,7 @@ def split_dataset(self):
 
         # Create object to delegate splitting to.
         if self.splitting is None:
-            self.splitting = split.create_splitting(self.params)
+            self.splitting = split.create_splitting(self.params, random_state=random_state, seed=seed)
         self.train_valid_dsets, self.test_dset, self.train_valid_attr, self.test_attr = \
             self.splitting.split_dataset(self.dataset, self.attr, self.params.smiles_col)
         if self.train_valid_dsets is None:
```
```diff
@@ -479,6 +479,12 @@ def _check_classes(self):
             (Boolean): boolean specifying if all classes are specified in all splits
         """
         ref_class_set = get_classes(self.train_valid_dsets[0][0].y)
+        num_classes = len(ref_class_set)
+        if num_classes != self.params.class_number:
+            logger = logging.getLogger('ATOM')
+            logger.warning(f"Expected class_number:{self.params.class_number} "
+                           f"classes but got {num_classes} instead. Double check "
+                           "response columns or class_number parameter.")
         for train, valid in self.train_valid_dsets:
             if not ref_class_set == get_classes(train.y):
                 return False
```
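To illustrate the new warning path, here is a hedged, self-contained sketch. `get_classes` is assumed to return the set of distinct class labels in a response matrix `y` (the real helper is defined elsewhere in this module), and `class_number` stands in for `self.params.class_number`.

```python
# Hedged sketch of the class-count check added above.
import logging
import numpy as np

def get_classes(y):
    # Hypothetical: the distinct, non-NaN labels observed in the responses.
    return set(np.unique(y[~np.isnan(y)]))

y_train = np.array([0.0, 1.0, 2.0, 1.0])
class_number = 2  # assumed value of params.class_number

ref_class_set = get_classes(y_train)
if len(ref_class_set) != class_number:
    logging.getLogger('ATOM').warning(
        f"Expected class_number:{class_number} classes but got "
        f"{len(ref_class_set)} instead. Double check response columns "
        "or class_number parameter.")
```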
```diff
@@ -563,7 +569,7 @@ def create_dataset_split_table(self):
         return split_df
 
     # ****************************************************************************************
-    def load_presplit_dataset(self, directory=None):
+    def load_presplit_dataset(self, directory=None, random_state=None, seed=None):
         """Loads a table of compound IDs assigned to split subsets, and uses them to split
         the currently loaded featurized dataset.
 
```
```diff
@@ -590,7 +596,7 @@ def load_presplit_dataset(self, directory=None):
         """
 
         # Load the split table from the datastore or filesystem
-        self.splitting = split.create_splitting(self.params)
+        self.splitting = split.create_splitting(self.params, random_state=random_state, seed=seed)
 
         try:
             split_df, split_kv = self.load_dataset_split_table(directory)
```
```diff
@@ -655,11 +661,31 @@ def combined_training_data(self):
         # All of the splits have the same combined train/valid data, regardless of whether we're using
         # k-fold or train/valid/test splitting.
         if self.combined_train_valid_data is None:
+            # normally combining one fold is sufficient, but if SMOTE or undersampling is being used
+            # just combining the first fold isn't enough
             (train, valid) = self.train_valid_dsets[0]
             combined_X = np.concatenate((train.X, valid.X), axis=0)
             combined_y = np.concatenate((train.y, valid.y), axis=0)
             combined_w = np.concatenate((train.w, valid.w), axis=0)
             combined_ids = np.concatenate((train.ids, valid.ids))
+
+            if self.params.sampling_method == 'SMOTE' or self.params.sampling_method == 'undersampling':
+                # for each successive fold, merge in any new compounds
+                # this loop just won't run if there are no additional folds
+                for train, valid in self.train_valid_dsets[1:]:
+                    fold_ids = np.concatenate((train.ids, valid.ids))
+                    new_id_indexes = [i for i in range(len(fold_ids)) if fold_ids[i] not in combined_ids]
+
+                    fold_ids = fold_ids[new_id_indexes]
+                    fold_X = np.concatenate((train.X, valid.X), axis=0)[new_id_indexes]
+                    fold_y = np.concatenate((train.y, valid.y), axis=0)[new_id_indexes]
+                    fold_w = np.concatenate((train.w, valid.w), axis=0)[new_id_indexes]
+
+                    combined_X = np.concatenate((combined_X, fold_X), axis=0)
+                    combined_y = np.concatenate((combined_y, fold_y), axis=0)
+                    combined_w = np.concatenate((combined_w, fold_w), axis=0)
+                    combined_ids = np.concatenate((combined_ids, fold_ids))
+
             self.combined_train_valid_data = NumpyDataset(combined_X, combined_y, w=combined_w, ids=combined_ids)
         return self.combined_train_valid_data
```

Review thread on the merge loop above:

Reviewer: For undersampling, it looks like this assumes that k-fold undersampling would sample the entire non-test dataset. What if this isn't the case? Is this assumption ensured elsewhere?

Author: I don't think a compound can ever be wiped entirely out of existence due to undersampling; undersampling is only applied to the training set of each fold. This isn't tested, but do we need to test it anywhere? I think it's OK if a compound is dropped entirely, since that's what happens when using undersampling without k-fold validation.
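As the thread suggests, the per-compound membership test in the merge loop can be vectorized. Below is a minimal sketch of the same merge step using `np.isin`; the arrays are hypothetical stand-ins for the `NumpyDataset` fields in the diff.

```python
# Vectorized variant of the merge step: np.isin flags compounds not yet in
# combined_ids without a Python-level loop over fold_ids.
import numpy as np

combined_ids = np.array(['c1', 'c2', 'c3'])
combined_X = np.zeros((3, 4))

fold_ids = np.array(['c2', 'c4', 'c5'])  # one fold's train+valid IDs
fold_X = np.ones((3, 4))

new_mask = ~np.isin(fold_ids, combined_ids)  # True for unseen compounds
combined_ids = np.concatenate((combined_ids, fold_ids[new_mask]))
combined_X = np.concatenate((combined_X, fold_X[new_mask]), axis=0)
print(combined_ids)  # ['c1' 'c2' 'c3' 'c4' 'c5']
```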
```diff
@@ -697,7 +723,8 @@ def get_subset_responses_and_weights(self, subset, transformers):
         """
         if subset not in self.subset_response_dict:
             if subset in ('train', 'valid', 'train_valid'):
-                dataset = self.combined_training_data()
+                for fold, (train, valid) in enumerate(self.train_valid_dsets):
+                    dataset = self.combined_training_data()
             elif subset == 'test':
                 dataset = self.test_dset
             else:
```

Review thread on the loop above:

Reviewer: If you are just looking for new compounds in each fold, can you just concatenate all train_valid_dsets and then call set(ids) or drop_duplicates() or something? It might make the code more efficient than multiple for loops, but I'm not sure whether it is actually easier given the way the datasets are stored.

Author: These datasets are NumpyDatasets and contain an X matrix, y matrix, w matrix, and ids. I'd have to put them into a data frame, call drop_duplicates, and then put it back into a NumpyDataset. However, I think I can get rid of that loop on 726; that's not necessary.
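For reference, the deduplication the thread discusses can also be done directly in numpy, avoiding the pandas round trip the author describes. This is a sketch under that assumption; the arrays are hypothetical stand-ins for concatenated `NumpyDataset` fields.

```python
# np.unique(ids, return_index=True) keeps the first occurrence of each
# compound ID; sorting the indices preserves the original row ordering.
import numpy as np

all_ids = np.concatenate([np.array(['c1', 'c2']), np.array(['c2', 'c3'])])
all_X = np.vstack([np.zeros((2, 4)), np.ones((2, 4))])

_, first_idx = np.unique(all_ids, return_index=True)
first_idx.sort()                 # keep rows in their original order
unique_ids = all_ids[first_idx]
unique_X = all_X[first_idx]
print(unique_ids)  # ['c1' 'c2' 'c3']
```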