Skip to content

Commit 886808b

Browse files
committed
Fix unrolling over field broadcast arguments
1 parent 2bc09fb commit 886808b

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

src/Fields/Fields.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ import ..Utilities: PlusHalf, half
3030
using ..RecursiveApply
3131
using ClimaComms
3232
import Adapt
33-
import UnrolledUtilities: unrolled_map, unrolled_findfirst
33+
import UnrolledUtilities: unrolled_map, unrolled_reduce, unrolled_findfirst
3434

3535
import StaticArrays, LinearAlgebra, Statistics, InteractiveUtils
3636

src/Fields/broadcast.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@ Base.Broadcast.BroadcastStyle(
4040
::FieldStyle{DS2},
4141
) where {DS1, DS2} = FieldStyle(Base.Broadcast.BroadcastStyle(DS1(), DS2()))
4242

43+
# Override the recursive unrolling used in combine_styles (which can lead to
44+
# inference failures in broadcast expressions with more than 10 arguments) with
45+
# manual unrolling (which sacrifices latency but is always inferrable).
46+
Base.Broadcast.combine_styles(style1::AbstractFieldStyle, style2, styles...) =
47+
unrolled_reduce(Base.Broadcast.combine_styles, (style1, style2, styles...))
48+
4349
Base.Broadcast.broadcastable(field::Field) = field
4450

4551
Base.eltype(bc::Base.Broadcast.Broadcasted{<:AbstractFieldStyle}) =

0 commit comments

Comments
 (0)