Skip to content

Commit b18800b

Browse files
Define getidx_return_type (#2280)
1 parent 40f7b0f commit b18800b

File tree

1 file changed

+45
-1
lines changed

1 file changed

+45
-1
lines changed

src/Operators/finitedifference.jl

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,8 @@ struct RightBoundaryWindow{name} <: BoundaryWindow end
164164
165165
An abstract type for finite difference operators. Instances of this should define:
166166
167+
- [`getidx_return_type`](@ref)
168+
- [`stencil_return_type`](@ref)
167169
- [`return_eltype`](@ref)
168170
- [`return_space`](@ref)
169171
- [`stencil_interior_width`](@ref)
@@ -175,6 +177,18 @@ abstract type FiniteDifferenceOperator <: AbstractOperator end
175177

176178
return_eltype(::FiniteDifferenceOperator, arg) = eltype(arg)
177179

180+
"""
181+
getidx_return_type(::Base.Broadcasted)
182+
getidx_return_type(::StencilBroadcasted)
183+
getidx_return_type(::Field)
184+
getidx_return_type(::Any)
185+
...
186+
187+
The return type of `getidx` on the arguemnt.
188+
Defaults to the type of the argument.
189+
"""
190+
function getidx_return_type end
191+
178192
# boundary width error fallback
179193
@noinline invalid_boundary_condition_error(op_type::Type, bc_type::Type) =
180194
error("Boundary `$bc_type` is not supported for operator `$op_type`")
@@ -327,6 +341,13 @@ Defines the stencil of the operator `Op` in the interior of the domain at `idx`;
327341
"""
328342
function stencil_interior end
329343

344+
"""
345+
stencil_return_type(::Op, args...)
346+
347+
The return type of the given stencil and arguments.
348+
"""
349+
function stencil_return_type end
350+
330351

331352
"""
332353
boundary_width(::Op, ::BC, args...)
@@ -355,6 +376,14 @@ function stencil_right_boundary end
355376

356377
abstract type InterpolationOperator <: FiniteDifferenceOperator end
357378

379+
# single argument interpolation must be the return type of getidx on the
380+
# argument, which should be cheaper / simpler than return_eltype(op, args...)
381+
@inline stencil_return_type(::InterpolationOperator, arg) =
382+
getidx_return_type(arg)
383+
384+
@inline stencil_return_type(op::FiniteDifferenceOperator, args...) =
385+
return_eltype(op, args...)
386+
358387
function assert_no_bcs(op, kwargs)
359388
length(kwargs) == 0 && return nothing
360389
error("InterpolateF2C does not accept boundary conditions.")
@@ -3812,6 +3841,20 @@ Base.@propagate_inbounds function getidx(
38123841
end
38133842
end
38143843

3844+
@inline getidx_return_type(scalar::Tuple{<:Any}) = eltype(scalar)
3845+
@inline getidx_return_type(scalar::Ref) = eltype(scalar)
3846+
@inline getidx_return_type(x::T) where {T} = T
3847+
@inline getidx_return_type(f::Fields.Field) = eltype(f)
3848+
3849+
@inline getidx_return_type(bc::Base.Broadcast.Broadcasted) =
3850+
Base.promote_op(bc.f, map(getidx_return_type, bc.args)...)
3851+
3852+
@inline getidx_return_type(op::AbstractOperator, args...) =
3853+
stencil_return_type(bc.op, bc.args...)
3854+
3855+
@inline getidx_return_type(bc::StencilBroadcasted) =
3856+
stencil_return_type(bc.op, bc.args...)
3857+
38153858
# broadcasting a ColumnStencilStyle gives the StencilBroadcasted's style
38163859
Base.Broadcast.BroadcastStyle(
38173860
::Type{<:StencilBroadcasted{Style}},
@@ -4104,6 +4147,7 @@ Base.@propagate_inbounds function apply_stencil!(
41044147
hidx,
41054148
(li, lw, rw, ri) = window_bounds(space, bc),
41064149
)
4150+
T = getidx_return_type(bc)
41074151
if !Topologies.isperiodic(Spaces.vertical_topology(space))
41084152
# left window
41094153
lbw = LeftBoundaryWindow{Spaces.left_boundary_name(space)}()
@@ -4114,7 +4158,7 @@ Base.@propagate_inbounds function apply_stencil!(
41144158
end
41154159
# interior
41164160
@inbounds for idx in lw:rw
4117-
val = getidx(space, bc, Interior(), idx, hidx)
4161+
val = getidx(space, bc, Interior(), idx, hidx)::T
41184162
setidx!(space, field_out, idx, hidx, val)
41194163
end
41204164
if !Topologies.isperiodic(Spaces.vertical_topology(space))

0 commit comments

Comments
 (0)