Skip to content

0.2.5 dev #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 78 additions & 46 deletions torch_cfd/advection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@


import math
from typing import Callable, Optional, Tuple
from functools import partial
from typing import Callable, Optional, Tuple

import torch
import torch.nn as nn

Expand All @@ -31,9 +32,11 @@
GridVariableVector = grids.GridVariableVector
FluxInterpFn = Callable[[GridVariable, GridVariableVector, float], GridVariable]


def default(value, d):
return d if value is None else value


def safe_div(x, y, default_numerator=1):
"""Safe division of `Array`'s."""
return x / torch.where(y != 0, y, default_numerator)
Expand All @@ -47,7 +50,7 @@ def van_leer_limiter(r):
class Upwind(nn.Module):
"""Upwind interpolation module for scalar fields.

Upwind interpolation of a scalar field `c` to a
Upwind interpolation of a scalar field `c` to a
target offset based on the velocity field `v`. The interpolation is done axis-wise and uses the upwind scheme where values are taken from upstream cells based on the flow direction.

The module identifies the interpolation axis (must be a single axis) and selects values from the previous cell for positive velocity or the next cell for negative velocity along that axis.
Expand All @@ -73,7 +76,9 @@ def __init__(
):
super().__init__()
self.grid = grid
self.target_offset = target_offset # this is the offset to which we will interpolate c
self.target_offset = (
target_offset # this is the offset to which we will interpolate c
)

def forward(
self,
Expand Down Expand Up @@ -113,7 +118,9 @@ def forward(
ceil = int(math.ceil(offset_delta))
c_floor = c.shift(floor, dim).data
c_ceil = c.shift(ceil, dim).data
return GridVariable(torch.where(u.data > 0, c_floor, c_ceil), self.target_offset, c.grid, c.bc)
return GridVariable(
torch.where(u.data > 0, c_floor, c_ceil), self.target_offset, c.grid, c.bc
)


class LaxWendroff(nn.Module):
Expand All @@ -132,7 +139,7 @@ class LaxWendroff(nn.Module):
Lax-Wendroff method can be used to form monotonic schemes when augmented with
a flux limiter. See https://en.wikipedia.org/wiki/Flux_limiter

Args:
Args:
grid: The computational grid on which interpolation is performed, only used for step.
offset: Target offset to which scalar fields will be interpolated during
forward passes. Target offset have the same length as `c.offset` in forward() and differ in at most one entry.
Expand All @@ -147,8 +154,8 @@ def __init__(
target_offset: Tuple[float, ...],
):
super().__init__()
self.grid = grid
self.target_offset = target_offset
self.grid = grid
self.target_offset = target_offset

def forward(
self,
Expand Down Expand Up @@ -189,8 +196,10 @@ def forward(
c_ceil = c.shift(ceil, dim).data
pos = c_floor + 0.5 * (1 - courant) * (c_ceil - c_floor)
neg = c_ceil - 0.5 * (1 + courant) * (c_ceil - c_floor)
return GridVariable(torch.where(u.data > 0, pos, neg),
self.target_offset, c.grid, c.bc)
return GridVariable(
torch.where(u.data > 0, pos, neg), self.target_offset, c.grid, c.bc
)


class AdvectAligned(nn.Module):
"""
Expand Down Expand Up @@ -277,16 +286,20 @@ def forward(self, cs: GridVariableVector, v: GridVariableVector) -> GridVariable
)

# Compute flux: cu
# if cs and v have different boundary conditions,
# if cs and v have different boundary conditions,
# flux's bc will become None
flux = GridVariableVector(tuple(c * u for c, u in zip(cs, v)))

# wrap flux with boundary conditions to flux if not periodic
# flux = GridVariableVector(
# tuple(bc.impose_bc(f) for f, bc in zip(flux, self.flux_bcs))
# )
flux = GridVariableVector(tuple(GridVariable(f.data, offset, f.grid, bc) for f, offset, bc in zip(flux, self.offsets, self.flux_bcs)))

flux = GridVariableVector(
tuple(
GridVariable(f.data, offset, f.grid, bc)
for f, offset, bc in zip(flux, self.offsets, self.flux_bcs)
)
)

# Return negative divergence of flux
# after taking divergence the bc becomes None
Expand Down Expand Up @@ -335,12 +348,12 @@ class TVDInterpolation(nn.Module):
http://www.ita.uni-heidelberg.de/~dullemond/lectures/num_fluid_2012/Chapter_4.pdf

Args:
target_offset: offset to which we will interpolate `c`.
target_offset: offset to which we will interpolate `c`.
Must have the same length as `c.offset` and differ in at most one entry. This offset should interface as other interpolation methods (take `c`, `v` and `dt` arguments and return value of `c` at offset `offset`).
limiter: flux limiter function that evaluates the portion of the correction (high_accuracy - low_accuracy) to add to low_accuracy solution based on the ratio of the consecutive gradients.
limiter: flux limiter function that evaluates the portion of the correction (high_accuracy - low_accuracy) to add to low_accuracy solution based on the ratio of the consecutive gradients.
Takes array as input and return array of weights. For more details see:
https://en.wikipedia.org/wiki/Flux_limiter

"""

def __init__(
Expand All @@ -353,8 +366,16 @@ def __init__(
):
super().__init__()
self.grid = grid
self.low_interp = Upwind(grid, target_offset=target_offset) if low_interp is None else low_interp
self.high_interp = LaxWendroff(grid, target_offset=target_offset) if high_interp is None else high_interp
self.low_interp = (
Upwind(grid, target_offset=target_offset)
if low_interp is None
else low_interp
)
self.high_interp = (
LaxWendroff(grid, target_offset=target_offset)
if high_interp is None
else high_interp
)
self.limiter = limiter
self.target_offset = target_offset

Expand All @@ -370,8 +391,9 @@ def forward(
v: GridVariableVector representing the velocity field.
dt: Time step size (not used in this interpolation).

Returns:
Interpolated scalar field c to a target offset using Van Leer flux limiting, which uses a combination of high and low order methods to produce monotonic interpolation method."""
Returns:
Interpolated scalar field c to a target offset using Van Leer flux limiting, which uses a combination of high and low order methods to produce monotonic interpolation method.
"""
for axis, axis_offset in enumerate(self.target_offset):
interpolation_offset = tuple(
[
Expand Down Expand Up @@ -420,7 +442,7 @@ class AdvectionBase(nn.Module):
4. Set the boundary condition on flux, which is inhereited from `c`.
5. Return the negative divergence of the flux.

Args:
Args:
grid: Grid.
offset: the current scalar field `c` to be advected.
bc_c: Boundary conditions for the scalar field `c`.
Expand All @@ -429,13 +451,14 @@ class AdvectionBase(nn.Module):

"""

def __init__(self,
grid: Grid,
offset: Tuple[float, ...],
bc_c: boundaries.BoundaryConditions,
bc_v: Tuple[boundaries.BoundaryConditions, ...],
limiter: Optional[Callable] = None,
) -> None:
def __init__(
self,
grid: Grid,
offset: Tuple[float, ...],
bc_c: boundaries.BoundaryConditions,
bc_v: Tuple[boundaries.BoundaryConditions, ...],
limiter: Optional[Callable] = None,
) -> None:
super().__init__()
self.grid = grid
self.offset = offset if offset is not None else (0.5,) * grid.ndim
Expand All @@ -450,19 +473,24 @@ def __init__(self,
),
)
self.advect_aligned = AdvectAligned(
grid=grid, bcs_c=(bc_c, bc_c), bcs_v=bc_v, offsets=self.target_offsets)
self._flux_interp = nn.ModuleList() # placeholder
self._velocity_interp = nn.ModuleList() # placeholder
grid=grid, bcs_c=(bc_c, bc_c), bcs_v=bc_v, offsets=self.target_offsets
)
self._flux_interp = nn.ModuleList() # placeholder
self._velocity_interp = nn.ModuleList() # placeholder

def __post_init__(self):
assert len(self._flux_interp) == len(self.target_offsets)
assert len(self._velocity_interp) == len(self.target_offsets)

for dim, interp in enumerate(self._flux_interp):
assert interp.target_offset == self.target_offsets[dim], f"Expected flux interpolation for dimension {dim} to have target offset {self.target_offsets[dim]}, but got {interp.target_offset}."

assert (
interp.target_offset == self.target_offsets[dim]
), f"Expected flux interpolation for dimension {dim} to have target offset {self.target_offsets[dim]}, but got {interp.target_offset}."

for dim, interp in enumerate(self._velocity_interp):
assert interp.target_offset == self.target_offsets[dim], f"Expected velocity interpolation for dimension {dim} to have target offset {self.target_offsets[dim]}, but got {interp.target_offset}."
assert (
interp.target_offset == self.target_offsets[dim]
), f"Expected velocity interpolation for dimension {dim} to have target offset {self.target_offsets[dim]}, but got {interp.target_offset}."

def flux_interp(
self,
Expand All @@ -478,7 +506,9 @@ def velocity_interp(
self, v: GridVariableVector, *args, **kwargs
) -> GridVariableVector:
"""Interpolate the velocity field `v` to the target offsets."""
return GridVariableVector(tuple(interp(u) for interp, u in zip(self._velocity_interp, v)))
return GridVariableVector(
tuple(interp(u) for interp, u in zip(self._velocity_interp, v))
)

def forward(
self,
Expand All @@ -490,10 +520,10 @@ def forward(
Args:
c: the scalar field to be advected.
v: representing the velocity field.

Returns:
An GridVariable containing the time derivative of `c` due to advection by `v`.

"""
aligned_v = self.velocity_interp(v)

Expand All @@ -507,7 +537,7 @@ class AdvectionLinear(AdvectionBase):
def __init__(
self,
grid: Grid,
offset = (0.5, 0.5),
offset=(0.5, 0.5),
bc_c: boundaries.BoundaryConditions = boundaries.periodic_boundary_conditions(
ndim=2
),
Expand All @@ -520,12 +550,14 @@ def __init__(
super().__init__(grid, offset, bc_c, bc_v)
self._flux_interp = nn.ModuleList(
LinearInterpolation(grid, target_offset=offset)
for offset in self.target_offsets)
for offset in self.target_offsets
)

self._velocity_interp = nn.ModuleList(
LinearInterpolation(grid, target_offset=offset)
for offset in self.target_offsets)

for offset in self.target_offsets
)


class AdvectionUpwind(AdvectionBase):
"""
Expand All @@ -535,9 +567,9 @@ class AdvectionUpwind(AdvectionBase):
- flux_interp: a Upwind interpolation for each component of the velocity field `v`.
- velocity_interp: a LinearInterpolation for each component of the velocity field `v`.

Args:
Args:
- offset: current offset of the scalar field `c` to be advected.

Returns:
Aligned advection of the scalar field `c` by the velocity field `v` using the target offsets on the control volume faces.
"""
Expand All @@ -557,8 +589,7 @@ def __init__(
):
super().__init__(grid, offset, bc_c, bc_v)
self._flux_interp = nn.ModuleList(
Upwind(grid, target_offset=offset)
for offset in self.target_offsets
Upwind(grid, target_offset=offset) for offset in self.target_offsets
)

self._velocity_interp = nn.ModuleList(
Expand All @@ -575,9 +606,9 @@ class AdvectionVanLeer(AdvectionBase):
- flux_interp: a TVDInterpolation with Upwind and LaxWendroff methods
- velocity_interp: a LinearInterpolation for each component of the velocity field `v`.

Args:
Args:
- offset: current offset of the scalar field `c` to be advected.

Returns:
Aligned advection of the scalar field `c` by the velocity field `v` using the target offsets on the control volume faces.
"""
Expand Down Expand Up @@ -611,6 +642,7 @@ def __init__(
for offset, bc in zip(self.target_offsets, bc_v)
)


class ConvectionVector(nn.Module):
"""Computes convection of a vector field `v` by the velocity field `u`.

Expand Down
Loading