Skip to content

Commit 761ee8b

Browse files
refine weight parameter
1 parent bfcb4cf commit 761ee8b

File tree

5 files changed

+29
-48
lines changed

5 files changed

+29
-48
lines changed

ppsci/data/__init__.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,9 @@ def build_dataloader(_dataset, cfg):
8383
# build collate_fn if specified
8484
batch_transforms_cfg = cfg.pop("batch_transforms", None)
8585

86+
collate_fn = None
8687
if isinstance(batch_transforms_cfg, dict) and batch_transforms_cfg:
8788
collate_fn = batch_transform.build_batch_transforms(batch_transforms_cfg)
88-
else:
89-
collate_fn = batch_transform.default_collate_fn_allow_none
9089

9190
# build init function
9291
init_fn = partial(
@@ -97,7 +96,7 @@ def build_dataloader(_dataset, cfg):
9796
)
9897

9998
# build dataloader
100-
dataloader = io.DataLoader(
99+
dataloader_ = io.DataLoader(
101100
dataset=_dataset,
102101
places=device.get_device(),
103102
batch_sampler=sampler,
@@ -107,4 +106,4 @@ def build_dataloader(_dataset, cfg):
107106
worker_init_fn=init_fn,
108107
)
109108

110-
return dataloader
109+
return dataloader_

ppsci/data/dataset/array_dataset.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,7 @@ def __init__(
5656
def __getitem__(self, idx):
5757
input_item = {key: value[idx] for key, value in self.input.items()}
5858
label_item = {key: value[idx] for key, value in self.label.items()}
59-
weight_item = (
60-
{key: value[idx] for key, value in self.weight.items()}
61-
if self.weight is not None
62-
else None
63-
)
59+
weight_item = {key: value[idx] for key, value in self.weight.items()}
6460

6561
# TODO(sensen): Transforms may be applied on label and weight.
6662
if self.transforms is not None:

ppsci/data/dataset/era5_dataset.py

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,12 @@ def __init__(
6868
self.label_keys = label_keys
6969
self.precip_file_path = precip_file_path
7070

71-
self.weight_dict = weight_dict
71+
self.weight_dict = {} if weight_dict is None else weight_dict
7272
if weight_dict is not None:
7373
self.weight_dict = {key: 1.0 for key in self.label_keys}
7474
self.weight_dict.update(weight_dict)
7575

76-
self.vars_channel = (
77-
vars_channel if vars_channel is not None else [i for i in range(20)]
78-
)
76+
self.vars_channel = list(range(20)) if vars_channel is None else vars_channel
7977
self.num_label_timestamps = num_label_timestamps
8078
self.transforms = transforms
8179
self.training = training
@@ -127,6 +125,7 @@ def __getitem__(self, global_idx):
127125
input_idx, label_idx = local_idx, local_idx + step
128126

129127
input_item = {self.input_keys[0]: input_file[input_idx, self.vars_channel]}
128+
130129
label_item = {}
131130
for i in range(self.num_label_timestamps):
132131
if self.precip_file_path is not None:
@@ -138,14 +137,11 @@ def __getitem__(self, global_idx):
138137
label_idx + i, self.vars_channel
139138
]
140139

141-
if self.weight_dict is not None:
142-
weight_shape = [1] * len(next(iter(label_item.values)).shape)
143-
weight_item = {
144-
key: np.full(weight_shape, value, paddle.get_default_dtype())
145-
for key, value in self.weight_dict.items()
146-
}
147-
else:
148-
weight_item = None
140+
weight_shape = [1] * len(next(iter(label_item.values)).shape)
141+
weight_item = {
142+
key: np.full(weight_shape, value, paddle.get_default_dtype())
143+
for key, value in self.weight_dict.items()
144+
}
149145

150146
if self.transforms is not None:
151147
input_item, label_item, weight_item = self.transforms(
@@ -187,7 +183,7 @@ def __init__(
187183
self.input_keys = input_keys
188184
self.label_keys = label_keys
189185

190-
self.weight_dict = weight_dict
186+
self.weight_dict = {} if weight_dict is None else weight_dict
191187
if weight_dict is not None:
192188
self.weight_dict = {key: 1.0 for key in self.label_keys}
193189
self.weight_dict.update(weight_dict)
@@ -201,8 +197,8 @@ def read_data(self, path: str):
201197
paths = glob.glob(path + "/*.h5")
202198
paths.sort()
203199
files = []
204-
for path in paths:
205-
_file = h5py.File(path, "r")
200+
for _path in paths:
201+
_file = h5py.File(_path, "r")
206202
files.append(_file)
207203
return files
208204

@@ -217,20 +213,18 @@ def __getitem__(self, global_idx):
217213
input_item[key] = np.asarray(
218214
_file["input_dict"][key], paddle.get_default_dtype()
219215
)
216+
220217
label_item = {}
221218
for key in _file["label_dict"]:
222219
label_item[key] = np.asarray(
223220
_file["label_dict"][key], paddle.get_default_dtype()
224221
)
225222

226-
if self.weight_dict is not None:
227-
weight_shape = [1] * len(next(iter(label_item.values)).shape)
228-
weight_item = {
229-
key: np.full(weight_shape, value, paddle.get_default_dtype())
230-
for key, value in self.weight_dict.items()
231-
}
232-
else:
233-
weight_item = None
223+
weight_shape = [1] * len(next(iter(label_item.values)).shape)
224+
weight_item = {
225+
key: np.full(weight_shape, value, paddle.get_default_dtype())
226+
for key, value in self.weight_dict.items()
227+
}
234228

235229
if self.transforms is not None:
236230
input_item, label_item, weight_item = self.transforms(

ppsci/data/dataset/vtu_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def __init__(
8383
def __getitem__(self, idx):
8484
input_item = {key: value[idx] for key, value in self.input.items()}
8585
label_item = {key: value[idx] for key, value in self.label.items()}
86-
return (input_item, label_item, None)
86+
return (input_item, label_item, {})
8787

8888
def __len__(self):
8989
return self.num_samples

ppsci/data/process/batch_transform/__init__.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,13 @@
2525

2626
from ppsci.data.process import transform
2727

28-
__all__ = ["build_batch_transforms", "default_collate_fn_allow_none"]
28+
__all__ = ["build_batch_transforms"]
2929

3030

31-
def default_collate_fn_allow_none(batch: List[Any]) -> Any:
32-
"""Modified collate function to allow some fields to be None, such as weight field.
31+
def default_collate_fn(batch: List[Any]) -> Any:
32+
"""Default_collate_fn for paddle dataloader.
3333
34-
ref: https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/dataloader/collate.py#L24
34+
ref: https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/io/dataloader/collate.py#L25
3535
3636
Args:
3737
batch (List[Any]): Batch of samples to be collated.
@@ -40,11 +40,6 @@ def default_collate_fn_allow_none(batch: List[Any]) -> Any:
4040
Any: Collated batch data.
4141
"""
4242
sample = batch[0]
43-
44-
# allow field to be None
45-
if sample is None:
46-
return None
47-
4843
if isinstance(sample, np.ndarray):
4944
batch = np.stack(batch, axis=0)
5045
return batch
@@ -56,15 +51,12 @@ def default_collate_fn_allow_none(batch: List[Any]) -> Any:
5651
elif isinstance(sample, (str, bytes)):
5752
return batch
5853
elif isinstance(sample, Mapping):
59-
return {
60-
key: default_collate_fn_allow_none([d[key] for d in batch])
61-
for key in sample
62-
}
54+
return {key: default_collate_fn([d[key] for d in batch]) for key in sample}
6355
elif isinstance(sample, Sequence):
6456
sample_fields_num = len(sample)
6557
if not all(len(sample) == sample_fields_num for sample in iter(batch)):
6658
raise RuntimeError("fileds number not same among samples in a batch")
67-
return [default_collate_fn_allow_none(fields) for fields in zip(*batch)]
59+
return [default_collate_fn(fields) for fields in zip(*batch)]
6860

6961
raise TypeError(
7062
"batch data can only contains: tensor, numpy.ndarray, "
@@ -80,6 +72,6 @@ def collate_fn_batch_transforms(batch: List[Any]):
8072
# apply batch transform on uncollated data
8173
batch = batch_transforms(batch)
8274
# then do collate
83-
return default_collate_fn_allow_none(batch)
75+
return default_collate_fn(batch)
8476

8577
return collate_fn_batch_transforms

0 commit comments

Comments (0)