20 | 20 | import os
21 | 21 | import sys
22 | 22 | from typing import Any
    | 23 | +from typing import Callable
23 | 24 | from typing import Dict
24 | 25 | from typing import Optional
25 | 26 | from typing import Union
@@ -261,7 +262,7 @@ def __init__(
261 | 262 |             logger.warning(
262 | 263 |                 f"Detected world_size({self.world_size}) > 1, it is recommended to "
263 | 264 |                 "scale up the learning rate and reduce the epochs or "
264 |     | -                "iters_per_epoch according to the world_size number both linearly."
    | 265 | +                "iters_per_epoch according to the world_size both linearly."
265 | 266 |             )
266 | 267 |
267 | 268 |         self.global_step = 0
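This warning encodes the linear scaling heuristic for data-parallel training: with `world_size` devices the effective batch size grows by `world_size`, so the learning rate should grow by the same factor while `epochs` or `iters_per_epoch` shrink by it, keeping the total work per sample roughly constant. A minimal sketch of the adjustment the warning leaves to the user (the helper below is hypothetical, not a PaddleScience API):

```python
def apply_linear_scaling(lr: float, iters_per_epoch: int, world_size: int):
    # hypothetical helper: PaddleScience only warns; the user applies the rule
    scaled_lr = lr * world_size  # effective batch is world_size x larger
    scaled_iters = max(1, iters_per_epoch // world_size)  # same total samples per epoch
    return scaled_lr, scaled_iters
```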
@@ -468,55 +469,100 @@ def visualize(self, epoch_id: int = 0):
468 | 469 |         self.visu_func(self, epoch_id)
469 | 470 |         logger.info(f"[Visualize][Epoch {epoch_id}] Finished visualization")
470 | 471 |
471 |     | -    @paddle.no_grad()
472 | 472 |     @misc.run_on_eval_mode
473 | 473 |     def predict(
474 | 474 |         self,
475 | 475 |         input_dict: Dict[str, Union[np.ndarray, paddle.Tensor]],
    | 476 | +        expr_dict: Optional[Dict[str, Callable]] = None,
476 | 477 |         batch_size: int = 64,
    | 478 | +        no_grad: bool = True,
477 | 479 |     ) -> Dict[str, paddle.Tensor]:
478 |     | -        """Pure prediction using model.forward(...), support single device prediction yet.
    | 480 | +        """Pure prediction using model.forward(...) and expression (optional, if given).
479 | 481 |
480 | 482 |         Args:
481 | 483 |             input_dict (Dict[str, Union[np.ndarray, paddle.Tensor]]): Input data in dict.
    | 484 | +            expr_dict (Optional[Dict[str, Callable]]): Expression dict, which guides how to
    | 485 | +                compute equation variables with callable functions. Defaults to None.
482 | 486 |             batch_size (int, optional): Predicting by batch size. Defaults to 64.
483 |     | -
    | 487 | +            no_grad (bool): Whether to set stop_gradient=True for the entire prediction,
    | 488 | +                mainly for memory efficiency. Defaults to True.
484 | 489 |         Returns:
485 | 490 |             Dict[str, paddle.Tensor]: Prediction in dict.
486 | 491 |         """
487 |     | -        if self.world_size > 1:
488 |     | -            raise NotImplementedError(
489 |     | -                "Solver.predict only support single device yet, "
490 |     | -                f"but got {self.world_size} devices."
491 |     | -            )
492 |     | -
493 | 492 |         num_samples = len(next(iter(input_dict.values())))
494 |     | -        batch_num = (num_samples + (batch_size - 1)) // batch_size
    | 493 | +        num_pad = (self.world_size - num_samples % self.world_size) % self.world_size
    | 494 | +        # pad with the last element if `num_samples` is not divisible by `world_size`,
    | 495 | +        # ensuring every device gets the same number of samples.
    | 496 | +        if num_pad > 0:
    | 497 | +            for k, v in input_dict.items():
    | 498 | +                repeat_times = (num_pad, *(1 for _ in range(v.ndim - 1)))
    | 499 | +                input_dict[k] = paddle.concat(
    | 500 | +                    (
    | 501 | +                        v,
    | 502 | +                        paddle.tile(v[num_samples - 1 : num_samples], repeat_times),
    | 503 | +                    ),
    | 504 | +                )
    | 505 | +
    | 506 | +        num_samples_pad = num_samples + num_pad
    | 507 | +        local_num_samples_pad = num_samples_pad // self.world_size
    | 508 | +        local_input_dict = (
    | 509 | +            {k: v[self.rank :: self.world_size] for k, v in input_dict.items()}
    | 510 | +            if self.world_size > 1
    | 511 | +            else input_dict
    | 512 | +        )
    | 513 | +        local_batch_num = (local_num_samples_pad + (batch_size - 1)) // batch_size
495 | 514 |         pred_dict = misc.Prettydefaultdict(list)
496 |     | -        for batch_id in range(batch_num):
497 |     | -            batch_input_dict = {}
498 |     | -            st = batch_id * batch_size
499 |     | -            ed = min(num_samples, (batch_id + 1) * batch_size)
500 |     | -
501 |     | -            # prepare batch input dict
502 |     | -            for key in input_dict:
503 |     | -                if not paddle.is_tensor(input_dict[key]):
504 |     | -                    batch_input_dict[key] = paddle.to_tensor(
505 |     | -                        input_dict[key][st:ed], paddle.get_default_dtype()
    | 515 | +        with self.no_grad_context_manager(no_grad), self.no_sync_context_manager(
    | 516 | +            self.world_size > 1, self.model
    | 517 | +        ):
    | 518 | +            for batch_id in range(local_batch_num):
    | 519 | +                batch_input_dict = {}
    | 520 | +                st = batch_id * batch_size
    | 521 | +                ed = min(local_num_samples_pad, (batch_id + 1) * batch_size)
    | 522 | +
    | 523 | +                # prepare batch input dict
    | 524 | +                for key in local_input_dict:
    | 525 | +                    if not paddle.is_tensor(local_input_dict[key]):
    | 526 | +                        batch_input_dict[key] = paddle.to_tensor(
    | 527 | +                            local_input_dict[key][st:ed], paddle.get_default_dtype()
    | 528 | +                        )
    | 529 | +                    else:
    | 530 | +                        batch_input_dict[key] = local_input_dict[key][st:ed]
    | 531 | +                    batch_input_dict[key].stop_gradient = no_grad
    | 532 | +
    | 533 | +                # forward
    | 534 | +                with self.autocast_context_manager(self.use_amp, self.amp_level):
    | 535 | +                    batch_output_dict = self.forward_helper.visu_forward(
    | 536 | +                        expr_dict, batch_input_dict, self.model
506 | 537 |                     )
507 |     | -                else:
508 |     | -                    batch_input_dict[key] = input_dict[key][st:ed]
509 |     | -                batch_input_dict[key].stop_gradient = False
510 |     | -
511 |     | -            # forward
512 |     | -            with self.autocast_context_manager(self.use_amp, self.amp_level):
513 |     | -                batch_output_dict = self.model(batch_input_dict)
514 | 538 |
515 |     | -            # collect batch data
516 |     | -            for key, batch_output in batch_output_dict.items():
517 |     | -                pred_dict[key].append(batch_output)
518 |     | -
519 |     | -        pred_dict = {key: paddle.concat(value) for key, value in pred_dict.items()}
    | 539 | +                # collect batch data
    | 540 | +                for key, batch_output in batch_output_dict.items():
    | 541 | +                    pred_dict[key].append(batch_output.detach())
    | 542 | +
    | 543 | +            # concatenate local predictions
    | 544 | +            pred_dict = {key: paddle.concat(value) for key, value in pred_dict.items()}
    | 545 | +
    | 546 | +            if self.world_size > 1:
    | 547 | +                # gather global predictions from all devices if world_size > 1
    | 548 | +                pred_dict = {
    | 549 | +                    key: misc.all_gather(value) for key, value in pred_dict.items()
    | 550 | +                }
    | 551 | +
    | 552 | +                # rearrange predictions into the same order as input_dict using the
    | 553 | +                # inverse permutation, then discard predictions of the padded tail
    | 554 | +                perm = np.arange(num_samples_pad, dtype="int64")
    | 555 | +                perm = np.concatenate(
    | 556 | +                    [perm[rank :: self.world_size] for rank in range(self.world_size)],
    | 557 | +                    axis=0,
    | 558 | +                )
    | 559 | +                perm_inv = np.empty_like(perm)
    | 560 | +                perm_inv[perm] = np.arange(num_samples_pad, dtype="int64")
    | 561 | +                perm_inv = paddle.to_tensor(perm_inv)
    | 562 | +                pred_dict = {
    | 563 | +                    key: value[perm_inv][:num_samples]
    | 564 | +                    for key, value in pred_dict.items()
    | 565 | +                }
520 | 566 |
521 | 567 |         return pred_dict
522 | 568 |
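The subtle part of the new multi-device path is order restoration. Each rank predicts the strided slice `v[self.rank::self.world_size]` of the padded input, and `misc.all_gather` concatenates the per-rank results back to back, so the gathered tensor is a known permutation of the padded input order; indexing with `perm_inv` undoes it before the padded tail is dropped. A self-contained numpy sketch of the same round-trip (single process, identity "model", so the restored order is directly checkable):

```python
import numpy as np

def simulate_sharded_predict(x: np.ndarray, world_size: int) -> np.ndarray:
    """Single-process numpy stand-in for the pad/shard/gather/restore steps above."""
    num_samples = len(x)
    # pad with the last element so num_samples divides evenly across ranks
    num_pad = (world_size - num_samples % world_size) % world_size
    if num_pad > 0:
        x = np.concatenate([x, np.tile(x[-1:], (num_pad, *([1] * (x.ndim - 1))))])
    num_samples_pad = len(x)

    # each rank takes a strided slice; "all_gather" concatenates rank by rank
    gathered = np.concatenate([x[rank::world_size] for rank in range(world_size)])

    # the gathered order is a known permutation of the padded order; invert it
    perm = np.concatenate(
        [np.arange(num_samples_pad)[rank::world_size] for rank in range(world_size)]
    )
    perm_inv = np.empty_like(perm)
    perm_inv[perm] = np.arange(num_samples_pad)

    # restore original order, then drop predictions for the padded tail
    return gathered[perm_inv][:num_samples]

out = simulate_sharded_predict(np.arange(10), world_size=4)
assert np.array_equal(out, np.arange(10))  # order survives the round-trip
```

From the defaults (`expr_dict=None`, `no_grad=True`) the change also looks backward compatible: `solver.predict(input_dict)` on a single device yields `num_pad == 0`, takes the `local_input_dict = input_dict` branch, and skips the gather/permute step entirely.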
@@ -599,7 +645,7 @@ def no_sync_context_manager(
599 | 645 |             if not isinstance(ddp_model, paddle.DataParallel):
600 | 646 |                 raise TypeError(
601 | 647 |                     "no_sync interface is only for model with type paddle.DataParallel, "
602 |     | -                    f"but got type {type(ddp_model)}"
    | 648 | +                    f"but got type {misc.typename(ddp_model)}"
603 | 649 |                 )
604 | 650 |             ctx_manager = ddp_model.no_sync()
605 | 651 |         else:
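As background for the context managers `predict` now relies on: `paddle.DataParallel.no_sync` suppresses the gradient all-reduce that data-parallel training triggers on backward, which is pure overhead for forward-only prediction, and the switch to `misc.typename` presumably just makes the error message print a concise type name instead of the full `<class '...'>` repr. A minimal sketch of the no-sync pattern (assumes a multi-process launch where `paddle.distributed.init_parallel_env()` has already run):

```python
import paddle

# toy DataParallel-wrapped model; a real run wraps it after init_parallel_env()
model = paddle.DataParallel(paddle.nn.Linear(4, 4))
x = paddle.randn([8, 4])

with paddle.no_grad(), model.no_sync():  # no autograd graph, no gradient all-reduce
    y = model(x)
```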