diff --git a/src/Fields/Fields.jl b/src/Fields/Fields.jl index 55dedb89c3..a30a0e000d 100644 --- a/src/Fields/Fields.jl +++ b/src/Fields/Fields.jl @@ -30,7 +30,7 @@ import ..Utilities: PlusHalf, half using ..RecursiveApply using ClimaComms import Adapt -import UnrolledUtilities: unrolled_map, unrolled_findfirst +import UnrolledUtilities: unrolled_map, unrolled_mapreduce, unrolled_findfirst import StaticArrays, LinearAlgebra, Statistics, InteractiveUtils diff --git a/src/Fields/broadcast.jl b/src/Fields/broadcast.jl index 03a80edbac..ce431e7f22 100644 --- a/src/Fields/broadcast.jl +++ b/src/Fields/broadcast.jl @@ -40,6 +40,20 @@ Base.Broadcast.BroadcastStyle( ::FieldStyle{DS2}, ) where {DS1, DS2} = FieldStyle(Base.Broadcast.BroadcastStyle(DS1(), DS2())) +# Override the recursive unrolling used in combine_styles (which can lead to +# inference failures in broadcast expressions with more than 10 arguments) with +# manual unrolling (which can have higher latency but is always inferrable). +Base.Broadcast.combine_styles( + arg1::Union{Field, Base.Broadcast.Broadcasted{<:AbstractFieldStyle}}, + arg2, + arg3, + args..., +) = unrolled_mapreduce( + Base.Broadcast.combine_styles, + Base.Broadcast.result_style, + (arg1, arg2, arg3, args...), +) + Base.Broadcast.broadcastable(field::Field) = field Base.eltype(bc::Base.Broadcast.Broadcasted{<:AbstractFieldStyle}) =