Skip to content

Commit b7ad945

Browse files
Merge pull request #328 from HydrogenSulfate/refine_code
optimize code
2 parents c0e6c0a + eaee439 commit b7ad945

File tree

21 files changed

+98
-129
lines changed

21 files changed

+98
-129
lines changed

docs/zh/api/data/dataset.md

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,14 @@
66
members:
77
- IterableNamedArrayDataset
88
- NamedArrayDataset
9-
- CylinderDataset
10-
- LorenzDataset
11-
- RosslerDataset
129
- CSVDataset
13-
- MatDataset
10+
- IterableCSVDataset
1411
- ERA5Dataset
1512
- ERA5SampledDataset
13+
- IterableMatDataset
14+
- MatDataset
15+
- CylinderDataset
16+
- LorenzDataset
17+
- RosslerDataset
18+
- VtuDataset
1619
show_root_heading: false

docs/zh/api/equation.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,6 @@
1010
- NavierStokes
1111
- NormalDotVec
1212
- Poisson
13+
- Vibration
1314
show_root_heading: false
1415
heading_level: 3

docs/zh/api/geometry.md

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20,25 +20,3 @@
2020
- TimeXGeometry
2121
show_root_heading: false
2222
heading_level: 3
23-
24-
<!-- # Geometry
25-
26-
::: ppsci.arch
27-
28-
This is on a separate line
29-
30-
$$
31-
\operatorname{ker} f=\{g\in G:f(g)=e_{H}\}{\mbox{.}}
32-
$$
33-
34-
The homomorphism $f$ is injective if and only if its kernel is only the
35-
singleton set $e_G$, because otherwise $\exists a,b\in G$ with $a\neq b$ such
36-
that $f(a)=f(b)$.
37-
38-
```python
39-
40-
--8<--
41-
./ppsci/data/dataset/array_dataset.py:16:49
42-
--8<--
43-
44-
``` -->

docs/zh/api/lr_scheduler.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
handler: python
55
options:
66
members:
7-
- ConstLR
87
- Linear
98
- Cosine
109
- Step

docs/zh/api/metric.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
handler: python
55
options:
66
members:
7+
- Metric
78
- MAE
89
- MSE
910
- RMSE

docs/zh/api/utils.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
- initializer
88
- logger
99
- misc
10-
- reader
10+
- load_csv_file
11+
- load_mat_file
12+
- load_vtk_file
1113
- run_check
1214
- AttrDict
1315
- ExpressionSolver

examples/cylinder/3d_unsteady_discrete/cylinder3d_unsteady.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,6 @@
342342
600000,
343343
label,
344344
time_list,
345-
len(time_list),
346345
"result_uvwp",
347346
)
348347
}

ppsci/arch/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ def concat_to_tensor(
6262
Returns:
6363
Tuple[paddle.Tensor, ...]: Concatenated tensor.
6464
"""
65+
if len(keys) == 1:
66+
return data_dict[keys[0]]
6567
data = [data_dict[key] for key in keys]
6668
return paddle.concat(data, axis)
6769

@@ -78,6 +80,8 @@ def split_to_dict(
7880
Returns:
7981
Dict[str, paddle.Tensor]: Dict contains tensor.
8082
"""
83+
if len(keys) == 1:
84+
return {keys[0]: data_tensor}
8185
data = paddle.split(data_tensor, len(keys), axis=axis)
8286
return {key: data[i] for i, key in enumerate(keys)}
8387

ppsci/data/__init__.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,9 @@ def build_dataloader(_dataset, cfg):
8383
# build collate_fn if specified
8484
batch_transforms_cfg = cfg.pop("batch_transforms", None)
8585

86+
collate_fn = None
8687
if isinstance(batch_transforms_cfg, dict) and batch_transforms_cfg:
8788
collate_fn = batch_transform.build_batch_transforms(batch_transforms_cfg)
88-
else:
89-
collate_fn = batch_transform.default_collate_fn_allow_none
9089

9190
# build init function
9291
init_fn = partial(
@@ -97,7 +96,7 @@ def build_dataloader(_dataset, cfg):
9796
)
9897

9998
# build dataloader
100-
dataloader = io.DataLoader(
99+
dataloader_ = io.DataLoader(
101100
dataset=_dataset,
102101
places=device.get_device(),
103102
batch_sampler=sampler,
@@ -107,4 +106,4 @@ def build_dataloader(_dataset, cfg):
107106
worker_init_fn=init_fn,
108107
)
109108

110-
return dataloader
109+
return dataloader_

ppsci/data/dataset/array_dataset.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,7 @@ def __init__(
5656
def __getitem__(self, idx):
5757
input_item = {key: value[idx] for key, value in self.input.items()}
5858
label_item = {key: value[idx] for key, value in self.label.items()}
59-
weight_item = (
60-
{key: value[idx] for key, value in self.weight.items()}
61-
if self.weight is not None
62-
else None
63-
)
59+
weight_item = {key: value[idx] for key, value in self.weight.items()}
6460

6561
# TODO(sensen): Transforms may be applied on label and weight.
6662
if self.transforms is not None:

0 commit comments

Comments
 (0)