diff --git a/torch_cfd/advection.py b/torch_cfd/advection.py index 3bb40ed..7aea9ab 100644 --- a/torch_cfd/advection.py +++ b/torch_cfd/advection.py @@ -282,6 +282,9 @@ def forward(self, cs: GridVariableVector, v: GridVariableVector) -> GridVariable flux = GridVariableVector(tuple(c * u for c, u in zip(cs, v))) # wrap flux with boundary conditions to flux if not periodic + # flux = GridVariableVector( + # tuple(bc.impose_bc(f) for f, bc in zip(flux, self.flux_bcs)) + # ) flux = GridVariableVector(tuple(GridVariable(f.data, offset, f.grid, bc) for f, offset, bc in zip(flux, self.offsets, self.flux_bcs))) diff --git a/torch_cfd/boundaries.py b/torch_cfd/boundaries.py index 4dc7ea0..708cfc1 100644 --- a/torch_cfd/boundaries.py +++ b/torch_cfd/boundaries.py @@ -384,7 +384,7 @@ def pad_and_impose_bc( ) return GridVariable(u.data, u.offset, u.grid, self) - def impose_bc(self, u: GridVariable, mode: str="") -> GridVariable: + def impose_bc(self, u: GridVariable, mode: str = "") -> GridVariable: """Returns GridVariable with correct boundary condition. Some grid points of GridVariable might coincide with boundary. This ensures @@ -435,12 +435,32 @@ def is_bc_periodic_boundary_conditions(bc: BoundaryConditions, dim: int) -> bool ) return True +def is_bc_all_periodic_boundary_conditions(bc: BoundaryConditions) -> bool: + """Returns true if scalar has periodic bc along all axes.""" + for dim in range(bc.ndim): + if not is_bc_periodic_boundary_conditions(bc, dim): + return False + return True + def is_periodic_boundary_conditions(c: GridVariable, dim: int) -> bool: """Returns true if scalar has periodic bc along axis.""" return is_bc_periodic_boundary_conditions(c.bc, dim) +def is_bc_pure_neumann_boundary_conditions(bc: BoundaryConditions) -> bool: + """Returns true if scalar has pure Neumann bc along all axes.""" + for dim in range(bc.ndim): + if bc.types[dim][0] != BCType.NEUMANN or bc.types[dim][1] != BCType.NEUMANN: + return False + return True + + +def is_pure_neumann_boundary_conditions(c: GridVariable) -> bool: + """Returns true if scalar has pure Neumann bc along all axes.""" + return is_bc_pure_neumann_boundary_conditions(c.bc) + + # Convenience utilities to ease updating of BoundaryConditions implementation def periodic_boundary_conditions(ndim: int) -> BoundaryConditions: """Returns periodic BCs for a variable with `ndim` spatial dimension.""" diff --git a/torch_cfd/finite_differences.py b/torch_cfd/finite_differences.py index 91ce71a..5be7751 100644 --- a/torch_cfd/finite_differences.py +++ b/torch_cfd/finite_differences.py @@ -42,7 +42,7 @@ import torch from torch_cfd import boundaries, grids -ArrayVector = Sequence[torch.Tensor] +ArrayVector = List[torch.Tensor] GridVariable = grids.GridVariable GridTensor = grids.GridTensor GridVariableVector = Union[grids.GridVariableVector, Sequence[grids.GridVariable]] @@ -159,20 +159,22 @@ def set_laplacian_matrix( grid: grids.Grid, bc: boundaries.BoundaryConditions, device: Optional[torch.device] = None, + dtype: torch.dtype = torch.float32, ) -> ArrayVector: """Initialize the Laplacian operators.""" offset = grid.cell_center - return laplacian_matrix_w_boundaries(grid, offset=offset, bc=bc, device=device) + return laplacian_matrix_w_boundaries(grid, offset=offset, bc=bc, device=device, dtype=dtype) -def laplacian_matrix(n: int, step: float, sparse: bool = False) -> torch.Tensor: +def laplacian_matrix(n: int, step: float, sparse: bool = False, dtype=torch.float32) -> torch.Tensor: """ Create 1D Laplacian operator matrix, with periodic BC. 
-    modified the scipy.linalg.circulant implementation to native torch
+    The matrix is tri-diagonal with the stencil [1, -2, 1]/h**2.
+    Modified the scipy.linalg.circulant implementation to native torch.
     """
     if sparse:
-        values = torch.tensor([1.0, -2.0, 1.0]) / step**2
+        values = torch.tensor([1.0, -2.0, 1.0], dtype=dtype) / step**2
         idx_row = torch.arange(n).repeat(3)
         idx_col = torch.cat(
             [
@@ -188,7 +190,7 @@ def laplacian_matrix(n: int, step: float, sparse: bool = False) -> torch.Tensor:
         )
         return torch.sparse_coo_tensor(indices, data, size=(n, n))
     else:
-        column = torch.zeros(n)
+        column = torch.zeros(n, dtype=dtype)
         column[0] = -2 / step**2
         column[1] = column[-1] = 1 / step**2
         idx = (n - torch.arange(n)[None].T + torch.arange(n)[None]) % n
@@ -196,25 +198,37 @@ def laplacian_matrix(n: int, step: float, sparse: bool = False) -> torch.Tensor:


 def _laplacian_boundary_dirichlet_cell_centered(
-    laplacians: ArrayVector, grid: grids.Grid, axis: int, side: str
+    laplacians: ArrayVector, grid: grids.Grid, dim: int, side: str
 ) -> None:
     """Converts 1d laplacian matrix to satisfy dirichlet homogeneous bc.

     laplacians[i] contains a 3 point stencil matrix L that approximates
     d^2/dx_i^2.
     For detailed documentation on laplacians input type see
-    array_utils.laplacian_matrix.
-    The default return of array_utils.laplacian_matrix makes a matrix for
-    periodic boundary. For dirichlet boundary, the correct equation is
-    L(u_interior) = rhs_interior and BL_boundary = u_fixed_boundary. So
+    fdm.laplacian_matrix.
+    The default return of fdm.laplacian_matrix makes a matrix for
+    periodic boundary. For (homogeneous) dirichlet boundary, the correct equations are
+    L(u_interior) = rhs_interior and
+    BL_boundary = u_fixed_boundary.
+    So
     laplacian_boundary_dirichlet restricts the matrix L to
-    interior points only.
+    interior points only.
+
+    Denote the nodes in the 3-pt stencil as
+    u[ghost], u[boundary], u[interior] = u[0], u[1], u[2].
+    The original stencil on the boundary is
+    [1, -2, 1] * [u[0], u[1], u[2]] = u[0] - 2*u[1] + u[2].
+    In the homogeneous Dirichlet case, if the offset is 0.5 away from the wall,
+    the ghost cell value is u[0] = -u[1], so the 3-point stencil becomes
+    [1, -2, 1] * [u[0], u[1], u[2]] = -3*u[1] + u[2].
+    The original diagonal entry Lap[0, 0] is -2/h**2 and we need -3/h**2;
+    thus 1/h**2 is subtracted from the diagonal, and the ghost cell dof is
+    zeroed out (Lap[0, -1] = 0).

     This function assumes RHS has cell-centered offset.
     Args:
       laplacians: list of 1d laplacians
       grid: grid object
-      axis: axis along which to impose dirichlet bc.
+      dim: axis along which to impose dirichlet bc.
       side: lower or upper side to assign boundary to.

     Returns:
@@ -223,52 +237,50 @@ def _laplacian_boundary_dirichlet_cell_centered(

     TODO:
     [ ]: this function is not implemented in the original Jax-CFD code.
     """
-    # This function assumes homogeneous boundary, in which case if the offset
-    # is 0.5 away from the wall, the ghost cell value u[0] = -u[1]. So the
-    # 3 point stencil [1 -2 1] * [u[0] u[1] u[2]] = -3 u[1] + u[2].
+
     if side == "lower":
-        laplacians[axis][0, 0] = laplacians[axis][0, 0] - 1 / grid.step[axis] ** 2
+        laplacians[dim][0, 0] = laplacians[dim][0, 0] - 1 / grid.step[dim] ** 2
     else:
-        laplacians[axis][-1, -1] = laplacians[axis][-1, -1] - 1 / grid.step[axis] ** 2
+        laplacians[dim][-1, -1] = laplacians[dim][-1, -1] - 1 / grid.step[dim] ** 2
     # deletes corner dependencies on the "looped-around" part.
     # this should be done irrespective of which side, since one boundary cannot
     # be periodic while the other is.
-    laplacians[axis][0, -1] = 0.0
-    laplacians[axis][-1, 0] = 0.0
-    return laplacians
+    laplacians[dim][0, -1] = 0.0
+    laplacians[dim][-1, 0] = 0.0
+    return
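A minimal numerical check of the Dirichlet modification described in the docstring above. This is a hedged standalone sketch, not part of the patch: it assumes `Grid` accepts a shape and domain as in Jax-CFD, and imports the private helper directly.

```python
import torch
from torch_cfd import grids
from torch_cfd.finite_differences import (
    laplacian_matrix,
    _laplacian_boundary_dirichlet_cell_centered,
)

grid = grids.Grid((4,), domain=((0.0, 1.0),))  # assumed constructor signature
h = grid.step[0]
laps = [laplacian_matrix(4, step=h)]
_laplacian_boundary_dirichlet_cell_centered(laps, grid, dim=0, side="lower")
# diagonal goes from -2/h**2 to -3/h**2, and the periodic wrap-around is removed
assert torch.isclose(laps[0][0, 0], torch.tensor(-3.0 / h**2))
assert laps[0][0, -1] == 0.0
```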


 def _laplacian_boundary_neumann_cell_centered(
-    laplacians: List[Any], grid: grids.Grid, axis: int, side: str
+    laplacians: ArrayVector, grid: grids.Grid, dim: int, side: str
 ) -> None:
     """Converts 1d laplacian matrix to satisfy neumann homogeneous bc.

     This function assumes the RHS will have a cell-centered offset.
     Neumann boundaries are not defined for edge-aligned offsets elsewhere in the
-    code.
+    code. For homogeneous Neumann BC (du/dn = 0), the ghost cell equals the
+    first interior cell, u[0] = u[1], so the stencil becomes
+    [1, -2, 1] * [u[1], u[1], u[2]] = u[1] - 2*u[1] + u[2] = -u[1] + u[2].
+    The original diagonal entry Lap[0, 0] is -2/h**2 and we need -1/h**2;
+    thus 1/h**2 is added to the diagonal, and the ghost cell dof is zeroed out
+    (Lap[0, -1] = 0).

     Args:
       laplacians: list of 1d laplacians
       grid: grid object
-      axis: axis along which to impose dirichlet bc.
+      dim: axis along which to impose neumann bc.
       side: which boundary side to convert to neumann homogeneous bc.

     Returns:
       updated list of 1d laplacians.
-
-    TODO
-    [ ]: this function is not implemented in the original Jax-CFD code.
     """
     if side == "lower":
-        laplacians[axis][0, 0] = laplacians[axis][0, 0] + 1 / grid.step[axis] ** 2
+        laplacians[dim][0, 0] = laplacians[dim][0, 0] + 1 / grid.step[dim] ** 2
     else:
-        laplacians[axis][-1, -1] = laplacians[axis][-1, -1] + 1 / grid.step[axis] ** 2
+        laplacians[dim][-1, -1] = laplacians[dim][-1, -1] + 1 / grid.step[dim] ** 2
     # deletes corner dependencies on the "looped-around" part.
     # this should be done irrespective of which side, since one boundary cannot
     # be periodic while the other is.
-    laplacians[axis][0, -1] = 0.0
-    laplacians[axis][-1, 0] = 0.0
-    return laplacians
+    laplacians[dim][0, -1] = 0.0
+    laplacians[dim][-1, 0] = 0.0
+    return


 def laplacian_matrix_w_boundaries(
@@ -277,6 +289,7 @@ def laplacian_matrix_w_boundaries(
     bc: grids.BoundaryConditions,
     laplacians: Optional[ArrayVector] = None,
     device: Optional[torch.device] = None,
+    dtype: torch.dtype = torch.float32,
     sparse: bool = False,
 ) -> ArrayVector:
     """Returns 1d laplacians that satisfy boundary conditions bc on grid.
@@ -323,11 +336,13 @@ def laplacian_matrix_w_boundaries(
         raise NotImplementedError(
             "edge-aligned Neumann boundaries are not implemented."
         )
-    return list(lap.to(device) for lap in laplacians) if device else laplacians
+    return list(lap.to(dtype).to(device) for lap in laplacians)


 def _linear_along_axis(c: GridVariable, offset: float, dim: int) -> GridVariable:
-    """Linear interpolation of `c` to `offset` along a single specified `axis`."""
+    """Linear interpolation of `c` to `offset` along a single specified `axis`.
+    `dim` here is >= 0; the negative indexing for the batched implementation is
+    handled by grids.shift.
+    """
     offset_delta = offset - c.offset[dim]

     # If offsets are the same, `c` is unchanged.
@@ -383,8 +398,8 @@ def linear(
             f"got {c.offset} and {offset}."
) interpolated = c - for a, o in enumerate(offset): - interpolated = _linear_along_axis(interpolated, offset=o, dim=a) + for dim, o in enumerate(offset): + interpolated = _linear_along_axis(interpolated, offset=o, dim=dim) return interpolated @@ -405,15 +420,15 @@ def gradient_tensor(v): if not isinstance(v, GridVariable): return GridTensor(torch.stack([gradient_tensor(u) for u in v], dim=-1)) grad = [] - for axis in range(v.grid.ndim): - offset = v.offset[axis] + for dim in range(-v.grid.ndim, 0): + offset = v.offset[dim] if offset == 0: - derivative = forward_difference(v, axis) + derivative = forward_difference(v, dim) elif offset == 1: - derivative = backward_difference(v, axis) + derivative = backward_difference(v, dim) elif offset == 0.5: v_centered = linear(v, v.grid.cell_center) - derivative = central_difference(v_centered, axis) + derivative = central_difference(v_centered, dim) else: raise ValueError(f"expected offset values in {{0, 0.5, 1}}, got {offset}") grad.append(derivative) @@ -427,4 +442,4 @@ def curl_2d(v: GridVariableVector) -> GridVariable: grid = grids.consistent_grid_arrays(*v) if grid.ndim != 2: raise ValueError(f"Grid dimensionality is not 2: {grid.ndim}") - return forward_difference(v[1], dim=0) - forward_difference(v[0], dim=1) + return forward_difference(v[1], dim=-2) - forward_difference(v[0], dim=-1) diff --git a/torch_cfd/forcings.py b/torch_cfd/forcings.py index 8900998..69caac4 100644 --- a/torch_cfd/forcings.py +++ b/torch_cfd/forcings.py @@ -25,6 +25,7 @@ Grid = grids.Grid GridVariable = grids.GridVariable +GridVariableVector = grids.GridVariableVector def forcing_eval(eval_func): @@ -79,7 +80,7 @@ class ForcingFn(nn.Module): def __init__( self, grid: Grid, - scale: float = 1, + scale: float = 1.0, wave_number: int = 1, diam: float = 1.0, swap_xy: bool = False, @@ -100,12 +101,12 @@ def __init__( @forcing_eval def velocity_eval( - grid: Grid, velocity: Optional[Tuple[torch.Tensor, torch.Tensor]] - ) -> Tuple[torch.Tensor, torch.Tensor]: + self, grid: Grid, velocity: Optional[Tuple[GridVariable, GridVariable]] + ) -> GridVariableVector: raise NotImplementedError @forcing_eval - def vorticity_eval(grid: Grid, vorticity: Optional[torch.Tensor]) -> torch.Tensor: + def vorticity_eval(self, grid: Grid, vorticity: Optional[torch.Tensor]) -> GridVariable: raise NotImplementedError def forward( @@ -113,7 +114,7 @@ def forward( grid: Optional[Union[Grid, Tuple[Grid, Grid]]] = None, velocity: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, vorticity: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Union[GridVariable, GridVariableVector]: if not self.vorticity: return self.velocity_eval(grid, velocity) else: @@ -166,7 +167,7 @@ def velocity_eval( self, grid: Optional[Grid], velocity: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> GridVariableVector: offsets = self.offsets grid = self.grid if grid is None else grid domain_factor = 2 * torch.pi / self.diam @@ -187,13 +188,13 @@ def velocity_eval( grid, ) v = GridVariable(torch.zeros_like(u.data), (1 / 2, 1), grid) - return tuple((u, v)) + return GridVariableVector(tuple((u, v))) def vorticity_eval( self, grid: Optional[Grid], vorticity: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> GridVariable: offsets = self.offsets grid = self.grid if grid is None else grid domain_factor = 2 * torch.pi / self.diam @@ -243,9 +244,9 @@ class SimpleSolenoidalForcing(ForcingFn): def __init__( self, - scale=1, + scale=1.0, diam=1.0, - 
k=1.0,
+        wave_number=1,
         offsets=((0, 0), (0, 0)),
         vorticity=True,
         *args,
@@ -255,7 +256,7 @@ def __init__(
             *args,
             scale=scale,
             diam=diam,
-            wave_number=k,
+            wave_number=wave_number,
             offsets=offsets,
             vorticity=vorticity,
             **kwargs,
@@ -273,7 +274,7 @@ def velocity_eval(
         self,
         grid: Optional[Grid],
         velocity: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> GridVariableVector:
         offsets = self.offsets
         grid = self.grid if grid is None else grid
         domain_factor = 2 * torch.pi / self.diam
@@ -292,7 +293,7 @@ def velocity_eval(
         rot = self.potential(x, y, scale, k)
         u = GridVariable(rot, offsets[0], grid)
         v = GridVariable(-rot, (1 / 2, 1), grid)
-        return tuple((u, v))
+        return GridVariableVector(tuple((u, v)))

     def vorticity_eval(
         self,
@@ -339,7 +340,7 @@ def __init__(
         self,
         scale=0.1,
         diam=1.0,
-        k=1.0,
+        wave_number=1,
         offsets=((0, 0), (0, 0)),
         *args,
         **kwargs,
@@ -348,7 +349,7 @@ def __init__(
             *args,
             scale=scale,
             diam=diam,
-            k=k,
+            wave_number=wave_number,
             offsets=offsets,
             **kwargs,
         )
diff --git a/torch_cfd/fvm.py b/torch_cfd/fvm.py
index bc71911..3ce7e0f 100644
--- a/torch_cfd/fvm.py
+++ b/torch_cfd/fvm.py
@@ -14,21 +14,25 @@
 # Modifications copyright (C) 2025 S.Cao
 # ported Google's Jax-CFD functional template to PyTorch's tensor ops

+"""Finite volume methods on MAC grids with pressure projection."""
+
 from __future__ import annotations

-from typing import Callable, Dict, List, Optional, Sequence, Tuple
+from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

 import torch
 import torch.nn as nn
 import torch_cfd.finite_differences as fdm

-from torch_cfd import advection, boundaries, forcings, grids, pressure
+from torch_cfd import advection, boundaries, forcings, grids, solvers

 Grid = grids.Grid
 GridVariable = grids.GridVariable
 GridVariableVector = grids.GridVariableVector
+BoundaryConditions = boundaries.BoundaryConditions
 ForcingFn = forcings.ForcingFn
+Solver = solvers.SolverBase


 def wrap_field_same_bcs(v, field_ref):
@@ -52,7 +56,9 @@ class ProjectionExplicitODE(nn.Module):
         """
         raise NotImplementedError

-    def pressure_projection(self, *args, **kwargs) -> Tuple[GridVariableVector, GridVariable]:
+    def pressure_projection(
+        self, *args, **kwargs
+    ) -> Tuple[GridVariableVector, GridVariable]:
         """Pressure projection step."""
         raise NotImplementedError

@@ -107,7 +113,7 @@ def __init__(
         # Set the tableau, either directly or from the method name
         if tableau is not None:
             self.tableau = tableau
-        else:
+        elif method is not None:
             self.method = method
         self._set_params()

@@ -187,8 +193,8 @@ def forward(
         Returns:
             Updated velocity field after one time step

-        Port note:
-            - In Jax-CFD, dvdt is wrapped with the same bc with v,
+        Port note:
+            - In Jax-CFD, dvdt is wrapped with the same bc as v,
               which does not work for inhomogeneous boundary condition.
              see explicit_terms_with_same_bcs in jax_cfd/base/equation.py
        """
        alpha = self.params["a"]
        beta = self.params["b"]
        num_steps = len(beta)

-        u = [None] * num_steps
-        k = [None] * num_steps
+        u: List[Optional[GridVariableVector]] = [None] * num_steps
+        k: List[Optional[GridVariableVector]] = [None] * num_steps

         # First stage
         u[0] = u0
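For orientation, the Butcher-tableau update that a stepper like `RKStepper.forward` carries out looks like the following generic sketch. It uses plain tensors and illustrative names (`explicit_rk_step`, `a`, `b`); the mapping to `self.params["a"]`/`self.params["b"]` and the GridVariableVector machinery is an assumption here, not the class's actual API.

```python
import torch

def explicit_rk_step(f, u0, dt, a, b):
    """One explicit Runge-Kutta step for du/dt = f(u), given tableau (a, b)."""
    k = [f(u0)]
    for row in a:  # lower-triangular stage coefficients
        u_stage = u0 + dt * sum(c * ki for c, ki in zip(row, k))
        k.append(f(u_stage))
    return u0 + dt * sum(c * ki for c, ki in zip(b, k))

# classic RK4 tableau
a = [(0.5,), (0.0, 0.5), (0.0, 0.0, 1.0)]
b = (1 / 6, 1 / 3, 1 / 3, 1 / 6)
u1 = explicit_rk_step(lambda u: -u, torch.ones(4), 0.1, a, b)
```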
@@ -227,15 +233,151 @@ def forward(
         return u_final, p


+class PressureProjection(nn.Module):
+    def __init__(
+        self,
+        grid: grids.Grid,
+        bc: BoundaryConditions,
+        dtype: torch.dtype = torch.float32,
+        solver: Union[str, Solver] = "pseudoinverse",
+        implementation: Optional[str] = None,
+        laplacians: Optional[List[torch.Tensor]] = None,
+        **solver_kwargs,
+    ):
+        """
+        Args:
+            grid: Grid object describing the spatial domain.
+            bc: Boundary conditions for the Laplacian operator (for pressure).
+            dtype: Tensor data type. For consistency purposes.
+            solver: A Solver instance, or one of ['pseudoinverse', 'fft', 'rfft', 'svd', 'conjugate_gradient', 'cg'].
+            implementation: One of ['fft', 'rfft', 'matmul'].
+            laplacians: Precomputed Laplacian operators. If None, they are computed from the grid during initialization.
+            solver_kwargs: Extra keyword arguments forwarded to the iterative solver.
+        """
+        super().__init__()
+        self.grid = grid
+        self.bc = bc
+        self.dtype = dtype
+        self.implementation = implementation
+        solvers._set_laplacian(self, laplacians, grid, bc)
+        self.ndim = grid.ndim
+
+        if isinstance(solver, nn.Module):
+            self.solver = solver
+        elif isinstance(solver, str):
+            if solver in ["conjugate_gradient", "cg"]:
+                self.solver = solvers.ConjugateGradient(
+                    grid=grid,
+                    bc=bc,
+                    dtype=dtype,
+                    laplacians=laplacians,
+                    pure_neumann=True,
+                    **solver_kwargs,
+                )
+            elif solver in ["pseudoinverse", "fft", "rfft", "svd"]:
+                self.solver = solvers.PseudoInverse(
+                    grid=grid,
+                    bc=bc,
+                    dtype=dtype,
+                    hermitian=True,
+                    implementation=implementation,
+                    laplacians=laplacians,
+                )
+            else:
+                raise NotImplementedError(f"Unsupported solver: {solver}")
+
+    @property
+    def inverse(self) -> torch.Tensor:
+        return self.solver.inverse
+
+    @property
+    def operators(self) -> List[torch.Tensor]:
+        """Get the list of 1D Laplacian operators."""
+        return [getattr(self.solver, f"laplacian_{i}") for i in range(self.ndim)]
+
+    def forward(self, v: GridVariableVector) -> Tuple[GridVariableVector, GridVariable]:
+        """Project velocity to be divergence-free."""
+        solver = self.solver.to(v.device)
+        if hasattr(self, "q0"):
+            # Use the previous pressure as the initial guess
+            q0 = self.q0.to(v.device)
+        else:
+            # No previous pressure, use zero as the initial guess
+            q0 = GridVariable(
+                torch.zeros_like(v[0].data, dtype=self.dtype),
+                v[0].offset,
+                v[0].grid,
+                v[0].bc,
+            ).to(v.device)
+            self.q0 = q0
+        _ = grids.consistent_grid(self.grid, *v)
+        pressure_bc = boundaries.get_pressure_bc_from_velocity(v)
+
+        rhs = fdm.divergence(v)
+        rhs_transformed = self.rhs_transform(rhs, pressure_bc)
+        rhs_inv = solver.solve(rhs_transformed, q0.data)
+        q = GridVariable(rhs_inv, rhs.offset, rhs.grid)
+        q = pressure_bc.impose_bc(q)
+        q_grad = fdm.forward_difference(q)
+        v_projected = GridVariableVector(
+            tuple(u.bc.impose_bc(u - q_g) for u, q_g in zip(v, q_grad))
+        )
+        self.q0 = q
+        return v_projected, q
+
+    @staticmethod
+    def rhs_transform(
+        u: GridVariable,
+        bc: BoundaryConditions,
+    ) -> torch.Tensor:
+        """Transform the RHS of the pressure projection equation for stability."""
+        u_data = u.data  # (b, n, m) or (n, m)
+        ndim = u.grid.ndim
+        for dim in range(ndim):
+            if (
+                bc.types[dim][0] == boundaries.BCType.NEUMANN
+                and bc.types[dim][1] == boundaries.BCType.NEUMANN
+            ):
+                # Check if we have batched data
+                if u_data.ndim > ndim:
+                    # For batched data, calculate the mean separately for each
+                    # batch: keep the batch dimension, reduce over grid dimensions
+                    dims = tuple(range(-ndim, 0))
+                    mean = torch.mean(u_data, dim=dims, keepdim=True)
+                else:
+                    # For non-batched data, calculate the global mean
+                    mean = torch.mean(u_data)
+                u_data = u_data - mean
+        return u_data
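The mean subtraction in `rhs_transform` above enforces the discrete compatibility condition of the pure-Neumann Poisson problem (the RHS must sum to zero for a solution to exist). A tiny plain-tensor illustration:

```python
import torch

rhs = torch.randn(2, 8, 8)                        # batched divergence, shape (b, n, m)
rhs = rhs - rhs.mean(dim=(-2, -1), keepdim=True)  # per-sample zero mean
assert torch.allclose(rhs.sum(dim=(-2, -1)), torch.zeros(2), atol=1e-5)
```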


 class NavierStokes2DFVMProjection(ProjectionExplicitODE):
     r"""incompressible Navier-Stokes velocity pressure formulation

     Runge-Kutta time stepper for the NSE discretized using a MAC grid FVM with a
     pressure projection Chorin's method. The x- and y-dofs of the velocity are on
     a staggered grid, which is reflected in the offset attr.

+    References:
+      - Sanderse, B., & Koren, B. (2012). Accuracy analysis of explicit Runge-Kutta methods applied to the incompressible Navier-Stokes equations. Journal of Computational Physics, 231(8), 3041-3063.
+      - Almgren, A. S., Bell, J. B., & Szymczak, W. G. (1996). A numerical method for the incompressible Navier-Stokes equations based on an approximate projection. SIAM Journal on Scientific Computing, 17(2), 358-369.
+      - Capuano, F., Coppola, G., Chiatto, M., & de Luca, L. (2016). Approximate projection method for the incompressible Navier-Stokes equations. AIAA Journal, 54(7), 2179-2182.
+
+    Args:
+      viscosity: 1/Re
+      grid: Grid on which the fields are defined
+      bcs: Boundary conditions for the velocity field (default: periodic)
+      drag: Drag coefficient applied to the velocity field (default: 0.0)
+      density: Density of the fluid (default: 1.0)
+      convection: Convection term function (default: advection.ConvectionVector)
+      pressure_proj: Pressure projection function (default: fvm.PressureProjection)
+      forcing: Forcing function applied to the velocity field (default: None)
+      step_fn: Runge-Kutta stepper function (default: RKStepper with classic_rk4 method)
+
     Original implementation in Jax-CFD repository:

     - semi_implicit_navier_stokes in jax_cfd.base.fvm which returns a stepper function `time_stepper(ode, dt)` where `ode` specifies the explicit terms and the pressure projection.
+    - The pressure projection is done by calling `pressure.projection`, which calls a Poisson solver to solve \Delta q = div(u).
     - The time_stepper is a wrapper function by jax.named_call( navier_stokes_rk()) that implements the various Runge-Kutta method according to the Butcher tableau.
     - navier_stokes_rk() implements Runge-Kutta time-stepping for the NSE using the explicit terms and pressure projection with equation as an input where user needs to specify the explicit terms and pressure projection.
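A conceptual plain-tensor sketch of the Chorin projection that `PressureProjection.forward` performs: compute the divergence, solve a Poisson equation for the pressure correction, subtract its gradient. This assumes a fully periodic grid with inline finite differences; the GridVariable/offset/bc bookkeeping of the real class is omitted.

```python
import torch
import torch.fft as fft

def project(u, v, h):
    """Make (u, v) divergence-free: solve lap(q) = div(u, v), subtract grad(q)."""
    # backward-difference divergence on a periodic MAC-like grid
    div = (u - torch.roll(u, 1, 0)) / h + (v - torch.roll(v, 1, 1)) / h
    n, m = div.shape
    theta_x = 2 * torch.pi * torch.arange(n) / n
    theta_y = 2 * torch.pi * torch.arange(m) / m
    # eigenvalues of the periodic [1, -2, 1]/h**2 stencil, pseudoinverse with cutoff
    lam = ((2 * torch.cos(theta_x) - 2)[:, None]
           + (2 * torch.cos(theta_y) - 2)[None, :]) / h**2
    inv = torch.where(lam.abs() > 1e-8, 1.0 / lam, torch.zeros_like(lam))
    q = fft.ifftn(inv * fft.fftn(div)).real
    # forward-difference gradient of the correction
    u_new = u - (torch.roll(q, -1, 0) - q) / h
    v_new = v - (torch.roll(q, -1, 1) - q) / h
    return u_new, v_new, q

u, v = torch.randn(32, 32), torch.randn(32, 32)
u2, v2, q = project(u, v, h=1.0 / 32)
div2 = (u2 - torch.roll(u2, 1, 0)) + (v2 - torch.roll(v2, 1, 1))
assert div2.abs().max() < 1e-3  # divergence-free up to float32 roundoff
```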
@@ -253,10 +395,10 @@ def __init__( bcs: Optional[Sequence[boundaries.BoundaryConditions]] = None, drag: float = 0.0, density: float = 1.0, - convection: Callable = None, - pressure_proj: Callable = None, + convection: Optional[Callable] = None, + pressure_proj: Optional[Callable] = None, forcing: Optional[ForcingFn] = None, - step_fn: RKStepper = None, + step_fn: Optional[RKStepper] = None, **kwargs, ): """ @@ -290,7 +432,7 @@ def _set_pressure_projection(self): if self.pressure_proj is not None: self._projection = self.pressure_proj return - self._projection = pressure.PressureProjection( + self._projection = PressureProjection( grid=self.grid, bc=self.pressure_bc, ) diff --git a/torch_cfd/grids.py b/torch_cfd/grids.py index a58c22c..35d788f 100644 --- a/torch_cfd/grids.py +++ b/torch_cfd/grids.py @@ -592,7 +592,7 @@ def device(self) -> torch.device: def norm(self, p: Optional[Union[int, float]] = None, **kwargs) -> torch.Tensor: """Returns the norm of the data.""" - return torch.linalg.norm(self.data, ord=p, **kwargs) + return torch.linalg.norm(self.data, p, **kwargs) @property def L2norm(self) -> torch.Tensor: @@ -613,12 +613,19 @@ def to(self, *args, **kwargs): def __getitem__(self, index): """Allows indexing into the GridVariable like a tensor.""" # This is necessary to ensure that the offset and grid are preserved - # when slicing the data. + # when slicing the data, bc will be removed. new_data = self.data[index] if isinstance(new_data, torch.Tensor): - return GridVariable(new_data, self.offset, self.grid, self.bc) + return GridVariable(new_data, self.offset, self.grid) return new_data + def __setitem__(self, index, value): + """Allows setting items in the GridVariable like a tensor.""" + if isinstance(value, GridVariable): + self.data[index] = value.data + else: + self.data[index] = value + @staticmethod def is_torch_fft_func(func): return getattr(func, "__module__", "").startswith("torch._C._fft") diff --git a/torch_cfd/initial_conditions.py b/torch_cfd/initial_conditions.py index 2170cde..d18a537 100644 --- a/torch_cfd/initial_conditions.py +++ b/torch_cfd/initial_conditions.py @@ -23,7 +23,7 @@ import torch.fft as fft import torch.nn as nn -from torch_cfd import boundaries, grids, pressure +from torch_cfd import boundaries, grids, fvm Grid = grids.Grid GridVariable = grids.GridVariable @@ -122,7 +122,14 @@ def project_and_normalize( grid = grids.consistent_grid_arrays(*v) pressure_bc = boundaries.get_pressure_bc_from_velocity(v) if projection is None: - projection = pressure.PressureProjection(grid, pressure_bc).to(v.device) + is_periodic = all( + [ + boundaries.is_bc_periodic_boundary_conditions(pressure_bc, dim) + for dim in range(grid.ndim) + ] + ) + solver = 'pseudoinverse' if is_periodic else 'cg' + projection = fvm.PressureProjection(grid, pressure_bc, solver=solver, dtype=v.dtype).to(v.device) v, _ = projection(v) vmax = torch.linalg.norm(torch.stack([u.data for u in v]), dim=0).max() v = GridVariableVector( diff --git a/torch_cfd/pressure.py b/torch_cfd/pressure.py deleted file mode 100644 index 762c4e7..0000000 --- a/torch_cfd/pressure.py +++ /dev/null @@ -1,365 +0,0 @@ -# Copyright 2021 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Modifications copyright (C) 2024 S.Cao -# ported Google's Jax-CFD functional template to PyTorch's tensor ops - -"""Functions for computing and applying pressure.""" - -from functools import reduce, partial -from typing import List, Optional, Tuple, Union - -import torch -import torch.fft as fft -import torch.nn as nn - -from torch_cfd import ( - boundaries, - finite_differences as fdm, - grids, -) - -GridVariable = grids.GridVariable -GridVariableVector = grids.GridVariableVector -BoundaryConditions = grids.BoundaryConditions - -def _set_laplacian(module: nn.Module, laplacians: torch.Tensor, grid: grids.Grid, bc: BoundaryConditions): - """ - Initialize the 1D Laplacian operators with ndim - Args: - laplacians have the shape (ndim, n, n) - """ - if laplacians is None: - laplacians = fdm.set_laplacian_matrix(grid, bc) - laplacians = torch.stack(laplacians, dim=0) - else: - # Check if the provided laplacians are consistent with the grid - for laplacian in laplacians: - if laplacian.shape != grid.shape: - raise ValueError("Provided laplacians do not match the grid shape.") - module.register_buffer("laplacians", laplacians, persistent=True) - - -class PressureProjection(nn.Module): - def __init__( - self, - grid: grids.Grid, - bc: BoundaryConditions, - dtype: Optional[torch.dtype] = torch.float32, - implementation: Optional[str] = None, - laplacians: Optional[torch.Tensor] = None, - initial_guess_pressure: Optional[GridVariable] = None, - ): - """ - Args: - grid: Grid object describing the spatial domain. - bc: Boundary conditions for the Laplacian operator (for pressure). - dtype: Tensor data type. For consistency purpose. - implementation: One of ['fft', 'rfft', 'matmul']. - circulant: If True, bc is periodical - laplacians: Precomputed Laplacian operators. If None, they are computed from the grid during initiliazation. - initial_guess_pressure: Initial guess for pressure. If None, a zero tensor is used. 
- """ - super().__init__() - self.grid = grid - self.bc = bc - self.dtype = dtype - self.implementation = implementation - _set_laplacian(self, laplacians, grid, bc) - - self.solver = Pseudoinverse( - grid=grid, - bc=bc, - dtype=dtype, - hermitian=True, - implementation=implementation, - laplacians=self.laplacians - ) - if initial_guess_pressure is None: - initial_guess_pressure = GridVariable( - torch.zeros(grid.shape), grid.cell_center, grid - ) - self.q0 = bc.impose_bc(initial_guess_pressure) - - def forward(self, v: GridVariableVector) -> Tuple[GridVariableVector, GridVariable]: - """Project velocity to be divergence-free.""" - _ = grids.consistent_grid(self.grid, *v) - pressure_bc = boundaries.get_pressure_bc_from_velocity(v) - - rhs = fdm.divergence(v) - rhs_transformed = self.rhs_transform(rhs, pressure_bc) - rhs_inv = self.solver(rhs_transformed) - q = GridVariable(rhs_inv, rhs.offset, rhs.grid) - q = pressure_bc.impose_bc(q) - q_grad = fdm.forward_difference(q) - v_projected = GridVariableVector( - tuple(u.bc.impose_bc(u.array - q_g) for u, q_g in zip(v, q_grad)) - ) - return v_projected, q - - - @staticmethod - def rhs_transform( - u: GridVariable, - bc: BoundaryConditions, - ) -> torch.Tensor: - """Transform the RHS of pressure projection equation for stability.""" - u_data = u.data # (b, n, m) or (n, m) - for axis in range(u.grid.ndim): - if ( - bc.types[axis][0] == boundaries.BCType.NEUMANN - and bc.types[axis][1] == boundaries.BCType.NEUMANN - ): - # Check if we have batched data - if u_data.ndim > u.grid.ndim: - # For batched data, calculate mean separately for each batch - # Keep the batch dimension, reduce over grid dimensions - dims = tuple(range(1, u_data.ndim)) - mean = torch.mean(u_data, dim=dims, keepdim=True) - else: - # For non-batched data, calculate global mean - mean = torch.mean(u_data) - - u_data = u_data - mean - return u_data - - @property - def inverse(self) -> torch.Tensor: - return self.solver.inverse - - @property - def laplacians(self) -> torch.Tensor: - return self.solver.laplacians - -class Pseudoinverse(nn.Module): - def __init__( - self, - grid: grids.Grid, - bc: Optional[BoundaryConditions] = None, - dtype: torch.dtype = torch.float32, - hermitian: bool = True, - circulant: bool = True, - implementation: Optional[str] = None, - laplacians: Optional[torch.Tensor] = None, - cutoff: Optional[float] = None, - ): - r""" - This class applies the pseudoinverse of the Laplacian operator on a given Grid. - This class re-implements to Jax-cfd's function_call type implementations - - _hermitian_matmul_transform() - - _circulant_fft_transform() - - _circulant_rfft_transform() - in the fast_diagonalization.py: - https://github.com/google/jax-cfd/blob/main/jax_cfd/base/fast_diagonalization.py - to PyTorch's tensor ops using nn.Module. - - The application of a linear operator (the inverse of Laplacian) - can be written as a sum of operators on each axis. - Such linear operators are *separable*, and can be written as a sum of tensor - products, e.g., `operators = [A, B]` corresponds to the linear operator - A ⊗ I + I ⊗ B, where the tensor product ⊗ indicates a separation between - operators applied along the first and second axis. - - This function computes matrix-valued functions of such linear operators via - the "fast diagonalization method" [1]: - F(A ⊗ I + I ⊗ B) - = (X(A) ⊗ X(B)) F(Λ(A) ⊗ I + I ⊗ Λ(B)) (X(A)^{-1} ⊗ X(B)^{-1}) - - where X(A) denotes the matrix of eigenvectors of A and Λ(A) denotes the - (diagonal) matrix of eigenvalues. 
The function `F` is easy to compute in - this basis, because matrix Λ(A) ⊗ I + I ⊗ Λ(B) is diagonal. - - The current implementation directly diagonalizes dense matrices for each - linear operator, which limits it's applicability to grids with less than - 1e3-1e4 elements per side (~1 second to several minutes of setup time). - - Example: The Laplacian operator can be written as a sum of 1D Laplacian - operators along each axis, i.e., as a sum of 1D convolutions along each axis. - This can be seen mathematically (∇² = ∂²/∂x² + ∂²/∂y² + ∂²/∂z²) or by - decomposing the 2D kernel: - - [0 1 0] [ 1] - [1 -4 1] = [1 -2 1] ⊕ [-2] - [0 1 0] [ 1] - - Args: - grid: Grid object describing the spatial domain. - bc: Boundary conditions for the Laplacian operator (for pressure). - dtype: Tensor data type. - hermitian: hermitian: whether or not all linear operator are Hermitian (i.e., symmetric in the real valued case). - circulant: If True, bc is periodical - implementation: One of ['fft', 'rfft', 'matmul']. - cutoff: Minimum eigenvalue to invert. - laplacians: Precomputed Laplacian operators. If None, they are computed from the grid during initiliazation. - - - implementation: how to implement fast diagonalization. Default uses 'rfft' - for grid size larger than 1024 and 'matmul' otherwise: - - 'matmul': scales like O(N**(d+1)) for d N-dimensional operators, but - makes good use of matmul hardware. Requires hermitian=True. - - 'fft': scales like O(N**d * log(N)) for d N-dimensional operators. - Requires circulant=True. - - 'rfft': use the RFFT instead of the FFT. This is a little faster than - 'fft' but also has slightly larger error. It currently requires an even - sized last axis and circulant=True. - precision: numerical precision for matrix multplication. Only relevant on - TPUs with implementation='matmul'. - - Returns: - The pseudoinverse of the Laplacian operator acting on the input tensor. - - TODO: - - [x] change the implementation to tensor2tensor - - [x] originally the laplacian is implemented as - laplacians = array_utils.laplacian_matrix_w_boundaries(rhs.grid, rhs.offset, pressure_bc), needs to add this wrapper to support non-periodic BCs. (May 2025): now this is passed by fdm.set_laplacian_matrix - - [x] add the precomputation to the eigenvalues - - References: - [1] Lynch, R. E., Rice, J. R. & Thomas, D. H. Direct solution of partial - difference equations by tensor product methods. Numer. Math. 6, 185-199 - (1964). 
https://paperpile.com/app/p/b7fdea4e-b2f7-0ada-b056-a282325c3ecf - - """ - super().__init__() - self.grid = grid - self.bc = bc - - if self.bc is None: - self.bc = boundaries.periodic_boundary_conditions(ndim=grid.ndim) - - self.cutoff = cutoff or 10 * torch.finfo(dtype).eps - - self.hermitian = hermitian - self.circulant = circulant - self.implementation = implementation - _set_laplacian(self, laplacians, grid, self.bc) - - - if implementation is None: - self.implementation = "rfft" - self.circulant = True - if implementation == "rfft" and self.laplacians[-1].shape[0] % 2: - self.implementation = "matmul" - self.circulant = False - - if self.implementation == "rfft": - self.ifft = partial(fft.irfftn, s=grid.shape) - self.fft = partial(fft.rfftn, dim=tuple(range(-grid.ndim, 0))) - elif self.implementation == "fft": - self.ifft = partial(fft.ifftn, s=grid.shape) - self.fft = partial(fft.fftn, dim=tuple(range(-grid.ndim, 0))) - if self.implementation not in ("fft", "rfft", "matmul"): - raise NotImplementedError(f"Unsupported implementation: {implementation}") - - self._compute_eigenvalues() - - if self.implementation in ("fft", "rfft"): - if not self.circulant: - raise ValueError( - f"non-circulant operators not yet supported with implementation='fft' or 'rfft' " - ) - self._forward = self._apply_in_frequency_space - elif self.implementation == "matmul": - if not self.hermitian: - raise ValueError( - "matmul implementation requires hermitian=True. " - "Use fft or rfft for non-hermitian operators." - ) - self._forward = self._apply_in_svd_space - - def forward(self, value: torch.Tensor) -> torch.Tensor: - """ - Apply the pseudoinverse (with a cutoff) Laplacian operator to the input tensor. - - Args: - value: right-hand-side of the linear operator. This is a tensor with `len(operators)` dimensions, where each dimension corresponds to one of the linear operators. - """ - return self._forward(value, self.inverse) - - @staticmethod - def outer_sum(x: Union[List[torch.Tensor], Tuple[torch.Tensor]]) -> torch.Tensor: - """ - Returns the outer sum of a list of one dimensional arrays - Example: - x = [a, b, c] - out = a[..., None, None] + b[..., None] + c - - The full outer sum is equivalent to: - def _sum(a, b): - return a[..., None] + b - return reduce(_sum, x) - """ - - return reduce(lambda a, b: a[..., None] + b, x) - - def _compute_eigenvalues(self): - """ - Precompute the Laplacian eigenvalues on the Grid mesh. - """ - eigenvalues = torch.tensor([1.0] * self.grid.ndim) - eigenvectors = torch.tensor([1.0] * self.grid.ndim) - if self.implementation == "fft": - eigenvalues = [fft.fft(op[:, 0]) for op in self.laplacians] - elif self.implementation == "rfft": - eigenvalues = [fft.fft(op[:, 0]) for op in self.laplacians[:-1]] + [ - fft.rfft(self.laplacians[-1][:, 0]) - ] - elif self.implementation == "matmul": - eigenvalues, eigenvectors = zip(*map(torch.linalg.eigh, self.laplacians)) - else: - raise NotImplementedError( - f"Unsupported implementation: {self.implementation} and eigenvalues are not precomputed." 
- ) - summed_eigenvalues = self.outer_sum(eigenvalues) - inverse_eigvs = torch.asarray( - self._filter_eigenvalues(summed_eigenvalues) - ) - - - if inverse_eigvs.shape != summed_eigenvalues.shape: - raise ValueError( - "output shape from func() does not match input shape: " - f"{inverse_eigvs.shape} vs {summed_eigenvalues.shape}" - ) - self.register_buffer("inverse", inverse_eigvs, persistent=True) - self.register_buffer("eigenvectors", eigenvectors, persistent=True) - - def _filter_eigenvalues(self, eigenvalues: torch.Tensor) -> torch.Tensor: - """ - Apply a cutoff function to the eigenvalues. - """ - return torch.where(torch.abs(eigenvalues) > self.cutoff, 1 / eigenvalues, 0) - - def _apply_in_frequency_space( - self, v: torch.Tensor, multiplier: torch.Tensor - ) -> torch.Tensor: - """ - Apply the inverse in frequency domain and return to real space. - """ - return self.ifft(multiplier * self.fft(v)).real - - def _apply_in_svd_space( - self, v: torch.Tensor, multiplier: torch.Tensor - ) -> torch.Tensor: - """ - Apply the inverse in SVD space and return to real space. - """ - assert self.implementation == "matmul" - out = v - for vectors in self.eigenvectors: - out = torch.tensordot(out, vectors, dims=(0, 0)) - out *= multiplier - for vectors in self.eigenvectors: - out = torch.tensordot(out, vectors, dims=(0, 1)) - return out diff --git a/torch_cfd/solvers.py b/torch_cfd/solvers.py new file mode 100644 index 0000000..dae84a9 --- /dev/null +++ b/torch_cfd/solvers.py @@ -0,0 +1,925 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+# Modifications copyright (C) 2025 S.Cao
+# ported Google's Jax-CFD functional template to PyTorch's tensor ops
+
+"""Collections of linear system solvers."""
+
+from functools import partial, reduce
+from typing import Any, Callable, List, Optional, Sequence, Union
+
+import torch
+import torch.fft as fft
+import torch.nn as nn
+
+from torch_cfd import boundaries, finite_differences as fdm, grids
+
+Grid = grids.Grid
+GridVariable = grids.GridVariable
+GridVariableVector = grids.GridVariableVector
+BoundaryConditions = grids.BoundaryConditions
+
+
+def _set_laplacian(
+    module: nn.Module,
+    laplacians: Optional[List[torch.Tensor]],
+    grid: Grid,
+    bc: BoundaryConditions,
+    device: Optional[torch.device] = None,
+    dtype: torch.dtype = torch.float32,
+):
+    """
+    Initialize the 1D Laplacian operators.
+    Args:
+        laplacians: a list of `ndim` 1D Laplacian matrices; the i-th entry has
+        shape (grid.shape[i], grid.shape[i]), so sizes may differ per dimension.
+    """
+    if laplacians is None:
+        laplacians = fdm.set_laplacian_matrix(grid, bc, device, dtype)
+    else:
+        # Check if the provided laplacians are consistent with the grid
+        for dim, laplacian in enumerate(laplacians):
+            if laplacian.shape[-1] != grid.shape[dim]:
+                raise ValueError("Provided laplacians do not match the grid shape.")
+
+    # Register each laplacian separately since they may have different sizes
+    for i, laplacian in enumerate(laplacians):
+        module.register_buffer(f"laplacian_{i}", laplacian, persistent=True)
+
+
+def outer_sum(x: Sequence[torch.Tensor]) -> torch.Tensor:
+    """
+    Returns the outer sum of a list of one dimensional arrays
+    Example:
+    x = [a, b, c]
+    out = a[..., None, None] + b[..., None] + c
+
+    The full outer sum is equivalent to:
+    def _sum(a, b):
+        return a[..., None] + b
+    return reduce(_sum, x)
+    """
+    return reduce(lambda a, b: a[..., None] + b, x)
+
+
+class Identity(nn.Module):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super().__init__()
+
+    def forward(self, x: torch.Tensor, *args) -> torch.Tensor:
+        return x
+
+
+class SolverBase(nn.Module):
+    """
+    Base class for solvers. This class defines the interface for solvers that
+    apply a linear operator equation on a 2D grid.
+
+    Args:
+        grid: Grid object describing the spatial domain.
+        bc: Boundary conditions for the Laplacian operator (for pressure).
+        dtype: Tensor data type.
+        laplacians: Precomputed Laplacian operators. If None, they are computed from
+        the grid during initialization.
+        tol: Cutoff for filtering eigenvalues in the pseudoinverse, or the
+        relative-residual tolerance in iterative solvers.
+    """
+
+    def __init__(
+        self,
+        grid: grids.Grid,
+        bc: Optional[BoundaryConditions] = None,
+        dtype: torch.dtype = torch.float32,
+        laplacians: Optional[List[torch.Tensor]] = None,
+        tol: float = 1e-8,
+        **kwargs,
+    ):
+        super().__init__()
+        self.grid = grid
+        if bc is None:
+            bc = boundaries.periodic_boundary_conditions(ndim=grid.ndim)
+        self.bc = bc
+        self.ndim = grid.ndim
+        self.tol = tol
+        self.dtype = dtype
+        _set_laplacian(self, laplacians, grid, bc, dtype=dtype)
+        self._compute_inverse_diagonals()
+
+    @property
+    def operators(self) -> List[torch.Tensor]:
+        """Get the list of 1D Laplacian operators."""
+        return [getattr(self, f"laplacian_{i}") for i in range(self.ndim)]
+
+    def _compute_inverse_diagonals(self):
+        """
+        Precompute the inverse diagonals of the Laplacian operator on the Grid
+        mesh. Must be implemented by subclasses.
+
+        For the PseudoInverse classes, the diagonals live in FFT/SVD space and
+        correspond to the eigenvalues of the Laplacian operator.
+        For the IterativeSolver classes, this is simply the inverse diagonal of
+        the original 1D Laplacian operators.
+        """
+        raise NotImplementedError("Subclasses must implement _compute_inverse_diagonals")
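A tiny illustration of `outer_sum` defined above: for 1D spectra it builds the separable operator's eigenvalue grid lambda_ij = a_i + b_j.

```python
import torch

a = torch.tensor([1.0, 2.0])
b = torch.tensor([10.0, 20.0, 30.0])
s = a[..., None] + b  # == outer_sum([a, b])
assert s.shape == (2, 3) and s[1, 2].item() == 32.0
```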
+ """ + raise NotImplementedError( + "Subclasses must implement _compute_inverse_diagonals" + ) + + def forward(self, f: torch.Tensor, q0: torch.Tensor) -> torch.Tensor: + """ + For PseudoInverseBase: apply the pseudoinverse (with a cutoff) Laplacian operator to the input tensor. + For IterativeSolverBase: solve the linear system Au = f, where A is the Laplacian operator and f is the right-hand side, q0 is the initial guess. + + Args: + value: right-hand-side of the linear operator. This is a tensor with `len(operators)` dimensions, where each dimension corresponds to one of the linear operators. + q0: initial guess for the solution. Not used in PseudoInverseBase, but may be used in IterativeSolverBase. + + Returns: + A^{*} rhs, where A^{*} is either the pseudoinverse of the Laplacian operator (eigen-expansion with a cut-off) or the iterative solver's A_h^{-1}'s approximation. + """ + raise NotImplementedError("Subclasses must implement forward method") + + def solve(self, f: torch.Tensor, q0: torch.Tensor) -> torch.Tensor: + raise NotImplementedError("Subclasses must implement solve method") + + +class PseudoInverseBase(SolverBase): + """ + Base class for pseudoinverse of the Laplacian operator on a given Grid. + + This class applies the pseudoinverse of the Laplacian operator using the + "fast diagonalization method" for separable linear operators. + + The application of a linear operator (the inverse of Laplacian) + can be written as a sum of operators on each axis. + Such linear operators are *separable*, and can be written as a sum of tensor + products, e.g., `operators = [A, B]` corresponds to the linear operator + A ⊗ I + I ⊗ B, where the tensor product ⊗ indicates a separation between + operators applied along the first and second axis. + + This function computes matrix-valued functions of such linear operators via + the "fast diagonalization method" [1]: + F(A ⊗ I + I ⊗ B) + = (X(A) ⊗ X(B)) F(Λ(A) ⊗ I + I ⊗ Λ(B)) (X(A)^{-1} ⊗ X(B)^{-1}) + + where X(A) denotes the matrix of eigenvectors of A and Λ(A) denotes the + (diagonal) matrix of eigenvalues. The function `F` is easy to compute in + this basis, because matrix Λ(A) ⊗ I + I ⊗ Λ(B) is diagonal. + + References: + [1] Lynch, R. E., Rice, J. R. & Thomas, D. H. Direct solution of partial + difference equations by tensor product methods. Numer. Math. 6, 185-199 + (1964). https://paperpile.com/app/p/b7fdea4e-b2f7-0ada-b056-a282325c3ecf + """ + + def __init__( + self, + grid: grids.Grid, + bc: Optional[BoundaryConditions] = None, + dtype: torch.dtype = torch.float32, + laplacians: Optional[List[torch.Tensor]] = None, + tol: float = 1e-8, + **kwargs, + ): + super().__init__(grid, bc, dtype, laplacians, tol, **kwargs) + + @property + def eigenvectors(self) -> List[torch.Tensor]: + """Get the list of eigenvector matrices.""" + if hasattr(self, "eigenvectors_0"): + return [getattr(self, f"eigenvectors_{i}") for i in range(self.ndim)] + return [] + + def _filter_eigenvalues(self, eigenvalues: torch.Tensor) -> torch.Tensor: + """ + Apply a cutoff function to the eigenvalues. + """ + return torch.where(torch.abs(eigenvalues) > self.tol, 1 / eigenvalues, 0) + + def solve(self, f: torch.Tensor, q0: torch.Tensor) -> torch.Tensor: + return self.forward(f, q0) + + +class PseudoInverseFFT(PseudoInverseBase): + """ + PseudoInverse implementation using complex FFT. + + This implementation uses standard FFT for complex-valued computations. + Requires circulant operators. + Scales like O(N**d * log(N)) for d N-dimensional operators. 
+
+
+class PseudoInverseFFT(PseudoInverseBase):
+    """
+    PseudoInverse implementation using the complex FFT.
+
+    This implementation uses the standard FFT for complex-valued computations.
+    Requires circulant operators.
+    Scales like O(N**d * log(N)) for d N-dimensional operators.
+    """
+
+    def __init__(
+        self,
+        grid: grids.Grid,
+        bc: Optional[BoundaryConditions] = None,
+        dtype: torch.dtype = torch.float32,
+        laplacians: Optional[List[torch.Tensor]] = None,
+        tol: float = 1e-8,
+        **kwargs,
+    ):
+        super().__init__(grid, bc, dtype, laplacians, tol, **kwargs)
+
+        self.ifft = partial(fft.ifftn, s=grid.shape)
+        self.fft = partial(fft.fftn, dim=tuple(range(-grid.ndim, 0)))
+
+    def _compute_inverse_diagonals(self):
+        """
+        Precompute the eigenvalues via the FFT of each operator's first column.
+        """
+        eigenvalues = [fft.fft(op[:, 0]) for op in self.operators]
+
+        summed_eigenvalues = outer_sum(eigenvalues)
+        inverse_eigvs = torch.asarray(self._filter_eigenvalues(summed_eigenvalues))
+
+        if inverse_eigvs.shape != summed_eigenvalues.shape:
+            raise ValueError(
+                "filtered eigenvalues do not match the input shape: "
+                f"{inverse_eigvs.shape} vs {summed_eigenvalues.shape}"
+            )
+        self.register_buffer("inverse_diag", inverse_eigvs, persistent=True)
+
+    def forward(self, f: torch.Tensor, q0: torch.Tensor) -> torch.Tensor:
+        """
+        Apply the inverse in frequency domain and return to real space.
+        """
+        return self.ifft(self.inverse_diag * self.fft(f)).real
+
+
+class PseudoInverseRFFT(PseudoInverseFFT):
+    """
+    PseudoInverse implementation using the real FFT.
+
+    This implementation uses RFFT for faster computation with real-valued data.
+    Requires circulant operators and an even-sized last axis.
+    Scales like O(N**d * log(N)) for d N-dimensional operators.
+    """
+
+    def __init__(
+        self,
+        grid: grids.Grid,
+        bc: Optional[BoundaryConditions] = None,
+        dtype: torch.dtype = torch.float32,
+        laplacians: Optional[List[torch.Tensor]] = None,
+        tol: float = 1e-8,
+        **kwargs,
+    ):
+        super().__init__(grid, bc, dtype, laplacians, tol, **kwargs)
+
+        if grid.shape[-1] % 2:
+            raise ValueError("RFFT implementation requires even-sized last axis")
+
+        self.ifft = partial(fft.irfftn, s=grid.shape)
+        self.fft = partial(fft.rfftn, dim=tuple(range(-grid.ndim, 0)))
+
+    def _compute_inverse_diagonals(self):
+        """
+        Precompute the eigenvalues, using the RFFT along the last axis.
+        """
+        eigenvalues = [fft.fft(op[:, 0]) for op in self.operators[:-1]] + [
+            fft.rfft(self.operators[-1][:, 0])
+        ]
+
+        summed_eigenvalues = outer_sum(eigenvalues)
+        inverse_eigvs = torch.asarray(self._filter_eigenvalues(summed_eigenvalues))
+
+        if inverse_eigvs.shape != summed_eigenvalues.shape:
+            raise ValueError(
+                "filtered eigenvalues do not match the input shape: "
+                f"{inverse_eigvs.shape} vs {summed_eigenvalues.shape}"
+            )
+        self.register_buffer("inverse_diag", inverse_eigvs, persistent=True)
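What the matmul path below computes, spelled out for a 2D separable operator with symmetric factors; a standalone sketch of fast diagonalization via `torch.linalg.eigh`, not the class itself:

```python
import torch

n, m = 4, 6
A = torch.randn(n, n); A = A + A.T                  # symmetric 1D factors
B = torch.randn(m, m); B = B + B.T
lam_a, Xa = torch.linalg.eigh(A)
lam_b, Xb = torch.linalg.eigh(B)
lam = lam_a[:, None] + lam_b[None, :]               # outer_sum of the 1D spectra
inv = torch.where(lam.abs() > 1e-8, 1.0 / lam, torch.zeros_like(lam))

f = torch.randn(n, m)
u = Xa @ (inv * (Xa.T @ f @ Xb)) @ Xb.T             # u solves A @ u + u @ B = f
assert torch.allclose(A @ u + u @ B, f, atol=1e-3)
```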
+ """ + eigenvalues, eigenvectors = zip(*map(torch.linalg.eigh, self.operators)) + + summed_eigenvalues = outer_sum(eigenvalues) + inverse_eigvs = torch.asarray(self._filter_eigenvalues(summed_eigenvalues)) + + if inverse_eigvs.shape != summed_eigenvalues.shape: + raise ValueError( + "output shape from func() does not match input shape: " + f"{inverse_eigvs.shape} vs {summed_eigenvalues.shape}" + ) + self.register_buffer("inverse_diag", inverse_eigvs, persistent=True) + + # Register eigenvectors + for i, evecs in enumerate(eigenvectors): + self.register_buffer(f"eigenvectors_{i}", evecs, persistent=True) + + def forward(self, f: torch.Tensor, q0: torch.Tensor) -> torch.Tensor: + """ + Apply the inverse in SVD space and return to real space. + """ + out = f + # Forward transform: contract along spatial dimensions from the end + for vectors in self.eigenvectors: + out = torch.tensordot(out, vectors, dims=([-2], [0])) # type: ignore + out *= torch.as_tensor(self.inverse_diag, dtype=out.dtype) + # Inverse transform: contract along spatial dimensions from the end + for vectors in self.eigenvectors: + out = torch.tensordot(out, vectors, dims=([-2], [1])) # type: ignore + + return out + + +class PseudoInverse(nn.Module): + """ + Factory class for creating PseudoInverse solvers with different implementations. + + This class automatically selects the appropriate implementation based on the + parameters and grid properties, or creates a specific implementation if requested. + + Args: + grid: Grid object describing the spatial domain. + bc: Boundary conditions for the Laplacian operator (for pressure). + dtype: Tensor data type. + hermitian: whether or not all linear operator are Hermitian (i.e., symmetric in the real valued case). + circulant: If True, bc is periodical + implementation: One of ['fft', 'rfft', 'matmul']. If None, automatically selects based on grid properties. + cutoff: Minimum eigenvalue to invert. + laplacians: Precomputed Laplacian operators. If None, they are computed from the grid during initialization. + + implementation: how to implement fast diagonalization: + - 'matmul': scales like O(N**(d+1)) for d N-dimensional operators, but + makes good use of matmul hardware. Requires hermitian=True. + - 'fft': scales like O(N**d * log(N)) for d N-dimensional operators. + Requires circulant=True. + - 'rfft': use the RFFT instead of the FFT. This is a little faster than + 'fft' but also has slightly larger error. It currently requires an even + sized last axis and circulant=True. + + Returns: + An instance of the appropriate PseudoInverse implementation. + """ + + def __new__( + cls, + grid: grids.Grid, + bc: Optional[BoundaryConditions] = None, + dtype: torch.dtype = torch.float32, + hermitian: bool = True, + circulant: bool = True, + implementation: Optional[str] = None, + laplacians: Optional[List[torch.Tensor]] = None, + cutoff: float = 1e-8, + **kwargs, + ): + # Auto-select implementation if not specified + if implementation is None: + implementation = "rfft" if circulant else "matmul" + + # if the last axis is odd, we cannot use rfft + if implementation == "rfft" and grid.shape[-1] % 2: + implementation = "fft" if circulant else "matmul" + + # Validate implementation requirements + if implementation in ["rfft", "fft"] and not circulant: + raise ValueError( + f"non-circulant operators not yet supported with implementation='{implementation}'" + ) + if implementation in ["matmul", "svd"] and not hermitian: + raise ValueError("matmul implementation requires hermitian=True. 
") + + # Create the appropriate implementation + if implementation == "rfft": + return PseudoInverseRFFT(grid, bc, dtype, laplacians, cutoff, **kwargs) + elif implementation == "fft": + return PseudoInverseFFT(grid, bc, dtype, laplacians, cutoff, **kwargs) + elif implementation == "matmul": + return PseudoInverseMatmul(grid, bc, dtype, laplacians, cutoff, **kwargs) + else: + raise NotImplementedError(f"Unsupported implementation: {implementation}") + + +class IterativeSolver(SolverBase): + """ + Base class for iterative solvers that apply a separable Laplacian + operator to a tensor `u` of shape (..., *grid.shape) and solve + the linear system Au = f. + + Args: + grid: Grid object describing the spatial domain. + bc: Boundary conditions for the Laplacian operator (for pressure). + dtype: Tensor data type. + laplacians: Precomputed Laplacian operators. If None, they are computed from + the grid during initialization. + tol: Tolerance for the iterative solver's relative residual. + max_iter: Maximum number of iterations for the iterative solver. + check_iter: Frequency of checking the residual norm during iterations. + record_residuals: If True, record the residual norms during iterations. + """ + + def __init__( + self, + grid: Grid, + bc: BoundaryConditions, + dtype: torch.dtype = torch.float64, + laplacians: Optional[List[torch.Tensor]] = None, + tol: float = 1e-5, + max_iter: int = 1000, + check_iter: int = 10, + record_residuals: bool = False, + ): + super().__init__(grid, bc, dtype, laplacians, tol) + + self.max_iter = max_iter + self.stop_iter = max_iter + self.check_iter = check_iter + self.record_residuals = record_residuals + self.residual_norms = [1.0] # relative residual + + def _compute_inverse_diagonals(self): + # inverse diagonal of sum of 1D ops + self.eps = 1e-10 # small value to avoid division by zero + diag = outer_sum([torch.diag(op) for op in self.operators]) + inv_diag = torch.where( + torch.abs(diag) > self.eps, + 1.0 / diag + self.eps, + torch.zeros_like(diag), + ) + self.register_buffer("inverse_diag", inv_diag) + + @property + def residual_norms(self): + return torch.tensor(self._residual_norms) + + @residual_norms.setter + def residual_norms(self, value): + if not hasattr(self, "_residual_norms"): + self._residual_norms = [] + if isinstance(value, (torch.Tensor, float)): + self._residual_norms.append(value) + elif isinstance(value, list): + self._residual_norms.extend(value) + else: + raise TypeError("Residual norms must be a tensor or a list of tensors.") + + def expand_as(self, inp: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """Expand tensor in-place for broadcasting with x.""" + if target.ndim > self.ndim: + return inp[(slice(None),) + (None,) * self.ndim] + return inp + + def _apply_laplacian(self, u, operators: Optional[List[torch.Tensor]] = None): + """ + Apply the separable 2D Laplacian: Au = Lx @ u + u @ Ly.T + """ + ndim = self.grid.ndim + out = torch.zeros_like(u.data) + operators = self.operators if operators is None else operators + data = u.data + for i, lap in enumerate(operators): + dim = i - ndim + _out = torch.tensordot(data, lap, dims=([dim], [-1])) # type: ignore + out += _out.transpose( + dim, -1 + ) # swap the first axis to the correct position + return out + + def residual(self, f, u): + return f - self._apply_laplacian(u) + + def forward(self, f, u, *args, **kwargs) -> torch.Tensor: + """ + Perform a single iteration step of the iterative solver. + + Args: + f: Right-hand side tensor. + u: Current solution tensor. 
+
+
+class Jacobi(IterativeSolver):
+    """Jacobi iteration for the separable Laplacian, with optional masking."""
+
+    def __init__(
+        self,
+        grid: Grid,
+        bc: BoundaryConditions,
+        dtype: torch.dtype = torch.float64,
+        laplacians: Optional[List[torch.Tensor]] = None,
+        tol: float = 1e-5,
+        max_iter: int = 1000,
+        check_iter: int = 10,
+        record_residuals: bool = False,
+        pure_neumann: bool = False,
+        interior_only: bool = False,
+    ):
+        super().__init__(
+            grid, bc, dtype, laplacians, tol, max_iter, check_iter, record_residuals
+        )
+        self.pure_neumann = pure_neumann
+        self.interior_only = interior_only
+        self._set_up_masks()
+
+    def update(
+        self, f: torch.Tensor, u: torch.Tensor, mask: torch.Tensor
+    ) -> torch.Tensor:
+        unsqueezed = False
+        if u.ndim == self.ndim:
+            unsqueezed = True
+            u = u.unsqueeze(0)
+        u_new = u.clone()
+        u_neighbors = torch.zeros(
+            *u.shape, 2 * self.ndim, dtype=u.dtype, device=u.device
+        )
+        mask = mask.to(u.device)
+        # neighbor index: 0 = left, 1 = right, 2 = down, 3 = up
+        Lx, Ly = self.operators
+        u_neighbors[..., 1:, :, 0] = Lx.diagonal(-1)[:, None] * u[..., :-1, :]
+        u_neighbors[..., :-1, :, 1] = Lx.diagonal(1)[:, None] * u[..., 1:, :]
+        u_neighbors[..., 1:, 2] = Ly.diagonal(-1)[None, :] * u[..., :, :-1]
+        u_neighbors[..., :-1, 3] = Ly.diagonal(1)[None, :] * u[..., :, 1:]
+        update = f - u_neighbors.sum(dim=-1)
+        Dinv = torch.as_tensor(self.inverse_diag).expand_as(update)
+        mask = mask.expand_as(update)
+        u_new[mask] = Dinv[mask] * update[mask]
+        return u_new.squeeze(0) if unsqueezed else u_new
+
+    def _set_up_masks(self):
+        self.masks = []
+        mask = torch.zeros(self.grid.shape, dtype=torch.bool)
+        if self.interior_only:
+            mask[..., 1:-1, 1:-1] = True
+        else:
+            mask[:] = True
+        self.masks.append(mask)
+
+    def forward(self, f: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
+        """Jacobi iteration using explicit neighbor contributions."""
+        for mask in self.masks:
+            u = self.update(f, u, mask)
+        if self.pure_neumann:
+            u -= torch.mean(
+                u, dim=(-2, -1), keepdim=True
+            )  # remove the mean for pure Neumann BC
+        return u
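A quick sanity check of the update rule (hypothetical, not part of the test suite; it assumes SolverBase wires up `inverse_diag` during construction): because `update` assembles the off-diagonal neighbor contributions explicitly, one Jacobi sweep should agree with the dense splitting u1 = u0 + D^{-1}(f - A u0), at least for the tridiagonal Dirichlet operators:

import torch
from torch_cfd import boundaries, grids, solvers

grid = grids.Grid((8, 8), step=(1 / 8, 1 / 8))
bc = boundaries.dirichlet_boundary_conditions(grid.ndim)
jac = solvers.Jacobi(grid, bc, dtype=torch.float64)

f = torch.randn(8, 8, dtype=torch.float64)
u0 = torch.randn(8, 8, dtype=torch.float64)

u1 = jac.forward(f, u0.clone())
r = f - jac._apply_laplacian(u0)   # residual f - A u0
u_ref = u0 + jac.inverse_diag * r  # the same Jacobi splitting, written densely
assert torch.allclose(u1, u_ref)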
+
+
+class GaussSeidel(Jacobi):
+    """
+    Gauss-Seidel iterative solver for Laplacian systems.
+
+    The implementation uses red-black ordering to update
+    the solution with a mask inherited from the Jacobi class.
+
+    Reference: Long Chen's notes on vectorizing finite difference methods in MATLAB
+    https://www.math.uci.edu/~chenlong/226/FDMcode.pdf
+    """
+
+    def __init__(
+        self,
+        grid: Grid,
+        bc: BoundaryConditions,
+        dtype: torch.dtype = torch.float64,
+        laplacians: Optional[List[torch.Tensor]] = None,
+        tol: float = 1e-5,
+        max_iter: int = 1000,
+        check_iter: int = 10,
+        record_residuals: bool = False,
+        pure_neumann: bool = False,
+    ):
+        super().__init__(
+            grid,
+            bc,
+            dtype,
+            laplacians,
+            tol,
+            max_iter,
+            check_iter,
+            record_residuals,
+            pure_neumann,
+        )
+
+    def _set_up_masks(self):
+        nx, ny = self.grid.shape
+        ix = torch.arange(nx, device=self.grid.device)
+        iy = torch.arange(ny, device=self.grid.device)
+        ix, iy = torch.meshgrid(ix, iy, indexing="ij")
+        red_mask = (ix + iy) % 2 == 0
+        black_mask = ~red_mask
+        self.masks = [red_mask, black_mask]
+
+    def forward(self, f: torch.Tensor, u: torch.Tensor, **kwargs) -> torch.Tensor:
+        """
+        Gauss-Seidel iteration using explicit neighbor contributions.
+        The method alternates between red and black masks to update the solution.
+        """
+        return super().forward(f, u, **kwargs)
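The red-black coloring above is what makes the Gauss-Seidel sweep vectorizable: within one color no two updated points are adjacent, so each half-sweep can reuse the Jacobi-style masked update. A tiny illustration of the checkerboard masks:

import torch

ix, iy = torch.meshgrid(torch.arange(4), torch.arange(4), indexing="ij")
red = (ix + iy) % 2 == 0
black = ~red
assert torch.all(red ^ black)  # the two colors partition the grid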
+ """ + return self.preconditioner.forward(r, torch.zeros_like(r)) + + def forward(self, u, r, p, rsold): + """One step of preconditioned conjugate gradient iteration.""" + Ap = self._apply_laplacian(p) + + # Use negative indexing to handle both batched and non-batched cases + spatial_dims = tuple(range(-self.ndim, 0)) + + pAp = torch.sum(p * Ap, dim=spatial_dims) + alpha = rsold / (pAp + self.eps) + + alpha = self.expand_as(alpha, u) + + u += alpha * p + r -= alpha * Ap + + z = self.apply_preconditioner(r) + # z = self.inv_diag.expand_as(r) * r + + rznew = torch.sum(r * z, dim=spatial_dims) + beta = rznew / (rsold + self.eps) + + beta = self.expand_as(beta, u) + + p = z + beta * p + + return u, r, p, rznew + + def solve(self, f, u=None): + u = torch.zeros_like(f) if u is None else u + res = self.residual(f, u) + z = self.apply_preconditioner(res) + p = z.clone() + + spatial_dims = tuple(range(-self.ndim, 0)) + f_norm = torch.linalg.norm(f) + rdotz = torch.sum(res * z, dim=spatial_dims) + for i in range(1, self.max_iter + 1): + u, res, p, rdotz = self.forward(u, res, p, rdotz) + if self.pure_neumann: + u -= torch.mean(u, dim=(-2, -1), keepdim=True) + if i % self.check_iter == 0: + residual_norm = torch.linalg.norm(res) / f_norm + if residual_norm >= self.residual_norms[-1]: + self.preconditioner = Identity() + if self.record_residuals: + self.residual_norms = residual_norm.item() + if residual_norm < self.tol: + self.stop_iter = i + break + return u + + +class MultigridSolver(IterativeSolver): + """ + Multilevel V-cycle multigrid solver for Neumann Laplacian (pressure projection). + + On MAC grids, the velocity's multigrid needs specially designed prolongation and restriction operators. + + References: + Long Chen's lecture notes on MAC grids and how to implement the multigrid for the Stokes system: + https://www.math.uci.edu/~chenlong/226/MACcode.pdf + """ + + def __init__( + self, + grid: Grid, + bc: BoundaryConditions, + dtype: torch.dtype = torch.float64, + laplacians: Optional[List[torch.Tensor]] = None, + tol: float = 1e-6, + max_iter: int = 10, + check_iter: int = 1, + levels: int = 2, + pre_smooth: int = 1, + post_smooth: int = 1, + record_residuals: bool = False, + pure_neumann: bool = False, + ): + super().__init__( + grid, bc, dtype, laplacians, tol, max_iter, check_iter, record_residuals + ) + self.levels = levels + self.pre_smooth = pre_smooth + self.post_smooth = post_smooth + self.pure_neumann = pure_neumann + # build grids hierarchy + self.grids: List[Grid] = [grid] + ops = [laplacians or getattr(self, "operators")] + + for _ in range(1, levels): + # 0 is the finest grid, -1 is the coarsest + prev = self.grids[-1] + shape = tuple(s // 2 for s in prev.shape) + c_grid = Grid(shape, domain=grid.domain, device=grid.device) + self.grids.append(c_grid) + ops.append(fdm.set_laplacian_matrix(c_grid, bc, dtype=dtype)) + self._register_level_operators(ops) + + # smoothers per level (Gauss-Seidel) + self.smoothers = nn.ModuleList( + [ + GaussSeidel( + g, + bc, + dtype, + tol=tol, + laplacians=op, + max_iter=1, + ) + for g, op in zip(self.grids, self.ops) + ] + ) + # precompute coarse operator and direct solver for coarsest level + L0, L1 = self.ops[-1] + I0 = torch.eye(L0.size(0), dtype=dtype) + I1 = torch.eye(L1.size(0), dtype=dtype) + A_coarse = torch.kron(I1, L0) + torch.kron(L1, I0) + self.register_buffer("A_coarse", A_coarse, persistent=True) + + def _register_level_operators(self, operators: List[List[torch.Tensor]]): + """Register operators for a specific level as buffers.""" + 
+
+
+class MultigridSolver(IterativeSolver):
+    """
+    Multilevel V-cycle multigrid solver for the Neumann Laplacian (pressure projection).
+
+    On MAC grids, the velocity's multigrid needs specially designed prolongation
+    and restriction operators.
+
+    References:
+        Long Chen's lecture notes on MAC grids and how to implement the multigrid
+        for the Stokes system:
+        https://www.math.uci.edu/~chenlong/226/MACcode.pdf
+    """
+
+    def __init__(
+        self,
+        grid: Grid,
+        bc: BoundaryConditions,
+        dtype: torch.dtype = torch.float64,
+        laplacians: Optional[List[torch.Tensor]] = None,
+        tol: float = 1e-6,
+        max_iter: int = 10,
+        check_iter: int = 1,
+        levels: int = 2,
+        pre_smooth: int = 1,
+        post_smooth: int = 1,
+        record_residuals: bool = False,
+        pure_neumann: bool = False,
+    ):
+        super().__init__(
+            grid, bc, dtype, laplacians, tol, max_iter, check_iter, record_residuals
+        )
+        self.levels = levels
+        self.pre_smooth = pre_smooth
+        self.post_smooth = post_smooth
+        self.pure_neumann = pure_neumann
+        # build the grid hierarchy
+        self.grids: List[Grid] = [grid]
+        ops = [laplacians or getattr(self, "operators")]
+
+        for _ in range(1, levels):
+            # 0 is the finest grid, -1 is the coarsest
+            prev = self.grids[-1]
+            shape = tuple(s // 2 for s in prev.shape)
+            c_grid = Grid(shape, domain=grid.domain, device=grid.device)
+            self.grids.append(c_grid)
+            ops.append(fdm.set_laplacian_matrix(c_grid, bc, dtype=dtype))
+        self._register_level_operators(ops)
+
+        # smoothers per level (Gauss-Seidel)
+        self.smoothers = nn.ModuleList(
+            [
+                GaussSeidel(
+                    g,
+                    bc,
+                    dtype,
+                    tol=tol,
+                    laplacians=op,
+                    max_iter=1,
+                )
+                for g, op in zip(self.grids, self.ops)
+            ]
+        )
+        # precompute the dense operator and a direct solver for the coarsest level
+        L0, L1 = self.ops[-1]
+        I0 = torch.eye(L0.size(0), dtype=dtype)
+        I1 = torch.eye(L1.size(0), dtype=dtype)
+        A_coarse = torch.kron(I1, L0) + torch.kron(L1, I0)
+        self.register_buffer("A_coarse", A_coarse, persistent=True)
+
+    def _register_level_operators(self, operators: List[List[torch.Tensor]]):
+        """Register the operators of every level as buffers."""
+        for lvl, ops in enumerate(operators):
+            for i, op in enumerate(ops):
+                self.register_buffer(f"ops_level_{lvl}_{i}", op, persistent=True)
+
+    def _get_level_operators(self, level: int) -> List[torch.Tensor]:
+        """Get operators for a specific level."""
+        assert (
+            0 <= level < self.levels
+        ), f"Invalid level: {level}. Must be in range [0, {self.levels - 1}]."
+        operators = []
+        for i in range(self.ndim):
+            operators.append(getattr(self, f"ops_level_{level}_{i}"))
+        return operators
+
+    @property
+    def ops(self) -> List[List[torch.Tensor]]:
+        """Get all operators for all levels (for backward compatibility)."""
+        return [self._get_level_operators(lvl) for lvl in range(self.levels)]
+
+    def restrict(self, r):
+        # full-weighting restriction: average each 2x2 block of fine cells
+        return 0.25 * (
+            r[..., ::2, ::2]
+            + r[..., 1::2, ::2]
+            + r[..., ::2, 1::2]
+            + r[..., 1::2, 1::2]
+        )
+
+    def prolong(self, e):
+        """
+        Prolongation (interpolation) operator.
+        Each coarse grid value is copied to the 4 corresponding fine grid cells
+        (piecewise-constant interpolation); this is the adjoint of the
+        full-weighting restriction up to a constant factor.
+        """
+        *batch, nx_c, ny_c = e.shape
+        nx_f, ny_f = nx_c * 2, ny_c * 2
+        up = torch.zeros(*batch, nx_f, ny_f, device=e.device, dtype=e.dtype)
+
+        # Distribute each coarse grid point to the corresponding 2x2 fine grid block
+        up[..., ::2, ::2] += e  # (even, even)
+        up[..., 1::2, ::2] += e  # (odd, even)
+        up[..., ::2, 1::2] += e  # (even, odd)
+        up[..., 1::2, 1::2] += e  # (odd, odd)
+
+        return up
+
+    def _coarse_solve(self, r):
+        """
+        Direct solve on the coarsest grid using the precomputed dense operator.
+        """
+        # flatten the residual: (batch_size, nx, ny) -> (nx*ny, batch_size)
+        batch_size, nx, ny = r.shape
+        r_vec = r.reshape(batch_size, -1)  # (batch_size, nx*ny)
+        r_vec = r_vec.T  # (nx*ny, batch_size)
+
+        # solve A @ x = r for each sample in the batch
+        x_vec = torch.linalg.solve(self.A_coarse, r_vec)  # (nx*ny, batch_size)
+
+        x_vec = x_vec.T  # (batch_size, nx*ny)
+        return x_vec.reshape(batch_size, nx, ny)
+
+    def v_cycle(self, level, f, u):
+        all_ops = self.ops
+        # pre-smoothing
+        for _ in range(self.pre_smooth):
+            u = self.smoothers[level].forward(f, u)
+        # residual
+        if level == 0:
+            r = f - self._apply_laplacian(u)
+        else:
+            r = f - self._apply_laplacian(u, operators=all_ops[level])
+        if level == self.levels - 1:
+            # coarsest level: direct solve via the dense matrix
+            e = self._coarse_solve(r)
+        else:
+            # restrict the residual, recurse, then prolong the correction
+            rc = self.restrict(r)
+            ec0 = torch.zeros_like(rc)
+            ec = self.v_cycle(level + 1, rc, ec0)
+            e = self.prolong(ec)
+        # correction
+        u += e
+        if self.pure_neumann:
+            # fix the constant mode by removing the mean
+            u -= torch.mean(u, dim=(-2, -1), keepdim=True)
+        # post-smoothing
+        for _ in range(self.post_smooth):
+            u = self.smoothers[level].forward(f, u)
+        return u
+
+    def forward(self, f, u=None):
+        u = torch.zeros_like(f) if u is None else u
+        return self.v_cycle(0, f, u)
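A property check for the grid-transfer pair (hypothetical, not from the test suite): with full-weighting restriction and piecewise-constant prolongation, the inner product <restrict(r), e> scaled by 4 should equal <r, prolong(e)>, which is the adjoint relation mentioned in the `prolong` docstring:

import torch
from torch_cfd import boundaries, grids, solvers

grid = grids.Grid((16, 16), step=(1 / 16, 1 / 16))
bc = boundaries.dirichlet_boundary_conditions(grid.ndim)
mg = solvers.MultigridSolver(grid, bc, levels=2)

r = torch.randn(1, 16, 16, dtype=torch.float64)  # fine-grid residual
e = torch.randn(1, 8, 8, dtype=torch.float64)    # coarse-grid correction
lhs = 4.0 * torch.sum(mg.restrict(r) * e)
rhs = torch.sum(r * mg.prolong(e))
assert torch.allclose(lhs, rhs)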
diff --git a/torch_cfd/tests/test_advection.py b/torch_cfd/tests/test_advection.py
index 8c0d033..f147b57 100644
--- a/torch_cfd/tests/test_advection.py
+++ b/torch_cfd/tests/test_advection.py
@@ -76,6 +76,10 @@ def _unit_velocity(grid, velocity_sign=1.0):
         )
     )
 
+def _velocity_implicit(grid, offset, u, t):
+    """Returns the solution of a Burgers equation at time `t`."""
+    x = grid.mesh(offset)[0]
+    return grids.GridVariable(torch.sin(x - u * t), offset, grid)
 
 def _total_variation(c: GridVariable, dim: int = 0):
     next_values = c.shift(1, dim)
@@ -89,7 +93,85 @@
 advect_van_leer_using_limiters = advection.AdvectionVanLeer
 
 
-class AdvectionTestAnalytical(test_utils.TestCase):
+class AdvectionTestAnalytical1D(test_utils.TestCase):
+    @parameterized.named_parameters(
+        dict(
+            testcase_name="dirichlet_1d_200_cell_center",
+            shape=(200,),
+            offset=0.5,
+            num_steps=200,
+        ),
+        dict(
+            testcase_name="dirichlet_1d_400_cell_center",
+            shape=(400,),
+            offset=0.5,
+            num_steps=400,
+        ),
+        dict(
+            testcase_name="dirichlet_1d_200_cell_edge_0",
+            shape=(200,),
+            offset=0.0,
+            num_steps=200,
+        ),
+        dict(
+            testcase_name="dirichlet_1d_400_cell_edge_0",
+            shape=(400,),
+            offset=0.0,
+            num_steps=400,
+        ),
+        dict(
+            testcase_name="dirichlet_1d_200_cell_edge_1",
+            shape=(200,),
+            offset=1.0,
+            num_steps=200,
+        ),
+        dict(
+            testcase_name="dirichlet_1d_400_cell_edge_1",
+            shape=(400,),
+            offset=1.0,
+            num_steps=400,
+        ),
+    )
+    def test_burgers_analytical_dirichlet_convergence(
+        self, shape, offset, num_steps
+    ):
+        def _step_func(v, dt, method):
+            """
+            The factor dt/2 appears because the Burgers flux is
+            u_t + (0.5 * u^2)_x = 0
+            """
+            dv_dt = method(c=v[0], v=v, dt=dt) / 2
+            return (bc.impose_bc(v[0].data + dt * dv_dt),)
+
+        cfl_number = 0.5
+        offset = (offset,)
+        grid = grids.Grid(shape, domain=((0.0, 2 * math.pi),))
+        bc = boundaries.dirichlet_boundary_conditions(
+            grid.ndim,
+            bc_values=[
+                (0.0, 0.0),
+            ],
+        )
+        v = (bc.impose_bc(_velocity_implicit(grid, offset, 0, 0)),)
+        dt = 1 / shape[0]  # t = 1 is the time at which the shock develops
+        dt *= cfl_number
+        atol = dt
+        rtol = cfl_number * 2 * math.pi / shape[0]
+        advect = advect_van_leer(grid, offset)
+
+        for _ in range(num_steps):
+            v = _step_func(v, dt, method=advect)
+
+        expected = bc.impose_bc(
+            _velocity_implicit(grid, offset, v[0].data, dt * num_steps)
+        ).data
+        self.assertAllClose(expected, v[0].data, atol=atol, rtol=rtol)
+
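Background for the test above: with initial data u0 = sin(x), the smooth solution of u_t + u u_x = 0 satisfies u = sin(x - u t) implicitly along characteristics, and the test checks that the numerical solution is self-consistent with this relation. A standalone reference solution can be computed by fixed-point iteration, which converges before the shock time t = 1 (a sketch, not part of the patch):

import torch

def burgers_exact(x: torch.Tensor, t: float, iters: int = 50) -> torch.Tensor:
    """Solve u = sin(x - u * t) by fixed-point iteration (a contraction for t < 1)."""
    u = torch.sin(x)
    for _ in range(iters):
        u = torch.sin(x - u * t)
    return u

x = torch.linspace(0, 2 * torch.pi, 201)[:-1]
u_ref = burgers_exact(x, t=0.5)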
+
+class AdvectionTestAnalytical2D(test_utils.TestCase):
 
     @parameterized.named_parameters(
         dict(
@@ -154,87 +236,6 @@ def test_advection_analytical(
 
         self.assertAllClose(expected.data, ct.data, atol=atol, rtol=rtol)
 
-    @parameterized.named_parameters(
-        dict(
-            testcase_name="dirichlet_1d_200_cell_center",
-            shape=(200,),
-            atol=0.00025,
-            rtol=1 / 200,
-            offset=0.5,
-        ),
-        dict(
-            testcase_name="dirichlet_1d_400_cell_center",
-            shape=(400,),
-            atol=0.00007,
-            rtol=1 / 400,
-            offset=0.5,
-        ),
-        dict(
-            testcase_name="dirichlet_1d_200_cell_edge_0",
-            shape=(200,),
-            atol=0.0005,
-            rtol=1 / 200,
-            offset=0.0,
-        ),
-        dict(
-            testcase_name="dirichlet_1d_400_cell_edge_0",
-            shape=(400,),
-            atol=0.000125,
-            rtol=1 / 400,
-            offset=0.0,
-        ),
-        dict(
-            testcase_name="dirichlet_1d_200_cell_edge_1",
-            shape=(200,),
-            atol=0.0005,
-            rtol=1 / 200,
-            offset=1.0,
-        ),
-        dict(
-            testcase_name="dirichlet_1d_400_cell_edge_1",
-            shape=(400,),
-            atol=0.000125,
-            rtol=1 / 400,
-            offset=1.0,
-        ),
-    )
-    def test_burgers_analytical_dirichlet_convergence(
-        self,
-        shape,
-        atol,
-        rtol,
-        offset,
-    ):
-        def _step_func(v, dt, method):
-            """
-            dt/2 is used because for Burgers equation
-            the flux is u_t + (0.5*u^2)_x = 0
-            """
-            dv_dt = method(c=v[0], v=v, dt=dt) / 2
-            return (bc.impose_bc(v[0].data + dt * dv_dt),)
-
-        def _velocity_implicit(grid, offset, u, t):
-            """Returns solution of a Burgers equation at time `t`."""
-            x = grid.mesh(offset)[0]
-            return grids.GridVariable(torch.sin(x - u * t), offset, grid)
-
-        num_steps = 1000
-        cfl_number = 0.01
-        step = 2 * math.pi / 1000
-        offset = (offset,)
-        grid = grids.Grid(shape, domain=((0.0, 2 * math.pi),))
-        bc = boundaries.dirichlet_boundary_conditions(grid.ndim, bc_values=[(0.0, 0.0),])
-        v = (bc.impose_bc(_velocity_implicit(grid, offset, 0, 0)),)
-        dt = cfl_number * step
-        advect = advect_van_leer(grid, offset)
-
-        for _ in range(num_steps):
-            v = _step_func(v, dt, method=advect)
-
-        expected = bc.impose_bc(
-            _velocity_implicit(grid, offset, v[0].data, dt * num_steps)
-        ).data
-        self.assertAllClose(expected, v[0].data, atol=atol, rtol=rtol)
 
 
 class AdvectionTestProperties(test_utils.TestCase):
diff --git a/torch_cfd/tests/test_finite_differences.py b/torch_cfd/tests/test_finite_differences.py
index 6ec2db3..07ffa78 100644
--- a/torch_cfd/tests/test_finite_differences.py
+++ b/torch_cfd/tests/test_finite_differences.py
@@ -17,12 +17,16 @@
 """Tests for torch_cfd.finite_difference."""
 
 import math
+
 import torch
-from einops import repeat
 from absl.testing import absltest, parameterized
+from einops import repeat
 
 from torch_cfd import boundaries, finite_differences as fdm, grids, test_utils
 
+BCType = grids.BCType
+
+
 def _trim_boundary(array):
     # fixed jax-cfd bug that trims all dimension for a batched GridVariable
     if isinstance(array, grids.GridVariable):
@@ -35,12 +39,19 @@ def _trim_boundary(array):
         trimmed_slices = (slice(1, -1),) * tensor.ndim
     return tensor[(..., *trimmed_slices)]
 
-def periodic_grid_variable(data, offset, grid):
+
+def grid_variable_periodic(data, offset, grid):
     return grids.GridVariable(
         data, offset, grid, bc=boundaries.periodic_boundary_conditions(grid.ndim)
     )
 
+def grid_variable_dirichlet(data, offset, grid):
+    return grids.GridVariable(
+        data, offset, grid, bc=boundaries.dirichlet_boundary_conditions(grid.ndim)
+    )
+
+
 def stack_tensor_matrix(matrix):
     """Stacks a 2D list or tuple of tensors into a rank-4 tensor."""
     return torch.stack([torch.stack(row, dim=0) for row in matrix], dim=0)
@@ -82,7 +93,7 @@ def test_finite_difference_indexing(
     ):
         """Tests finite difference code using explicit indices."""
         grid = grids.Grid(shape, step)
-        u = periodic_grid_variable(
+        u = grid_variable_periodic(
             torch.arange(math.prod(shape)).reshape(shape), (0, 0), grid
         )
         actual_x, actual_y = method(u)
@@ -157,7 +168,7 @@ def test_finite_difference_analytic(
         step = tuple([2.0 * torch.pi / s for s in shape])
         grid = grids.Grid(shape, step)
         mesh = grid.mesh()
-        u = periodic_grid_variable(f(*mesh), offset, grid)
+        u = grid_variable_periodic(f(*mesh), offset, grid)
         expected_grad = torch.stack([df(*mesh) for df in gradf])
         actual_grad = torch.stack([array.data for array in method(u)])
         self.assertAllClose(expected_grad, actual_grad, atol=atol, rtol=rtol)
@@ -179,13 +190,58 @@ def test_finite_difference_analytic(
             atol=1e-3,
             rtol=1e-8,
         ),
+        dict(
+            testcase_name="_2D_sine",
+            shape=(32, 32),
+            f=lambda x, y: torch.sin(math.pi * x) * torch.sin(math.pi * y),
+            g=lambda x, y: -2 * math.pi**2 * torch.sin(math.pi * x) * torch.sin(math.pi * y),
+            atol=1 / 32,
+            rtol=1e-3,
+        )
     )
-    def test_laplacian(self, shape, f, g, atol, rtol):
+    def test_laplacian_periodic(self, shape, f, g, atol, rtol):
         step = tuple([1.0 / s for s in shape])
         grid = grids.Grid(shape, step)
         offset = (0,) * len(shape)
         mesh = grid.mesh(offset)
-        u = periodic_grid_variable(f(*mesh), offset, grid)
+        u = grid_variable_periodic(f(*mesh), offset, grid)
         expected_laplacian = _trim_boundary(grids.GridVariable(g(*mesh), offset, grid))
         actual_laplacian = _trim_boundary(fdm.laplacian(u))
         self.assertAllClose(expected_laplacian, actual_laplacian, atol=atol, rtol=rtol)
+
+
+    @parameterized.named_parameters(
+        dict(
+            testcase_name="_2D_constant",
+            shape=(20, 20),
+            f=lambda x, y: torch.ones_like(x),
+            g=lambda x, y: torch.zeros_like(x),
+            atol=1e-3,
+            rtol=1e-8,
+        ),
+        dict(
+            testcase_name="_2D_quadratic",
+            shape=(21, 21),
+            f=lambda x, y: x * (x - 1.0) + y * (y - 1.0),
+            g=lambda x, y: 4 * torch.ones_like(x),
+            atol=1e-3,
+            rtol=1e-8,
+        ),
+        dict(
+            testcase_name="_2D_sine",
+            shape=(32, 32),
+            f=lambda x, y: torch.sin(math.pi * x) * torch.sin(math.pi * y),
+            g=lambda x, y: -2 * math.pi**2 * torch.sin(math.pi * x) * torch.sin(math.pi * y),
+            atol=1 / 32,
+            rtol=1e-3,
+        )
+    )
+    def test_laplacian_dirichlet(self, shape, f, g, atol, rtol):
+        step = tuple([1.0 / s for s in shape])
+        grid = grids.Grid(shape, step)
+        offset = (0,) * len(shape)
+        mesh = grid.mesh(offset)
+        u = grid_variable_dirichlet(f(*mesh), offset, grid)
         expected_laplacian = _trim_boundary(grids.GridVariable(g(*mesh), offset, grid))
         actual_laplacian = _trim_boundary(fdm.laplacian(u))
         self.assertAllClose(expected_laplacian, actual_laplacian, atol=atol, rtol=rtol)
@@ -214,7 +270,7 @@ def test_divergence(self, shape, offsets, f, g, atol, rtol):
         step = tuple([1.0 / s for s in shape])
         grid = grids.Grid(shape, step)
         v = [
-            periodic_grid_variable(f(*grid.mesh(offset))[axis], offset, grid)
+            grid_variable_periodic(f(*grid.mesh(offset))[axis], offset, grid)
             for axis, offset in enumerate(offsets)
         ]
         expected_divergence = _trim_boundary(
@@ -250,7 +306,7 @@ def test_curl_2d(self, shape, offsets, f, g, atol, rtol):
         step = tuple([1.0 / s for s in shape])
         grid = grids.Grid(shape, step)
         v = [
-            periodic_grid_variable(f(*grid.mesh(offset))[axis], offset, grid)
+            grid_variable_periodic(f(*grid.mesh(offset))[axis], offset, grid)
             for axis, offset in enumerate(offsets)
         ]
         result_offset = (0.5, 0.5)
@@ -260,6 +316,60 @@ def test_curl_2d(self, shape, offsets, f, g, atol, rtol):
         actual_curl = _trim_boundary(fdm.curl_2d(v))
         self.assertAllClose(expected_curl, actual_curl, atol=atol, rtol=rtol)
 
+    @parameterized.parameters(
+        # Periodic BC
+        dict(
+            offset=(0,),
+            bc_types=((BCType.PERIODIC, BCType.PERIODIC),),
+            expected=[[-2, 1, 0, 1], [1, -2, 1, 0], [0, 1, -2, 1], [1, 0, 1, -2]],
+        ),
+        dict(
+            offset=(0.5,),
+            bc_types=((BCType.PERIODIC, BCType.PERIODIC),),
+            expected=[[-2, 1, 0, 1], [1, -2, 1, 0], [0, 1, -2, 1], [1, 0, 1, -2]],
+        ),
+        dict(
+            offset=(1.0,),
+            bc_types=((BCType.PERIODIC, BCType.PERIODIC),),
+            expected=[[-2, 1, 0, 1], [1, -2, 1, 0], [0, 1, -2, 1], [1, 0, 1, -2]],
+        ),
+        # Dirichlet BC
+        dict(
+            offset=(0,),
+            bc_types=((BCType.DIRICHLET, BCType.DIRICHLET),),
+            expected=[[-2, 1, 0], [1, -2, 1], [0, 1, -2]],
+        ),
+        dict(
+            offset=(0.5,),
+            bc_types=((BCType.DIRICHLET, BCType.DIRICHLET),),
+            expected=[[-3, 1, 0, 0], [1, -2, 1, 0], [0, 1, -2, 1], [0, 0, 1, -3]],
+        ),
+        dict(
+            offset=(1.0,),
+            bc_types=((BCType.DIRICHLET, BCType.DIRICHLET),),
+            expected=[[-2, 1, 0], [1, -2, 1], [0, 1, -2]],
+        ),
+        # Neumann BC
+        dict(
+            offset=(0.5,),
+            bc_types=((BCType.NEUMANN, BCType.NEUMANN),),
+            expected=[[-1, 1, 0, 0], [1, -2, 1, 0], [0, 1, -2, 1], [0, 0, 1, -1]],
+        ),
+        # Neumann-Dirichlet BC
+        dict(
+            offset=(0.5,),
+            bc_types=((BCType.NEUMANN, BCType.DIRICHLET),),
+            expected=[[-1, 1, 0, 0], [1, -2, 1, 0], [0, 1, -2, 1], [0, 0, 1, -3]],
+        ),
+    )
+    def test_laplacian_matrix_w_boundaries(self, offset, bc_types, expected):
+        grid = grids.Grid((4,), step=(0.5,))
+        bc = boundaries.HomogeneousBoundaryConditions(bc_types)
+        actual = fdm.laplacian_matrix_w_boundaries(grid, offset, bc)
+        actual = torch.cat([a for a in actual], dim=0)
+        expected = 4.0 * torch.tensor(expected)
+        self.assertAllClose(actual, expected)
+
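Reading one row of the table above: with h = 0.5 and cell-centered offset (0.5,), homogeneous Dirichlet folds the ghost cell into the diagonal, turning -2/h**2 into -3/h**2 at the wall (the factor 4.0 in the test is 1/h**2). A standalone version of one parameterized case (a sketch mirroring the test, not additional coverage):

import torch
from torch_cfd import boundaries, finite_differences as fdm, grids

grid = grids.Grid((4,), step=(0.5,))
bc = boundaries.HomogeneousBoundaryConditions(
    ((grids.BCType.DIRICHLET, grids.BCType.DIRICHLET),)
)
(lap,) = fdm.laplacian_matrix_w_boundaries(grid, (0.5,), bc)
expected = 4.0 * torch.tensor(
    [[-3.0, 1, 0, 0], [1, -2, 1, 0], [0, 1, -2, 1], [0, 0, 1, -3]]
)
assert torch.allclose(lap, expected)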
 
 class FiniteDifferenceBatchTest(test_utils.TestCase):
     """Test finite difference operations with batch dimensions in 2D."""
 
@@ -295,18 +405,18 @@ def test_finite_difference_batch_preserves_shape(
     ):
         """Test that finite difference operations preserve batch dimensions."""
         grid = grids.Grid(shape, step)
-
+
         # Create batched data: (batch_size, *shape)
         batched_data = torch.randn(batch_size, *shape)
-        u = periodic_grid_variable(batched_data, (0, 0), grid)
-
+        u = grid_variable_periodic(batched_data, (0, 0), grid)
+
         # Apply finite difference
         grad_x, grad_y = method(u)
-
+
         # Check that batch dimension is preserved
         self.assertEqual(grad_x.data.shape, (batch_size, *shape))
         self.assertEqual(grad_y.data.shape, (batch_size, *shape))
-
+
         # Check offsets are correct
         self.assertEqual(grad_x.offset, (expected_offset, 0))
         self.assertEqual(grad_y.offset, (0, expected_offset))
@@ -322,8 +432,8 @@ def test_finite_difference_batch_preserves_shape(
                 lambda x, y: torch.cos(x) * torch.cos(y),
                 lambda x, y: -torch.sin(x) * torch.sin(y),
             ),
-            atol=2*math.pi/40,
-            rtol=1/40,
+            atol=2 * math.pi / 40,
+            rtol=1 / 40,
         ),
         dict(
             testcase_name="_forward_difference",
@@ -335,8 +445,8 @@ def test_finite_difference_batch_preserves_shape(
                 lambda x, y: torch.cos(x) * torch.cos(y),
                 lambda x, y: -torch.sin(x) * torch.sin(y),
            ),
-            atol=2*math.pi/128,
-            rtol=1/128,
+            atol=2 * math.pi / 128,
+            rtol=1 / 128,
         ),
         dict(
             testcase_name="_central_difference_coarse_fine",
@@ -348,8 +458,8 @@ def test_finite_difference_batch_preserves_shape(
                 lambda x, y: torch.cos(x) * torch.cos(y),
                 lambda x, y: -torch.sin(x) * torch.sin(y),
             ),
-            atol=2*math.pi/1024,
-            rtol=1/1024,
+            atol=2 * math.pi / 1024,
+            rtol=1 / 1024,
         ),
     )
     def test_finite_difference_batch_analytic(
@@ -359,21 +469,21 @@ def test_finite_difference_batch_analytic(
         step = tuple([2.0 * math.pi / s for s in shape])
         grid = grids.Grid(shape, step)
         mesh = grid.mesh()
-
+
         # Create batched data by repeating the same function
         # In practice, each batch element could be different
         single_data = f(*mesh)
-        batched_data = repeat(single_data, 'h w -> b h w', b=batch_size)
-
-        u = periodic_grid_variable(batched_data, (0, 0), grid)
-
+        batched_data = repeat(single_data, "h w -> b h w", b=batch_size)
+
+        u = grid_variable_periodic(batched_data, (0, 0), grid)
+
         # Compute gradients
         actual_grad_x, actual_grad_y = method(u)
-
+
         # Expected gradients (also batched)
-        expected_grad_x = repeat(gradf[0](*mesh), 'h w -> b h w', b=batch_size)
-        expected_grad_y = repeat(gradf[1](*mesh), 'h w -> b h w', b=batch_size)
-
+        expected_grad_x = repeat(gradf[0](*mesh), "h w -> b h w", b=batch_size)
+        expected_grad_y = repeat(gradf[1](*mesh), "h w -> b h w", b=batch_size)
+
         self.assertAllClose(expected_grad_x, actual_grad_x.data, atol=atol, rtol=rtol)
         self.assertAllClose(expected_grad_y, actual_grad_y.data, atol=atol, rtol=rtol)
@@ -384,24 +494,24 @@ def test_laplacian_batch(self):
         batch_size = 2
         shape = (16, 16)
         step = (0.1, 0.1)
         grid = grids.Grid(shape, step)
         offset = (0, 0)
-
+
         # Test function: f(x,y) = x^2 + y^2, so Laplacian should be 4
         mesh = grid.mesh(offset)
-        single_data = mesh[0]**2 + mesh[1]**2
+        single_data = mesh[0] ** 2 + mesh[1] ** 2
         batched_data = single_data.unsqueeze(0).repeat(batch_size, 1, 1)
-
-        u = periodic_grid_variable(batched_data, offset, grid)
+
+        u = grid_variable_periodic(batched_data, offset, grid)
         actual_laplacian = fdm.laplacian(u)
-
+
         # Expected Laplacian is 4 everywhere
         expected_laplacian = 4 * torch.ones(batch_size, *shape)
-
+
         # Trim boundary for comparison
         trimmed_actual = _trim_boundary(actual_laplacian)
         trimmed_expected = _trim_boundary(
             grids.GridVariable(expected_laplacian, offset, grid)
         )
-
+
         self.assertAllClose(
             trimmed_expected.data, trimmed_actual.data, atol=1e-2, rtol=1e-8
         )
@@ -410,22 +520,22 @@ def test_laplacian_batch_analytic(self):
         """Test Laplacian operator on batched data against Laplacian on a single data."""
         batch_size = 3
         shape = (128, 128)
-        step = (2*math.pi/128, 2*math.pi/128)
+        step = (2 * math.pi / 128, 2 * math.pi / 128)
         grid = grids.Grid(shape, step)
-
+
         mesh = grid.mesh()
         single_data = torch.sin(mesh[0]) * torch.cos(mesh[1])
-        batched_data = repeat(single_data, 'h w -> b h w', b=batch_size)
-
-        u_single = periodic_grid_variable(single_data, (0, 0), grid)
-        u_batch = periodic_grid_variable(batched_data, (0, 0), grid)
+        batched_data = repeat(single_data, "h w -> b h w", b=batch_size)
+
+        u_single = grid_variable_periodic(single_data, (0, 0), grid)
+        u_batch = grid_variable_periodic(batched_data, (0, 0), grid)
         single_laplacian = fdm.laplacian(u_single)
         batch_laplacian = fdm.laplacian(u_batch)
-
+
         # Trim boundary for comparison
         trimmed_single = _trim_boundary(single_laplacian)
         trimmed_batch = _trim_boundary(batch_laplacian)
-
+
         for i in range(batch_size):
             self.assertAllClose(
                 trimmed_single.data, trimmed_batch.data[i], atol=1e-8, rtol=1e-12
@@ -438,34 +548,34 @@ def test_divergence_batch(self):
         batch_size = 2
         shape = (16, 16)
         step = (0.1, 0.1)
         grid = grids.Grid(shape, step)
         offsets = ((0.5, 0), (0, 0.5))
-
+
         # Test vector field: v = (x, y), so divergence should be 2
         mesh_x = grid.mesh(offsets[0])
         mesh_y = grid.mesh(offsets[1])
-
+
         # Create batched vector components
         vx_single = mesh_x[0]  # x component
         vy_single = mesh_y[1]  # y component
-
-        vx_batched = repeat(vx_single, 'h w -> b h w', b=batch_size)
-        vy_batched = repeat(vy_single, 'h w -> b h w', b=batch_size)
-
+
+        vx_batched = repeat(vx_single, "h w -> b h w", b=batch_size)
+        vy_batched = repeat(vy_single, "h w -> b h w", b=batch_size)
+
         v = [
-            periodic_grid_variable(vx_batched, offsets[0], grid),
-            periodic_grid_variable(vy_batched, offsets[1], grid),
+            grid_variable_periodic(vx_batched, offsets[0], grid),
+            grid_variable_periodic(vy_batched, offsets[1], grid),
         ]
-
+
         actual_divergence = fdm.divergence(v)
-
+
         # Expected divergence is 2 everywhere
         expected_divergence = 2 * torch.ones(batch_size, *shape)
-
+
         # Trim boundary for comparison
         trimmed_actual = _trim_boundary(actual_divergence)
         trimmed_expected = _trim_boundary(
             grids.GridVariable(expected_divergence, (0, 0), grid)
         )
-
+
         self.assertAllClose(
             trimmed_expected.data, trimmed_actual.data, atol=1e-2, rtol=1e-8
         )
@@ -477,35 +587,35 @@ def test_curl_2d_batch(self):
         batch_size = 2
         shape = (16, 16)
         step = (0.1, 0.1)
         grid = grids.Grid(shape, step)
         offsets = ((0.5, 0), (0, 0.5))
-
+
         # Test vector field: v = (-y, x), so curl should be 2
         mesh_x = grid.mesh(offsets[0])
         mesh_y = grid.mesh(offsets[1])
-
+
         # Create batched vector components
         vx_single = -mesh_x[1]  # -y component
-        vy_single = mesh_y[0]  # x component
-
-        vx_batched = repeat(vx_single, 'h w -> b h w', b=batch_size)
-        vy_batched = repeat(vy_single, 'h w -> b h w', b=batch_size)
-
+        vy_single = mesh_y[0]  # x component
+
+        vx_batched = repeat(vx_single, "h w -> b h w", b=batch_size)
+        vy_batched = repeat(vy_single, "h w -> b h w", b=batch_size)
+
         v = [
-            periodic_grid_variable(vx_batched, offsets[0], grid),
-            periodic_grid_variable(vy_batched, offsets[1], grid),
+            grid_variable_periodic(vx_batched, offsets[0], grid),
+            grid_variable_periodic(vy_batched, offsets[1], grid),
         ]
-
+
         actual_curl = fdm.curl_2d(v)
-
+
         # Expected curl is 2 everywhere
         result_offset = (0.5, 0.5)
         expected_curl = 2 * torch.ones(batch_size, *shape)
-
+
         # Trim boundary for comparison
         trimmed_actual = _trim_boundary(actual_curl)
         trimmed_expected = _trim_boundary(
             grids.GridVariable(expected_curl, result_offset, grid)
         )
-
+
         self.assertAllClose(
             trimmed_expected.data, trimmed_actual.data, atol=1e-2, rtol=1e-8
         )
@@ -515,24 +625,25 @@ def test_batch_consistency_across_operations(self):
         shape = (24, 24)
         step = (0.05, 0.05)
         grid = grids.Grid(shape, step)
-
+
         # Create test function
         mesh = grid.mesh()
         single_data = torch.sin(mesh[0]) * torch.cos(mesh[1])
-
+
         # Test with different batch sizes
         for batch_size in [1, 2, 4]:
-            batched_data = repeat(single_data, 'h w -> b h w', b=batch_size)
-            u = periodic_grid_variable(batched_data, (0, 0), grid)
-
+            batched_data = repeat(single_data, "h w -> b h w", b=batch_size)
+            u = grid_variable_periodic(batched_data, (0, 0), grid)
+
             # Apply central difference
             grad_x = fdm.central_difference(u, dim=0)
             grad_y = fdm.central_difference(u, dim=1)
-
+
             # Each batch element should be identical
             for i in range(1, batch_size):
                 self.assertAllClose(grad_x.data[0], grad_x.data[i])
                 self.assertAllClose(grad_y.data[0], grad_y.data[i])
 
+
 if __name__ == "__main__":
     absltest.main()
diff --git a/torch_cfd/tests/test_solvers.py b/torch_cfd/tests/test_solvers.py
new file mode 100644
index 0000000..cd2f16f
--- /dev/null
+++ b/torch_cfd/tests/test_solvers.py
@@ -0,0 +1,598 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Modifications copyright (C) 2025 S.Cao
+# ported Google's Jax-CFD functional template to PyTorch's tensor ops
+
+"""Tests for torch_cfd.solvers."""
+
+import math
+
+import torch
+from absl.testing import absltest, parameterized
+
+from torch_cfd import boundaries, finite_differences as fdm, grids, solvers, test_utils
+
+BCType = grids.BCType
+
+
+def grid_variable_periodic(data, offset, grid):
+    return grids.GridVariable(
+        data, offset, grid, bc=boundaries.periodic_boundary_conditions(grid.ndim)
+    )
+
+
+def grid_variable_dirichlet(data, offset, grid):
+    return grids.GridVariable(
+        data, offset, grid, bc=boundaries.dirichlet_boundary_conditions(grid.ndim)
+    )
+
+
+class SolversTest(test_utils.TestCase):
+
+    def setUp(self):
+        super().setUp()
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.dtype = torch.float64
+        torch.set_default_dtype(self.dtype)
+        torch.manual_seed(42)
+        torch.cuda.manual_seed_all(42)
+
+    def poisson_random_data(self, bc, offset, shape, batch_size=2, random_state=42):
+        """Set up a random right-hand side for the Poisson equation tests."""
+        torch.manual_seed(random_state)
+        b = torch.randn((batch_size, *shape), dtype=self.dtype, device=self.device)
+        if boundaries.is_bc_pure_neumann_boundary_conditions(bc):
+            # For pure Neumann BC, subtract the mean to ensure solvability
+            b = b - b.mean(dim=(-2, -1), keepdim=True)
+        if boundaries.is_bc_all_periodic_boundary_conditions(bc):
+            b = torch.fft.ifftn(b, dim=(-2, -1)).real
+        grid = grids.Grid(shape, step=tuple(1.0 / s for s in shape), device=self.device)
+        b = grids.GridVariable(b, offset, grid, bc=bc)
+
+        return b, grid
+
+    def poisson_smooth_data(
+        self,
+        bc,
+        offset,
+        shape,
+        batch_size: int = 2,
+        num_modes: int = 10,
+        random_state: int = 42,
+    ):
+        torch.manual_seed(random_state)
+        grid = grids.Grid(shape, step=tuple(1.0 / s for s in shape), device=self.device)
+        # Create meshgrid
+        X, Y = grid.mesh(offset)
+        pure_neumann = boundaries.is_bc_pure_neumann_boundary_conditions(bc)
+        func = torch.cos if pure_neumann else torch.sin
+        # Initialize solution
+        u_true = torch.zeros((batch_size, *shape), device=self.device, dtype=self.dtype)
+
+        for i in range(batch_size):
+            # Generate random coefficients a_k
+            a_k = torch.randn(num_modes, device=self.device, dtype=self.dtype)
+            # Sum over modes: u_true = sum_k a_k * func(c_k*pi*x) * func(b_k*pi*y) / k
+            for k in range(1, num_modes + 1):
+                b_k = torch.randint(
+                    1, num_modes + 1, (1,), device=self.device
+                ).item()  # randomly chosen mode in y
+                c_k = torch.randint(
+                    1, num_modes + 1, (1,), device=self.device
+                ).item()  # randomly chosen mode in x
+                components = (
+                    a_k[k - 1] * func(c_k * torch.pi * X) * func(b_k * torch.pi * Y) / k
+                )
+                u_true[i] += components  # accumulate the modes for sample i
+
+        if pure_neumann:
+            u_true -= torch.mean(u_true, dim=(-2, -1), keepdim=True)
+
+        u_true = grids.GridVariable(u_true, offset, grid, bc=bc)
+
+        return u_true, grid
+
equation.""" + ndim = len(shape) + h = 1.0 / shape[0] # Grid spacing + bc = bc_factory(ndim) + b, grid = self.poisson_random_data(bc, offset=(0.5,) * ndim, shape=shape) + + # # For periodic BC, subtract mean to ensure solvability + if boundaries.is_bc_all_periodic_boundary_conditions(bc): + b.data = b.data - b.data.mean() + + # Create solver + solver = solver(grid, bc, dtype=self.dtype, tol=1e-12).to(self.device) + + # Solve + u = solver(b.data, torch.zeros_like(b.data)) + u_var = grids.GridVariable(u, offset=(0.5,) * grid.ndim, grid=grid, bc=bc) + + # Apply Laplacian to solution + laplacian_u = fdm.laplacian(u_var) + + # Check that Laplacian of solution equals RHS (up to mean for periodic) + if boundaries.is_bc_all_periodic_boundary_conditions(bc): + expected = b.data - b.data.mean() + actual = laplacian_u.data - laplacian_u.data.mean() + else: + expected = b.data + actual = laplacian_u.data + + self.assertAllClose(actual, expected, atol=h**2, rtol=h) + + @parameterized.named_parameters( + dict( + testcase_name="_jacobi_2D", + solver=solvers.Jacobi, + shape=(32, 32), + maxiter=2000, + tol=1e-5, + ), + dict( + testcase_name="_gauss_seidel_2D", + solver=solvers.GaussSeidel, + shape=(32, 32), + maxiter=1000, + tol=1e-6, + ), + dict( + testcase_name="_cg_2D", + solver=solvers.ConjugateGradient, + shape=(32, 32), + maxiter=100, + tol=1e-8, + ), + ) + def test_iterative_solvers(self, solver, shape, maxiter, tol): + """Test iterative solvers on Poisson equation.""" + + ndim = len(shape) + bc = boundaries.dirichlet_boundary_conditions(ndim) + b, grid = self.poisson_random_data(bc, offset=(0.5,) * ndim, shape=shape) + + # Create solver + solver = solver(grid, bc, dtype=self.dtype, tol=tol, max_iter=maxiter).to( + self.device + ) + + # Solve + u_init = torch.zeros_like(b.data) + u = solver.solve(b.data, u_init) + u_var = grids.GridVariable(u, offset=(0.5, 0.5), grid=grid, bc=bc) + + # Apply Laplacian to solution + laplacian_u = fdm.laplacian(u_var) + + # Check that Laplacian of solution equals RHS + self.assertAllClose(laplacian_u.data, b.data, atol=1e-4, rtol=1e-4) + + @parameterized.named_parameters( + dict( + testcase_name="_pseudoinverse_fft", + solver=solvers.PseudoInverseFFT, + batch_size=3, + bc_factory=boundaries.periodic_boundary_conditions, + ), + dict( + testcase_name="_pseudoinverse_rfft", + solver=solvers.PseudoInverseRFFT, + batch_size=2, + bc_factory=boundaries.periodic_boundary_conditions, + ), + dict( + testcase_name="_pseudoinverse_matmul", + solver=solvers.PseudoInverseMatmul, + batch_size=4, + bc_factory=boundaries.dirichlet_boundary_conditions, + ), + ) + def test_batch_pseudoinverse(self, solver, batch_size, bc_factory): + """Test that solvers work with batch dimensions.""" + shape = (16, 16) + ndim = len(shape) + bc = bc_factory(ndim) + + torch.manual_seed(9876) + b, grid = self.poisson_random_data( + bc, offset=(0.5,) * ndim, shape=shape, batch_size=batch_size + ) + + # Handle periodic BC mean subtraction + if all(bc_type == BCType.PERIODIC for bc_type in bc.types[0]): + b = b - torch.mean(b, dim=(-2, -1), keepdim=True) + + solver = solver(grid, bc, dtype=self.dtype, tol=1e-10).to(self.device) + + # Solve batch + u_init = torch.zeros_like(b.data) + u_batch = solver.forward(b.data, u_init) + + # Test each item in batch individually + for i in range(batch_size): + b_single = b[i] + u_single = u_batch[i] + + # Create GridVariable and apply Laplacian + u_single = grids.GridVariable(u_single, offset=(0.5, 0.5), grid=grid, bc=bc) + laplacian_u = fdm.laplacian(u_single) + + # Check 
+    def test_laplacian_eigenvalues(self):
+        """Test that the eigenvalue computation is correct."""
+        shape = (8, 8)
+        grid = grids.Grid(shape, step=(1.0, 1.0))
+        bc = boundaries.periodic_boundary_conditions(grid.ndim)
+
+        solver = solvers.PseudoInverseFFT(grid, bc, dtype=self.dtype).to(self.device)
+
+        # For periodic BC with step=1, the eigenvalues are related to the Fourier modes
+        expected_eigenvals_1d = []
+        for n in shape:
+            k = torch.arange(n, dtype=self.dtype, device=self.device) - n // 2
+            # Shift to match the FFT ordering
+            k = torch.fft.fftshift(k)
+            # Eigenvalues of the central difference stencil on a periodic grid
+            eigvals = -4 * torch.sin(math.pi * k / n) ** 2
+            expected_eigenvals_1d.append(eigvals)
+
+        # Create the 2D eigenvalue matrix
+        expected_eigenvals_2d = (
+            expected_eigenvals_1d[0][:, None] + expected_eigenvals_1d[1][None, :]
+        )
+
+        # Get the actual eigenvalues from the solver (stored as inverses)
+        inv_eigs = torch.as_tensor(solver.inverse_diag).real
+        actual_eigenvals = 1 / inv_eigs
+
+        # Compare eigenvalues (excluding the zero mode for periodic BC)
+        mask = torch.abs(expected_eigenvals_2d) > 1e-12
+        self.assertAllClose(
+            actual_eigenvals[mask], expected_eigenvals_2d[mask], atol=1e-10, rtol=1e-10
+        )
+
+        # The solver should handle the zero eigenvalue correctly
+        self.assertTrue(hasattr(solver, "inverse_diag"))
+
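For reference, the eigenvalue formula in this test follows from applying the periodic [1, -2, 1]/h**2 stencil to the Fourier mode v_k[j] = exp(2*pi*i*j*k/n), with h = 1 here:

\lambda_k = \frac{e^{-2\pi i k/n} - 2 + e^{2\pi i k/n}}{h^2}
          = \frac{2\cos(2\pi k/n) - 2}{h^2}
          = -\frac{4}{h^2}\,\sin^2\!\left(\frac{\pi k}{n}\right),

which matches `eigvals = -4 * torch.sin(math.pi * k / n) ** 2` above; the k = 0 mode gives the zero eigenvalue that the pseudo-inverse masks out.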
+    def test_solver_with_zero_rhs(self):
+        """Test the solver behavior with a zero right-hand side."""
+        shape = (16, 16)
+        grid = grids.Grid(shape, domain=((0, 1), (0, 1)))
+        bc = boundaries.dirichlet_boundary_conditions(grid.ndim)
+
+        b = torch.zeros(shape, dtype=self.dtype, device=self.device)
+
+        solver = solvers.PseudoInverseMatmul(grid, bc, dtype=self.dtype).to(self.device)
+        u = solver(b, torch.zeros_like(b))
+
+        # The solution should be zero (up to numerical precision)
+        self.assertAllClose(u, torch.zeros_like(u), atol=1e-10, rtol=1e-10)
+
+    def test_solver_single_mode(self):
+        shape = (64, 64)
+        h = 1.0 / shape[0]  # grid spacing
+        grid = grids.Grid(shape, domain=((0, 1), (0, 1)))
+        bc = boundaries.periodic_boundary_conditions(grid.ndim)
+
+        X, Y = grid.mesh()
+        b = torch.sin(2 * math.pi * X) * torch.sin(2 * math.pi * Y)
+        b = b - b.mean()
+        b = b.to(self.device, dtype=self.dtype)
+
+        solver = solvers.PseudoInverseFFT(grid, bc, dtype=self.dtype).to(self.device)
+        u = solver(b, torch.zeros_like(b))
+
+        expected_u = b / (-8 * math.pi**2)
+        expected_u = expected_u - expected_u.mean()
+
+        self.assertAllClose(u - u.mean(), expected_u, atol=h**2, rtol=h**2)
+
+    @parameterized.named_parameters(
+        dict(
+            testcase_name="_periodic",
+            bc_func=boundaries.periodic_boundary_conditions,
+            shape=(16, 16),
+            tol=1e-8,
+        ),
+        dict(
+            testcase_name="_dirichlet",
+            bc_func=boundaries.dirichlet_boundary_conditions,
+            shape=(32, 32),
+            tol=1e-10,
+        ),
+        dict(
+            testcase_name="_neumann",
+            bc_func=boundaries.neumann_boundary_conditions,
+            shape=(32, 32),
+            tol=1e-10,
+        ),
+    )
+    def test_pseudoinverse_factory(self, bc_func, shape, tol):
+        """Test the PseudoInverse factory class."""
+        grid = grids.Grid(shape, domain=((0, 1), (0, 1)))
+
+        # Test the automatic implementation selection for the given BC
+        bc = bc_func(grid.ndim)
+        is_periodic = all(
+            [
+                boundaries.is_bc_periodic_boundary_conditions(bc, dim)
+                for dim in range(grid.ndim)
+            ]
+        )
+        solver_auto = solvers.PseudoInverse(
+            grid,
+            bc,
+            dtype=self.dtype,
+            hermitian=True,
+            circulant=is_periodic,
+            cutoff=tol,
+        )
+
+        if is_periodic:
+            self.assertIsInstance(
+                solver_auto, (solvers.PseudoInverseFFT, solvers.PseudoInverseRFFT)
+            )
+        else:
+            self.assertIsInstance(solver_auto, solvers.PseudoInverseMatmul)
+
+    @parameterized.named_parameters(
+        dict(
+            testcase_name="_dirichlet_small",
+            bc_func=boundaries.dirichlet_boundary_conditions,
+            shape=(64, 64),
+            factor=4,
+        ),
+        dict(
+            testcase_name="_dirichlet_medium",
+            bc_func=boundaries.dirichlet_boundary_conditions,
+            shape=(256, 256),
+            factor=8,
+        ),
+        dict(
+            testcase_name="_neumann_small",
+            bc_func=boundaries.neumann_boundary_conditions,
+            shape=(64, 64),
+            factor=4,
+        ),
+        dict(
+            testcase_name="_neumann_medium",
+            bc_func=boundaries.neumann_boundary_conditions,
+            shape=(256, 256),
+            factor=8,
+        ),
+    )
+    def test_conjugate_gradient_convergence(self, bc_func, shape, factor):
+        ndim = len(shape)
+        h = 1.0 / shape[0]  # grid spacing
+        bc = bc_func(ndim)
+
+        b, grid = self.poisson_random_data(
+            bc, offset=(0.5,) * ndim, shape=shape, batch_size=2
+        )
+        pure_neumann = boundaries.is_bc_pure_neumann_boundary_conditions(bc)
+        # Test the CG convergence rate
+        tol = h**2 * factor
+        rate = (1 - h / factor) / (1 + h / factor)
+        max_iter = int(math.ceil(math.log(tol) / math.log(rate)))
+        cg = solvers.ConjugateGradient(
+            grid,
+            bc,
+            dtype=self.dtype,
+            tol=tol,
+            max_iter=max_iter,
+            check_iter=1,
+            record_residuals=True,
+            pure_neumann=pure_neumann,
+        ).to(self.device)
+
+        _ = cg.solve(b, torch.zeros_like(b.data))
+        residuals = cg.residual_norms.detach().cpu().numpy()
+        for i in range(1, len(residuals)):
+            if residuals[i] > 0.5:
+                self.assertLess(residuals[i], residuals[i - 1] * rate)
+        self.assertLess(residuals[-1], tol)
+
+    @parameterized.named_parameters(
+        dict(
+            testcase_name="_dirichlet_small",
+            bc_func=boundaries.dirichlet_boundary_conditions,
+            shape=(64, 64),
+            level=2,
+        ),
+        dict(
+            testcase_name="_dirichlet_medium",
+            bc_func=boundaries.dirichlet_boundary_conditions,
+            shape=(128, 128),
+            level=3,
+        ),
+        dict(
+            testcase_name="_neumann_small",
+            bc_func=boundaries.neumann_boundary_conditions,
+            shape=(64, 64),
+            level=2,
+        ),
+        dict(
+            testcase_name="_neumann_medium",
+            bc_func=boundaries.neumann_boundary_conditions,
+            shape=(128, 128),
+            level=3,
+        ),
+    )
+    def test_multigrid_convergence(self, bc_func, shape, level):
+        """Test the convergence of the multigrid V-cycle solver."""
+        h = 1.0 / shape[0]  # grid spacing
+        ndim = len(shape)
+        bc = bc_func(ndim)
+        offset = (0.5,) * ndim
+        u_true, grid = self.poisson_smooth_data(
+            bc, offset, shape, num_modes=4 * level, batch_size=2
+        )
+        pure_neumann = boundaries.is_bc_pure_neumann_boundary_conditions(bc)
+        tol = h / 8
+        factor = (1 - h / 2) / (1 + h / 2)
+        solver = solvers.MultigridSolver(
+            grid,
+            bc,
+            dtype=self.dtype,
+            tol=tol,
+            max_iter=5,
+            levels=level,
+            record_residuals=True,
+            pure_neumann=pure_neumann,
+        ).to(self.device)
+
+        b = solver._apply_laplacian(u_true.data)
+        if pure_neumann:
+            b -= b.mean(dim=(-2, -1), keepdim=True)
+
+        u = solver.solve(b, torch.zeros_like(b))
+        residuals = solver.residual_norms.detach().cpu().numpy()
+        rel_err = torch.linalg.norm(u_true.data - u) / torch.linalg.norm(u_true.data)
+        for i in range(1, len(residuals)):
+            if residuals[i] > h:
+                self.assertLess(residuals[i], residuals[i - 1] * factor)
+        self.assertLess(rel_err.item(), tol)
+
+    @parameterized.named_parameters(
+        dict(
+            testcase_name="_dirichlet_medium",
+            bc_func=boundaries.dirichlet_boundary_conditions,
+            shape=(128, 128),
+            level=3,
+        ),
+        dict(
+            testcase_name="_dirichlet_large",
+            bc_func=boundaries.dirichlet_boundary_conditions,
+            shape=(256, 256),
+            level=4,
+        ),
+        dict(
+            testcase_name="_neumann_medium",
+            bc_func=boundaries.neumann_boundary_conditions,
+            shape=(128, 128),
+            level=3,
+        ),
+        dict(
+            testcase_name="_neumann_large",
+            bc_func=boundaries.neumann_boundary_conditions,
+            shape=(256, 256),
+            level=4,
+        ),
+    )
+    def test_mg_preconditioned_cg_convergence(self, bc_func, shape, level):
+        """Test CG with multigrid preconditioning."""
+        h = 1.0 / shape[0]
+        ndim = len(shape)
+        bc = bc_func(ndim)
+        offset = (0.5,) * ndim
+        u_true, grid = self.poisson_smooth_data(
+            bc, offset, shape, num_modes=4 * level, batch_size=1
+        )
+        pure_neumann = boundaries.is_bc_pure_neumann_boundary_conditions(bc)
+        tol = h
+        factor = (1 - h / 8) / (1 + h / 8)
+
+        precond = solvers.MultigridSolver(
+            grid,
+            bc,
+            dtype=self.dtype,
+            tol=h,
+            max_iter=1,
+            pre_smooth=1,
+            post_smooth=1,
+            levels=level,
+        ).to(self.device)
+
+        solver = solvers.ConjugateGradient(
+            grid,
+            bc,
+            dtype=self.dtype,
+            tol=tol,
+            max_iter=10,
+            check_iter=1,
+            record_residuals=True,
+            preconditioner=precond,
+            pure_neumann=pure_neumann,
+        ).to(self.device)
+
+        b = solver._apply_laplacian(u_true.data)
+        if pure_neumann:
+            b -= b.mean(dim=(-2, -1), keepdim=True)
+
+        u = solver.solve(b, torch.zeros_like(b))
+        residuals = solver.residual_norms.detach().cpu().numpy()
+        rel_err = torch.linalg.norm(u_true.data - u) / torch.linalg.norm(u_true.data)
+        for i in range(1, solver.stop_iter):
+            if residuals[i] > 8 * h:
+                self.assertLess(residuals[i], residuals[i - 1] * factor)
+        self.assertLess(rel_err.item(), tol)
+
+    @parameterized.named_parameters(
+        dict(testcase_name="_128x128", shape=(128, 128)),
+        dict(testcase_name="_256x256", shape=(256, 256)),
+        dict(testcase_name="_512x512", shape=(512, 512)),
+    )
+    def test_pseudoinverse_fft(self, shape):
+        grid = grids.Grid(shape, domain=((0, 1), (0, 1)))
+        bc = boundaries.periodic_boundary_conditions(grid.ndim)
+
+        b = torch.randn(shape, dtype=self.dtype, device=self.device)
+        b = b - b.mean()
+
+        solver = solvers.PseudoInverseFFT(grid, bc, dtype=self.dtype).to(self.device)
+        u_fft = solver(b, torch.zeros_like(b))
+
+        # Basic correctness check
+        u_var = grids.GridVariable(u_fft, offset=(0.5, 0.5), grid=grid, bc=bc)
+        laplacian_u = fdm.laplacian(u_var)
+        expected = b - b.mean()
+        actual = laplacian_u.data - laplacian_u.data.mean()
+
+        self.assertAllClose(actual, expected, atol=1e-3, rtol=1e-3)
+
+
+if __name__ == "__main__":
+    absltest.main()