Skip to content

Commit 3958512

Browse files
Merge pull request #311 from lijialin03/develop
update features
2 parents df07a65 + b0ba4fa commit 3958512

File tree

9 files changed

+59
-11
lines changed

9 files changed

+59
-11
lines changed

ppsci/arch/base.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from typing import Callable
1516
from typing import Dict
1617
from typing import Tuple
1718

@@ -80,10 +81,28 @@ def split_to_dict(
8081
data = paddle.split(data_tensor, len(keys), axis=axis)
8182
return {key: data[i] for i, key in enumerate(keys)}
8283

83-
def register_input_transform(self, transform):
84+
def register_input_transform(
85+
self,
86+
transform: Callable[[Dict[str, paddle.Tensor]], Dict[str, paddle.Tensor]],
87+
):
88+
"""Register input transform.
89+
90+
Args:
91+
transform (Callable[[Dict[str, paddle.Tensor]], Dict[str, paddle.Tensor]]):
92+
Input transform of network, receive a single tensor dict and return a single tensor dict.
93+
"""
8494
self._input_transform = transform
8595

86-
def register_output_transform(self, transform):
96+
def register_output_transform(
97+
self,
98+
transform: Callable[[Dict[str, paddle.Tensor]], Dict[str, paddle.Tensor]],
99+
):
100+
"""Register output transform.
101+
102+
Args:
103+
transform (Callable[[Dict[str, paddle.Tensor]], Dict[str, paddle.Tensor]]):
104+
Output transform of network, receive a single tensor dict and return a single tensor dict.
105+
"""
87106
self._output_transform = transform
88107

89108
def __str__(self):

ppsci/arch/mlp.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from typing import Optional
1516
from typing import Tuple
1617
from typing import Union
1718

@@ -33,6 +34,7 @@ class MLP(base.Arch):
3334
activation (str, optional): Name of activation function. Defaults to "tanh".
3435
skip_connection (bool, optional): Whether to use skip connection. Defaults to False.
3536
weight_norm (bool, optional): Whether to apply weight norm on parameter(s). Defaults to False.
37+
input_dim (Optional[int], optional): Number of input's dimension. Defaults to None.
3638
3739
Examples:
3840
>>> import ppsci
@@ -48,6 +50,7 @@ def __init__(
4850
activation: str = "tanh",
4951
skip_connection: bool = False,
5052
weight_norm: bool = False,
53+
input_dim: Optional[int] = None,
5154
):
5255
super().__init__()
5356
self.input_keys = input_keys
@@ -71,7 +74,7 @@ def __init__(
7174
)
7275

7376
# initialize FC layer(s)
74-
cur_size = len(self.input_keys)
77+
cur_size = len(self.input_keys) if input_dim is None else input_dim
7578
for _size in hidden_size:
7679
self.linears.append(nn.Linear(cur_size, _size))
7780
if weight_norm:

ppsci/constraint/supervised_constraint.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,15 @@ def __init__(
7272

7373
# construct dataloader with dataset and dataloader_cfg
7474
super().__init__(_dataset, dataloader_cfg, loss, name)
75+
76+
def __str__(self):
77+
return ", ".join(
78+
[
79+
self.__class__.__name__,
80+
f"name = {self.name}",
81+
f"input_keys = {self.input_keys}",
82+
f"output_keys = {self.output_keys}",
83+
f"output_expr = {self.output_expr}",
84+
f"loss = {self.loss}",
85+
]
86+
)

ppsci/data/dataset/mat_dataset.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class MatDataset(io.Dataset):
3434
Args:
3535
file_path (str): Mat file path.
3636
input_keys (Tuple[str, ...]): List of input keys.
37-
label_keys (Tuple[str, ...]): List of label keys.
37+
label_keys (Tuple[str, ...], optional): List of label keys. Defaults to ().
3838
alias_dict (Optional[Dict[str, str]]): Dict of alias(es) for input and label keys.
3939
i.e. {inner_key: outer_key}. Defaults to None.
4040
weight_dict (Optional[Dict[str, Union[Callable, float]]]): Define the weight of
@@ -57,7 +57,7 @@ def __init__(
5757
self,
5858
file_path: str,
5959
input_keys: Tuple[str, ...],
60-
label_keys: Tuple[str, ...],
60+
label_keys: Tuple[str, ...] = (),
6161
alias_dict: Optional[Dict[str, str]] = None,
6262
weight_dict: Optional[Dict[str, Union[Callable, float]]] = None,
6363
timestamps: Optional[Tuple[float, ...]] = None,
@@ -153,7 +153,7 @@ class IterableMatDataset(io.IterableDataset):
153153
Args:
154154
file_path (str): Mat file path.
155155
input_keys (Tuple[str, ...]): List of input keys.
156-
label_keys (Tuple[str, ...]): List of label keys.
156+
label_keys (Tuple[str, ...], optional): List of label keys. Defaults to ().
157157
alias_dict (Optional[Dict[str, str]]): Dict of alias(es) for input and label keys.
158158
i.e. {inner_key: outer_key}. Defaults to None.
159159
weight_dict (Optional[Dict[str, Union[Callable, float]]]): Define the weight of
@@ -176,7 +176,7 @@ def __init__(
176176
self,
177177
file_path: str,
178178
input_keys: Tuple[str, ...],
179-
label_keys: Tuple[str, ...],
179+
label_keys: Tuple[str, ...] = (),
180180
alias_dict: Optional[Dict[str, str]] = None,
181181
weight_dict: Optional[Dict[str, Union[Callable, float]]] = None,
182182
timestamps: Optional[Tuple[Union[int, float], ...]] = None,

ppsci/loss/integral.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def forward(self, output_dict, label_dict, weight_dict=None):
6262
label_dict[key],
6363
"none",
6464
)
65-
if weight_dict is not None:
65+
if weight_dict:
6666
loss *= weight_dict[key]
6767
if isinstance(self.weight, float):
6868
loss *= self.weight

ppsci/loss/l1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def forward(self, output_dict, label_dict, weight_dict=None):
5656
losses = 0.0
5757
for key in label_dict:
5858
loss = F.l1_loss(output_dict[key], label_dict[key], "none")
59-
if weight_dict is not None:
59+
if weight_dict:
6060
loss *= weight_dict[key]
6161
if isinstance(self.weight, float):
6262
loss *= self.weight

ppsci/loss/l2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def forward(self, output_dict, label_dict, weight_dict=None):
8484
loss = F.mse_loss(
8585
output_dict[key][:n_output], output_dict[key][n_output:], "none"
8686
)
87-
if weight_dict is not None:
87+
if weight_dict:
8888
loss *= weight_dict[key]
8989
if "area" in output_dict:
9090
loss *= output_dict["area"]

ppsci/loss/mse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def forward(self, output_dict, label_dict, weight_dict=None):
5858
losses = 0.0
5959
for key in label_dict:
6060
loss = F.mse_loss(output_dict[key], label_dict[key], "none")
61-
if weight_dict is not None:
61+
if weight_dict:
6262
loss *= weight_dict[key]
6363
if isinstance(self.weight, (float, int)):
6464
loss *= self.weight

ppsci/validate/sup_validator.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,17 @@ def __init__(
8383

8484
# construct dataloader with dataset and dataloader_cfg
8585
super().__init__(_dataset, dataloader_cfg, loss, metric, name)
86+
87+
def __str__(self):
88+
return ", ".join(
89+
[
90+
self.__class__.__name__,
91+
f"name = {self.name}",
92+
f"input_keys = {self.input_keys}",
93+
f"output_keys = {self.output_keys}",
94+
f"output_expr = {self.output_expr}",
95+
f"len(dataloader) = {len(self.data_loader)}",
96+
f"loss = {self.loss}",
97+
f"metric = {list(self.metric.keys())}",
98+
]
99+
)

0 commit comments

Comments
 (0)