Skip to content

Commit 50b1dd1

Browse files
Reduce latency in stencil operators (#2225)
1 parent 869cc54 commit 50b1dd1

File tree

2 files changed

+53
-97
lines changed

2 files changed

+53
-97
lines changed

src/DataLayouts/struct.jl

Lines changed: 7 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -179,6 +179,13 @@ Base.@propagate_inbounds @generated function get_struct(
179179
::Val{D},
180180
start_index::CartesianIndex,
181181
) where {T, S, D}
182+
# recursion base case: hit array type is the same as the struct leaf type
183+
if T === S # Use Union-splitting for better latency
184+
return quote
185+
Base.@_propagate_inbounds_meta
186+
@inbounds return array[start_index]
187+
end
188+
end
182189
tup = :(())
183190
for i in 1:fieldcount(S)
184191
push!(
@@ -201,16 +208,6 @@ Base.@propagate_inbounds @generated function get_struct(
201208
end
202209
end
203210

204-
# recursion base case: hit array type is the same as the struct leaf type
205-
Base.@propagate_inbounds function get_struct(
206-
array::AbstractArray{S},
207-
::Type{S},
208-
::Val{D},
209-
start_index::CartesianIndex,
210-
) where {S, D}
211-
@inbounds return array[start_index]
212-
end
213-
214211
"""
215212
set_struct!(array, val::S, Val(D), start_index)
216213

src/Operators/finitedifference.jl

Lines changed: 46 additions & 87 deletions
Original file line number | Diff line number | Diff line change
@@ -3759,25 +3759,37 @@ end
37593759

37603760
Base.@propagate_inbounds function getidx(
37613761
parent_space,
3762-
bc::StencilBroadcasted,
3763-
loc::Interior,
3764-
idx,
3765-
hidx,
3766-
)
3767-
space = reconstruct_placeholder_space(axes(bc), parent_space)
3768-
stencil_interior(bc.op, loc, space, idx, hidx, bc.args...)
3769-
end
3770-
3771-
Base.@propagate_inbounds function getidx(
3772-
parent_space,
3773-
bc::StencilBroadcasted,
3774-
loc::LeftBoundaryWindow,
3762+
bc::Union{StencilBroadcasted, Base.Broadcast.Broadcasted},
3763+
loc::Location,
37753764
idx,
37763765
hidx,
37773766
)
3767+
# Use Union-splitting here (x isa X) instead of dispatch
3768+
# for improved latency.
37783769
space = reconstruct_placeholder_space(axes(bc), parent_space)
3770+
if bc isa Base.Broadcast.Broadcasted
3771+
# Manually call bc.f for small tuples (improved latency)
3772+
(; args) = bc
3773+
N = length(bc.args)
3774+
if N == 1
3775+
return bc.f(getidx(space, args[1], loc, idx, hidx))
3776+
elseif N == 2
3777+
return bc.f(
3778+
getidx(space, args[1], loc, idx, hidx),
3779+
getidx(space, args[2], loc, idx, hidx),
3780+
)
3781+
elseif N == 3
3782+
return bc.f(
3783+
getidx(space, args[1], loc, idx, hidx),
3784+
getidx(space, args[2], loc, idx, hidx),
3785+
getidx(space, args[3], loc, idx, hidx),
3786+
)
3787+
end
3788+
return call_bc_f(bc.f, space, loc, idx, hidx, args...)
3789+
end
37793790
op = bc.op
3780-
if should_call_left_boundary(idx, space, bc, loc)
3791+
if loc isa LeftBoundaryWindow &&
3792+
should_call_left_boundary(idx, space, bc, loc)
37813793
stencil_left_boundary(
37823794
op,
37833795
get_boundary(op, loc),
@@ -3787,22 +3799,8 @@ Base.@propagate_inbounds function getidx(
37873799
hidx,
37883800
bc.args...,
37893801
)
3790-
else
3791-
# fallback to interior stencil
3792-
stencil_interior(op, loc, space, idx, hidx, bc.args...)
3793-
end
3794-
end
3795-
3796-
Base.@propagate_inbounds function getidx(
3797-
parent_space,
3798-
bc::StencilBroadcasted,
3799-
loc::RightBoundaryWindow,
3800-
idx,
3801-
hidx,
3802-
)
3803-
op = bc.op
3804-
space = reconstruct_placeholder_space(axes(bc), parent_space)
3805-
if should_call_right_boundary(idx, space, bc, loc)
3802+
elseif loc isa RightBoundaryWindow &&
3803+
should_call_right_boundary(idx, space, bc, loc)
38063804
stencil_right_boundary(
38073805
op,
38083806
get_boundary(op, loc),
@@ -3813,11 +3811,11 @@ Base.@propagate_inbounds function getidx(
38133811
bc.args...,
38143812
)
38153813
else
3816-
# fallback to interior stencil
3817-
stencil_interior(op, loc, space, idx, hidx, bc.args...)
3814+
stencil_interior(bc.op, loc, space, idx, hidx, bc.args...)
38183815
end
38193816
end
38203817

3818+
38213819
# broadcasting a StencilStyle gives a CompositeStencilStyle
38223820
Base.Broadcast.BroadcastStyle(
38233821
::Type{<:StencilBroadcasted{Style}},
@@ -3902,48 +3900,25 @@ end
39023900
@noinline inferred_getidx_error(idx_type::Type, space_type::Type) =
39033901
error("Invalid index type `$idx_type` for field on space `$space_type`")
39043902

3905-
39063903
# recursively unwrap getidx broadcast arguments in a way that is statically reducible by the optimizer
3907-
Base.@propagate_inbounds getidx_args(
3908-
space,
3909-
args::Tuple,
3910-
loc::Location,
3911-
idx,
3912-
hidx,
3913-
) = (
3914-
getidx(space, args[1], loc, idx, hidx),
3915-
getidx_args(space, Base.tail(args), loc, idx, hidx)...,
3916-
)
3917-
Base.@propagate_inbounds getidx_args(
3904+
@generated function call_bc_f(
3905+
f::F,
39183906
space,
3919-
arg::Tuple{Any},
3920-
loc::Location,
3921-
idx,
3922-
hidx,
3923-
) = (getidx(space, arg[1], loc, idx, hidx),)
3924-
Base.@propagate_inbounds getidx_args(
3925-
space,
3926-
::Tuple{},
3927-
loc::Location,
3928-
idx,
3929-
hidx,
3930-
) = ()
3931-
3932-
Base.@propagate_inbounds function getidx(
3933-
parent_space,
3934-
bc::Base.Broadcast.Broadcasted,
39353907
loc::Location,
39363908
idx,
39373909
hidx,
3938-
)
3939-
space = reconstruct_placeholder_space(axes(bc), parent_space)
3940-
_args = getidx_args(space, bc.args, loc, idx, hidx)
3941-
bc.f(_args...)
3910+
args...,
3911+
) where {F}
3912+
N = length(args)
3913+
return quote
3914+
Base.@_propagate_inbounds_meta
3915+
Base.Cartesian.@ncall $N f i -> getidx(space, args[i], loc, idx, hidx)
3916+
end
39423917
end
39433918

39443919
if hasfield(Method, :recursion_relation)
39453920
dont_limit = (args...) -> true
3946-
for m in methods(getidx_args)
3921+
for m in methods(call_bc_f)
39473922
m.recursion_relation = dont_limit
39483923
end
39493924
for m in methods(getidx)
@@ -4123,7 +4098,6 @@ function window_bounds(space, bc)
41234098
return (li, lw, rw, ri)
41244099
end
41254100

4126-
41274101
Base.@propagate_inbounds function apply_stencil!(
41284102
space,
41294103
field_out,
@@ -4135,36 +4109,21 @@ Base.@propagate_inbounds function apply_stencil!(
41354109
# left window
41364110
lbw = LeftBoundaryWindow{Spaces.left_boundary_name(space)}()
41374111
@inbounds for idx in li:(lw - 1)
4138-
setidx!(
4139-
space,
4140-
field_out,
4141-
idx,
4142-
hidx,
4143-
getidx(space, bc, lbw, idx, hidx),
4144-
)
4112+
val = getidx(space, bc, lbw, idx, hidx)
4113+
setidx!(space, field_out, idx, hidx, val)
41454114
end
41464115
end
41474116
# interior
41484117
@inbounds for idx in lw:rw
4149-
setidx!(
4150-
space,
4151-
field_out,
4152-
idx,
4153-
hidx,
4154-
getidx(space, bc, Interior(), idx, hidx),
4155-
)
4118+
val = getidx(space, bc, Interior(), idx, hidx)
4119+
setidx!(space, field_out, idx, hidx, val)
41564120
end
41574121
if !Topologies.isperiodic(Spaces.vertical_topology(space))
41584122
# right window
41594123
rbw = RightBoundaryWindow{Spaces.right_boundary_name(space)}()
41604124
@inbounds for idx in (rw + 1):ri
4161-
setidx!(
4162-
space,
4163-
field_out,
4164-
idx,
4165-
hidx,
4166-
getidx(space, bc, rbw, idx, hidx),
4167-
)
4125+
val = getidx(space, bc, rbw, idx, hidx)
4126+
setidx!(space, field_out, idx, hidx, val)
41684127
end
41694128
end
41704129
return field_out

0 commit comments

Comments
 (0)