Skip to content

Commit bfc36b8

Browse files
Merge pull request #2350 from CliMA/dy/broadcast_argument_unrolling
Fix unrolling over field broadcast arguments
2 parents 476c5a9 + 66c0a84 commit bfc36b8

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

src/Fields/Fields.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ import ..Utilities: PlusHalf, half
3030
using ..RecursiveApply
3131
using ClimaComms
3232
import Adapt
33-
import UnrolledUtilities: unrolled_map, unrolled_findfirst
33+
import UnrolledUtilities: unrolled_map, unrolled_mapreduce, unrolled_findfirst
3434

3535
import StaticArrays, LinearAlgebra, Statistics, InteractiveUtils
3636

src/Fields/broadcast.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,20 @@ Base.Broadcast.BroadcastStyle(
4040
::FieldStyle{DS2},
4141
) where {DS1, DS2} = FieldStyle(Base.Broadcast.BroadcastStyle(DS1(), DS2()))
4242

43+
# Override the recursive unrolling used in combine_styles (which can lead to
44+
# inference failures in broadcast expressions with more than 10 arguments) with
45+
# manual unrolling (which can have higher latency but is always inferrable).
46+
Base.Broadcast.combine_styles(
47+
arg1::Union{Field, Base.Broadcast.Broadcasted{<:AbstractFieldStyle}},
48+
arg2,
49+
arg3,
50+
args...,
51+
) = unrolled_mapreduce(
52+
Base.Broadcast.combine_styles,
53+
Base.Broadcast.result_style,
54+
(arg1, arg2, arg3, args...),
55+
)
56+
4357
Base.Broadcast.broadcastable(field::Field) = field
4458

4559
Base.eltype(bc::Base.Broadcast.Broadcasted{<:AbstractFieldStyle}) =

0 commit comments

Comments
 (0)