Skip to content

Commit 4dad760

Browse files
change more float32 to get_default_dtype
1 parent 84a1fe8 commit 4dad760

File tree

4 files changed

+11
-4
lines changed

4 files changed

+11
-4
lines changed

ppsci/arch/physx_transformer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,9 @@ def _init_weights(self, module):
308308

309309
def get_position_embed(self, x):
310310
B, N, _ = x.shape
311-
position_ids = paddle.arange(0, N, dtype="float32").reshape([1, N, 1])
311+
position_ids = paddle.arange(0, N, dtype=paddle.get_default_dtype()).reshape(
312+
[1, N, 1]
313+
)
312314
position_ids = position_ids.repeat_interleave(B, axis=0)
313315

314316
position_embeds = paddle.zeros_like(x)

ppsci/utils/misc.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,11 @@ def combine_array_with_time(x, t):
148148
nx = len(x)
149149
tx = []
150150
for ti in t:
151-
tx.append(np.hstack((np.full([nx, 1], float(ti), dtype="float32"), x)))
151+
tx.append(
152+
np.hstack(
153+
(np.full([nx, 1], float(ti), dtype=paddle.get_default_dtype()), x)
154+
)
155+
)
152156
tx = np.vstack(tx)
153157
return tx
154158

ppsci/validate/geo_validator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from typing import Union
2121

2222
import numpy as np
23+
import paddle
2324
import sympy
2425
from sympy.parsing import sympy_parser as sp_parser
2526
from typing_extensions import Literal
@@ -152,7 +153,7 @@ def __init__(
152153
label[key] = np.full(
153154
(next(iter(input.values())).shape[0], 1),
154155
label[key],
155-
"float32",
156+
paddle.get_default_dtype(),
156157
)
157158
else:
158159
raise NotImplementedError(f"type of {type(value)} is invalid yet.")

ppsci/visualize/vtu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def _save_vtu_from_array(filename, coord, value, value_keys, num_timestamp=1):
6363
if coord_ndim == 2:
6464
axis_x = np.ascontiguousarray(coord[t * nx : (t + 1) * nx, 0])
6565
axis_y = np.ascontiguousarray(coord[t * nx : (t + 1) * nx, 1])
66-
axis_z = np.zeros([nx], dtype="float32")
66+
axis_z = np.zeros([nx], dtype=paddle.get_default_dtype())
6767
elif coord_ndim == 3:
6868
axis_x = np.ascontiguousarray(coord[t * nx : (t + 1) * nx, 0])
6969
axis_y = np.ascontiguousarray(coord[t * nx : (t + 1) * nx, 1])

0 commit comments

Comments
 (0)