Skip to content

Added multigrid solvers for pressure projection and tests #6

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions torch_cfd/advection.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,9 @@ def forward(self, cs: GridVariableVector, v: GridVariableVector) -> GridVariable
flux = GridVariableVector(tuple(c * u for c, u in zip(cs, v)))

# wrap flux with boundary conditions to flux if not periodic
# flux = GridVariableVector(
# tuple(bc.impose_bc(f) for f, bc in zip(flux, self.flux_bcs))
# )
flux = GridVariableVector(tuple(GridVariable(f.data, offset, f.grid, bc) for f, offset, bc in zip(flux, self.offsets, self.flux_bcs)))


Expand Down
22 changes: 21 additions & 1 deletion torch_cfd/boundaries.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def pad_and_impose_bc(
)
return GridVariable(u.data, u.offset, u.grid, self)

def impose_bc(self, u: GridVariable, mode: str="") -> GridVariable:
def impose_bc(self, u: GridVariable, mode: str = "") -> GridVariable:
"""Returns GridVariable with correct boundary condition.

Some grid points of GridVariable might coincide with boundary. This ensures
Expand Down Expand Up @@ -435,12 +435,32 @@ def is_bc_periodic_boundary_conditions(bc: BoundaryConditions, dim: int) -> bool
)
return True

def is_bc_all_periodic_boundary_conditions(bc: BoundaryConditions) -> bool:
"""Returns true if scalar has periodic bc along all axes."""
for dim in range(bc.ndim):
if not is_bc_periodic_boundary_conditions(bc, dim):
return False
return True


def is_periodic_boundary_conditions(c: GridVariable, dim: int) -> bool:
"""Returns true if scalar has periodic bc along axis."""
return is_bc_periodic_boundary_conditions(c.bc, dim)


def is_bc_pure_neumann_boundary_conditions(bc: BoundaryConditions) -> bool:
"""Returns true if scalar has pure Neumann bc along all axes."""
for dim in range(bc.ndim):
if bc.types[dim][0] != BCType.NEUMANN or bc.types[dim][1] != BCType.NEUMANN:
return False
return True


def is_pure_neumann_boundary_conditions(c: GridVariable) -> bool:
"""Returns true if scalar has pure Neumann bc along all axes."""
return is_bc_pure_neumann_boundary_conditions(c.bc)


# Convenience utilities to ease updating of BoundaryConditions implementation
def periodic_boundary_conditions(ndim: int) -> BoundaryConditions:
"""Returns periodic BCs for a variable with `ndim` spatial dimension."""
Expand Down
99 changes: 57 additions & 42 deletions torch_cfd/finite_differences.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
import torch
from torch_cfd import boundaries, grids

ArrayVector = Sequence[torch.Tensor]
ArrayVector = List[torch.Tensor]
GridVariable = grids.GridVariable
GridTensor = grids.GridTensor
GridVariableVector = Union[grids.GridVariableVector, Sequence[grids.GridVariable]]
Expand Down Expand Up @@ -159,20 +159,22 @@ def set_laplacian_matrix(
grid: grids.Grid,
bc: boundaries.BoundaryConditions,
device: Optional[torch.device] = None,
dtype: torch.dtype = torch.float32,
) -> ArrayVector:
"""Initialize the Laplacian operators."""

offset = grid.cell_center
return laplacian_matrix_w_boundaries(grid, offset=offset, bc=bc, device=device)
return laplacian_matrix_w_boundaries(grid, offset=offset, bc=bc, device=device, dtype=dtype)


def laplacian_matrix(n: int, step: float, sparse: bool = False) -> torch.Tensor:
def laplacian_matrix(n: int, step: float, sparse: bool = False, dtype=torch.float32) -> torch.Tensor:
"""
Create 1D Laplacian operator matrix, with periodic BC.
modified the scipy.linalg.circulant implementation to native torch
The matrix is a tri-diagonal matrix with [1, -2, 1]/h**2
Modified the scipy.linalg.circulant implementation to native torch
"""
if sparse:
values = torch.tensor([1.0, -2.0, 1.0]) / step**2
values = torch.tensor([1.0, -2.0, 1.0], dtype=dtype) / step**2
idx_row = torch.arange(n).repeat(3)
idx_col = torch.cat(
[
Expand All @@ -188,33 +190,45 @@ def laplacian_matrix(n: int, step: float, sparse: bool = False) -> torch.Tensor:
)
return torch.sparse_coo_tensor(indices, data, size=(n, n))
else:
column = torch.zeros(n)
column = torch.zeros(n, dtype=dtype)
column[0] = -2 / step**2
column[1] = column[-1] = 1 / step**2
idx = (n - torch.arange(n)[None].T + torch.arange(n)[None]) % n
return torch.gather(column[None, ...].expand(n, -1), 1, idx)


def _laplacian_boundary_dirichlet_cell_centered(
laplacians: ArrayVector, grid: grids.Grid, axis: int, side: str
laplacians: ArrayVector, grid: grids.Grid, dim: int, side: str
) -> None:
"""Converts 1d laplacian matrix to satisfy dirichlet homogeneous bc.

laplacians[i] contains a 3 point stencil matrix L that approximates
d^2/dx_i^2.
For detailed documentation on laplacians input type see
array_utils.laplacian_matrix.
The default return of array_utils.laplacian_matrix makes a matrix for
periodic boundary. For dirichlet boundary, the correct equation is
L(u_interior) = rhs_interior and BL_boundary = u_fixed_boundary. So
fdm.laplacian_matrix.
The default return of fdm.laplacian_matrix makes a matrix for
periodic boundary. For (homogeneous) dirichlet boundary, the correct equation is
L(u_interior) = rhs_interior
BL_boundary = u_fixed_boundary.
So
laplacian_boundary_dirichlet restricts the matrix L to
interior points only.
interior points only.

Denote the node in the 3-pt stencil as
u[ghost], u[boundary], u[interior] = u[0], u[1], u[2].
The original stencil on the boundary is
[1, -2, 1] * [u[0], u[1], u[2]] = u[0] - 2*u[1] + u[2]
In the homogeneous Dirichlet bc case if the offset
is 0.5 away from the wall, the ghost cell value u[0] = -u[1]. So the
3 point stencil [1 -2 1] * [u[0] u[1] u[2]] = -3 u[1] + u[2].
The original diagonal of Laplacian Lap[0, 0] is -2/h**2, we need -3/h**2,
thus 1/h**2 is subtracted from the diagonal, and the ghost cell dof is set to zero (Lap[0, -1])

This function assumes RHS has cell-centered offset.
Args:
laplacians: list of 1d laplacians
grid: grid object
axis: axis along which to impose dirichlet bc.
dim: axis along which to impose dirichlet bc.
side: lower or upper side to assign boundary to.

Returns:
Expand All @@ -223,52 +237,50 @@ def _laplacian_boundary_dirichlet_cell_centered(
TODO:
[ ]: this function is not implemented in the original Jax-CFD code.
"""
# This function assumes homogeneous boundary, in which case if the offset
# is 0.5 away from the wall, the ghost cell value u[0] = -u[1]. So the
# 3 point stencil [1 -2 1] * [u[0] u[1] u[2]] = -3 u[1] + u[2].

if side == "lower":
laplacians[axis][0, 0] = laplacians[axis][0, 0] - 1 / grid.step[axis] ** 2
laplacians[dim][0, 0] = laplacians[dim][0, 0] - 1 / grid.step[dim] ** 2
else:
laplacians[axis][-1, -1] = laplacians[axis][-1, -1] - 1 / grid.step[axis] ** 2
laplacians[dim][-1, -1] = laplacians[dim][-1, -1] - 1 / grid.step[dim] ** 2
# deletes corner dependencies on the "looped-around" part.
# this should be done irrespective of which side, since one boundary cannot
# be periodic while the other is.
laplacians[axis][0, -1] = 0.0
laplacians[axis][-1, 0] = 0.0
return laplacians
laplacians[dim][0, -1] = 0.0
laplacians[dim][-1, 0] = 0.0
return


def _laplacian_boundary_neumann_cell_centered(
laplacians: List[Any], grid: grids.Grid, axis: int, side: str
laplacians: ArrayVector, grid: grids.Grid, dim: int, side: str
) -> None:
"""Converts 1d laplacian matrix to satisfy neumann homogeneous bc.

This function assumes the RHS will have a cell-centered offset.
Neumann boundaries are not defined for edge-aligned offsets elsewhere in the
code.
code. For homogeneous Neumann BC (du/dn = 0), the ghost cell should equal the interior cell: u[ghost] = u[1]. The stencil becomes:
[1, -2, 1] * [u[1], u[1], u[2]] = u[1] - 2*u[1] + u[2] = -u[1] + u[2]
The original diagonal of Laplacian Lap[0, 0] is -2/h**2, we need -1/h**2,
thus 1/h**2 is added to the diagonal, and the ghost cell dof is set to zero (Lap[0, -1]).

Args:
laplacians: list of 1d laplacians
grid: grid object
axis: axis along which to impose dirichlet bc.
dim: axis along which to impose dirichlet bc.
side: which boundary side to convert to neumann homogeneous bc.

Returns:
updated list of 1d laplacians.

TODO
[ ]: this function is not implemented in the original Jax-CFD code.
"""
if side == "lower":
laplacians[axis][0, 0] = laplacians[axis][0, 0] + 1 / grid.step[axis] ** 2
laplacians[dim][0, 0] = laplacians[dim][0, 0] + 1 / grid.step[dim] ** 2
else:
laplacians[axis][-1, -1] = laplacians[axis][-1, -1] + 1 / grid.step[axis] ** 2
laplacians[dim][-1, -1] = laplacians[dim][-1, -1] + 1 / grid.step[dim] ** 2
# deletes corner dependencies on the "looped-around" part.
# this should be done irrespective of which side, since one boundary cannot
# be periodic while the other is.
laplacians[axis][0, -1] = 0.0
laplacians[axis][-1, 0] = 0.0
return laplacians
laplacians[dim][0, -1] = 0.0
laplacians[dim][-1, 0] = 0.0
return


def laplacian_matrix_w_boundaries(
Expand All @@ -277,6 +289,7 @@ def laplacian_matrix_w_boundaries(
bc: grids.BoundaryConditions,
laplacians: Optional[ArrayVector] = None,
device: Optional[torch.device] = None,
dtype: torch.dtype = torch.float32,
sparse: bool = False,
) -> ArrayVector:
"""Returns 1d laplacians that satisfy boundary conditions bc on grid.
Expand Down Expand Up @@ -323,11 +336,13 @@ def laplacian_matrix_w_boundaries(
raise NotImplementedError(
"edge-aligned Neumann boundaries are not implemented."
)
return list(lap.to(device) for lap in laplacians) if device else laplacians
return list(lap.to(dtype).to(device) for lap in laplacians)


def _linear_along_axis(c: GridVariable, offset: float, dim: int) -> GridVariable:
"""Linear interpolation of `c` to `offset` along a single specified `axis`."""
"""Linear interpolation of `c` to `offset` along a single specified `axis`.
dim here is >= 0, the negative indexing for batched implementation is handled by grids.shift.
"""
offset_delta = offset - c.offset[dim]

# If offsets are the same, `c` is unchanged.
Expand Down Expand Up @@ -383,8 +398,8 @@ def linear(
f"got {c.offset} and {offset}."
)
interpolated = c
for a, o in enumerate(offset):
interpolated = _linear_along_axis(interpolated, offset=o, dim=a)
for dim, o in enumerate(offset):
interpolated = _linear_along_axis(interpolated, offset=o, dim=dim)
return interpolated


Expand All @@ -405,15 +420,15 @@ def gradient_tensor(v):
if not isinstance(v, GridVariable):
return GridTensor(torch.stack([gradient_tensor(u) for u in v], dim=-1))
grad = []
for axis in range(v.grid.ndim):
offset = v.offset[axis]
for dim in range(-v.grid.ndim, 0):
offset = v.offset[dim]
if offset == 0:
derivative = forward_difference(v, axis)
derivative = forward_difference(v, dim)
elif offset == 1:
derivative = backward_difference(v, axis)
derivative = backward_difference(v, dim)
elif offset == 0.5:
v_centered = linear(v, v.grid.cell_center)
derivative = central_difference(v_centered, axis)
derivative = central_difference(v_centered, dim)
else:
raise ValueError(f"expected offset values in {{0, 0.5, 1}}, got {offset}")
grad.append(derivative)
Expand All @@ -427,4 +442,4 @@ def curl_2d(v: GridVariableVector) -> GridVariable:
grid = grids.consistent_grid_arrays(*v)
if grid.ndim != 2:
raise ValueError(f"Grid dimensionality is not 2: {grid.ndim}")
return forward_difference(v[1], dim=0) - forward_difference(v[0], dim=1)
return forward_difference(v[1], dim=-2) - forward_difference(v[0], dim=-1)
31 changes: 16 additions & 15 deletions torch_cfd/forcings.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

Grid = grids.Grid
GridVariable = grids.GridVariable
GridVariableVector = grids.GridVariableVector


def forcing_eval(eval_func):
Expand Down Expand Up @@ -79,7 +80,7 @@ class ForcingFn(nn.Module):
def __init__(
self,
grid: Grid,
scale: float = 1,
scale: float = 1.0,
wave_number: int = 1,
diam: float = 1.0,
swap_xy: bool = False,
Expand All @@ -100,20 +101,20 @@ def __init__(

@forcing_eval
def velocity_eval(
grid: Grid, velocity: Optional[Tuple[torch.Tensor, torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor]:
self, grid: Grid, velocity: Optional[Tuple[GridVariable, GridVariable]]
) -> GridVariableVector:
raise NotImplementedError

@forcing_eval
def vorticity_eval(grid: Grid, vorticity: Optional[torch.Tensor]) -> torch.Tensor:
def vorticity_eval(self, grid: Grid, vorticity: Optional[torch.Tensor]) -> GridVariable:
raise NotImplementedError

def forward(
self,
grid: Optional[Union[Grid, Tuple[Grid, Grid]]] = None,
velocity: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
vorticity: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Union[GridVariable, GridVariableVector]:
if not self.vorticity:
return self.velocity_eval(grid, velocity)
else:
Expand Down Expand Up @@ -166,7 +167,7 @@ def velocity_eval(
self,
grid: Optional[Grid],
velocity: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> GridVariableVector:
offsets = self.offsets
grid = self.grid if grid is None else grid
domain_factor = 2 * torch.pi / self.diam
Expand All @@ -187,13 +188,13 @@ def velocity_eval(
grid,
)
v = GridVariable(torch.zeros_like(u.data), (1 / 2, 1), grid)
return tuple((u, v))
return GridVariableVector(tuple((u, v)))

def vorticity_eval(
self,
grid: Optional[Grid],
vorticity: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> GridVariable:
offsets = self.offsets
grid = self.grid if grid is None else grid
domain_factor = 2 * torch.pi / self.diam
Expand Down Expand Up @@ -243,9 +244,9 @@ class SimpleSolenoidalForcing(ForcingFn):

def __init__(
self,
scale=1,
scale=1.0,
diam=1.0,
k=1.0,
wave_number=1,
offsets=((0, 0), (0, 0)),
vorticity=True,
*args,
Expand All @@ -255,7 +256,7 @@ def __init__(
*args,
scale=scale,
diam=diam,
wave_number=k,
wave_number=wave_number,
offsets=offsets,
vorticity=vorticity,
**kwargs,
Expand All @@ -273,7 +274,7 @@ def velocity_eval(
self,
grid: Optional[Grid],
velocity: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> GridVariableVector:
offsets = self.offsets
grid = self.grid if grid is None else grid
domain_factor = 2 * torch.pi / self.diam
Expand All @@ -292,7 +293,7 @@ def velocity_eval(
rot = self.potential(x, y, scale, k)
u = GridVariable(rot, offsets[0], grid)
v = GridVariable(-rot, (1 / 2, 1), grid)
return tuple((u, v))
return GridVariableVector(tuple((u, v)))

def vorticity_eval(
self,
Expand Down Expand Up @@ -339,7 +340,7 @@ def __init__(
self,
scale=0.1,
diam=1.0,
k=1.0,
wave_number=1,
offsets=((0, 0), (0, 0)),
*args,
**kwargs,
Expand All @@ -348,7 +349,7 @@ def __init__(
*args,
scale=scale,
diam=diam,
k=k,
wave_number=wave_number,
offsets=offsets,
**kwargs,
)
Expand Down
Loading