Skip to content

Commit 0e0089b

Browse files
support multi-gpu prediction (#350)
1 parent 145ac86 commit 0e0089b

File tree

4 files changed

+99
-47
lines changed

4 files changed

+99
-47
lines changed

ppsci/solver/solver.py

Lines changed: 80 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import os
2121
import sys
2222
from typing import Any
23+
from typing import Callable
2324
from typing import Dict
2425
from typing import Optional
2526
from typing import Union
@@ -261,7 +262,7 @@ def __init__(
261262
logger.warning(
262263
f"Detected world_size({self.world_size}) > 1, it is recommended to "
263264
"scale up the learning rate and reduce the epochs or "
264-
"iters_per_epoch according to the world_size number both linearly."
265+
"iters_per_epoch according to the world_size both linearly."
265266
)
266267

267268
self.global_step = 0
@@ -468,55 +469,100 @@ def visualize(self, epoch_id: int = 0):
468469
self.visu_func(self, epoch_id)
469470
logger.info(f"[Visualize][Epoch {epoch_id}] Finished visualization")
470471

471-
@paddle.no_grad()
472472
@misc.run_on_eval_mode
473473
def predict(
474474
self,
475475
input_dict: Dict[str, Union[np.ndarray, paddle.Tensor]],
476+
expr_dict: Optional[Dict[str, Callable]] = None,
476477
batch_size: int = 64,
478+
no_grad: bool = True,
477479
) -> Dict[str, paddle.Tensor]:
478-
"""Pure prediction using model.forward(...), support single device prediction yet.
480+
"""Pure prediction using model.forward(...) and expressions (optional, if given).
479481
480482
Args:
481483
input_dict (Dict[str, Union[np.ndarray, paddle.Tensor]]): Input data in dict.
484+
expr_dict (Optional[Dict[str, Callable]]): Expression dict, which guides how to
485+
compute equation variables with callable functions. Defaults to None.
482486
batch_size (int, optional): Predicting by batch size. Defaults to 64.
483-
487+
no_grad (bool): Whether to set stop_gradient=True for the entire prediction, mainly
488+
for memory-efficiency. Defaults to True.
484489
Returns:
485490
Dict[str, paddle.Tensor]: Prediction in dict.
486491
"""
487-
if self.world_size > 1:
488-
raise NotImplementedError(
489-
"Solver.predict only support single device yet, "
490-
f"but got {self.world_size} devices."
491-
)
492-
493492
num_samples = len(next(iter(input_dict.values())))
494-
batch_num = (num_samples + (batch_size - 1)) // batch_size
493+
num_pad = (self.world_size - num_samples % self.world_size) % self.world_size
494+
# pad with last element if `num_samples` is not divisible by `world_size`
495+
# ensuring every device gets the same number of samples.
496+
if num_pad > 0:
497+
for k, v in input_dict.items():
498+
repeat_times = (num_pad, *(1 for _ in range(v.ndim - 1)))
499+
input_dict[k] = paddle.concat(
500+
(
501+
v,
502+
paddle.tile(v[num_samples - 1 : num_samples], repeat_times),
503+
),
504+
)
505+
506+
num_samples_pad = num_samples + num_pad
507+
local_num_samples_pad = num_samples_pad // self.world_size
508+
local_input_dict = (
509+
{k: v[self.rank :: self.world_size] for k, v in input_dict.items()}
510+
if self.world_size > 1
511+
else input_dict
512+
)
513+
local_batch_num = (local_num_samples_pad + (batch_size - 1)) // batch_size
495514
pred_dict = misc.Prettydefaultdict(list)
496-
for batch_id in range(batch_num):
497-
batch_input_dict = {}
498-
st = batch_id * batch_size
499-
ed = min(num_samples, (batch_id + 1) * batch_size)
500-
501-
# prepare batch input dict
502-
for key in input_dict:
503-
if not paddle.is_tensor(input_dict[key]):
504-
batch_input_dict[key] = paddle.to_tensor(
505-
input_dict[key][st:ed], paddle.get_default_dtype()
515+
with self.no_grad_context_manager(no_grad), self.no_sync_context_manager(
516+
self.world_size > 1, self.model
517+
):
518+
for batch_id in range(local_batch_num):
519+
batch_input_dict = {}
520+
st = batch_id * batch_size
521+
ed = min(local_num_samples_pad, (batch_id + 1) * batch_size)
522+
523+
# prepare batch input dict
524+
for key in local_input_dict:
525+
if not paddle.is_tensor(local_input_dict[key]):
526+
batch_input_dict[key] = paddle.to_tensor(
527+
local_input_dict[key][st:ed], paddle.get_default_dtype()
528+
)
529+
else:
530+
batch_input_dict[key] = local_input_dict[key][st:ed]
531+
batch_input_dict[key].stop_gradient = no_grad
532+
533+
# forward
534+
with self.autocast_context_manager(self.use_amp, self.amp_level):
535+
batch_output_dict = self.forward_helper.visu_forward(
536+
expr_dict, batch_input_dict, self.model
506537
)
507-
else:
508-
batch_input_dict[key] = input_dict[key][st:ed]
509-
batch_input_dict[key].stop_gradient = False
510-
511-
# forward
512-
with self.autocast_context_manager(self.use_amp, self.amp_level):
513-
batch_output_dict = self.model(batch_input_dict)
514538

515-
# collect batch data
516-
for key, batch_output in batch_output_dict.items():
517-
pred_dict[key].append(batch_output)
518-
519-
pred_dict = {key: paddle.concat(value) for key, value in pred_dict.items()}
539+
# collect batch data
540+
for key, batch_output in batch_output_dict.items():
541+
pred_dict[key].append(batch_output.detach())
542+
543+
# concatenate local predictions
544+
pred_dict = {key: paddle.concat(value) for key, value in pred_dict.items()}
545+
546+
if self.world_size > 1:
547+
# gather global predictions from all devices if world_size > 1
548+
pred_dict = {
549+
key: misc.all_gather(value) for key, value in pred_dict.items()
550+
}
551+
552+
# rearrange predictions into the same order as input_dict according to the inverse
553+
# permutation, then discard predictions of padding data at the end
554+
perm = np.arange(num_samples_pad, dtype="int64")
555+
perm = np.concatenate(
556+
[perm[rank :: self.world_size] for rank in range(self.world_size)],
557+
axis=0,
558+
)
559+
perm_inv = np.empty_like(perm)
560+
perm_inv[perm] = np.arange(num_samples_pad, dtype="int64")
561+
perm_inv = paddle.to_tensor(perm_inv)
562+
pred_dict = {
563+
key: value[perm_inv][:num_samples]
564+
for key, value in pred_dict.items()
565+
}
520566

521567
return pred_dict
522568

@@ -599,7 +645,7 @@ def no_sync_context_manager(
599645
if not isinstance(ddp_model, paddle.DataParallel):
600646
raise TypeError(
601647
"no_sync interface is only for model with type paddle.DataParallel, "
602-
f"but got type {type(ddp_model)}"
648+
f"but got type {misc.typename(ddp_model)}"
603649
)
604650
ctx_manager = ddp_model.no_sync()
605651
else:

ppsci/solver/visu.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,17 @@
1414

1515
import os
1616
import os.path as osp
17+
from typing import TYPE_CHECKING
1718

1819
import paddle
1920

21+
if TYPE_CHECKING:
22+
from ppsci import solver
23+
2024
from ppsci.utils import misc
2125

2226

23-
def visualize_func(solver, epoch_id: int):
27+
def visualize_func(solver: "solver.Solver", epoch_id: int):
2428
"""Visualization program
2529
2630
Args:

ppsci/utils/expression.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from typing import TYPE_CHECKING
1616
from typing import Callable
1717
from typing import Dict
18+
from typing import Optional
1819
from typing import Tuple
1920

2021
import paddle
@@ -121,23 +122,23 @@ def eval_forward(
121122

122123
def visu_forward(
123124
self,
124-
expr_dict: Dict[str, Callable],
125+
expr_dict: Optional[Dict[str, Callable]],
125126
input_dict: Dict[str, paddle.Tensor],
126127
model: nn.Layer,
127128
):
128129
# model forward
129-
if callable(next(iter(expr_dict.values()))):
130-
output_dict = model(input_dict)
130+
output_dict = model(input_dict)
131131

132-
# equation forward
133-
for name, expr in expr_dict.items():
134-
if callable(expr):
135-
output_dict[name] = expr({**output_dict, **input_dict})
136-
else:
137-
raise TypeError(f"expr type({type(expr)}) is invalid")
132+
if isinstance(expr_dict, dict):
133+
# equation forward
134+
for name, expr in expr_dict.items():
135+
if callable(expr):
136+
output_dict[name] = expr({**output_dict, **input_dict})
137+
else:
138+
raise TypeError(f"expr type({type(expr)}) is invalid")
138139

139-
# clear differentiation cache
140-
clear()
140+
# clear differentiation cache
141+
clear()
141142

142143
# compute loss for each validator according to its own output, label and weight
143144
return output_dict

ppsci/utils/misc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from typing import Dict
2020
from typing import List
2121
from typing import Tuple
22+
from typing import Union
2223

2324
import numpy as np
2425
import paddle
@@ -121,7 +122,7 @@ def convert_to_dict(array: np.ndarray, keys: Tuple[str, ...]) -> Dict[str, np.nd
121122

122123
def all_gather(
123124
tensor: paddle.Tensor, concat: bool = True, axis: int = 0
124-
) -> List[paddle.Tensor]:
125+
) -> Union[paddle.Tensor, List[paddle.Tensor]]:
125126
"""Gather tensor from all devices, concatenate them along given axis if specified.
126127
127128
Args:

0 commit comments

Comments
 (0)