@@ -3759,25 +3759,37 @@ end
3759
3759
3760
3760
Base. @propagate_inbounds function getidx (
3761
3761
parent_space,
3762
- bc:: StencilBroadcasted ,
3763
- loc:: Interior ,
3764
- idx,
3765
- hidx,
3766
- )
3767
- space = reconstruct_placeholder_space (axes (bc), parent_space)
3768
- stencil_interior (bc. op, loc, space, idx, hidx, bc. args... )
3769
- end
3770
-
3771
- Base. @propagate_inbounds function getidx (
3772
- parent_space,
3773
- bc:: StencilBroadcasted ,
3774
- loc:: LeftBoundaryWindow ,
3762
+ bc:: Union{StencilBroadcasted, Base.Broadcast.Broadcasted} ,
3763
+ loc:: Location ,
3775
3764
idx,
3776
3765
hidx,
3777
3766
)
3767
+ # Use Union-splitting here (x isa X) instead of dispatch
3768
+ # for improved latency.
3778
3769
space = reconstruct_placeholder_space (axes (bc), parent_space)
3770
+ if bc isa Base. Broadcast. Broadcasted
3771
+ # Manually call bc.f for small tuples (improved latency)
3772
+ (; args) = bc
3773
+ N = length (bc. args)
3774
+ if N == 1
3775
+ return bc. f (getidx (space, args[1 ], loc, idx, hidx))
3776
+ elseif N == 2
3777
+ return bc. f (
3778
+ getidx (space, args[1 ], loc, idx, hidx),
3779
+ getidx (space, args[2 ], loc, idx, hidx),
3780
+ )
3781
+ elseif N == 3
3782
+ return bc. f (
3783
+ getidx (space, args[1 ], loc, idx, hidx),
3784
+ getidx (space, args[2 ], loc, idx, hidx),
3785
+ getidx (space, args[3 ], loc, idx, hidx),
3786
+ )
3787
+ end
3788
+ return call_bc_f (bc. f, space, loc, idx, hidx, args... )
3789
+ end
3779
3790
op = bc. op
3780
- if should_call_left_boundary (idx, space, bc, loc)
3791
+ if loc isa LeftBoundaryWindow &&
3792
+ should_call_left_boundary (idx, space, bc, loc)
3781
3793
stencil_left_boundary (
3782
3794
op,
3783
3795
get_boundary (op, loc),
@@ -3787,22 +3799,8 @@ Base.@propagate_inbounds function getidx(
3787
3799
hidx,
3788
3800
bc. args... ,
3789
3801
)
3790
- else
3791
- # fallback to interior stencil
3792
- stencil_interior (op, loc, space, idx, hidx, bc. args... )
3793
- end
3794
- end
3795
-
3796
- Base. @propagate_inbounds function getidx (
3797
- parent_space,
3798
- bc:: StencilBroadcasted ,
3799
- loc:: RightBoundaryWindow ,
3800
- idx,
3801
- hidx,
3802
- )
3803
- op = bc. op
3804
- space = reconstruct_placeholder_space (axes (bc), parent_space)
3805
- if should_call_right_boundary (idx, space, bc, loc)
3802
+ elseif loc isa RightBoundaryWindow &&
3803
+ should_call_right_boundary (idx, space, bc, loc)
3806
3804
stencil_right_boundary (
3807
3805
op,
3808
3806
get_boundary (op, loc),
@@ -3813,11 +3811,11 @@ Base.@propagate_inbounds function getidx(
3813
3811
bc. args... ,
3814
3812
)
3815
3813
else
3816
- # fallback to interior stencil
3817
- stencil_interior (op, loc, space, idx, hidx, bc. args... )
3814
+ stencil_interior (bc. op, loc, space, idx, hidx, bc. args... )
3818
3815
end
3819
3816
end
3820
3817
3818
+
3821
3819
# broadcasting a StencilStyle gives a CompositeStencilStyle
3822
3820
Base. Broadcast. BroadcastStyle (
3823
3821
:: Type{<:StencilBroadcasted{Style}} ,
@@ -3902,48 +3900,25 @@ end
3902
3900
@noinline inferred_getidx_error (idx_type:: Type , space_type:: Type ) =
3903
3901
error (" Invalid index type `$idx_type ` for field on space `$space_type `" )
3904
3902
3905
-
3906
3903
# recursively unwrap getidx broadcast arguments in a way that is statically reducible by the optimizer
3907
- Base. @propagate_inbounds getidx_args (
3908
- space,
3909
- args:: Tuple ,
3910
- loc:: Location ,
3911
- idx,
3912
- hidx,
3913
- ) = (
3914
- getidx (space, args[1 ], loc, idx, hidx),
3915
- getidx_args (space, Base. tail (args), loc, idx, hidx)... ,
3916
- )
3917
- Base. @propagate_inbounds getidx_args (
3904
+ @generated function call_bc_f (
3905
+ f:: F ,
3918
3906
space,
3919
- arg:: Tuple{Any} ,
3920
- loc:: Location ,
3921
- idx,
3922
- hidx,
3923
- ) = (getidx (space, arg[1 ], loc, idx, hidx),)
3924
- Base. @propagate_inbounds getidx_args (
3925
- space,
3926
- :: Tuple{} ,
3927
- loc:: Location ,
3928
- idx,
3929
- hidx,
3930
- ) = ()
3931
-
3932
- Base. @propagate_inbounds function getidx (
3933
- parent_space,
3934
- bc:: Base.Broadcast.Broadcasted ,
3935
3907
loc:: Location ,
3936
3908
idx,
3937
3909
hidx,
3938
- )
3939
- space = reconstruct_placeholder_space (axes (bc), parent_space)
3940
- _args = getidx_args (space, bc. args, loc, idx, hidx)
3941
- bc. f (_args... )
3910
+ args... ,
3911
+ ) where {F}
3912
+ N = length (args)
3913
+ return quote
3914
+ Base. @_propagate_inbounds_meta
3915
+ Base. Cartesian. @ncall $ N f i -> getidx (space, args[i], loc, idx, hidx)
3916
+ end
3942
3917
end
3943
3918
3944
3919
if hasfield (Method, :recursion_relation )
3945
3920
dont_limit = (args... ) -> true
3946
- for m in methods (getidx_args )
3921
+ for m in methods (call_bc_f )
3947
3922
m. recursion_relation = dont_limit
3948
3923
end
3949
3924
for m in methods (getidx)
@@ -4123,7 +4098,6 @@ function window_bounds(space, bc)
4123
4098
return (li, lw, rw, ri)
4124
4099
end
4125
4100
4126
-
4127
4101
Base. @propagate_inbounds function apply_stencil! (
4128
4102
space,
4129
4103
field_out,
@@ -4135,36 +4109,21 @@ Base.@propagate_inbounds function apply_stencil!(
4135
4109
# left window
4136
4110
lbw = LeftBoundaryWindow {Spaces.left_boundary_name(space)} ()
4137
4111
@inbounds for idx in li: (lw - 1 )
4138
- setidx! (
4139
- space,
4140
- field_out,
4141
- idx,
4142
- hidx,
4143
- getidx (space, bc, lbw, idx, hidx),
4144
- )
4112
+ val = getidx (space, bc, lbw, idx, hidx)
4113
+ setidx! (space, field_out, idx, hidx, val)
4145
4114
end
4146
4115
end
4147
4116
# interior
4148
4117
@inbounds for idx in lw: rw
4149
- setidx! (
4150
- space,
4151
- field_out,
4152
- idx,
4153
- hidx,
4154
- getidx (space, bc, Interior (), idx, hidx),
4155
- )
4118
+ val = getidx (space, bc, Interior (), idx, hidx)
4119
+ setidx! (space, field_out, idx, hidx, val)
4156
4120
end
4157
4121
if ! Topologies. isperiodic (Spaces. vertical_topology (space))
4158
4122
# right window
4159
4123
rbw = RightBoundaryWindow {Spaces.right_boundary_name(space)} ()
4160
4124
@inbounds for idx in (rw + 1 ): ri
4161
- setidx! (
4162
- space,
4163
- field_out,
4164
- idx,
4165
- hidx,
4166
- getidx (space, bc, rbw, idx, hidx),
4167
- )
4125
+ val = getidx (space, bc, rbw, idx, hidx)
4126
+ setidx! (space, field_out, idx, hidx, val)
4168
4127
end
4169
4128
end
4170
4129
return field_out
0 commit comments