Skip to content

Commit 6e5a58c

Browse files
support more dtype, such as float64
1 parent d739ace commit 6e5a58c

24 files changed

+359
-186
lines changed

ppsci/arch/embedding_koopman.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,14 +117,14 @@ def build_koopman_operator(self, embed_size: int):
117117
data = paddle.linspace(1, 0, embed_size)
118118
k_diag = paddle.create_parameter(
119119
shape=data.shape,
120-
dtype=data.dtype,
120+
dtype=paddle.get_default_dtype(),
121121
default_initializer=nn.initializer.Assign(data),
122122
)
123123

124124
data = 0.1 * paddle.rand([2 * embed_size - 3])
125125
k_ut = paddle.create_parameter(
126126
shape=data.shape,
127-
dtype=data.dtype,
127+
dtype=paddle.get_default_dtype(),
128128
default_initializer=nn.initializer.Assign(data),
129129
)
130130
return k_diag, k_ut

ppsci/constraint/boundary_constraint.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def __init__(
111111
label = {}
112112
for key, value in label_dict.items():
113113
if isinstance(value, (int, float)):
114-
label[key] = np.full_like(next(iter(input.values())), float(value))
114+
label[key] = np.full_like(next(iter(input.values())), value)
115115
elif isinstance(value, sympy.Basic):
116116
func = sympy.lambdify(
117117
sympy.symbols(geom.dim_keys),
@@ -125,9 +125,7 @@ def __init__(
125125
func = value
126126
label[key] = func(input)
127127
if isinstance(label[key], (int, float)):
128-
label[key] = np.full_like(
129-
next(iter(input.values())), float(label[key])
130-
)
128+
label[key] = np.full_like(next(iter(input.values())), label[key])
131129
else:
132130
raise NotImplementedError(f"type of {type(value)} is invalid yet.")
133131

@@ -139,7 +137,7 @@ def __init__(
139137
value = sp_parser.parse_expr(value)
140138

141139
if isinstance(value, (int, float)):
142-
weight[key] = np.full_like(next(iter(label.values())), float(value))
140+
weight[key] = np.full_like(next(iter(label.values())), value)
143141
elif isinstance(value, sympy.Basic):
144142
func = sympy.lambdify(
145143
[sympy.Symbol(k) for k in geom.dim_keys],
@@ -152,7 +150,7 @@ def __init__(
152150
weight[key] = func(input)
153151
if isinstance(weight[key], (int, float)):
154152
weight[key] = np.full_like(
155-
next(iter(input.values())), float(weight[key])
153+
next(iter(input.values())), weight[key]
156154
)
157155
else:
158156
raise NotImplementedError(f"type of {type(value)} is invalid yet.")

ppsci/constraint/initial_constraint.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def __init__(
110110
if isinstance(value, str):
111111
value = sp_parser.parse_expr(value)
112112
if isinstance(value, (int, float)):
113-
label[key] = np.full_like(next(iter(input.values())), float(value))
113+
label[key] = np.full_like(next(iter(input.values())), value)
114114
elif isinstance(value, sympy.Basic):
115115
func = sympy.lambdify(
116116
sympy.symbols(geom.dim_keys),
@@ -124,9 +124,7 @@ def __init__(
124124
func = value
125125
label[key] = func(input)
126126
if isinstance(label[key], (int, float)):
127-
label[key] = np.full_like(
128-
next(iter(input.values())), float(label[key])
129-
)
127+
label[key] = np.full_like(next(iter(input.values())), label[key])
130128
else:
131129
raise NotImplementedError(f"type of {type(value)} is invalid yet.")
132130

@@ -137,7 +135,7 @@ def __init__(
137135
if isinstance(value, str):
138136
value = sp_parser.parse_expr(value)
139137
if isinstance(value, (int, float)):
140-
weight[key] = np.full_like(next(iter(label.values())), float(value))
138+
weight[key] = np.full_like(next(iter(label.values())), value)
141139
elif isinstance(value, sympy.Basic):
142140
func = sympy.lambdify(
143141
sympy.symbols(geom.dim_keys),
@@ -152,7 +150,7 @@ def __init__(
152150
weight[key] = func(input)
153151
if isinstance(weight[key], (int, float)):
154152
weight[key] = np.full_like(
155-
next(iter(input.values())), float(weight[key])
153+
next(iter(input.values())), weight[key]
156154
)
157155
else:
158156
raise NotImplementedError(f"type of {type(value)} is invalid yet.")

ppsci/constraint/integral_constraint.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from typing import Union
2121

2222
import numpy as np
23+
import paddle
2324
import sympy
2425
from sympy.parsing import sympy_parser as sp_parser
2526
from typing_extensions import Literal
@@ -114,7 +115,9 @@ def __init__(
114115
for key, value in label_dict.items():
115116
if isinstance(value, (int, float)):
116117
label[key] = np.full(
117-
(next(iter(input.values())).shape[0], 1), float(value), "float32"
118+
(next(iter(input.values())).shape[0], 1),
119+
value,
120+
paddle.get_default_dtype(),
118121
)
119122
elif isinstance(value, sympy.Basic):
120123
func = sympy.lambdify(
@@ -130,7 +133,9 @@ def __init__(
130133
label[key] = func(input)
131134
if isinstance(label[key], (int, float)):
132135
label[key] = np.full(
133-
(next(iter(input.values())).shape[0], 1), float(label[key])
136+
(next(iter(input.values())).shape[0], 1),
137+
label[key],
138+
paddle.get_default_dtype(),
134139
)
135140
else:
136141
raise NotImplementedError(f"type of {type(value)} is invalid yet.")
@@ -144,7 +149,7 @@ def __init__(
144149
value = sp_parser.parse_expr(value)
145150

146151
if isinstance(value, (int, float)):
147-
weight[key] = np.full_like(next(iter(label.values())), float(value))
152+
weight[key] = np.full_like(next(iter(label.values())), value)
148153
elif isinstance(value, sympy.Basic):
149154
func = sympy.lambdify(
150155
sympy.symbols(geom.dim_keys),
@@ -159,7 +164,7 @@ def __init__(
159164
weight[key] = func(input)
160165
if isinstance(weight[key], (int, float)):
161166
weight[key] = np.full_like(
162-
next(iter(input.values())), float(weight[key])
167+
next(iter(input.values())), weight[key]
163168
)
164169
else:
165170
raise NotImplementedError(f"type of {type(value)} is invalid yet.")

ppsci/constraint/interior_constraint.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def __init__(
113113
if isinstance(value, str):
114114
value = sp_parser.parse_expr(value)
115115
if isinstance(value, (int, float)):
116-
label[key] = np.full_like(next(iter(input.values())), float(value))
116+
label[key] = np.full_like(next(iter(input.values())), value)
117117
elif isinstance(value, sympy.Basic):
118118
func = sympy.lambdify(
119119
sympy.symbols(geom.dim_keys),
@@ -127,9 +127,7 @@ def __init__(
127127
func = value
128128
label[key] = func(input)
129129
if isinstance(label[key], (int, float)):
130-
label[key] = np.full_like(
131-
next(iter(input.values())), float(label[key])
132-
)
130+
label[key] = np.full_like(next(iter(input.values())), label[key])
133131
else:
134132
raise NotImplementedError(f"type of {type(value)} is invalid yet.")
135133

@@ -141,7 +139,7 @@ def __init__(
141139
value = sp_parser.parse_expr(value)
142140

143141
if isinstance(value, (int, float)):
144-
weight[key] = np.full_like(next(iter(label.values())), float(value))
142+
weight[key] = np.full_like(next(iter(label.values())), value)
145143
elif isinstance(value, sympy.Basic):
146144
func = sympy.lambdify(
147145
sympy.symbols(geom.dim_keys),
@@ -156,7 +154,7 @@ def __init__(
156154
weight[key] = func(input)
157155
if isinstance(weight[key], (int, float)):
158156
weight[key] = np.full_like(
159-
next(iter(input.values())), float(weight[key])
157+
next(iter(input.values())), weight[key]
160158
)
161159
else:
162160
raise NotImplementedError(f"type of {type(value)} is invalid yet.")

ppsci/constraint/periodic_constraint.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from typing import Union
2121

2222
import numpy as np
23+
import paddle
2324
import sympy
2425
from sympy.parsing import sympy_parser as sp_parser
2526
from typing_extensions import Literal
@@ -123,7 +124,9 @@ def __init__(
123124
for key, value in label_dict.items():
124125
# set all label's to zero for dummy data.
125126
label[key] = np.full(
126-
(next(iter(mixed_input.values())).shape[0], 1), 0, "float32"
127+
(next(iter(mixed_input.values())).shape[0], 1),
128+
0,
129+
paddle.get_default_dtype(),
127130
)
128131

129132
# # prepare weight, keep weight the same shape as input_periodic
@@ -134,7 +137,7 @@ def __init__(
134137
value = sp_parser.parse_expr(value)
135138

136139
if isinstance(value, (int, float)):
137-
weight[key] = np.full_like(next(iter(label.values())), float(value))
140+
weight[key] = np.full_like(next(iter(label.values())), value)
138141
elif isinstance(value, sympy.Basic):
139142
func = sympy.lambdify(
140143
[sympy.Symbol(k) for k in geom.dim_keys],
@@ -147,7 +150,7 @@ def __init__(
147150
weight[key] = func(mixed_input)
148151
if isinstance(weight[key], (int, float)):
149152
weight[key] = np.full_like(
150-
next(iter(mixed_input.values())), float(weight[key])
153+
next(iter(mixed_input.values())), weight[key]
151154
)
152155
else:
153156
raise NotImplementedError(f"type of {type(value)} is invalid yet.")

ppsci/data/dataset/csv_dataset.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,14 +116,14 @@ def __init__(
116116
for key, value in weight_dict.items():
117117
if isinstance(value, (int, float)):
118118
self.weight[key] = np.full_like(
119-
next(iter(self.label.values())), float(value)
119+
next(iter(self.label.values())), value
120120
)
121121
elif isinstance(value, types.FunctionType):
122122
func = value
123123
self.weight[key] = func(self.input)
124124
if isinstance(self.weight[key], (int, float)):
125125
self.weight[key] = np.full_like(
126-
next(iter(self.label.values())), float(self.weight[key])
126+
next(iter(self.label.values())), self.weight[key]
127127
)
128128
else:
129129
raise NotImplementedError(f"type of {type(value)} is invalid yet.")
@@ -234,14 +234,14 @@ def __init__(
234234
for key, value in weight_dict.items():
235235
if isinstance(value, (int, float)):
236236
self.weight[key] = np.full_like(
237-
next(iter(self.label.values())), float(value)
237+
next(iter(self.label.values())), value
238238
)
239239
elif isinstance(value, types.FunctionType):
240240
func = value
241241
self.weight[key] = func(self.input)
242242
if isinstance(self.weight[key], (int, float)):
243243
self.weight[key] = np.full_like(
244-
next(iter(self.label.values())), float(self.weight[key])
244+
next(iter(self.label.values())), self.weight[key]
245245
)
246246
else:
247247
raise NotImplementedError(f"type of {type(value)} is invalid yet.")

ppsci/data/dataset/mat_dataset.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,14 +116,14 @@ def __init__(
116116
for key, value in weight_dict.items():
117117
if isinstance(value, (int, float)):
118118
self.weight[key] = np.full_like(
119-
next(iter(self.label.values())), float(value)
119+
next(iter(self.label.values())), value
120120
)
121121
elif isinstance(value, types.FunctionType):
122122
func = value
123123
self.weight[key] = func(self.input)
124124
if isinstance(self.weight[key], (int, float)):
125125
self.weight[key] = np.full_like(
126-
next(iter(self.label.values())), float(self.weight[key])
126+
next(iter(self.label.values())), self.weight[key]
127127
)
128128
else:
129129
raise NotImplementedError(f"type of {type(value)} is invalid yet.")
@@ -234,14 +234,14 @@ def __init__(
234234
for key, value in weight_dict.items():
235235
if isinstance(value, (int, float)):
236236
self.weight[key] = np.full_like(
237-
next(iter(self.label.values())), float(value)
237+
next(iter(self.label.values())), value
238238
)
239239
elif isinstance(value, types.FunctionType):
240240
func = value
241241
self.weight[key] = func(self.input)
242242
if isinstance(self.weight[key], (int, float)):
243243
self.weight[key] = np.full_like(
244-
next(iter(self.label.values())), float(self.weight[key])
244+
next(iter(self.label.values())), self.weight[key]
245245
)
246246
else:
247247
raise NotImplementedError(f"type of {type(value)} is invalid yet.")

ppsci/data/dataset/trphysx_dataset.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def read_data(self, file_path: str, block_size: int, stride: int):
9191
with h5py.File(file_path, "r") as f:
9292
data_num = 0
9393
for key in f.keys():
94-
data_series = np.asarray(f[key], dtype="float32")
94+
data_series = np.asarray(f[key], dtype=paddle.get_default_dtype())
9595
for i in range(0, data_series.shape[0] - block_size + 1, stride):
9696
data.append(data_series[i : i + block_size])
9797
data_num += 1
@@ -246,9 +246,9 @@ def read_data(self, file_path: str, block_size: int, stride: int):
246246
data_num = 0
247247
for key in f.keys():
248248
visc0 = 2.0 / float(key)
249-
ux = np.asarray(f[key + "/ux"], dtype="float32")
250-
uy = np.asarray(f[key + "/uy"], dtype="float32")
251-
p = np.asarray(f[key + "/p"], dtype="float32")
249+
ux = np.asarray(f[key + "/ux"], dtype=paddle.get_default_dtype())
250+
uy = np.asarray(f[key + "/uy"], dtype=paddle.get_default_dtype())
251+
p = np.asarray(f[key + "/p"], dtype=paddle.get_default_dtype())
252252
data_series = np.stack([ux, uy, p], axis=1)
253253

254254
for i in range(0, data_series.shape[0] - block_size + 1, stride):
@@ -260,7 +260,7 @@ def read_data(self, file_path: str, block_size: int, stride: int):
260260
break
261261

262262
data = np.asarray(data)
263-
visc = np.asarray(visc, dtype="float32")
263+
visc = np.asarray(visc, dtype=paddle.get_default_dtype())
264264
return data, visc
265265

266266
def __len__(self):

ppsci/equation/pde/viv.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@ def __init__(self, rho: float, k1: float, k2: float):
3838
self.rho = rho
3939
self.k1 = paddle.create_parameter(
4040
shape=[1],
41-
dtype="float32",
41+
dtype=paddle.get_default_dtype(),
4242
default_initializer=initializer.Constant(k1),
4343
)
4444
self.k2 = paddle.create_parameter(
4545
shape=[1],
46-
dtype="float32",
46+
dtype=paddle.get_default_dtype(),
4747
default_initializer=initializer.Constant(k2),
4848
)
4949
self.learnable_parameters.append(self.k1)

0 commit comments

Comments
 (0)