Skip to content

Commit 3d054e2

Browse files
committed
fix shuffle bug
1 parent 95566b3 commit 3d054e2

File tree

3 files changed

+3
-10
lines changed

3 files changed

+3
-10
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ Per users' request, we processed two non-anthropogenic datasets
108108
Explore the following tutorials that can be opened directly in Google Colab:
109109

110110
- [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ant-research/EasyTemporalPointProcess/blob/main/notebooks/easytpp_1_dataset.ipynb) Tutorial 1: Dataset in EasyTPP.
111+
- [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ant-research/EasyTemporalPointProcess/blob/main/notebooks/easytpp_2_tfb_wb.ipynb) Tutorial 2: Tensorboard in EasyTPP.
111112

112113
### End-to-end Example
113114

easy_tpp/config_factory/runner_config.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ def __init__(self, base_config, model_config, data_config, trainer_config):
2626
self.base_config = base_config
2727
self.trainer_config = trainer_config
2828

29-
self.ensure_valid_config()
3029
self.update_config()
3130

3231
# save the complete config
@@ -87,15 +86,6 @@ def parse_from_yaml_config(yaml_config, **kwargs):
8786
trainer_config=trainer_config
8887
)
8988

90-
def ensure_valid_config(self):
91-
"""Do some sanity check about the config, to avoid conflicts in settings.
92-
"""
93-
94-
# during testing we dont do shuffle by default
95-
self.trainer_config.shuffle = False
96-
97-
return
98-
9989
def update_config(self):
10090
"""Updated config dict.
10191
"""

easy_tpp/preprocess/data_loader.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@ def test_loader(self, **kwargs):
136136
Returns:
137137
EasyTPP.DataLoader: data loader for test set.
138138
"""
139+
# for test set, we do not shuffle
140+
kwargs['shuffle'] = False
139141
return self.get_loader('test', **kwargs)
140142

141143
def get_statistics(self, split='train'):

0 commit comments

Comments
 (0)