diff --git a/torch_cfd/advection.py b/torch_cfd/advection.py index 7aea9ab..c7010e5 100644 --- a/torch_cfd/advection.py +++ b/torch_cfd/advection.py @@ -17,8 +17,9 @@ import math -from typing import Callable, Optional, Tuple from functools import partial +from typing import Callable, Optional, Tuple + import torch import torch.nn as nn @@ -31,9 +32,11 @@ GridVariableVector = grids.GridVariableVector FluxInterpFn = Callable[[GridVariable, GridVariableVector, float], GridVariable] + def default(value, d): return d if value is None else value + def safe_div(x, y, default_numerator=1): """Safe division of `Array`'s.""" return x / torch.where(y != 0, y, default_numerator) @@ -47,7 +50,7 @@ def van_leer_limiter(r): class Upwind(nn.Module): """Upwind interpolation module for scalar fields. - Upwind interpolation of a scalar field `c` to a + Upwind interpolation of a scalar field `c` to a target offset based on the velocity field `v`. The interpolation is done axis-wise and uses the upwind scheme where values are taken from upstream cells based on the flow direction. The module identifies the interpolation axis (must be a single axis) and selects values from the previous cell for positive velocity or the next cell for negative velocity along that axis. @@ -73,7 +76,9 @@ def __init__( ): super().__init__() self.grid = grid - self.target_offset = target_offset # this is the offset to which we will interpolate c + self.target_offset = ( + target_offset # this is the offset to which we will interpolate c + ) def forward( self, @@ -113,7 +118,9 @@ def forward( ceil = int(math.ceil(offset_delta)) c_floor = c.shift(floor, dim).data c_ceil = c.shift(ceil, dim).data - return GridVariable(torch.where(u.data > 0, c_floor, c_ceil), self.target_offset, c.grid, c.bc) + return GridVariable( + torch.where(u.data > 0, c_floor, c_ceil), self.target_offset, c.grid, c.bc + ) class LaxWendroff(nn.Module): @@ -132,7 +139,7 @@ class LaxWendroff(nn.Module): Lax-Wendroff method can be used to form monotonic schemes when augmented with a flux limiter. See https://en.wikipedia.org/wiki/Flux_limiter - Args: + Args: grid: The computational grid on which interpolation is performed, only used for step. offset: Target offset to which scalar fields will be interpolated during forward passes. Target offset have the same length as `c.offset` in forward() and differ in at most one entry. @@ -147,8 +154,8 @@ def __init__( target_offset: Tuple[float, ...], ): super().__init__() - self.grid = grid - self.target_offset = target_offset + self.grid = grid + self.target_offset = target_offset def forward( self, @@ -189,8 +196,10 @@ def forward( c_ceil = c.shift(ceil, dim).data pos = c_floor + 0.5 * (1 - courant) * (c_ceil - c_floor) neg = c_ceil - 0.5 * (1 + courant) * (c_ceil - c_floor) - return GridVariable(torch.where(u.data > 0, pos, neg), - self.target_offset, c.grid, c.bc) + return GridVariable( + torch.where(u.data > 0, pos, neg), self.target_offset, c.grid, c.bc + ) + class AdvectAligned(nn.Module): """ @@ -277,7 +286,7 @@ def forward(self, cs: GridVariableVector, v: GridVariableVector) -> GridVariable ) # Compute flux: cu - # if cs and v have different boundary conditions, + # if cs and v have different boundary conditions, # flux's bc will become None flux = GridVariableVector(tuple(c * u for c, u in zip(cs, v))) @@ -285,8 +294,12 @@ def forward(self, cs: GridVariableVector, v: GridVariableVector) -> GridVariable # flux = GridVariableVector( # tuple(bc.impose_bc(f) for f, bc in zip(flux, self.flux_bcs)) # ) - flux = GridVariableVector(tuple(GridVariable(f.data, offset, f.grid, bc) for f, offset, bc in zip(flux, self.offsets, self.flux_bcs))) - + flux = GridVariableVector( + tuple( + GridVariable(f.data, offset, f.grid, bc) + for f, offset, bc in zip(flux, self.offsets, self.flux_bcs) + ) + ) # Return negative divergence of flux # after taking divergence the bc becomes None @@ -335,12 +348,12 @@ class TVDInterpolation(nn.Module): http://www.ita.uni-heidelberg.de/~dullemond/lectures/num_fluid_2012/Chapter_4.pdf Args: - target_offset: offset to which we will interpolate `c`. + target_offset: offset to which we will interpolate `c`. Must have the same length as `c.offset` and differ in at most one entry. This offset should interface as other interpolation methods (take `c`, `v` and `dt` arguments and return value of `c` at offset `offset`). - limiter: flux limiter function that evaluates the portion of the correction (high_accuracy - low_accuracy) to add to low_accuracy solution based on the ratio of the consecutive gradients. + limiter: flux limiter function that evaluates the portion of the correction (high_accuracy - low_accuracy) to add to low_accuracy solution based on the ratio of the consecutive gradients. Takes array as input and return array of weights. For more details see: https://en.wikipedia.org/wiki/Flux_limiter - + """ def __init__( @@ -353,8 +366,16 @@ def __init__( ): super().__init__() self.grid = grid - self.low_interp = Upwind(grid, target_offset=target_offset) if low_interp is None else low_interp - self.high_interp = LaxWendroff(grid, target_offset=target_offset) if high_interp is None else high_interp + self.low_interp = ( + Upwind(grid, target_offset=target_offset) + if low_interp is None + else low_interp + ) + self.high_interp = ( + LaxWendroff(grid, target_offset=target_offset) + if high_interp is None + else high_interp + ) self.limiter = limiter self.target_offset = target_offset @@ -370,8 +391,9 @@ def forward( v: GridVariableVector representing the velocity field. dt: Time step size (not used in this interpolation). - Returns: - Interpolated scalar field c to a target offset using Van Leer flux limiting, which uses a combination of high and low order methods to produce monotonic interpolation method.""" + Returns: + Interpolated scalar field c to a target offset using Van Leer flux limiting, which uses a combination of high and low order methods to produce monotonic interpolation method. + """ for axis, axis_offset in enumerate(self.target_offset): interpolation_offset = tuple( [ @@ -420,7 +442,7 @@ class AdvectionBase(nn.Module): 4. Set the boundary condition on flux, which is inhereited from `c`. 5. Return the negative divergence of the flux. - Args: + Args: grid: Grid. offset: the current scalar field `c` to be advected. bc_c: Boundary conditions for the scalar field `c`. @@ -429,13 +451,14 @@ class AdvectionBase(nn.Module): """ - def __init__(self, - grid: Grid, - offset: Tuple[float, ...], - bc_c: boundaries.BoundaryConditions, - bc_v: Tuple[boundaries.BoundaryConditions, ...], - limiter: Optional[Callable] = None, - ) -> None: + def __init__( + self, + grid: Grid, + offset: Tuple[float, ...], + bc_c: boundaries.BoundaryConditions, + bc_v: Tuple[boundaries.BoundaryConditions, ...], + limiter: Optional[Callable] = None, + ) -> None: super().__init__() self.grid = grid self.offset = offset if offset is not None else (0.5,) * grid.ndim @@ -450,19 +473,24 @@ def __init__(self, ), ) self.advect_aligned = AdvectAligned( - grid=grid, bcs_c=(bc_c, bc_c), bcs_v=bc_v, offsets=self.target_offsets) - self._flux_interp = nn.ModuleList() # placeholder - self._velocity_interp = nn.ModuleList() # placeholder + grid=grid, bcs_c=(bc_c, bc_c), bcs_v=bc_v, offsets=self.target_offsets + ) + self._flux_interp = nn.ModuleList() # placeholder + self._velocity_interp = nn.ModuleList() # placeholder def __post_init__(self): assert len(self._flux_interp) == len(self.target_offsets) assert len(self._velocity_interp) == len(self.target_offsets) for dim, interp in enumerate(self._flux_interp): - assert interp.target_offset == self.target_offsets[dim], f"Expected flux interpolation for dimension {dim} to have target offset {self.target_offsets[dim]}, but got {interp.target_offset}." - + assert ( + interp.target_offset == self.target_offsets[dim] + ), f"Expected flux interpolation for dimension {dim} to have target offset {self.target_offsets[dim]}, but got {interp.target_offset}." + for dim, interp in enumerate(self._velocity_interp): - assert interp.target_offset == self.target_offsets[dim], f"Expected velocity interpolation for dimension {dim} to have target offset {self.target_offsets[dim]}, but got {interp.target_offset}." + assert ( + interp.target_offset == self.target_offsets[dim] + ), f"Expected velocity interpolation for dimension {dim} to have target offset {self.target_offsets[dim]}, but got {interp.target_offset}." def flux_interp( self, @@ -478,7 +506,9 @@ def velocity_interp( self, v: GridVariableVector, *args, **kwargs ) -> GridVariableVector: """Interpolate the velocity field `v` to the target offsets.""" - return GridVariableVector(tuple(interp(u) for interp, u in zip(self._velocity_interp, v))) + return GridVariableVector( + tuple(interp(u) for interp, u in zip(self._velocity_interp, v)) + ) def forward( self, @@ -490,10 +520,10 @@ def forward( Args: c: the scalar field to be advected. v: representing the velocity field. - + Returns: An GridVariable containing the time derivative of `c` due to advection by `v`. - + """ aligned_v = self.velocity_interp(v) @@ -507,7 +537,7 @@ class AdvectionLinear(AdvectionBase): def __init__( self, grid: Grid, - offset = (0.5, 0.5), + offset=(0.5, 0.5), bc_c: boundaries.BoundaryConditions = boundaries.periodic_boundary_conditions( ndim=2 ), @@ -520,12 +550,14 @@ def __init__( super().__init__(grid, offset, bc_c, bc_v) self._flux_interp = nn.ModuleList( LinearInterpolation(grid, target_offset=offset) - for offset in self.target_offsets) + for offset in self.target_offsets + ) self._velocity_interp = nn.ModuleList( LinearInterpolation(grid, target_offset=offset) - for offset in self.target_offsets) - + for offset in self.target_offsets + ) + class AdvectionUpwind(AdvectionBase): """ @@ -535,9 +567,9 @@ class AdvectionUpwind(AdvectionBase): - flux_interp: a Upwind interpolation for each component of the velocity field `v`. - velocity_interp: a LinearInterpolation for each component of the velocity field `v`. - Args: + Args: - offset: current offset of the scalar field `c` to be advected. - + Returns: Aligned advection of the scalar field `c` by the velocity field `v` using the target offsets on the control volume faces. """ @@ -557,8 +589,7 @@ def __init__( ): super().__init__(grid, offset, bc_c, bc_v) self._flux_interp = nn.ModuleList( - Upwind(grid, target_offset=offset) - for offset in self.target_offsets + Upwind(grid, target_offset=offset) for offset in self.target_offsets ) self._velocity_interp = nn.ModuleList( @@ -575,9 +606,9 @@ class AdvectionVanLeer(AdvectionBase): - flux_interp: a TVDInterpolation with Upwind and LaxWendroff methods - velocity_interp: a LinearInterpolation for each component of the velocity field `v`. - Args: + Args: - offset: current offset of the scalar field `c` to be advected. - + Returns: Aligned advection of the scalar field `c` by the velocity field `v` using the target offsets on the control volume faces. """ @@ -611,6 +642,7 @@ def __init__( for offset, bc in zip(self.target_offsets, bc_v) ) + class ConvectionVector(nn.Module): """Computes convection of a vector field `v` by the velocity field `u`. diff --git a/torch_cfd/boundaries.py b/torch_cfd/boundaries.py index 708cfc1..6399128 100644 --- a/torch_cfd/boundaries.py +++ b/torch_cfd/boundaries.py @@ -18,7 +18,7 @@ import dataclasses import math -from typing import Optional, Sequence, Tuple, Union +from typing import Callable, List, Optional, Sequence, Tuple, Union import torch @@ -32,6 +32,7 @@ BCType = grids.BCType() Padding = grids.Padding() +BCValues = grids.BCValues @dataclasses.dataclass(init=False, frozen=True, repr=False) @@ -57,6 +58,7 @@ class ConstantBoundaryConditions(BoundaryConditions): _types: Tuple[Tuple[str, str], ...] bc_values: Tuple[Tuple[Optional[float], Optional[float]], ...] + ndim: int def __init__( self, @@ -226,7 +228,7 @@ def pad_all( """Pads along all axes with pad width specified by width tuple. Args: - u: a `GridArray` object. + u: a `GridVariable` object. width: Tuple of padding width for each side for each axis. mode: type of padding to use in non-periodic case. Mirror mirrors the array values across the boundary. @@ -260,8 +262,14 @@ def values( value = self.bc_values[dim][-i] if value is None: bc.append(None) - else: + elif isinstance(value, float): bc.append(torch.full(grid.shape[:dim] + grid.shape[dim + 1 :], value)) + elif isinstance(value, torch.Tensor): + if value.shape != grid.shape[:dim] + grid.shape[dim + 1 :]: + raise ValueError( + f"Boundary value shape {value.shape} does not match expected shape {grid.shape[:dim] + grid.shape[dim + 1 :]}" + ) + bc.append(value) return tuple(bc) @@ -393,7 +401,10 @@ def impose_bc(self, u: GridVariable, mode: str = "") -> GridVariable: u: a `GridVariable` object. Returns: - A GridVariable that has correct boundary values. If ghost_cell == True, then ghost cells are added on the other side of DoFs living at cell center if the bc is Dirichlet or Neumann. + A GridVariable that has correct boundary values. + + Notes: + If one needs ghost_cells, please use a manual function pad_all to add ghost cells are added on the other side of DoFs living at cell center if the bc is Dirichlet or Neumann. """ offset = u.offset u = self.trim_boundary(u) @@ -435,6 +446,7 @@ def is_bc_periodic_boundary_conditions(bc: BoundaryConditions, dim: int) -> bool ) return True + def is_bc_all_periodic_boundary_conditions(bc: BoundaryConditions) -> bool: """Returns true if scalar has periodic bc along all axes.""" for dim in range(bc.ndim): @@ -559,6 +571,487 @@ def periodic_and_neumann_boundary_conditions( ) +@dataclasses.dataclass(init=False, frozen=True, repr=False) +class DiscreteBoundaryConditions(ConstantBoundaryConditions): + """Boundary conditions that can vary spatially along the boundary. + + Array-based values that are evaluated at boundary nodes with proper offsets. The values must match a variable's offset in order that the numerical differentiation is correct. + + Attributes: + types: boundary condition types for each dimension + bc_values: boundary values that can be: + - torch.Tensor: precomputed values along boundary + - None: homogeneous boundary condition + + Example usage: + # Array-based boundary conditions + grid = Grid((10, 20)) + x_boundary = torch.linspace(0, 1, 20) # values along y-axis + y_boundary = torch.sin(torch.linspace(0, 2*np.pi, 10)) # values along x-axis + + bc = VariableBoundaryConditions( + types=((BCType.DIRICHLET, BCType.DIRICHLET), + (BCType.NEUMANN, BCType.DIRICHLET)), + values=((y_boundary, y_boundary), # left/right boundaries + (None, x_boundary)) # bottom/top boundaries + ) + """ + + _types: Tuple[Tuple[str, str], ...] + _bc_values: Tuple[Tuple[Union[float, BCValues], Union[float, BCValues]], ...] + ndim: int # default 2d, dataclass init=False + + def __init__( + self, + types: Sequence[Tuple[str, str]], + values: Sequence[ + Tuple[ + Union[float, BCValues], + Union[float, BCValues], + ] + ], + ): + types = tuple(types) + values = tuple(values) + object.__setattr__(self, "_types", types) + object.__setattr__(self, "_bc_values", values) + object.__setattr__(self, "ndim", len(types)) + + @property + def has_callable(self) -> bool: + """Check if any boundary values are callable functions.""" + for dim in range(self.ndim): + for side in range(2): + if callable(self._bc_values[dim][side]): + return True + return False + + def _validate_boundary_arrays_with_grid(self, grid: Grid): + """Validate boundary arrays against grid dimensions.""" + for dim in range(self.ndim): + for side in range(2): + value = self._bc_values[dim][side] + if isinstance(value, torch.Tensor): + # Calculate expected boundary shape + expected_shape = grid.shape[:dim] + grid.shape[dim + 1 :] + if len(expected_shape) == 0: + # 1D case - boundary is a scalar + if value.numel() != 1: + raise ValueError( + f"Boundary array for 1D grid at dim {dim}, side {side} " + f"should be a scalar, got shape {value.shape}" + ) + elif value.ndim == self.ndim-1 and value.shape != expected_shape: + raise ValueError( + f"Boundary array for dim {dim}, side {side} has shape " + f"{value.shape}, expected {expected_shape}" + ) + + @property + def bc_values( + self, + ) -> Sequence[Tuple[Optional[BCValues], Optional[BCValues]]]: + """Returns boundary values as tensors for each boundary. + + For callable boundary conditions, this will raise an error asking the user + to use FunctionBoundaryConditions instead. + For float boundary conditions, returns tensors with the constant value. + For tensor boundary conditions, returns them as-is. + For None, returns None. + """ + if self.has_callable: + raise ValueError( + "Callable boundary conditions detected. Please use " + "FunctionBoundaryConditions class for callable boundary conditions." + ) + + # Process non-callable values + result = [] + for dim in range(self.ndim): + dim_values = [] + for side in range(2): + value = self._bc_values[dim][side] + if value is None: + dim_values.append(None) + elif isinstance(value, torch.Tensor): + dim_values.append(value) + elif isinstance(value, (int, float)): + # Return scalar tensor for float values + dim_values.append(torch.tensor(float(value))) + else: + raise ValueError(f"Unsupported boundary value type: {type(value)}") + result.append(tuple(dim_values)) + return tuple(result) + + def __repr__(self) -> str: + try: + lines = [f"VariableBoundaryConditions({self.ndim}D):"] + + for dim in range(self.ndim): + lower_type, upper_type = self.types[dim] + lower_val, upper_val = self._bc_values[dim] + + # Format values based on type + def format_value(val): + if val is None: + return "None" + elif isinstance(val, torch.Tensor): + return f"Tensor{tuple(val.shape)}" + elif callable(val): + return f"Callable({val.__name__ if hasattr(val, '__name__') else 'lambda'})" + else: + return str(val) + + lower_val_str = format_value(lower_val) + upper_val_str = format_value(upper_val) + + lines.append( + f" dim {dim}: [{lower_type}({lower_val_str}), {upper_type}({upper_val_str})]" + ) + + return "\n".join(lines) + except Exception as e: + return f"VariableBoundaryConditions not initialized: {e}" + + def clone( + self, + types: Optional[Sequence[Tuple[str, str]]] = None, + values: Optional[ + Sequence[ + Tuple[ + BCValues, + BCValues, + ] + ] + ] = None, + ) -> BoundaryConditions: + """Creates a copy with optionally modified parameters.""" + new_types = types if types is not None else self.types + new_values = values if values is not None else self._bc_values + return DiscreteBoundaryConditions(new_types, new_values) + + def _boundary_slices( + self, offset: Tuple[float, ...] + ) -> Tuple[Tuple[Optional[slice], Optional[slice]], ...]: + """Returns slices for boundary values after considering trimming effects. + + When a GridVariable with certain offsets gets trimmed, the boundary coordinates + need to be sliced accordingly to match the trimmed interior data. + Currently, this only works for 2D grids (spatially the variable lives on a 2D grid, i.e., good for 2D+time+channel variables). + + Args: + offset: The offset of the GridVariable + grid: The grid associated with the GridVariable + + Returns: + A tuple of (lower_slice, upper_slice) for each dimension, where each slice + indicates how to index the boundary values for that dimension and side. + (None, None) means no slicing needed (use full boundary array). + """ + if self.ndim > 2: + raise NotImplementedError( + "Multi-dimensional boundary slicing not implemented" + ) + if len(offset) != self.ndim: + raise ValueError( + f"Offset length {len(offset)} doesn't match number of sets of boundary edges {self.ndim}" + ) + + # Initialize with default "no slicing" tuples + slices: List[Tuple[Optional[slice], Optional[slice]]] = [ + (None, None), + (None, None), + ] + + for dim in range(self.ndim): + other_dim = dim ^ 1 # flip the bits to get the other dimension index + trimmed_lower = math.isclose(offset[dim], 0.0) + trimmed_upper = math.isclose(offset[dim], 1.0) + + assert not ( + trimmed_lower and trimmed_upper + ), "MAC grids cannot ahve both lower and upper trimmed for bc." + if trimmed_lower: + slices[other_dim] = (slice(1, None), slice(1, None)) + elif trimmed_upper: + slices[other_dim] = (slice(None, -1), slice(None, -1)) + # else: keep the default (None, None) + + return tuple(slices) + + def _boundary_mesh( + self, + dim: int, + grid: Grid, + offset: Tuple[float, ...], + ) -> Tuple[Tuple[torch.Tensor, ...], Tuple[torch.Tensor, ...]]: + """Get coordinate arrays for boundary points along dimension dim.""" + # Use the Grid's boundary_mesh method and return coordinates for lower boundary + # (both lower and upper have same coordinate structure for the boundary points) + return grid.boundary_mesh(dim, offset) + + def pad_and_impose_bc( + self, + u: GridVariable, + offset_to_pad_to: Optional[Tuple[float, ...]] = None, + mode: Optional[str] = "", + ) -> GridVariable: + """Pad and impose variable boundary conditions.""" + assert u.bc is None, "u must be trimmed before padding and imposing bc." + if offset_to_pad_to is None: + offset_to_pad_to = u.offset + + bc_values = self.bc_values + boundary_slices = self._boundary_slices(offset_to_pad_to) + x_boundary_slice = boundary_slices[-2] + + if not all(s is None for s in x_boundary_slice): + # Apply slicing to boundary values + new_bc_values = list(list(v) for v in bc_values) + for i in range(2): + if bc_values[-2][i] is not None: + if bc_values[-2][i].ndim > 0: + new_bc_values[-2][i] = bc_values[-2][i][x_boundary_slice[i]] + else: + new_bc_values[-2][i] = bc_values[-2][i] + bc_values = new_bc_values + + for dim in range(-u.grid.ndim, 0): + _ = self._is_aligned(u, dim) + if self.types[dim][0] != BCType.PERIODIC: + # the values passed to grids.pad should consider the offset of the variable + # if the offset is 1, the the trimmed variable will have the upper edge of that dimension trimmed, one only needs n-1 entries. + if mode: + u = grids.pad(u, (1, 1), dim, self, mode=mode, values=bc_values) + elif self.types[dim][0] == BCType.DIRICHLET and not mode: + if math.isclose(offset_to_pad_to[dim], 1.0): + u = grids.pad(u, 1, dim, self, values=bc_values) + elif math.isclose(offset_to_pad_to[dim], 0.0): + u = grids.pad(u, -1, dim, self, values=bc_values) + elif self.types[dim][0] == BCType.NEUMANN and not mode: + if not math.isclose(offset_to_pad_to[dim], 0.5): + raise ValueError("Neumann bc is not defined on edges.") + else: + raise NotImplementedError( + f"Padding for {self.types[dim][0]} boundary conditions is not implemented." + ) + + return GridVariable(u.data, u.offset, u.grid, self) + + +@dataclasses.dataclass(init=False, frozen=True, repr=False) +class FunctionBoundaryConditions(DiscreteBoundaryConditions): + """Boundary conditions defined by callable functions. + + This class handles boundary conditions that are defined as functions of + spatial coordinates (and optionally time). The functions are automatically + evaluated on the boundary mesh during initialization. + + Attributes: + types: boundary condition types for each dimension + _bc_values: evaluated boundary values (tensors/floats after evaluation) + ndim: number of spatial dimensions + + Example usage: + # Function-based boundary conditions with individual functions + def left_bc(x, y): + return torch.sin(y) + + def right_bc(x, y): + return torch.cos(y) + + grid = Grid((10, 20)) + bc = FunctionBoundaryConditions( + types=((BCType.DIRICHLET, BCType.DIRICHLET), + (BCType.NEUMANN, BCType.DIRICHLET)), + values=((left_bc, right_bc), # left/right boundaries + (None, lambda x, y: x**2)) # bottom/top boundaries + grid=grid, + offset=(0.5, 0.5) + ) + + # Or with a single function applied to all boundaries + def global_bc(x, y): + return x + y + + bc = FunctionBoundaryConditions( + types=((BCType.DIRICHLET, BCType.DIRICHLET), + (BCType.DIRICHLET, BCType.DIRICHLET)), + values=global_bc, # Single function applied everywhere + grid=grid, + offset=(0.5, 0.5) + ) + """ + + _raw_bc_values: Tuple[ + Tuple[ + Union[ + Callable[..., torch.Tensor], + Union[Callable[..., torch.Tensor], BCValues, float], + ], + BCValues, + float, + ], + ..., + ] + + def __init__( + self, + types: Sequence[Tuple[str, str]], + values: Union[ + Callable[..., torch.Tensor], # Single function for all boundaries + Sequence[ + Tuple[ + Union[Callable[..., torch.Tensor], BCValues, float], + Union[Callable[..., torch.Tensor], BCValues, float], + ] + ], + ], + grid: Grid, + offset: Optional[Tuple[float, ...]] = None, + time: Optional[torch.Tensor] = None, + ): + """Initialize function-based boundary conditions. + + Args: + types: boundary condition types for each dimension + values: boundary values that can be: + - Single Callable: function to apply to all boundaries + - Sequence of tuples: individual values per boundary that can be: + - Callable: function to evaluate on boundary mesh + - torch.Tensor: precomputed values along boundary + - float/int: constant value + - None: homogeneous boundary condition + grid: Grid to evaluate boundary conditions on + offset: Grid offset for boundary coordinate calculation + time: Optional time parameter for time-dependent boundary conditions + """ + types = tuple(types) + + # Handle single callable function case + if callable(values): + # Apply the same function to all boundaries + ndim = len(types) + values = tuple((values, values) for _ in range(ndim)) + else: + values = tuple(values) + + # Set basic attributes first + object.__setattr__(self, "_types", types) + object.__setattr__(self, "ndim", len(types)) + object.__setattr__(self, "_raw_bc_values", values) + + if offset is None: + offset = grid.cell_center + + # Evaluate callable boundary conditions + evaluated_values = [] + + for dim in range(len(types)): + dim_values = [] + + # Get boundary coordinates for this dimension if needed + boundary_coords = None + + for side in range(2): + value = values[dim][side] + + if value is None: + dim_values.append(None) + elif isinstance(value, torch.Tensor): + dim_values.append(value) + elif isinstance(value, (int, float)): + dim_values.append(torch.tensor(float(value))) + elif isinstance(value, Callable): + # Get boundary coordinates if not already computed + if boundary_coords is None: + lower_coords, upper_coords = grid.boundary_mesh(dim, offset) + boundary_coords = (lower_coords, upper_coords) + + # Evaluate callable on appropriate boundary + boundary_points = boundary_coords[side] + if time is not None: + evaluated_value = value(*boundary_points, t=time) + else: + evaluated_value = value(*boundary_points) + dim_values.append(evaluated_value) + else: + raise ValueError(f"Unsupported boundary value type: {type(value)}") + + evaluated_values.append(tuple(dim_values)) + + # Set the evaluated values + object.__setattr__(self, "_bc_values", tuple(evaluated_values)) + + # Validate the evaluated arrays + self._validate_boundary_arrays_with_grid(grid) + + @property + def has_callable(self) -> bool: + """Always returns False since all callables are evaluated during init.""" + return False + + @property + def bc_values( + self, + ) -> Sequence[Tuple[Optional[BCValues], Optional[BCValues]]]: + """Returns boundary values as tensors for each boundary. + + Since all callable functions are evaluated during initialization, + this property will never encounter callable values and always returns + the evaluated tensor/float values. + """ + # Process all values (no callables should exist at this point) + result = [] + for dim in range(self.ndim): + dim_values = [] + for side in range(2): + value = self._bc_values[dim][side] + if value is None: + dim_values.append(None) + elif isinstance(value, torch.Tensor): + dim_values.append(value) + elif isinstance(value, (int, float)): + # Return scalar tensor for float values + dim_values.append(torch.tensor(float(value))) + else: + raise ValueError( + f"Unexpected boundary value type after evaluation: {type(value)}" + ) + result.append(tuple(dim_values)) + return tuple(result) + + +def dirichlet_boundary_conditions_nonhomogeneous( + ndim: int, + bc_values: Sequence[Tuple[BCValues, BCValues]], +) -> DiscreteBoundaryConditions: + """Create variable Dirichlet boundary conditions.""" + types = ((BCType.DIRICHLET, BCType.DIRICHLET),) * ndim + return DiscreteBoundaryConditions(types, bc_values) + + +def neumann_boundary_conditions_nonhomogeneous( + ndim: int, + bc_values: Sequence[Tuple[BCValues, BCValues],], +) -> DiscreteBoundaryConditions: + """Create variable Neumann boundary conditions.""" + types = ((BCType.NEUMANN, BCType.NEUMANN),) * ndim + return DiscreteBoundaryConditions(types, bc_values) + +def function_boundary_conditions_nonhomogeneous( + ndim: int, + bc_function: Callable[..., torch.Tensor], + bc_type: str, + grid: Grid, + offset: Optional[Tuple[float, ...]] = None, + time: Optional[torch.Tensor] = None, +) -> FunctionBoundaryConditions: + """Create function boundary conditions with the same function applied to all boundaries. + """ + types = ((bc_type, bc_type),) * ndim + return FunctionBoundaryConditions(types, bc_function, grid, offset, time) + def _count_bc_components(bc: BoundaryConditions) -> int: """Counts the number of components in the boundary conditions. @@ -567,12 +1060,12 @@ def _count_bc_components(bc: BoundaryConditions) -> int: """ count = 0 ndim = len(bc.types) - for axis in range(ndim): # ndim - if len(bc.types[axis]) != 2: + for dim in range(ndim): # ndim + if len(bc.types[dim]) != 2: raise ValueError( - f"Boundary conditions for axis {axis} must have two values got {len(bc.types[axis])}." + f"Boundary conditions for axis {dim} must have two values got {len(bc.types[dim])}." ) - count += len(bc.types[axis]) + count += len(bc.types[dim]) return count @@ -612,8 +1105,8 @@ def consistent_boundary_conditions_gridvariable( they are consistent. """ bc_types = [] - for axis in range(arrays[0].grid.ndim): - bcs = {is_periodic_boundary_conditions(array, axis) for array in arrays} + for dim in range(arrays[0].grid.ndim): + bcs = {is_periodic_boundary_conditions(array, dim) for array in arrays} if len(bcs) != 1: raise Exception(f"arrays do not have consistent bc: {arrays}") elif bcs.pop(): @@ -660,8 +1153,8 @@ def get_pressure_bc_from_velocity( def has_all_periodic_boundary_conditions(*arrays: GridVariable) -> bool: """Returns True if arrays have periodic BC in every dimension, else False.""" for array in arrays: - for axis in range(array.grid.ndim): - if not is_periodic_boundary_conditions(array, axis): + for dim in range(array.grid.ndim): + if not is_periodic_boundary_conditions(array, dim): return False return True @@ -723,11 +1216,11 @@ def get_advection_flux_bc_from_velocity_and_scalar_bc( f"Flux boundary condition is not implemented for scalar with {type(c_bc)}" ) - for axis in range(c_bc.ndim): - if u_bc.types[axis][0] == BCType.PERIODIC: + for dim in range(c_bc.ndim): + if u_bc.types[dim][0] == BCType.PERIODIC: flux_bc_types.append((BCType.PERIODIC, BCType.PERIODIC)) flux_bc_values.append((None, None)) - elif flux_direction != axis: + elif flux_direction != dim: # Flux boundary condition parallel to flux direction # Set to homogeneous Dirichlet as it doesn't affect divergence computation flux_bc_types.append((BCType.DIRICHLET, BCType.DIRICHLET)) @@ -738,10 +1231,10 @@ def get_advection_flux_bc_from_velocity_and_scalar_bc( flux_bc_values_ax = [] for i in range(2): # lower and upper boundary - u_type = u_bc.types[axis][i] - c_type = c_bc.types[axis][i] - u_val = u_values[axis][i] if u_values[axis][i] is not None else 0.0 - c_val = c_values[axis][i] if c_values[axis][i] is not None else 0.0 + u_type = u_bc.types[dim][i] + c_type = c_bc.types[dim][i] + u_val = u_values[dim][i] if u_values[dim][i] is not None else 0.0 + c_val = c_values[dim][i] if c_values[dim][i] is not None else 0.0 # Case 1: Dirichlet velocity with Dirichlet scalar if u_type == BCType.DIRICHLET and c_type == BCType.DIRICHLET: diff --git a/torch_cfd/finite_differences.py b/torch_cfd/finite_differences.py index 5be7751..50be02c 100644 --- a/torch_cfd/finite_differences.py +++ b/torch_cfd/finite_differences.py @@ -36,17 +36,41 @@ import math import typing -from typing import Any, List, Optional, Sequence, Tuple, Union +from typing import List, Optional, Sequence, Tuple, Union from functools import reduce import operator import torch from torch_cfd import boundaries, grids -ArrayVector = List[torch.Tensor] +TensorList = Sequence[torch.Tensor] GridVariable = grids.GridVariable GridTensor = grids.GridTensor GridVariableVector = Union[grids.GridVariableVector, Sequence[grids.GridVariable]] +def trim_boundary(u): + # fixed jax-cfd bug that trims all dimension for a batched GridVariable + if isinstance(u, grids.GridVariable): + trimmed_slices = () + for dim in range(-u.grid.ndim, 0): + if u.offset[dim] == 0: + trimmed_slice = slice(1, None) + elif u.offset[dim] == 1: + trimmed_slice = slice(None, -1) + elif u.offset[dim] == 0.5: + trimmed_slice = slice(1, -1) + elif u.offset[dim] < 0: + trimmed = math.floor(u.offset[dim]) + trimmed_slice = slice(-trimmed, None) + elif u.offset[dim] > 1: + trimmed = math.floor(u.offset[dim]) + trimmed_slice = slice(None, -trimmed) + trimmed_slices += (trimmed_slice,) + data = u.data[(..., *trimmed_slices)] + return grids.GridVariable(data, u.offset, u.grid) + else: + u = torch.as_tensor(u) + trimmed_slices = (slice(1, -1),) * u.ndim + return u[(..., *trimmed_slices)] def stencil_sum(*arrays: GridVariable, return_tensor=False) -> Union[GridVariable, torch.Tensor]: """ @@ -160,7 +184,7 @@ def set_laplacian_matrix( bc: boundaries.BoundaryConditions, device: Optional[torch.device] = None, dtype: torch.dtype = torch.float32, -) -> ArrayVector: +) -> TensorList: """Initialize the Laplacian operators.""" offset = grid.cell_center @@ -198,7 +222,7 @@ def laplacian_matrix(n: int, step: float, sparse: bool = False, dtype=torch.floa def _laplacian_boundary_dirichlet_cell_centered( - laplacians: ArrayVector, grid: grids.Grid, dim: int, side: str + laplacians: TensorList, grid: grids.Grid, dim: int, side: str ) -> None: """Converts 1d laplacian matrix to satisfy dirichlet homogeneous bc. @@ -251,7 +275,7 @@ def _laplacian_boundary_dirichlet_cell_centered( def _laplacian_boundary_neumann_cell_centered( - laplacians: ArrayVector, grid: grids.Grid, dim: int, side: str + laplacians: TensorList, grid: grids.Grid, dim: int, side: str ) -> None: """Converts 1d laplacian matrix to satisfy neumann homogeneous bc. @@ -287,11 +311,11 @@ def laplacian_matrix_w_boundaries( grid: grids.Grid, offset: Tuple[float, ...], bc: grids.BoundaryConditions, - laplacians: Optional[ArrayVector] = None, + laplacians: Optional[TensorList] = None, device: Optional[torch.device] = None, dtype: torch.dtype = torch.float32, sparse: bool = False, -) -> ArrayVector: +) -> TensorList: """Returns 1d laplacians that satisfy boundary conditions bc on grid. Given grid, offset and boundary conditions, returns a list of 1 laplacians diff --git a/torch_cfd/grids.py b/torch_cfd/grids.py index 35d788f..b752918 100644 --- a/torch_cfd/grids.py +++ b/torch_cfd/grids.py @@ -44,6 +44,9 @@ class BCType: NONE = None +BCValues = Union[torch.Tensor, None] + + class Padding: MIRROR = "reflect" EXTEND = "replicate" @@ -65,6 +68,13 @@ class Grid: - `step[i]` is the width of each grid cell. - `(lower, upper) = domain[i]` gives the locations of lower and upper boundaries. The identity `upper - lower = step[i] * shape[i]` is enforced. + + Args: + shape: (nx, ny) + step: (dx, dy) or a single float for isotropic grids. + domain: ((x0, x1), (y0, y1)), by default if only step + is given the domain ((0, 1), (0, 1)). + device: the device of the output grid.mesh(). """ shape: Tuple[int, ...] @@ -159,12 +169,16 @@ def cell_faces(self) -> Tuple[Tuple[float, ...], ...]: def stagger(self, v: Tuple[torch.Tensor, ...]) -> Tuple[GridVariable, ...]: """Places the velocity components of `v` on the `Grid`'s cell faces.""" offsets = self.cell_faces - return tuple(GridVariable(u, o, self) for u, o in zip(v, offsets)) + return GridVariableVector( + tuple(GridVariable(u, o, self) for u, o in zip(v, offsets)) + ) def center(self, v: Tuple[torch.Tensor, ...]) -> Tuple[GridVariable, ...]: """Places all arrays in the pytree `v` at the `Grid`'s cell center.""" offset = self.cell_center - return tuple(GridVariable(tensor, offset, self) for tensor in v) + return GridVariableVector( + tuple(GridVariable(tensor, offset, self) for tensor in v) + ) def axes( self, offset: Optional[Sequence[float]] = None @@ -240,20 +254,94 @@ def eval_on_mesh( self, fn: Callable[..., torch.Tensor], offset: Optional[Tuple[float, ...]] = None, + bc: Optional[BoundaryConditions] = None, ) -> GridVariable: """Evaluates the function on the grid mesh with the specified offset. Args: fn: A function that accepts the mesh arrays and returns an array. - offset: an optional sequence of length `ndim`. If not specified, uses the - offset for the cell center. + offset: an optional sequence of length `ndim`. If not specified, uses the offset for the cell center. + bc: optional boundary conditions to wrap the variable with. Returns: fn(x, y, ...) evaluated on the mesh, as a GridArray with specified offset. + + Example: + >>> f = lambda x, y: x + 2 * y + >>> grid = Grid((4, 4), domain=((0.0, 1.0), (0.0, 1.0))) + >>> offset = (0, 0) + >>> u = grid.eval_on_mesh(f, offset) + """ + if offset is None: + offset = self.cell_center + return GridVariable(fn(*self.mesh(offset)), offset, self, bc) + + def boundary_mesh( + self, + dim: int, + offset: Optional[Tuple[float, ...]] = None, + ) -> Tuple[Tuple[torch.Tensor, ...], Tuple[torch.Tensor, ...]]: + """Get coordinate arrays for boundary points along dimension dim. + + Args: + dim: The dimension along which to get boundary coordinates + offset: Grid offset for coordinate calculation + + Returns: + A tuple of (lower_boundary_coords, upper_boundary_coords) where each + contains coordinate arrays for the boundary points. + - 1D case: returns ((domain[0],), (domain[1],)) - scalar boundary points + - 2D case: + * dim=0 (x boundaries): ((x_left, y_coords), (x_right, y_coords)) + dim 0 corresponds to u[0, :] and u[-1, :] + * dim=1 (y boundaries): ((x_coords, y_left), (x_coords, y_right)) + dim 1 corresponds to u[:, 0] and u[:, -1] """ if offset is None: offset = self.cell_center - return GridVariable(fn(*self.mesh(offset)), offset, self) + + # Handle 1D case + if self.ndim == 1: + lower_boundary = torch.tensor(self.domain[0][0], device=self.device) + upper_boundary = torch.tensor(self.domain[0][1], device=self.device) + return ((lower_boundary,), (upper_boundary,)) + + # Handle 2D case + elif self.ndim == 2: + if dim < 0: + dim = self.ndim + dim + + if dim not in (0, 1): + raise ValueError(f"dim must be 0 or 1 for 2D grids, got {dim}") + + # Use XOR to determine which dimension varies along the boundary + other_dim = dim ^ 1 # 0 ^ 1 = 1, 1 ^ 1 = 0 + + # Get coordinates for the varying dimension (same for both boundaries) + bd_varying_coords = ( + self.domain[other_dim][0] + + (torch.arange(self.shape[other_dim], device=self.device) + offset[other_dim]) + * self.step[other_dim] + ) + + # Get boundary coordinates for the fixed dimension + lower_coord = self.domain[dim][0] + upper_coord = self.domain[dim][1] + + # Create coordinate arrays for boundaries + lower_fixed_coords = torch.full_like(bd_varying_coords, lower_coord) + upper_fixed_coords = torch.full_like(bd_varying_coords, upper_coord) + + # Arrange coordinates in proper order based on dimension + if dim == 0: # x boundaries + return ((lower_fixed_coords, bd_varying_coords), (upper_fixed_coords, bd_varying_coords)) + else: # dim == 1, y boundaries + return ((bd_varying_coords, lower_fixed_coords), (bd_varying_coords, upper_fixed_coords)) + + else: + raise NotImplementedError( + f"boundary_mesh not implemented for {self.ndim}D grids" + ) @dataclasses.dataclass(init=False, frozen=True) @@ -266,6 +354,8 @@ class BoundaryConditions: """ types: Tuple[Tuple[str, str], ...] + bc_values: Tuple[Tuple[BCValues, BCValues], ...] + ndim: int def shift( self, @@ -511,6 +601,7 @@ class GridVariable(GridTensorOpsMixin): - [x] (0.1.1) In original Google Research's Jax-CFD code, the devs opted to separate GridArray (no bc) and GridVariable (bc). After carefully studied the FVM implementation, I decided to combine GridArray with GridVariable to reduce code duplication. - One can definitely try to use TensorClass from tensordict to implement a more generic GridVariable class, however I found using Tensorclass or @tensorclass actually slows down the code quite a bit. - [x] (0.2.0) Finished refactoring the whole GridVariable class for various routines, getting rid of the GridArray class, adding several helper functions for GridVariableVector, and adding batch dimension checks. + - [x] (0.2.5) Added imposing variable/function-valued nonhomogeneous Dirichlet boundary conditions. """ data: torch.Tensor @@ -533,7 +624,8 @@ def __post_init__(self): def __repr__(self) -> str: lines = [f"GridVariable:"] - lines.append(f"data tensor: \n{self.data.cpu().detach().numpy()}") + display_data = self.disp_data + lines.append(f"data tensor: \n{display_data.numpy()}\n") lines.append(f"data shape: {tuple(s for s in self.data.shape)}") lines.append(f"offset: {self.offset}") lines.append(f"grid shape: {self.grid.shape}") @@ -543,13 +635,13 @@ def __repr__(self) -> str: step = self.grid.step[i] lines.append(f" dim {i}: [{lower:.3f}, {upper:.3f}], step={step:.3f}") - lines.append(f"dtype : {self.data.dtype}") + lines.append(f"\ndtype : {self.data.dtype}") lines.append(f"device: {self.device}") # Add boundary condition info if available if self.bc is not None: bc_repr = repr(self.bc) - lines.append(f"boundary conditions:") + lines.append(f"\nboundary conditions:") bc_lines = bc_repr.split("\n") for bc_line in bc_lines: lines.append(f" {bc_line}") @@ -586,6 +678,15 @@ def array(self, v: GridVariable): self.grid = v.grid self.bc = None # reset boundary conditions + @property + def disp_data(self) -> torch.Tensor: + """Returns the data tensor with the second-to-last dimension flipped. Otherwise return a numpy array for printing.""" + # This is useful for displaying 2D data in a natural way + disp_data = self.data.clone().cpu().detach() + if self.grid.ndim >= 2: + disp_data = torch.flip(disp_data.swapaxes(-2, -1), dims=[-2]) + return disp_data + @property def device(self) -> torch.device: return self.data.device @@ -1008,7 +1109,7 @@ def pad( bc: Optional[BoundaryConditions] = None, mode: Optional[str] = Padding.EXTEND, bc_types: Optional[Tuple[str, str]] = None, - values: Optional[Union[float, torch.Tensor]] = None, + values: Optional[Union[BCValues, Sequence[BCValues]]] = None, ) -> GridVariable: """Pad a GridVariable by `padding`. @@ -1023,7 +1124,7 @@ def pad( bc: boundary conditions to use for padding. If None, uses the boundary conditions of u. bc_types: boundary condition types for the dimension `dim`. If None, uses the boundary conditions of u. mode: padding mode for Ghost cells! The function tries to automatically select the mode based on the boundary conditions. - values: (TODO) allowing tensorial values to be passed for future development, currently used for a fallback implementation of Dirichlet BC. + values: tensorial values can be passed directly, see boundaries.DiscreteBoundaryConditions bc_values and pad_and_impose_bc() for details. Returns: Padded array, elongated along the indicated axis. @@ -1039,10 +1140,11 @@ def pad( Padding.MIRROR, Padding.EXTEND, Padding.SYMMETRIC, - ], f"Padding mode must be one of ['{Padding.MIRROR}', '{Padding.EXTEND}', '{Padding.SYMMETRIC}'], got '{mode}'" + Padding.NONE, + ], f"Padding mode must be one of ['{Padding.MIRROR}', '{Padding.EXTEND}', '{Padding.SYMMETRIC}', None], got '{mode}'" bc = bc if bc is not None else u.bc # use bc in priority bc_types = bc.types[dim] if bc_types is None else bc_types - values = bc.bc_values if values is None else values + values = values if values is not None else bc.bc_values if isinstance(width, int): if width < 0: # pad lower boundary bc_type = bc_types[0] @@ -1106,47 +1208,75 @@ def pad( # Then the mirrored ghost cells need to be appended. # if only one value is needed, no mode is necessary. - if math.isclose(sum(full_padding[dim]), 1) or math.isclose( - sum(full_padding[dim]), 0 - ): + if ( + math.isclose(sum(full_padding[dim]), 1) + or math.isclose(sum(full_padding[dim]), 0) + ) and (0 <= new_offset[dim] <= 1): data = expand_dims_pad( u.data, full_padding, mode="constant", constant_values=values ) return GridVariable(data, tuple(new_offset), u.grid, bc) - elif sum(full_padding[dim]) > 1: - if mode == Padding.EXTEND: - data = expand_dims_pad( - u.data, - full_padding, - mode="constant", - constant_values=values, - ) - return GridVariable(data, tuple(new_offset), u.grid, bc) - elif mode == Padding.MIRROR: - bc_padding = [(0, 0)] * u.grid.ndim - bc_padding[dim] = tuple(1 if pad > 0 else 0 for pad in padding) - # subtract the padded cell - full_padding_past_bc = [(0, 0)] * u.grid.ndim - full_padding_past_bc[dim] = tuple( - pad - 1 if pad > 0 else 0 for pad in padding - ) - # here we are adding 0 boundary cell with 0 value - expanded_data = expand_dims_pad( - u.data, bc_padding, mode="constant", constant_values=(0, 0) - ) - padding_values = list(values) - padding_values[dim] = tuple( - [pad / 2 for pad in padding_values[dim]] - ) + elif ( + sum(full_padding[dim]) > 1 + or (new_offset[dim] < 0) + or (new_offset[dim] > 1) + ): + # either (2, 0) or (1, 1) padding in that dimension + # only triggered when bc.pad_all is called + if new_offset[dim] < 0 or new_offset[dim] > 1: + # if padding beyond the boundary, use the linear extrapolation + # if not specified, of new_offset still >=0, use the user define values data = 2 * expand_dims_pad( - u.data, - full_padding, - mode="constant", - constant_values=padding_values, - ) - expand_dims_pad( - expanded_data, full_padding_past_bc, mode="reflect" - ) - return GridVariable(data, tuple(new_offset), u.grid, bc) + u.data, full_padding, mode="constant", constant_values=values + ) - expand_dims_pad(u.data, full_padding, mode=Padding.MIRROR) + return GridVariable(data, tuple(new_offset), u.grid, bc) + else: + if mode == Padding.EXTEND: + data = expand_dims_pad( + u.data, + full_padding, + mode="constant", + constant_values=values, + ) + return GridVariable(data, tuple(new_offset), u.grid, bc) + elif mode == Padding.MIRROR: + bc_padding = [(0, 0)] * u.grid.ndim + bc_padding[dim] = tuple(1 if pad > 0 else 0 for pad in padding) + # subtract the padded cell + full_padding_past_bc = [(0, 0)] * u.grid.ndim + full_padding_past_bc[dim] = tuple( + pad - 1 if pad > 0 else 0 for pad in padding + ) + # here we are adding 0 boundary cell with 0 value + expanded_data = expand_dims_pad( + u.data, bc_padding, mode="constant", constant_values=(0, 0) + ) + padding_values = list(values) + padding_values[dim] = tuple( + [p / 2 for p in padding_values[dim]] + ) + data = 2 * expand_dims_pad( + u.data, + full_padding, + mode="constant", + constant_values=padding_values, + ) - expand_dims_pad( + expanded_data, full_padding_past_bc, mode=mode + ) + return GridVariable(data, tuple(new_offset), u.grid, bc) + elif mode == Padding.SYMMETRIC: + # symmetric padding, mirrors values at the boundaries + data = 2 * expand_dims_pad( + u.data, + full_padding, + mode="constant", + constant_values=values, + ) - expand_dims_pad(u.data, full_padding, mode=mode) + return GridVariable(data, tuple(new_offset), u.grid, bc) + else: + raise ValueError( + f"Unsupported padding mode '{mode}' for Dirichlet BC with cell edge offset" + ) else: raise ValueError( f"invalid padding width for Dirichlet BC, expected padding[dim={dim}] to have sum >= 0, got {padding[dim]}" @@ -1218,7 +1348,7 @@ def expand_dims_pad( pad: Sequence[Tuple[int, int]], mode: str = "constant", constant_values: Union[ - float, Tuple[float, float], Sequence[Tuple[float, float]] + float, Tuple[BCValues, BCValues], Sequence[Tuple[BCValues, BCValues]] ] = 0, **kwargs, ) -> torch.Tensor: @@ -1228,13 +1358,13 @@ def expand_dims_pad( - jnp's pad pad_width starts from the first dimension to the last dimension while torch's pad pad_width starts from the last dimension to the previous dimension example: - - for torch (1, 1, 2, 2) means padding last dim by (1, 1) and 2nd to last by (2, 2), the pad arg for the expand_dims_pad function should be ((2, 2), (1, 1)) + - for torch.nn.functional.pad: pad = (1, 1, 2, 2) means padding last dim by (1, 1) and 2nd to last by (2, 2), the pad arg for the expand_dims_pad function should be ((2, 2), (1, 1)) which the natural ordering of dimensions Args: inputs: torch.Tensor or a tuple of arrays to pad. pad_width: padding width for each dimension. mode: padding mode, one of 'constant', 'reflect', 'symmetric'. - values: constant value to pad with. + values: values to pad with. Returns: Padded `inputs`. @@ -1274,20 +1404,23 @@ def expand_dims_pad( def _constant_pad_tensor( inputs: torch.Tensor, pad: Sequence[Tuple[int, int]], - constant_values: Sequence[Tuple[float, float]], + constant_values: Sequence[Tuple[BCValues, BCValues]], **kwargs, ) -> torch.Tensor: """ - Corrected padding function that supports different constant values for each side. + Corrected padding function that supports different constant/tensor values for each side. Pads each dimension from first to last as per the user input, bypassing PyTorch's F.pad behavior of last-to-first padding order. + Extended to support tensor values for padding - if the input is already a correctly shaped tensor, + it will be used directly for concatenation instead of creating a new tensor with torch.full. + Args: inputs: torch.Tensor to pad. pad: padding width for each dimension, e.g. ((2, 2), (1, 1)) for 2D tensor. (2, 2) means padding the first dimension by 2 on both sides, and (1, 1) means padding the second dimension by 1 on both sides. - constant_values: constant values to pad with for each dimension, e.g. ((0, 0), (1, 1)). + constant_values: constant values to pad with for each dimension, e.g. ((0, 0), (1, 1)). Can also be tensors with correct boundary shapes. Example: - If ((2, 2), (1, 1)) is given for a 2D (potentially batched) tensor of shape (*, 10, 20), + If pad = ((2, 2), (1, 1)) is given for a 2D (potentially batched) tensor of shape (*, 10, 20), - pad[1] corresponds to the padding of the last dimension (20), - pad[0] corresponds to the padding of the second-to-last dimension (10). - the resulting tensor shape will be (*, 10 + 2 + 2, 20 + 1 + 1) = (*, 14, 22) @@ -1301,7 +1434,7 @@ def _constant_pad_tensor( dims_to_pad = len(pad) # number of dimensions to pad result = inputs - for i in reversed(range(dims_to_pad)): + for i in reversed(range(dims_to_pad)): # iterate from last to first dimension dim = i - dims_to_pad # correct mapping from pad index to tensor dim @@ -1311,7 +1444,7 @@ def _constant_pad_tensor( continue # Get constant values - if len(constant_values) > i: + if len(constant_values) > 0: if ( isinstance(constant_values[i], (tuple, list)) and len(constant_values[i]) == 2 @@ -1322,36 +1455,130 @@ def _constant_pad_tensor( else: left_val = right_val = 0.0 - left_val = ( - float(left_val[0]) - if isinstance(left_val, (list, tuple)) - else float(left_val) - ) - right_val = ( - float(right_val[0]) - if isinstance(right_val, (list, tuple)) - else float(right_val) - ) - - shape = list(result.shape) - + # Handle left padding if left_pad > 0: - shape[dim] = left_pad - left_tensor = torch.full( - shape, left_val, dtype=result.dtype, device=result.device + left_tensor = ( + _create_boundary_tensor(left_val, result.shape, dim, left_pad) + .to(result.dtype) + .to(result.device) ) result = torch.cat([left_tensor, result], dim=dim) + # Handle right padding if right_pad > 0: - shape[dim] = right_pad - right_tensor = torch.full( - shape, right_val, dtype=result.dtype, device=result.device + right_tensor = ( + _create_boundary_tensor(right_val, result.shape, dim, right_pad) + .to(result.dtype) + .to(result.device) ) result = torch.cat([result, right_tensor], dim=dim) return result +def _create_boundary_tensor( + value: Union[BCValues, float], + target_shape: Tuple[int, ...], + pad_dim: int, + pad_width: int, +) -> torch.Tensor: + """ + Create a boundary tensor for padding with proper shape handling. + + Args: + value: The boundary value (tensor, scalar, or None) + target_shape: Shape of the tensor being padded + pad_dim: The dimension being padded (negative indexing) + pad_width: Width of padding for this side + + Returns: + Properly shaped tensor for concatenation + + Notes: + current only handle 1D bc for concatenation with + (b, nx, ny) shaped tensors. + """ + # Calculate expected boundary shape + expected_shape = list(target_shape) + expected_shape[pad_dim] = pad_width + + if isinstance(value, torch.Tensor): + # Handle tensor boundary values + if list(value.shape) == expected_shape: + # Tensor already has correct shape + return value + + elif value.ndim == 1: + # Handle 1D boundary tensor - need to properly reshape/expand + if value.shape[0] > 1: + boundary_size = value.shape[0] + + # Determine target size for the boundary dimension + if pad_dim == -1: # Last dimension + # For padding last dim, boundary values correspond to second-to-last dim + boundary_dim = -2 + target_size = expected_shape[boundary_dim] + elif pad_dim == -2: # Second to last dimension + # For padding second-to-last dim, boundary values correspond to last dim + boundary_dim = -1 + target_size = expected_shape[boundary_dim] + else: + raise NotImplementedError( + f"Padding BC for dimension {pad_dim} is not implemented. " + "Currently only -1 (y dimension) and -2 (x dimension) are supported." + ) + + # Handle size mismatch with interpolation + if boundary_size != target_size: + # Use interpolation to resize the boundary tensor + # Reshape to [1, 1, boundary_size] for F.interpolate + # TODO: interpolate here is a janky monkey patch + # as the location may be mis-aligned + # a better way + # TODO: a better way to handle this is simply use slicing + + reshaped_for_interp = value[None, None, :] + interpolated = F.interpolate( + reshaped_for_interp, + size=target_size, + mode="linear", + align_corners=False, + ) + # Remove the added dimensions: [1, 1, target_size] -> [target_size] + value = interpolated.squeeze() + boundary_size = target_size + + # Create the target tensor by using view and expand appropriately + new_shape = [1] * len(expected_shape) + new_shape[boundary_dim] = boundary_size + + # Reshape and expand + reshaped = value.view(new_shape) + boundary_tensor = reshaped.expand(expected_shape) + + return boundary_tensor + elif value.shape[0] == 1: + # If the tensor has only one element, we can simply expand it + # to the expected shape + return value.expand(expected_shape) + else: + raise ValueError( + f"Boundary tensor with {value.ndim}D shape {value.shape} is not supported. " + f"Only 1D boundary tensors are supported." + ) + + else: + raise ValueError( + f"Boundary tensor with {value.ndim}D shape {value.shape} is not supported. " + f"Only 1D boundary tensors are supported." + ) + + else: + # Handle scalar boundary values + scalar_val = float(value) if value is not None else 0.0 + return torch.full(expected_shape, scalar_val) + + def _symmetric_pad_tensor( inputs: torch.Tensor, pad: Sequence[Tuple[int, int]], @@ -1360,6 +1587,7 @@ def _symmetric_pad_tensor( """ Symmetric padding function that mirrors values at the boundaries. Pads each dimension from first to last as per the user input. + This is a drop-in replacement for np.pad with mode == 'symmetric'. Args: inputs: torch.Tensor to pad. @@ -1377,6 +1605,10 @@ def _symmetric_pad_tensor( >>> _symmetric_pad_tensor(data, ((2, 0),)) tensor([[12., 11., 11., 12., 13., 14.],]) + Note: + - the 'reflect' mode of F.pad would yield for the data above + tensor([[13., 12., 11., 12., 13., 14.],]) + """ dims_to_pad = len(pad) # number of dimensions to pad result = inputs diff --git a/torch_cfd/test_utils.py b/torch_cfd/test_utils.py index 3553926..cbabe7e 100644 --- a/torch_cfd/test_utils.py +++ b/torch_cfd/test_utils.py @@ -62,12 +62,38 @@ def _check_and_remove_alignment_and_grid(self, *arrays): # pylint: disable=unbalanced-tuple-unpacking def assertArrayEqual(self, actual, expected, **kwargs): actual, expected = self._check_and_remove_alignment_and_grid(actual, expected) - atol = torch.finfo(expected.data.dtype).eps - rtol = expected.abs().max() * atol + rtol = torch.finfo(expected.data.dtype).eps + atol = expected.abs().max() * rtol torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol, **kwargs) def assertAllClose(self, actual, expected, **kwargs): actual, expected = self._check_and_remove_alignment_and_grid(actual, expected) torch.testing.assert_close(actual, expected, **kwargs) - # pylint: enable=unbalanced-tuple-unpacking + def assertNestedTuplesEqual(self, tuple1, tuple2, atol=1e-6, rtol=1e-6): + """Assert that two nested tuples containing tensors are equal.""" + def _compare_recursive(t1, t2): + if type(t1) != type(t2): + return False + + if isinstance(t1, tuple): + if len(t1) != len(t2): + return False + return all(_compare_recursive(x, y) for x, y in zip(t1, t2)) + + elif isinstance(t1, torch.Tensor) and isinstance(t2, torch.Tensor): + try: + self.assertAllClose(t1, t2, atol=atol, rtol=rtol) + return True + except AssertionError: + return False + + elif t1 is None and t2 is None: + return True + + else: + return t1 == t2 + + self.assertTrue(_compare_recursive(tuple1, tuple2), + f"Nested tuples are not equal:\n{tuple1}\nvs\n{tuple2}") + diff --git a/torch_cfd/tests/test_boundaries.py b/torch_cfd/tests/test_boundaries.py index 4c294fb..4a68964 100644 --- a/torch_cfd/tests/test_boundaries.py +++ b/torch_cfd/tests/test_boundaries.py @@ -18,6 +18,8 @@ from functools import partial +import math +import numpy as np import torch from absl.testing import absltest, parameterized @@ -320,14 +322,15 @@ def test_trim_padding_1d( @parameterized.parameters( # Dirichlet BC + # test_pad_1d_inhomogeneous0 dict( - bc_types=(((BCType.DIRICHLET, BCType.DIRICHLET),), ((1.0, 2.0),)), - input_data=tensor([1, 12, 13, 14]), - input_offset=(0,), # cell nodes in 1d (cell edge in 2d) - width=-1, - expected_data=tensor([1, 1, 12, 13, 14]), - expected_offset=(-1,), - ), + bc_types=(((BCType.DIRICHLET, BCType.DIRICHLET),), ((1.0, 2.0),)), + input_data=tensor([1, 12, 13, 14]), + input_offset=(0,), + width=-3, + expected_data=tensor([-12, -11, -10, 1, 12, 13, 14]), + expected_offset=(-3,), + ), dict( bc_types=(((BCType.DIRICHLET, BCType.DIRICHLET),), ((1.0, 2.0),)), input_data=tensor([1, 12, 13, 14]), @@ -894,5 +897,393 @@ def test_get_pressure_bc_from_velocity_2d(self): ) +class VariableBoundaryConditionsTest(test_utils.TestCase): + """DiscreteBoundaryConditions and FunctionBoundaryConditions.""" + + def test_initialization_with_arrays(self): + """Test initialization with tensor boundary values.""" + grid = grids.Grid((4, 4)) + + # Create boundary value arrays + left_boundary = tensor([1., 2., 3., 4.]) # left boundary (4 points) + right_boundary = tensor([5., 6., 7., 8.]) # right boundary (4 points) + bottom_boundary = tensor([10., 11., 12., 13.]) # bottom boundary (4 points) + top_boundary = tensor([20., 21., 22., 23.]) # top boundary (4 points) + + bc = boundaries.DiscreteBoundaryConditions( + types=((BCType.DIRICHLET, BCType.DIRICHLET), + (BCType.DIRICHLET, BCType.DIRICHLET)), + values=((left_boundary, right_boundary), + (bottom_boundary, top_boundary)) + ) + + self.assertEqual(len(bc.bc_values), 2) + self.assertEqual(len(bc.bc_values[0]), 2) + self.assertEqual(len(bc.bc_values[1]), 2) + + def test_initialization_with_callables(self): + """Test initialization with callable boundary values.""" + def inlet_profile(x, y): + return torch.sin(torch.pi * y) + + def outlet_profile(x, y): + return torch.zeros_like(y) + + grid = grids.Grid((4, 4)) + + bc = boundaries.FunctionBoundaryConditions( + types=((BCType.DIRICHLET, BCType.DIRICHLET), + (BCType.PERIODIC, BCType.PERIODIC)), + values=((inlet_profile, outlet_profile), + (None, None)), + grid=grid, + offset=(0.5, 0.5) + ) + + # raw is the callable, bc_values is the evaluated tensor + self.assertTrue(callable(bc._raw_bc_values[0][0])) + self.assertTrue(callable(bc._raw_bc_values[0][1])) + self.assertTrue(isinstance(bc._bc_values[0][0], torch.Tensor)) + self.assertTrue(isinstance(bc._bc_values[0][1], torch.Tensor)) + + def test_initialization_with_mixed_values(self): + """Test initialization with mixed boundary value types.""" + left_boundary = tensor([1., 2., 3., 4.]) + + def right_profile(x, y): + return y * 2.0 + + bc = boundaries.FunctionBoundaryConditions( + types=((BCType.DIRICHLET, BCType.DIRICHLET), + (BCType.NEUMANN, BCType.NEUMANN)), + values=((left_boundary, right_profile), + (0.0, 1.0)), # constant values + grid=grids.Grid((4, 4)), + offset=(0.5, 0.5) + ) + + # Use private _bc_values attribute to check types without grid context + self.assertIsInstance(bc._bc_values[0][0], torch.Tensor) + self.assertTrue(callable(bc._raw_bc_values[0][1])) + self.assertTrue(isinstance(bc._bc_values[0][1], torch.Tensor)) + self.assertEqual(bc._bc_values[1][0], 0.0) + self.assertEqual(bc._bc_values[1][1], 1.0) + + def test_evaluate_boundary_value_with_callable(self): + """Test evaluation of callable boundary values.""" + def parabolic_profile(x, y): + return y * (1 - y) + + grid = grids.Grid((4, 4), domain=((0, 1), (0, 1))) + expected = tensor([0., 0.1875, 0.25, 0.1875]) + + bc = boundaries.FunctionBoundaryConditions( + types=((BCType.DIRICHLET, BCType.DIRICHLET), + (BCType.NEUMANN, BCType.NEUMANN)), + values=((parabolic_profile, 0.0), + (0.0, 1.0)), # constant values + grid=grid, + offset=(0.0, 0.0) + ) + + result = bc.bc_values[0][0] # Evaluate left boundary + self.assertArrayEqual(result, expected) + + def test_get_boundary_coordinates_2d(self): + """Test boundary coordinate generation for 2D case.""" + grid = grids.Grid((4, 4), domain=((0, 2), (0, 2))) + offset = (0.5, 0.5) + + bc = boundaries.DiscreteBoundaryConditions( + types=((BCType.DIRICHLET, BCType.DIRICHLET), + (BCType.DIRICHLET, BCType.DIRICHLET)), + values=((1.0, 2.0), (3.0, 4.0)) + ) + + # Test boundary coordinates for dimension 0 (y-coordinates for left/right boundaries) + coords = bc._boundary_mesh(0, grid, offset) + expected_x = tensor([0, 0, 0, 0]) # x-coordinates for left boundary + expected_y = tensor([0.25, 0.75, 1.25, 1.75]) # offset[1] * step[1] for each grid point + # left edge + self.assertArrayEqual(coords[0][0], expected_x) + self.assertArrayEqual(coords[0][1], expected_y) + # right edge + expected_x = tensor([2, 2, 2, 2]) # + self.assertArrayEqual(coords[1][0], expected_x) + self.assertArrayEqual(coords[1][1], expected_y) + + + # Test boundary coordinates for dimension 1 (x-coordinates for top/bottom boundaries) + coords = bc._boundary_mesh(1, grid, offset) + expected_x = tensor([0.25, 0.75, 1.25, 1.75]) # offset[0] * step[0] for each grid point + expected_y = tensor([0, 0, 0, 0]) # y-coordinates for bottom boundary + # bottom edge + self.assertArrayEqual(coords[0][0], expected_x) + self.assertArrayEqual(coords[0][1], expected_y) + # top edge + expected_y = tensor([2, 2, 2, 2]) # + self.assertArrayEqual(coords[1][0], expected_x) + self.assertArrayEqual(coords[1][1], expected_y) + + def test_values_method_with_callables(self): + """Test values method with callable boundary conditions.""" + grid = grids.Grid((3, 3), domain=((0, 3), (0, 3))) + + def left_profile(x, y): + return y * 2.0 + + def right_profile(x, y): + return y + 1.0 + + bc = boundaries.FunctionBoundaryConditions( + types=((BCType.DIRICHLET, BCType.DIRICHLET), + (BCType.PERIODIC, BCType.PERIODIC)), + values=((left_profile, right_profile), + (None, None)), + grid=grid, + offset=(0.0, 0.5) + ) + + bc_lower, bc_upper = bc.bc_values[0] + + # For the default edge center (0.0, 0.5), y coordinates are [0.5, 1.5, 2.5] + expected_left = tensor([1.0, 3.0, 5.0]) # y * 2.0 + expected_right = tensor([1.5, 2.5, 3.5]) # y + 1.0 + + self.assertArrayEqual(bc_lower, expected_left) + self.assertArrayEqual(bc_upper, expected_right) + + + def test_variable_boundary_vs_constant_boundary_consistency(self): + """Test that VariableBoundaryConditions gives same results as ConstantBoundaryConditions for constant values.""" + grid = grids.Grid((8, 4)) + + # Constant boundary conditions + const_bc = boundaries.ConstantBoundaryConditions( + types=((BCType.DIRICHLET, BCType.DIRICHLET), + (BCType.DIRICHLET, BCType.DIRICHLET)), + values=((1.0, 2.0), (3.0, 4.0)) + ) + + # Variable boundary conditions with constant values + var_bc = boundaries.DiscreteBoundaryConditions( + types=((BCType.DIRICHLET, BCType.DIRICHLET), + (BCType.DIRICHLET, BCType.DIRICHLET)), + values=((torch.ones(4), 2*torch.ones(4)), + (3*torch.ones(8), 4*torch.ones(8)) + )) + + # Test that values method gives same results + const_values_0 = const_bc.values(0, grid) + var_values_0 = var_bc.values(0, grid) + + # Both should give scalar constant values broadcast to boundary shape + expected_left = torch.full((4,), 1.0) + expected_right = torch.full((4,), 2.0) + self.assertArrayEqual(const_values_0[0], expected_left) + self.assertArrayEqual(var_values_0[0], expected_left) + self.assertArrayEqual(const_values_0[1], expected_right) + self.assertArrayEqual(var_values_0[1], expected_right) + + + def test_compatible_bc_with_grid(self): + """Test error handling in VariableBoundaryConditions.""" + # Test with mismatched array dimensions + grid = grids.Grid((4, 4)) + wrong_size_boundary = tensor([1., 2.]) # Should be size 4 + + bc = boundaries.DiscreteBoundaryConditions( + types=((BCType.DIRICHLET, BCType.DIRICHLET), + (BCType.DIRICHLET, BCType.DIRICHLET)), + values=((wrong_size_boundary, 2.0), + (3.0, 4.0)) + ) + + # This should raise an error when validating with grid + with self.assertRaises(ValueError): + bc._validate_boundary_arrays_with_grid(grid) + + @parameterized.named_parameters( + # Test different offsets for linear function + dict( + testcase_name="_linear_lower_left_corner", + shape=(4, 4), + offset=(0, 0), + f=lambda x, y: x + 2 * y, + expected_bc_values=( + (tensor([0., 0.5, 1.0, 1.5]), tensor([1., 1.5, 2.0, 2.5])), + (tensor([0., 0.25, 0.5, 0.75]), tensor([2., 2.25, 2.5, 2.75])) + ), + ), + dict( + testcase_name="_linear_cell_center", + shape=(4, 4), + offset=(0.5, 0.5), + f=lambda x, y: x + 2 * y, + expected_bc_values=( + (tensor([0.25, 0.75, 1.25, 1.75]), tensor([1.25, 1.75, 2.25, 2.75])), + (tensor([0.125, 0.375, 0.625, 0.875]), tensor([2.125, 2.375, 2.625, 2.875])) + ), + ), + dict( + testcase_name="_linear_upper_right_corner", + shape=(4, 4), + offset=(1.0, 1.0), + f=lambda x, y: x + 2 * y, + expected_bc_values=( + (tensor([0.5, 1.0, 1.5, 2.0]), tensor([1.5, 2.0, 2.5, 3.0])), + (tensor([0.25, 0.5, 0.75, 1.0]), tensor([2.25, 2.5, 2.75, 3.0])) + ), + ), + # Test constant function + dict( + testcase_name="_constant_cell_center", + shape=(4, 4), + offset=(0.5, 0.5), + f=lambda x, y: torch.ones_like(x), + expected_bc_values=( + (tensor([1., 1., 1., 1.]), tensor([1., 1., 1., 1.])), + (tensor([1., 1., 1., 1.]), tensor([1., 1., 1., 1.])) + ), + ), + # Test polynomial function + dict( + testcase_name="_quadratic_lower_left_corner", + shape=(3, 3), + offset=(0, 0), + f=lambda x, y: x**2 + y**2, + expected_bc_values=( + (tensor([0., 1./9., 4./9.]), tensor([1., 10./9., 13./9.])), + (tensor([0., 1./9., 4./9.]), tensor([1., 10./9., 13./9.])) + ), + ), + ) + def test_function_bc_evaluation(self, shape, offset, f, expected_bc_values): + """Test that FunctionBoundaryConditions evaluates functions correctly.""" + grid = grids.Grid(shape, domain=((0.0, 1.0), (0.0, 1.0))) + + bc = boundaries.FunctionBoundaryConditions( + types=((BCType.DIRICHLET, BCType.DIRICHLET), + (BCType.DIRICHLET, BCType.DIRICHLET)), + values=((f, f), (f, f)), + grid=grid, + offset=offset + ) + + actual_bc_values = bc.bc_values + self.assertNestedTuplesEqual(actual_bc_values, expected_bc_values) + + @parameterized.named_parameters( + dict( + testcase_name="_linear", + shape=(16, 16), + f=lambda x, y: x + 2 * y, + ), + dict( + testcase_name="_quadratic", + shape=(16, 16), + f=lambda x, y: x**2 + 2 * y**2, + ), + ) + def test_function_bc_vs_discrete_bc_consistency(self, shape, f): + """Test that FunctionBoundaryConditions gives same results as DiscreteBoundaryConditions for equivalent inputs.""" + # Create a function that matches predefined tensor values + + grid = grids.Grid(shape, domain=((0.0, 1.0), (0.0, 1.0))) + offsets = [(0, 0), (0, 1), (1, 0), (0.5, 1), (1, 0.5), (1, 1)] + + for offset in offsets: + # Pre-calculate what the function should give + lower_coords, upper_coords = grid.boundary_mesh(0, offset) + expected_left = f(*lower_coords) + expected_right = f(*upper_coords) + + lower_coords, upper_coords = grid.boundary_mesh(1, offset) + expected_bottom = f(*lower_coords) + expected_top = f(*upper_coords) + + # Function-based BC + function_bc = boundaries.FunctionBoundaryConditions( + types=((BCType.DIRICHLET, BCType.DIRICHLET), + (BCType.DIRICHLET, BCType.DIRICHLET)), + values=f, + grid=grid, + offset=offset + ) + + # Discrete BC with pre-calculated values + discrete_bc = boundaries.DiscreteBoundaryConditions( + types=((BCType.DIRICHLET, BCType.DIRICHLET), + (BCType.DIRICHLET, BCType.DIRICHLET)), + values=((expected_left, expected_right), + (expected_bottom, expected_top)) + ) + + # Compare bc_values + self.assertNestedTuplesEqual(function_bc.bc_values, discrete_bc.bc_values, atol=1e-6) + + def test_function_bc_time_dependent(self): + """Test FunctionBoundaryConditions with time-dependent functions.""" + def time_varying_inlet(x, y, t): + return torch.sin(2 * torch.pi * t) * y + + def steady_outlet(x, y, t): + return torch.zeros_like(y) + + grid = grids.Grid((3, 3), domain=((0.0, 1.0), (0.0, 1.0))) + time = torch.tensor([0.25]) # sin(2*pi*0.25) = sin(pi/2) = 1 + + bc = boundaries.FunctionBoundaryConditions( + types=((BCType.DIRICHLET, BCType.DIRICHLET), + (BCType.PERIODIC, BCType.PERIODIC)), + values=((time_varying_inlet, steady_outlet), + (None, None)), + grid=grid, + offset=(0.5, 0.5), + time=time + ) + + bc_values = bc.bc_values + + # At t=0.25, sin(2*pi*0.25) = 1, so inlet should be just y + expected_inlet = tensor([1./6., 0.5, 5./6.]) # y values at offset (0.5, 0.5) + expected_outlet = tensor([0., 0., 0.]) + + self.assertAllClose(bc_values[0][0], expected_inlet, atol=1e-6, rtol=1e-8) + self.assertAllClose(bc_values[0][1], expected_outlet, atol=1e-6, rtol=1e-8) + + + def test_function_bc_mixed_values(self): + """Test FunctionBoundaryConditions with mixed value types.""" + def parabolic_profile(x, y): + return y * (1 - y) + + grid = grids.Grid((4, 4), domain=((0.0, 1.0), (0.0, 1.0))) + left_boundary = tensor([0.1, 0.2, 0.3, 0.4]) + + bc = boundaries.FunctionBoundaryConditions( + types=((BCType.DIRICHLET, BCType.DIRICHLET), + (BCType.NEUMANN, BCType.NEUMANN)), + values=((left_boundary, parabolic_profile), + (0.5, 1.0)), # constant values + grid=grid, + offset=(0.0, 0.0) + ) + + bc_values = bc.bc_values + + # Check left boundary (tensor input) + self.assertAllClose(bc_values[0][0], left_boundary) + + # Check right boundary (function evaluated) + expected_right = tensor([0., 3./16., 1./4., 3./16.]) # y*(1-y) at y=[0, 0.25, 0.5, 0.75] + self.assertAllClose(bc_values[0][1], expected_right, atol=1e-6, rtol=1e-8) + + # Check constant boundaries + self.assertEqual(bc_values[1][0], tensor(0.5)) + self.assertEqual(bc_values[1][1], tensor(1.0)) + + + if __name__ == "__main__": - absltest.main() + absltest.main() \ No newline at end of file diff --git a/torch_cfd/tests/test_finite_differences.py b/torch_cfd/tests/test_finite_differences.py index 07ffa78..abc3b84 100644 --- a/torch_cfd/tests/test_finite_differences.py +++ b/torch_cfd/tests/test_finite_differences.py @@ -25,19 +25,8 @@ from torch_cfd import boundaries, finite_differences as fdm, grids, test_utils BCType = grids.BCType - - -def _trim_boundary(array): - # fixed jax-cfd bug that trims all dimension for a batched GridVariable - if isinstance(array, grids.GridVariable): - # Convert tuple of slices to individual slice objects - trimmed_slices = (slice(1, -1),) * array.grid.ndim - data = array.data[(..., *trimmed_slices)] - return grids.GridVariable(data, array.offset, array.grid) - else: - tensor = torch.as_tensor(array) - trimmed_slices = (slice(1, -1),) * tensor.ndim - return tensor[(..., *trimmed_slices)] +Padding = grids.Padding +trim_boundary = fdm.trim_boundary def grid_variable_periodic(data, offset, grid): @@ -46,10 +35,56 @@ def grid_variable_periodic(data, offset, grid): ) -def grid_variable_dirichlet(data, offset, grid): +def grid_variable_dirichlet_constant(data, offset, grid, bc_values=None): return grids.GridVariable( - data, offset, grid, bc=boundaries.dirichlet_boundary_conditions(grid.ndim) + data, + offset, + grid, + bc=boundaries.dirichlet_boundary_conditions(grid.ndim, bc_values), + ) + + +def grid_variable_dirichlet_nonhomogeneous(data, offset, grid, bc_values): + bc = boundaries.DiscreteBoundaryConditions( + ((BCType.DIRICHLET, BCType.DIRICHLET),) * grid.ndim, bc_values + ) + return grids.GridVariable(data, offset, grid, bc) + + +def grid_variable_dirichlet_function_nonhomogeneous(data, offset, grid, bc_funcs): + bc_types = ((BCType.DIRICHLET, BCType.DIRICHLET),) * grid.ndim + bc = boundaries.FunctionBoundaryConditions(bc_types, bc_funcs, grid, offset) + return grids.GridVariable(data, offset, grid, bc) + + +def grid_variable_dirichlet_nonhomogeneous_and_periodic( + data, offset, grid, bc_values, periodic_dim=0 +): + bc_dirichlet = (BCType.DIRICHLET, BCType.DIRICHLET) + bc_periodic = (BCType.PERIODIC, BCType.PERIODIC) + bc_types = tuple( + bc_periodic if i == periodic_dim else bc_dirichlet for i in range(grid.ndim) ) + bc = boundaries.DiscreteBoundaryConditions(bc_types, bc_values) + return grids.GridVariable(data, offset, grid, bc) + + +def grid_variable_vector_batch_from_functions( + grid, offsets, vfuncs, bc_u, bc_v, batch_size=1 +): + v = [] + for dim, (offset, bc) in enumerate(zip(offsets, (bc_u, bc_v))): + x, y = grid.mesh(offset) + data = vfuncs(x, y) + data = repeat(data[dim], "h w -> b h w", b=batch_size) + v.append( + grid_variable_dirichlet_nonhomogeneous_and_periodic( + data, offset, grid, bc, periodic_dim=dim + ) + ) + + v = grids.GridVariableVector(tuple(v)) + return v def stack_tensor_matrix(matrix): @@ -193,11 +228,14 @@ def test_finite_difference_analytic( dict( testcase_name="_2D_sine", shape=(32, 32), - f=lambda x, y: torch.sin(math.pi*x) * torch.sin(math.pi*y), - g=lambda x, y: -2 *math.pi**2 * torch.sin(math.pi*x) * torch.sin(math.pi*y), - atol=1/32, + f=lambda x, y: torch.sin(math.pi * x) * torch.sin(math.pi * y), + g=lambda x, y: -2 + * math.pi**2 + * torch.sin(math.pi * x) + * torch.sin(math.pi * y), + atol=1 / 32, rtol=1e-3, - ) + ), ) def test_laplacian_periodic(self, shape, f, g, atol, rtol): step = tuple([1.0 / s for s in shape]) @@ -205,45 +243,39 @@ def test_laplacian_periodic(self, shape, f, g, atol, rtol): offset = (0,) * len(shape) mesh = grid.mesh(offset) u = grid_variable_periodic(f(*mesh), offset, grid) - expected_laplacian = _trim_boundary(grids.GridVariable(g(*mesh), offset, grid)) - actual_laplacian = _trim_boundary(fdm.laplacian(u)) + expected_laplacian = trim_boundary(grids.GridVariable(g(*mesh), offset, grid)) + actual_laplacian = trim_boundary(fdm.laplacian(u)) self.assertAllClose(expected_laplacian, actual_laplacian, atol=atol, rtol=rtol) - @parameterized.named_parameters( dict( - testcase_name="_2D_constant", - shape=(20, 20), - f=lambda x, y: torch.ones_like(x), - g=lambda x, y: torch.zeros_like(x), - atol=1e-3, - rtol=1e-8, - ), - dict( - testcase_name="_2D_quadratic", + testcase_name="_2D_quartic", shape=(21, 21), - f=lambda x, y: x * (x - 1.0) + y * (y - 1.0), - g=lambda x, y: 4 * torch.ones_like(x), - atol=1e-3, - rtol=1e-8, + f=lambda x, y: x * (x - 1.0) * y * (y - 1.0), + g=lambda x, y: 2 * y * (y - 1.0) + 2 * x * (x - 1.0), + atol=1e-2, + rtol=1e-5, ), dict( testcase_name="_2D_sine", - shape=(32, 32), - f=lambda x, y: torch.sin(math.pi*x) * torch.sin(math.pi*y), - g=lambda x, y: -2 *math.pi**2 * torch.sin(math.pi*x) * torch.sin(math.pi*y), - atol=1/32, + shape=(32, 32), + f=lambda x, y: torch.sin(math.pi * x) * torch.sin(math.pi * y), + g=lambda x, y: -2 + * math.pi**2 + * torch.sin(math.pi * x) + * torch.sin(math.pi * y), + atol=1 / 32, rtol=1e-3, - ) + ), ) - def test_laplacian_dirichlet(self, shape, f, g, atol, rtol): + def test_laplacian_dirichlet_homogeneous(self, shape, f, g, atol, rtol): step = tuple([1.0 / s for s in shape]) grid = grids.Grid(shape, step) offset = (0,) * len(shape) mesh = grid.mesh(offset) - u = grid_variable_dirichlet(f(*mesh), offset, grid) - expected_laplacian = _trim_boundary(grids.GridVariable(g(*mesh), offset, grid)) - actual_laplacian = _trim_boundary(fdm.laplacian(u)) + u = grid_variable_dirichlet_constant(f(*mesh), offset, grid) + expected_laplacian = trim_boundary(grids.GridVariable(g(*mesh), offset, grid)) + actual_laplacian = trim_boundary(fdm.laplacian(u)) self.assertAllClose(expected_laplacian, actual_laplacian, atol=atol, rtol=rtol) @parameterized.named_parameters( @@ -267,54 +299,66 @@ def test_laplacian_dirichlet(self, shape, f, g, atol, rtol): ), ) def test_divergence(self, shape, offsets, f, g, atol, rtol): + # note: somehow the bcs are incorrectly set but the divergence is still correct step = tuple([1.0 / s for s in shape]) grid = grids.Grid(shape, step) v = [ - grid_variable_periodic(f(*grid.mesh(offset))[axis], offset, grid) - for axis, offset in enumerate(offsets) + grid_variable_periodic(f(*grid.mesh(offset))[dim], offset, grid) + for dim, offset in enumerate(offsets) ] - expected_divergence = _trim_boundary( + expected_divergence = trim_boundary( grids.GridVariable(g(*grid.mesh()), (0,) * grid.ndim, grid) ) - actual_divergence = _trim_boundary(fdm.divergence(v)) + actual_divergence = trim_boundary(fdm.divergence(v)) self.assertAllClose( expected_divergence, actual_divergence, atol=atol, rtol=rtol ) @parameterized.named_parameters( - # https://en.wikipedia.org/wiki/Curl_(mathematics)#Examples dict( - testcase_name="_solenoidal", - shape=(20, 20), + testcase_name="_solenoidal_8x8", + shape=(8, 8), offsets=((0.5, 0), (0, 0.5)), f=lambda x, y: (y, -x), g=lambda x, y: -2 * torch.ones_like(x), - atol=1e-3, - rtol=1e-10, + bc_u=((None, None), (torch.zeros(8), torch.ones(8))), + bc_v=((torch.zeros(8), -torch.ones(8)), (None, None)), + ), + dict( + testcase_name="_solenoidal_32x32", + shape=(32, 32), + offsets=((0.5, 0), (0, 0.5)), + f=lambda x, y: (y, -x), + g=lambda x, y: -2 * torch.ones_like(x), + bc_u=((None, None), (torch.zeros(32), torch.ones(32))), + bc_v=((torch.zeros(32), -torch.ones(32)), (None, None)), ), dict( - testcase_name="_wikipedia_example_2", + testcase_name="_wikipedia_example_2d_21x21", shape=(21, 21), offsets=((0.5, 0), (0, 0.5)), f=lambda x, y: (torch.ones_like(x), -(x**2)), g=lambda x, y: -2 * x, - atol=1e-3, - rtol=1e-10, + bc_u=((None, None), (torch.ones(21), torch.ones(21))), + bc_v=((torch.zeros(21), -torch.ones(21)), (None, None)), ), ) - def test_curl_2d(self, shape, offsets, f, g, atol, rtol): + def test_curl_2d(self, shape, offsets, f, g, bc_u, bc_v): step = tuple([1.0 / s for s in shape]) grid = grids.Grid(shape, step) + bcvals = [bc_u, bc_v] v = [ - grid_variable_periodic(f(*grid.mesh(offset))[axis], offset, grid) - for axis, offset in enumerate(offsets) + grid_variable_dirichlet_nonhomogeneous_and_periodic( + f(*grid.mesh(offset))[dim], offset, grid, bcval, dim + ) + for dim, (offset, bcval) in enumerate(zip(offsets, bcvals)) ] result_offset = (0.5, 0.5) - expected_curl = _trim_boundary( + expected_curl = trim_boundary( grids.GridVariable(g(*grid.mesh(result_offset)), result_offset, grid) ) - actual_curl = _trim_boundary(fdm.curl_2d(v)) - self.assertAllClose(expected_curl, actual_curl, atol=atol, rtol=rtol) + actual_curl = trim_boundary(fdm.curl_2d(v)) + self.assertAllClose(actual_curl, expected_curl, atol=1e-5, rtol=1e-10) @parameterized.parameters( # Periodic BC @@ -371,6 +415,498 @@ def test_laplacian_matrix_w_boundaries(self, offset, bc_types, expected): self.assertAllClose(actual, expected) +class FiniteDifferenceNonHomogeneousTest(test_utils.TestCase): + """Test finite difference operations with non-homogeneous boundary conditions.""" + + @parameterized.parameters( + dict( + shape=(8,), + offset=(0,), + ), + dict( + shape=(8,), + offset=(1,), + ), + dict( + shape=(16,), + offset=(0,), + ), + dict(shape=(16,), offset=(1,)), + dict(shape=(16,), offset=(0.5,)), + ) + def test_forward_difference_nonhomogeneous_bc_1d(self, shape, offset): + """Test forward difference operator with non-homogeneous boundary conditions.""" + grid = grids.Grid(shape, domain=((0.0, 1.0),)) + mesh = grid.mesh(offset) + + # Linear function: u = 2x + 1 + # Forward difference should give 2 + # checking the boundary behavior of padding + u_data = 2 * mesh[0] + 1 + + # Non-homogeneous boundary conditions + bc_values = ((1.0, 3.0),) # u(0) = 1, u(1) = 3 + + u = grids.GridVariable( + u_data, + offset, + grid, + bc=boundaries.dirichlet_boundary_conditions(grid.ndim, bc_values), + ) + u = u.impose_bc() + + # the forward diff needs another padding beyond the boundary + # by default the padding mode is 'extend' or replicate? + # for the MAC + # u.shift(+1, 0) gets the replicate padding at the end + # check the behavior of pad in this case + forward_diff = trim_boundary(fdm.forward_difference(u, dim=0)) + + expected = 2.0 * torch.ones_like(forward_diff.data) + + self.assertAllClose(forward_diff.data, expected, atol=1e-4, rtol=1e-7) + + @parameterized.parameters( + dict( + shape=(8,), + offset=(0,), + ), + dict( + shape=(8,), + offset=(1,), + ), + dict( + shape=(16,), + offset=(0,), + ), + dict(shape=(16,), offset=(1,)), + dict(shape=(16,), offset=(0.5,)), + ) + def test_backward_difference_nonhomogeneous_bc_1d(self, shape, offset): + """Test backward difference operator with non-homogeneous boundary conditions.""" + grid = grids.Grid(shape, domain=((0.0, 1.0),)) + mesh = grid.mesh(offset) + + u_data = 3 * mesh[0] + 0.5 + bc_values = ((0.5, 3.5),) # u(0) = 0.5, u(1) = 3.5 + + u = grids.GridVariable( + u_data, + offset, + grid, + bc=boundaries.dirichlet_boundary_conditions(grid.ndim, bc_values), + ) + u = u.impose_bc() + backward_diff = trim_boundary(fdm.backward_difference(u, dim=0)) + + expected = 3.0 * torch.ones_like(backward_diff.data) + + self.assertAllClose(backward_diff.data, expected, atol=1e-4, rtol=1e-7) + + @parameterized.parameters( + dict( + shape=(8,), + offset=(0,), + ), + dict( + shape=(8,), + offset=(1,), + ), + dict( + shape=(16,), + offset=(0,), + ), + dict(shape=(16,), offset=(1,)), + dict(shape=(16,), offset=(0.5,)), + ) + def test_central_difference_nonhomogeneous_bc_1d(self, shape, offset): + """Test central difference operator with non-homogeneous boundary conditions.""" + grid = grids.Grid(shape, domain=((0.0, 1.0),)) + mesh = grid.mesh(offset) + + u_data = 4 * mesh[0] + 2 + + bc_values = ((2.0, 6.0),) # u(0) = 2, u(1) = 6 + + u = grids.GridVariable( + u_data, + offset, + grid, + bc=boundaries.dirichlet_boundary_conditions(grid.ndim, bc_values), + ) + u = u.impose_bc() + + central_diff = trim_boundary(fdm.central_difference(u, dim=0)) + + expected = 4.0 * torch.ones_like(central_diff.data) + + self.assertAllClose(central_diff.data, expected, atol=1e-4, rtol=1e-7) + + @parameterized.parameters( + dict( + shape=(16, 16), + offset=(0, 0), + ), + dict( + shape=(16, 16), + offset=(0, 1), + ), + dict( + shape=(16, 16), + offset=(1, 0), + ), + dict(shape=(16, 16), offset=(1, 1)), + dict(shape=(32, 32), offset=(0.5, 1)), + dict(shape=(32, 32), offset=(1, 0.5)), + ) + def test_central_difference_nonhomogeneous_bc_2d(self, shape, offset): + """Test central difference operator with non-homogeneous boundary conditions in 2D.""" + grid = grids.Grid(shape, domain=((0.0, 1.0), (0.0, 1.0))) + x, y = grid.mesh(offset) + h = max(grid.step) + + f = lambda x, y: x**2 + 2 * y**2 + fx = lambda x, y: 2 * x + fy = lambda x, y: 4 * y + u_data = f(x, y) + fx_data = fx(x, y) + fy_data = fy(x, y) + + u = grid_variable_dirichlet_function_nonhomogeneous(u_data, offset, grid, f) + u = u.impose_bc() + + grad_x = grids.GridVariable(fx_data, offset, grid) + grad_y = grids.GridVariable(fy_data, offset, grid) + + # Check that gradients are reasonable in interior + interior_grad_x = trim_boundary(fdm.central_difference(u, dim=0)) + interior_grad_y = trim_boundary(fdm.central_difference(u, dim=1)) + + # Get expected gradients at interior points + expected_grad_x = trim_boundary(grad_x) + expected_grad_y = trim_boundary(grad_y) + + # Use relaxed tolerance for finite difference approximation + self.assertAllClose(interior_grad_x, expected_grad_x, atol=6 * h, rtol=h) + self.assertAllClose(interior_grad_y, expected_grad_y, atol=6 * h, rtol=h) + + @parameterized.named_parameters( + dict( + testcase_name="_x_direction_offset_0", + shape=(8, 4), + offset=(0, 0), + f=lambda x, y: x**2, + g=lambda x, y: 2 * torch.ones_like(x), + bc_values=((torch.zeros(4), torch.ones(4)), (None, None)), + periodic_dim=1, + ), + dict( + testcase_name="_x_direction_offset_1", + shape=(8, 8), + offset=(1, 0), + f=lambda x, y: x**2, + g=lambda x, y: 2 * torch.ones_like(x), + bc_values=((torch.zeros(8), torch.ones(8)), (None, None)), + periodic_dim=1, + ), + dict( + testcase_name="_y_direction_offset_0", + shape=(8, 4), + offset=(0, 0), + f=lambda x, y: y**2, + g=lambda x, y: 2 * torch.ones_like(y), + bc_values=((None, None), (torch.zeros(8), torch.ones(8))), + periodic_dim=0, + ), + dict( + testcase_name="_y_direction_offset_1", + shape=(4, 4), + offset=(0, 1), + f=lambda x, y: y**2, + g=lambda x, y: 2 * torch.ones_like(y), + bc_values=((None, None), (torch.zeros(4), torch.ones(4))), + periodic_dim=0, + ), + ) + def test_laplacian_dirichlet_nonhomogeneous( + self, shape, offset, f, g, bc_values, periodic_dim + ): + grid = grids.Grid(shape, domain=((0.0, 1.0), (0.0, 1.0))) + mesh = grid.mesh(offset) + + # u = x^2, Laplacian of u is 2 + u_data = f(*mesh) + expected_laplacian = trim_boundary(grids.GridVariable(g(*mesh), offset, grid)) + + # Create GridVariable with non-homogeneous Dirichlet BCs + u = grid_variable_dirichlet_nonhomogeneous_and_periodic( + u_data, offset, grid, bc_values, periodic_dim=periodic_dim + ) + # u = u.bc.impose_bc(u, mode=Padding.EXTEND) + u = u.impose_bc() + + # Compute Laplacian using finite differences + actual_laplacian = trim_boundary(fdm.laplacian(u)) + + # Use relaxed tolerance due to boundary effects + self.assertAllClose(actual_laplacian, expected_laplacian, atol=1e-2, rtol=1e-2) + + @parameterized.named_parameters( + dict( + testcase_name="_constant_offset_0_0", + shape=(4, 4), + offset=(0, 0), + f=lambda x, y: torch.ones_like(x), + g=lambda x, y: torch.zeros_like(x), + bc_values=((torch.ones(4), torch.ones(4)), (torch.ones(4), torch.ones(4))), + atol=1e-3, + rtol=1e-10, + ), + dict( + testcase_name="_constant_offset_0_1", + shape=(4, 4), + offset=(0, 1), + f=lambda x, y: torch.ones_like(x), + g=lambda x, y: torch.zeros_like(x), + bc_values=((torch.ones(4), torch.ones(4)), (torch.ones(4), torch.ones(4))), + atol=1e-3, + rtol=1e-10, + ), + dict( + testcase_name="_constant_offset_1_0", + shape=(4, 4), + offset=(1, 0), + f=lambda x, y: torch.ones_like(x), + g=lambda x, y: torch.zeros_like(x), + bc_values=((torch.ones(4), torch.ones(4)), (torch.ones(4), torch.ones(4))), + atol=1e-3, + rtol=1e-10, + ), + dict( + testcase_name="_constant_offset_1_1", + shape=(4, 4), + offset=(1, 1), + f=lambda x, y: torch.ones_like(x), + g=lambda x, y: torch.zeros_like(x), + bc_values=((torch.ones(4), torch.ones(4)), (torch.ones(4), torch.ones(4))), + atol=1e-3, + rtol=1e-10, + ), + dict( + testcase_name="_linear_offset_0_0", + shape=(8, 8), + offset=(0, 0), + f=lambda x, y: x + 2 * y, + g=lambda x, y: torch.zeros_like(x), + bc_values=( + (torch.linspace(0, 2, 9)[:-1], torch.linspace(1, 3, 9)[:-1]), + (torch.linspace(0, 1, 9)[:-1], torch.linspace(2, 3, 9)[:-1]), + ), # ((2y, 1+2y), (x, 2+x)) + atol=1e-3, + rtol=1e-10, + ), + dict( + testcase_name="_linear_offset_0_1", + shape=(8, 8), + offset=(0, 1), + f=lambda x, y: x + 2 * y, + g=lambda x, y: torch.zeros_like(x), + bc_values=( + (torch.linspace(0, 2, 9)[1:], torch.linspace(1, 3, 9)[1:]), + (torch.linspace(0, 1, 9)[:-1], torch.linspace(2, 3, 9)[:-1]), + ), # ((2y, 1+2y), (x, 2+x)) + atol=1e-3, + rtol=1e-10, + ), + dict( + testcase_name="_linear_offset_1_0", + shape=(8, 8), + offset=(1, 0), + f=lambda x, y: x + 2 * y, + g=lambda x, y: torch.zeros_like(x), + bc_values=( + (torch.linspace(0, 2, 9)[:-1], torch.linspace(1, 3, 9)[:-1]), + (torch.linspace(0, 1, 9)[1:], torch.linspace(2, 3, 9)[1:]), + ), # ((2y, 1+2y), (x, 2+x)) + atol=1e-3, + rtol=1e-10, + ), + dict( + testcase_name="_linear_offset_1_1", + shape=(8, 8), + offset=(1, 1), + f=lambda x, y: x + 2 * y, + g=lambda x, y: torch.zeros_like(x), + bc_values=( + (torch.linspace(0, 2, 9)[1:], torch.linspace(1, 3, 9)[1:]), + (torch.linspace(0, 1, 9)[1:], torch.linspace(2, 3, 9)[1:]), + ), # ((2y, 1+2y), (x, 2+x)) + atol=1e-3, + rtol=1e-10, + ), + ) + def test_laplacian_dirichlet_nonhomogeneous_2d( + self, shape, offset, f, g, bc_values, atol, rtol + ): + grid = grids.Grid(shape, domain=((0.0, 1.0), (0.0, 1.0))) + mesh = grid.mesh(offset) + + # Create GridVariable with non-homogeneous Dirichlet BCs + u_data = f(*mesh) + expected_laplacian = trim_boundary(grids.GridVariable(g(*mesh), offset, grid)) + + u = grid_variable_dirichlet_nonhomogeneous(u_data, offset, grid, bc_values) + u = u.impose_bc() + + # Compute Laplacian using finite differences + actual_laplacian = trim_boundary(fdm.laplacian(u)) + + # Use relaxed tolerance due to boundary effects + self.assertAllClose(actual_laplacian, expected_laplacian, atol=atol, rtol=rtol) + + @parameterized.named_parameters( + dict( + testcase_name="_linear_cell_center", + shape=(16, 16), + offset=(0.5, 0.5), + f=lambda x, y: x + 2 * y, + g=lambda x, y: torch.zeros_like(x), + atol=1e-6, + rtol=1e-10, + ), + dict( + testcase_name="_linear_vertical_edge_center", + shape=(16, 16), + offset=(1.0, 0.5), + f=lambda x, y: 2 * x + y, + g=lambda x, y: torch.zeros_like(x), + atol=1e-6, + rtol=1e-10, + ), + dict( + testcase_name="_linear_horizontal_edge_center", + shape=(16, 16), + offset=(0.5, 1.0), + f=lambda x, y: x + 2 * y, + g=lambda x, y: torch.zeros_like(x), + atol=1e-6, + rtol=1e-10, + ), + dict( + testcase_name="_quadratic_lower_left_corner", + shape=(32, 32), + offset=(0, 0), + f=lambda x, y: x**2 + 2 * y**2 + x * y, + g=lambda x, y: (2 + 4) * torch.ones_like(x), + atol=1 / 32, + rtol=1e-2, + ), + dict( + testcase_name="_quadratic_vertical_edge_center", + shape=(32, 32), + offset=(1, 0.5), + f=lambda x, y: 3 * x**2 + y**2 - x * y, + g=lambda x, y: (6 + 2) * torch.ones_like(x), + atol=1 / 32, + rtol=1e-2, + ), + dict( + testcase_name="_quadratic_lower_right_corner", + shape=(32, 32), + offset=(1.0, 0), + f=lambda x, y: 0.5 * x**2 + 1.5 * y**2 + 2 * x * y, + g=lambda x, y: (1 + 3) * torch.ones_like(x), + atol=1 / 32, + rtol=1e-2, + ), + dict( + testcase_name="_quadratic_upper_right_corner", + shape=(32, 32), + offset=(1, 1), + f=lambda x, y: 2 * x**2 + 0.5 * y**2 - 0.5 * x * y + x + y, + g=lambda x, y: (4 + 1) * torch.ones_like(x), + atol=1 / 32, + rtol=1e-2, + ), + dict( + testcase_name="_quadratic_horizontal_edge_center", + shape=(32, 32), + offset=(0.5, 1.0), + f=lambda x, y: x**2 + y**2 + 3 * x * y + 2 * x - y, + g=lambda x, y: (2 + 2) * torch.ones_like(x), + atol=1 / 32, + rtol=1e-2, + ), + dict( + testcase_name="_quadratic_cell_center", + shape=(16, 16), + offset=(0.5, 0.5), + f=lambda x, y: 4 * x**2 + 3 * y**2 + 2 * x * y, + g=lambda x, y: (8.0 + 6.0) * torch.ones_like(x), + atol=1 / 16, + rtol=1e-2, + ), + ) + def test_laplacian_dirichlet_function_nonhomogeneous_2d( + self, shape, offset, f, g, atol, rtol + ): + """Test Laplacian with FunctionBoundaryConditions using quadratic functions.""" + grid = grids.Grid(shape, domain=((-1.0, 1.0), (-1.0, 1.0))) + mesh = grid.mesh(offset) + + # Create GridVariable with function-based non-homogeneous Dirichlet BCs + u_data = f(*mesh) + expected_laplacian = trim_boundary(grids.GridVariable(g(*mesh), offset, grid)) + + # Use FunctionBoundaryConditions instead of discrete values + u = grid_variable_dirichlet_function_nonhomogeneous(u_data, offset, grid, f) + u = u.impose_bc() + + # Compute Laplacian using finite differences + actual_laplacian = trim_boundary(fdm.laplacian(u)) + + self.assertAllClose(actual_laplacian, expected_laplacian, atol=atol, rtol=rtol) + + @parameterized.named_parameters( + dict( + testcase_name="_laplacian_consistency_8x8", + shape=(8, 8), + ), + dict( + testcase_name="_laplacian_consistency_8x16", + shape=(8, 16), + ), + dict( + testcase_name="_laplacian_consistency_16x8", + shape=(16, 8), + ), + dict( + testcase_name="_laplacian_consistency_32x32", + shape=(32, 32), + ), + ) + def test_laplacian_consistency(self, shape): + """Test that Laplacian computation is consistent across different grid resolutions.""" + f = lambda x, y: 0.25 * (x**2 + y**2) + offsets = [(0, 0), (0.5, 1), (1.0, 0.5), (0.5, 0.5), (1.0, 1.0)] + + for offset in offsets: + grid = grids.Grid(shape, domain=((0.0, 1.0), (0.0, 1.0))) + u_data = f(*grid.mesh(offset)) + u = grid_variable_dirichlet_function_nonhomogeneous(u_data, offset, grid, f) + u = u.impose_bc() + + laplacian_result = fdm.laplacian(u) + + # Check interior points only where we expect Laplacian ≈ 4 + interior_laplacian = trim_boundary(laplacian_result).data + expected_interior = torch.ones_like(interior_laplacian) + + self.assertAllClose( + interior_laplacian, expected_interior, atol=1e-3, rtol=1e-2 + ) + + class FiniteDifferenceBatchTest(test_utils.TestCase): """Test finite difference operations with batch dimensions in 2D.""" @@ -490,30 +1026,29 @@ def test_finite_difference_batch_analytic( def test_laplacian_batch(self): """Test Laplacian operator with batch dimensions.""" batch_size = 2 - shape = (20, 20) - step = (0.1, 0.1) - grid = grids.Grid(shape, step) + shape = (32, 64) + grid = grids.Grid(shape, domain=((-1.0, 1.0), (-1.0, 1.0))) offset = (0, 0) - # Test function: f(x,y) = x^2 + y^2, so Laplacian should be 4 + f = lambda x, y: x**2 + y**2 mesh = grid.mesh(offset) - single_data = mesh[0] ** 2 + mesh[1] ** 2 - batched_data = single_data.unsqueeze(0).repeat(batch_size, 1, 1) + single_data = f(*mesh) + batched_data = repeat(single_data, "h w -> b h w", b=batch_size) - u = grid_variable_periodic(batched_data, offset, grid) + u = grid_variable_dirichlet_function_nonhomogeneous(batched_data, offset, grid, f) actual_laplacian = fdm.laplacian(u) # Expected Laplacian is 4 everywhere expected_laplacian = 4 * torch.ones(batch_size, *shape) # Trim boundary for comparison - trimmed_actual = _trim_boundary(actual_laplacian) - trimmed_expected = _trim_boundary( + trimmed_actual = trim_boundary(actual_laplacian) + trimmed_expected = trim_boundary( grids.GridVariable(expected_laplacian, offset, grid) ) self.assertAllClose( - trimmed_expected.data, trimmed_actual.data, atol=1e-2, rtol=1e-8 + trimmed_expected.data, trimmed_actual.data, atol=1e-3, rtol=1e-8 ) def test_laplacian_batch_analytic(self): @@ -533,8 +1068,8 @@ def test_laplacian_batch_analytic(self): batch_laplacian = fdm.laplacian(u_batch) # Trim boundary for comparison - trimmed_single = _trim_boundary(single_laplacian) - trimmed_batch = _trim_boundary(batch_laplacian) + trimmed_single = trim_boundary(single_laplacian) + trimmed_batch = trim_boundary(batch_laplacian) for i in range(batch_size): self.assertAllClose( @@ -571,8 +1106,8 @@ def test_divergence_batch(self): expected_divergence = 2 * torch.ones(batch_size, *shape) # Trim boundary for comparison - trimmed_actual = _trim_boundary(actual_divergence) - trimmed_expected = _trim_boundary( + trimmed_actual = trim_boundary(actual_divergence) + trimmed_expected = trim_boundary( grids.GridVariable(expected_divergence, (0, 0), grid) ) @@ -580,45 +1115,56 @@ def test_divergence_batch(self): trimmed_expected.data, trimmed_actual.data, atol=1e-2, rtol=1e-8 ) - def test_curl_2d_batch(self): + @parameterized.named_parameters( + dict( + testcase_name="_solenoidal_8x8", + shape=(8, 8), + offsets=((0.5, 0), (0, 0.5)), + f=lambda x, y: (y, -x), + g=lambda x, y: -2 * torch.ones_like(x), + bc_u=((None, None), (torch.zeros(8), torch.ones(8))), + bc_v=((torch.zeros(8), -torch.ones(8)), (None, None)), + ), + dict( + testcase_name="_solenoidal_32x32", + shape=(32, 32), + offsets=((0.5, 0), (0, 0.5)), + f=lambda x, y: (y, -x), + g=lambda x, y: -2 * torch.ones_like(x), + bc_u=((None, None), (torch.zeros(32), torch.ones(32))), + bc_v=((torch.zeros(32), -torch.ones(32)), (None, None)), + ), + dict( + testcase_name="_wikipedia_example_2d_21x21", + shape=(21, 21), + offsets=((0.5, 0), (0, 0.5)), + f=lambda x, y: (torch.ones_like(x), -(x**2)), + g=lambda x, y: -2 * x, + bc_u=((None, None), (torch.ones(21), torch.ones(21))), + bc_v=((torch.zeros(21), -torch.ones(21)), (None, None)), + ), + ) + def test_curl_2d_batch(self, shape, offsets, f, g, bc_u, bc_v): """Test 2D curl operator with batch dimensions.""" batch_size = 2 - shape = (20, 20) - step = (0.1, 0.1) - grid = grids.Grid(shape, step) + grid = grids.Grid(shape, domain=((0, 1), (0, 1))) offsets = ((0.5, 0), (0, 0.5)) - # Test vector field: v = (-y, x), so curl should be 2 - mesh_x = grid.mesh(offsets[0]) - mesh_y = grid.mesh(offsets[1]) - - # Create batched vector components - vx_single = -mesh_x[1] # -y component - vy_single = mesh_y[0] # x component - - vx_batched = repeat(vx_single, "h w -> b h w", b=batch_size) - vy_batched = repeat(vy_single, "h w -> b h w", b=batch_size) - - v = [ - grid_variable_periodic(vx_batched, offsets[0], grid), - grid_variable_periodic(vy_batched, offsets[1], grid), - ] + v = grid_variable_vector_batch_from_functions( + grid, offsets, f, bc_u, bc_v, batch_size=batch_size + ) - actual_curl = fdm.curl_2d(v) + actual_curl = trim_boundary(fdm.curl_2d(v)) # Expected curl is 2 everywhere result_offset = (0.5, 0.5) - expected_curl = 2 * torch.ones(batch_size, *shape) - - # Trim boundary for comparison - trimmed_actual = _trim_boundary(actual_curl) - trimmed_expected = _trim_boundary( + expected_curl = g(*grid.mesh(result_offset)) + expected_curl = repeat(expected_curl, "h w -> b h w", b=batch_size) + expected_curl = trim_boundary( grids.GridVariable(expected_curl, result_offset, grid) ) - self.assertAllClose( - trimmed_expected.data, trimmed_actual.data, atol=1e-2, rtol=1e-8 - ) + self.assertAllClose(actual_curl.data, expected_curl.data, atol=1e-2, rtol=1e-8) def test_batch_consistency_across_operations(self): """Test that batch operations are consistent across different batch sizes.""" @@ -644,6 +1190,240 @@ def test_batch_consistency_across_operations(self): self.assertAllClose(grad_x.data[0], grad_x.data[i]) self.assertAllClose(grad_y.data[0], grad_y.data[i]) + @parameterized.named_parameters( + dict( + testcase_name="_quadratic_1_batch", + batch_size=3, + shape=(16, 16), + offset=(1, 0), + f=lambda x, y: x**2 + 2 * y**2 + x * y, + g=lambda x, y: (2 + 4) * torch.ones_like(x), + atol=1 / 16, + rtol=1e-2, + ), + dict( + testcase_name="_quadratic_2_batch", + batch_size=4, + shape=(32, 32), + offset=(0, 1), + f=lambda x, y: 3 * x**2 + y**2 - x * y, + g=lambda x, y: (6 + 2) * torch.ones_like(x), + atol=1 / 32, + rtol=1e-2, + ), + dict( + testcase_name="_quadratic_and_trig_batch", + batch_size=2, + shape=(32, 32), + offset=(1, 0), + f=lambda x, y: 4 * x**2 + + 3 * y**2 + + 2 * x * y + + torch.sin(torch.pi * x) * torch.sin(torch.pi * y)/2, + g=lambda x, y: 14 * torch.ones_like(x) + - torch.pi**2 * torch.sin(torch.pi * x) * torch.sin(torch.pi * y), + atol=1 / 32, + rtol=1e-2, + ), + ) + def test_laplacian_dirichlet_function_nonhomogeneous_batch( + self, batch_size, shape, offset, f, g, atol, rtol + ): + """Test Laplacian with FunctionBoundaryConditions using quadratic functions with batch dimensions.""" + grid = grids.Grid(shape, domain=((-1.0, 1.0), (-1.0, 1.0))) + mesh = grid.mesh(offset) + + # Create batched data by repeating the same function + single_u_data = f(*mesh) + batched_u_data = repeat(single_u_data, "h w -> b h w", b=batch_size) + + single_expected = g(*mesh) + batched_expected = repeat(single_expected, "h w -> b h w", b=batch_size) + expected_laplacian = trim_boundary( + grids.GridVariable(batched_expected, offset, grid) + ) + + # Use FunctionBoundaryConditions with batched data + u = grid_variable_dirichlet_function_nonhomogeneous( + batched_u_data, offset, grid, f + ) + u = u.impose_bc() + + # Compute Laplacian using finite differences + actual_laplacian = trim_boundary(fdm.laplacian(u)) + + # Check that batch dimension is preserved + self.assertEqual(actual_laplacian.data.shape[0], batch_size) + + # Check accuracy for each batch element + self.assertAllClose( + actual_laplacian.data, expected_laplacian.data, atol=atol, rtol=rtol + ) + + @parameterized.named_parameters( + dict( + testcase_name="_linear_discrete_bc_batch", + batch_size=2, + shape=(8, 8), + offset=(0, 0), + f=lambda x, y: x + 2 * y, + g=lambda x, y: torch.zeros_like(x), + bc_values_func=lambda: ( + (torch.linspace(0, 2, 9)[:-1], torch.linspace(1, 3, 9)[:-1]), + (torch.linspace(0, 1, 9)[:-1], torch.linspace(2, 3, 9)[:-1]), + ), + atol=1e-3, + rtol=1e-10, + ), + dict( + testcase_name="_constant_discrete_bc_batch", + batch_size=4, + shape=(12, 12), + offset=(1, 1), + f=lambda x, y: torch.ones_like(x), + g=lambda x, y: torch.zeros_like(x), + bc_values_func=lambda: ( + (torch.ones(12), torch.ones(12)), + (torch.ones(12), torch.ones(12)), + ), + atol=1e-3, + rtol=1e-10, + ), + dict( + testcase_name="_quadratic_discrete_bc_batch", + batch_size=3, + shape=(16, 16), + offset=(0.5, 0.5), + f=lambda x, y: x**2 + y**2, + g=lambda x, y: 4 * torch.ones_like(x), + bc_values_func=lambda: ( + (torch.linspace(0, 2, 17)[1:-1], torch.linspace(1, 3, 17)[1:-1]), + (torch.linspace(0, 1, 17)[1:-1], torch.linspace(1, 2, 17)[1:-1]), + ), + atol=1e-2, + rtol=1e-2, + ), + ) + def test_laplacian_dirichlet_discrete_nonhomogeneous_batch( + self, batch_size, shape, offset, f, g, bc_values_func, atol, rtol + ): + """Test Laplacian with discrete non-homogeneous Dirichlet BCs with batch dimensions.""" + grid = grids.Grid(shape, domain=((0.0, 1.0), (0.0, 1.0))) + mesh = grid.mesh(offset) + + # Create batched data + single_u_data = f(*mesh) + batched_u_data = repeat(single_u_data, "h w -> b h w", b=batch_size) + + single_expected = g(*mesh) + batched_expected = repeat(single_expected, "h w -> b h w", b=batch_size) + expected_laplacian = trim_boundary( + grids.GridVariable(batched_expected, offset, grid) + ) + + # Get boundary condition values + bc_values = bc_values_func() + + # Create GridVariable with batched non-homogeneous Dirichlet BCs + u = grid_variable_dirichlet_nonhomogeneous( + batched_u_data, offset, grid, bc_values + ) + u = u.impose_bc() + + # Compute Laplacian using finite differences + actual_laplacian = trim_boundary(fdm.laplacian(u)) + + # Check that batch dimension is preserved + self.assertEqual(actual_laplacian.data.shape[0], batch_size) + + # Check accuracy + self.assertAllClose( + actual_laplacian.data, expected_laplacian.data, atol=atol, rtol=rtol + ) + + @parameterized.named_parameters( + dict( + testcase_name="_quadratic_gradient_batch", + batch_size=2, + shape=(16, 16), + offset=(0, 0), + f=lambda x, y: x**2 + 2 * y**2, + fx=lambda x, y: 2 * x, + fy=lambda x, y: 4 * y, + ), + dict( + testcase_name="_mixed_quadratic_gradient_batch", + batch_size=3, + shape=(32, 32), + offset=(0.5, 1), + f=lambda x, y: 3 * x**2 + y**2 + x * y, + fx=lambda x, y: 6 * x + y, + fy=lambda x, y: 2 * y + x, + ), + dict( + testcase_name="_cubic_gradient_batch", + batch_size=4, + shape=(24, 24), + offset=(1, 0.5), + f=lambda x, y: x**3 + y**3 + x * y**2, + fx=lambda x, y: 3 * x**2 + y**2, + fy=lambda x, y: 3 * y**2 + 2 * x * y, + ), + ) + def test_central_difference_function_nonhomogeneous_batch( + self, batch_size, shape, offset, f, fx, fy + ): + """Test central difference with FunctionBoundaryConditions with batch dimensions.""" + grid = grids.Grid(shape, domain=((0.0, 1.0), (0.0, 1.0))) + x, y = grid.mesh(offset) + h = max(grid.step) + + # Create batched data + single_u_data = f(x, y) + batched_u_data = repeat(single_u_data, "h w -> b h w", b=batch_size) + + single_fx_data = fx(x, y) + single_fy_data = fy(x, y) + batched_fx_data = repeat(single_fx_data, "h w -> b h w", b=batch_size) + batched_fy_data = repeat(single_fy_data, "h w -> b h w", b=batch_size) + + u = grid_variable_dirichlet_function_nonhomogeneous( + batched_u_data, offset, grid, f + ) + u = u.impose_bc() + + expected_grad_x = grids.GridVariable(batched_fx_data, offset, grid) + expected_grad_y = grids.GridVariable(batched_fy_data, offset, grid) + + # Check that gradients are reasonable in interior + interior_grad_x = trim_boundary(fdm.central_difference(u, dim=0)) + interior_grad_y = trim_boundary(fdm.central_difference(u, dim=1)) + + # Get expected gradients at interior points + expected_grad_x_interior = trim_boundary(expected_grad_x) + expected_grad_y_interior = trim_boundary(expected_grad_y) + + # Check that batch dimension is preserved + self.assertEqual(interior_grad_x.data.shape[0], batch_size) + self.assertEqual(interior_grad_y.data.shape[0], batch_size) + + # Use relaxed tolerance for finite difference approximation + self.assertAllClose( + interior_grad_x.data, expected_grad_x_interior.data, atol=6 * h, rtol=h + ) + self.assertAllClose( + interior_grad_y.data, expected_grad_y_interior.data, atol=6 * h, rtol=h + ) + + # Test batch consistency: each batch element should be identical + for i in range(1, batch_size): + self.assertAllClose( + interior_grad_x.data[0], interior_grad_x.data[i], atol=1e-12, rtol=1e-15 + ) + self.assertAllClose( + interior_grad_y.data[0], interior_grad_y.data[i], atol=1e-12, rtol=1e-15 + ) + if __name__ == "__main__": absltest.main() diff --git a/torch_cfd/tests/test_grids.py b/torch_cfd/tests/test_grids.py index 275539d..9cc828b 100644 --- a/torch_cfd/tests/test_grids.py +++ b/torch_cfd/tests/test_grids.py @@ -397,15 +397,42 @@ def test_interior_dirichlet(self): expected_data=tensor([2, 3, 4, 5, -5]), expected_offset=(1.5,), ), + dict( + shape=(5,), + data=tensor([1, 2, 3, 4, 5]), + offset=(0.5,), + shift_offset=-1, + bc_values=((0.0, 0.0),), + expected_data=tensor([-1, 1, 2, 3, 4]), + expected_offset=(-0.5,), + ), dict( shape=(5,), data=tensor([1, 2, 3, 4, 5]), offset=(0.0,), shift_offset=-1, bc_values=((0.0, 0.0),), - expected_data=tensor([0, 0, 2, 3, 4]), + expected_data=tensor([-2, 0, 2, 3, 4]), expected_offset=(-1.0,), ), + dict( + shape=(5,), + data=tensor([1, 2, 3, 4, 5]), + offset=(0.0,), + shift_offset=1, + bc_values=((0.0, 0.0),), + expected_data=tensor([2, 3, 4, 5, 0]), + expected_offset=(1.0,), + ), + dict( + shape=(5,), + data=tensor([1, 2, 3, 4, 5]), + offset=(0.0,), + shift_offset=1, + bc_values=((0.0, 20.0),), + expected_data=tensor([2, 3, 4, 5, 20]), + expected_offset=(1.0,), + ), dict( shape=(5,), data=tensor([1, 2, 3, 4, 5]), @@ -426,12 +453,10 @@ def test_shift_1d_dirichlet( expected_data, expected_offset, ): - """Test grids.shift with 1D arrays - """ + """Test grids.shift with 1D arrays""" grid = grids.Grid(shape) - bc = boundaries.dirichlet_boundary_conditions( - ndim=1, bc_values=bc_values) + bc = boundaries.dirichlet_boundary_conditions(ndim=1, bc_values=bc_values) u = grids.GridVariable(data, offset, grid, bc) u = u.impose_bc() @@ -475,8 +500,7 @@ def test_shift_1d_periodic( """Test grids.shift with 1D arrays and various boundary conditions.""" grid = grids.Grid(shape) - bc = boundaries.periodic_boundary_conditions( - ndim=1) + bc = boundaries.periodic_boundary_conditions(ndim=1) u = grids.GridVariable(data, offset, grid, bc) u_shifted = u.shift(offset=shift_offset, dim=0) @@ -485,7 +509,6 @@ def test_shift_1d_periodic( self.assertEqual(u_shifted.grid, grid) self.assertIsNone(u_shifted.bc) - @parameterized.parameters( dict( shape=(10,), @@ -1371,6 +1394,177 @@ def test_domain_interior_masks(self): ) self.assertAllClose(expected, grids.domain_interior_masks(grid)) + @parameterized.named_parameters( + # Test dim=0 (x boundaries) with different offsets + dict( + testcase_name="_dim_0_lower_left_corner", + shape=(4, 6), + domain=((0.0, 2.0), (0.0, 3.0)), + dim=0, + offset=(0.0, 0.0), + expected_lower_x=0.0, + expected_upper_x=2.0, + expected_y=tensor([0.0, 0.5, 1.0, 1.5, 2.0, 2.5]), + ), + dict( + testcase_name="_dim_0_edge_center", + shape=(4, 6), + domain=((0.0, 2.0), (0.0, 3.0)), + dim=0, + offset=(0.0, 0.5), + expected_lower_x=0.0, + expected_upper_x=2.0, + expected_y=tensor([0.25, 0.75, 1.25, 1.75, 2.25, 2.75]), + ), + dict( + testcase_name="_dim_0_cell_center", + # for cell center, the boundary mesh is not technically at the edge center + # but if a function-valued bc is given, the value at these edge center will be used to impose the boundary condition + shape=(4, 6), + domain=((0.0, 2.0), (0.0, 3.0)), + dim=0, + offset=(0.5, 0.5), + expected_lower_x=0.0, + expected_upper_x=2.0, + expected_y=tensor([0.25, 0.75, 1.25, 1.75, 2.25, 2.75]), + ), + dict( + testcase_name="_dim_0_upper_right_corner", + shape=(4, 6), + domain=((0.0, 2.0), (0.0, 3.0)), + dim=0, + offset=(1.0, 1.0), + expected_lower_x=0.0, + expected_upper_x=2.0, + expected_y=tensor([0.5, 1.0, 1.5, 2.0, 2.5, 3.0]), + ), + # Test dim=1 (y boundaries) with different offsets + dict( + testcase_name="_dim_1_offset_lower_left_corner", + shape=(4, 6), + domain=((0.0, 2.0), (0.0, 3.0)), + dim=1, + offset=(0.0, 0.0), + expected_lower_y=0.0, + expected_upper_y=3.0, + expected_x=tensor([0.0, 0.5, 1.0, 1.5]), + ), + dict( + testcase_name="_dim_1_offset_edge_center", + shape=(4, 6), + domain=((0.0, 2.0), (0.0, 3.0)), + dim=1, + offset=(0.5, 0.0), + expected_lower_y=0.0, + expected_upper_y=3.0, + expected_x=tensor([0.25, 0.75, 1.25, 1.75]), + ), + dict( + testcase_name="_dim_1_offset_upper_right_corner", + shape=(4, 6), + domain=((0.0, 2.0), (0.0, 3.0)), + dim=1, + offset=(1.0, 1.0), + expected_lower_y=0.0, + expected_upper_y=3.0, + expected_x=tensor([0.5, 1.0, 1.5, 2.0]), + ), + # Test with different domains + dict( + testcase_name="_dim_0_custom_domain", + shape=(3, 4), + domain=((-1.0, 1.0), (-2.0, 2.0)), + dim=0, + offset=(0.5, 0.5), + expected_lower_x=-1.0, + expected_upper_x=1.0, + expected_y=tensor([-1.5, -0.5, 0.5, 1.5]), + ), + dict( + testcase_name="_dim_1_custom_domain", + shape=(3, 4), + domain=((-1.0, 1.0), (-2.0, 2.0)), + dim=1, + offset=(0.5, 0.5), + expected_lower_y=-2.0, + expected_upper_y=2.0, + expected_x=tensor([-2/3, 0.0, 2/3]), + ), + ) + def test_boundary_mesh_2d( + self, + shape, + domain, + dim, + offset, + expected_lower_x=None, + expected_upper_x=None, + expected_lower_y=None, + expected_upper_y=None, + expected_x=None, + expected_y=None, + ): + """Test Grid.boundary_mesh for 2D grids with different offsets and dimensions.""" + grid = grids.Grid(shape, domain=domain) + + lower_coords, upper_coords = grid.boundary_mesh(dim, offset) + + if dim == 0: # x boundaries + # Check structure: ((x_left, y_coords), (x_right, y_coords)) + self.assertEqual(len(lower_coords), 2) + self.assertEqual(len(upper_coords), 2) + + x_left, y_coords_left = lower_coords + x_right, y_coords_right = upper_coords + + # Check y coordinates (should be the same for both boundaries) + self.assertAllClose(y_coords_left, expected_y, atol=1e-6, rtol=1e-10) + self.assertAllClose(y_coords_right, expected_y, atol=1e-6, rtol=1e-10) + + # Check x coordinates (should be constant arrays with boundary values) + self.assertTrue( + torch.allclose(x_left, torch.full_like(y_coords_left, expected_lower_x)) + ) + self.assertTrue( + torch.allclose( + x_right, torch.full_like(y_coords_right, expected_upper_x) + ) + ) + + # Check shapes + self.assertEqual(x_left.shape, (shape[1],)) + self.assertEqual(x_right.shape, (shape[1],)) + self.assertEqual(y_coords_left.shape, (shape[1],)) + self.assertEqual(y_coords_right.shape, (shape[1],)) + + elif dim == 1: # y boundaries + # Check structure: ((x_coords, y_bottom), (x_coords, y_top)) + self.assertEqual(len(lower_coords), 2) + self.assertEqual(len(upper_coords), 2) + + x_coords_bottom, y_bottom = lower_coords + x_coords_top, y_top = upper_coords + + # Check x coordinates (should be the same for both boundaries) + self.assertAllClose(x_coords_bottom, expected_x, atol=1e-6, rtol=1e-10) + self.assertAllClose(x_coords_top, expected_x, atol=1e-6, rtol=1e-10) + + # Check y coordinates (should be constant arrays with boundary values) + self.assertTrue( + torch.allclose( + y_bottom, torch.full_like(x_coords_bottom, expected_lower_y) + ) + ) + self.assertTrue( + torch.allclose(y_top, torch.full_like(x_coords_top, expected_upper_y)) + ) + + # Check shapes + self.assertEqual(x_coords_bottom.shape, (shape[0],)) + self.assertEqual(x_coords_top.shape, (shape[0],)) + self.assertEqual(y_bottom.shape, (shape[0],)) + self.assertEqual(y_top.shape, (shape[0],)) + if __name__ == "__main__": absltest.main()