Commit eaef739

opt out of CR broadcasting rrule
1 parent ed84d53 commit eaef739

3 files changed: 14 additions, 3 deletions

Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 name = "Zygote"
 uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.41"
+version = "0.6.42"

 [deps]
 AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"

src/compiler/chainrules.jl

Lines changed: 3 additions & 0 deletions
@@ -18,6 +18,9 @@ such that if a suitable rule is defined later, the generated function will recom
 function has_chain_rrule(T)
   config_T, arg_Ts = Iterators.peel(T.parameters)
   configured_rrule_m = meta(Tuple{typeof(rrule), config_T, arg_Ts...})
+
+  isnothing(configured_rrule_m) && return false, nothing # too crude, surely
+
   if _is_rrule_redispatcher(configured_rrule_m.method)
     # The config is not being used:
     # it is being redispatched without config, so we need the method it redispatches to

src/lib/broadcast.jl

Lines changed: 10 additions & 2 deletions
@@ -152,16 +152,24 @@ end
 # General Fallback
 # ================

-# The fused reverse mode implementation is the most general but currently has
+# The ~~fused~~ reverse mode implementation is the most general but currently has
 # poor performance. It works by flattening the broadcast and mapping the call to
 # `_pullback` over the input.
-
 # However, the core call
 # broadcast(_pullback, (cx,), f, args...)
 # is already 10x slower than a simple broadcast (presumably due to inlining
 # issues, or something similar) and the other operations needed take it to about
 # 100x overhead.

+# https://github.com/FluxML/Zygote.jl/pull/1001 tries to use broadcast_forward (using Dual numbers)
+# whenever possible, this was previously used only for CuArrays. It is usually much faster.
+
+# https://github.com/JuliaDiff/ChainRules.jl/pull/644 implements broadcasting.
+# Its generic rule would be applied before the one defined here, with AbstractArrayStyle
+# @adjoint function broadcasted(::AbstractArrayStyle, f::F, args...) where {F}
+# but does not pass all Zygote's tests. So disable it:
+ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(Broadcast.broadcasted), f::F, args::Vararg{Any,N}) where {F,N}
+
 @generated inclen(::NTuple{N,Any}) where N = Val(N+1)

 # Avoid hitting special cases for `Adjoint` etc.
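
Note on `@opt_out`: the line added in src/lib/broadcast.jl uses `ChainRulesCore.@opt_out` so that a more specific `rrule` method returns `nothing`, which an AD system can treat as "no rule here, differentiate the function itself". The following is a minimal sketch of that mechanism outside Zygote, not the code from this commit. The function `double` and its rule are hypothetical and exist only for illustration.

```julia
using ChainRulesCore

# Hypothetical function, used only for this illustration.
double(x) = 2 * x

# A generic reverse rule that applies to any Real argument.
function ChainRulesCore.rrule(::typeof(double), x::Real)
    double_pullback(dy) = (NoTangent(), 2 * dy)
    return double(x), double_pullback
end

# Opt out for the more specific Float32 signature. For these arguments
# rrule now returns `nothing` instead of a (value, pullback) pair.
@opt_out ChainRulesCore.rrule(::typeof(double), ::Float32)

y, back = rrule(double, 1.0)      # Float64 still hits the generic rule
no_rule = rrule(double, 1.0f0)    # Float32 hits the opt-out method
```

In the commit the same idea is applied to `Broadcast.broadcasted` for `ZygoteRuleConfig`, so Zygote keeps using its own `@adjoint` for broadcasting rather than the generic ChainRules rule.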

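
Note on forward-mode broadcasting: the comment referencing Zygote.jl pull/1001 mentions differentiating a broadcast with Dual numbers. The sketch below shows the general idea using ForwardDiff.jl directly. It is an assumption-laden illustration, not Zygote's `broadcast_forward`; the function `f` and the input `xs` are made up for the example.

```julia
using ForwardDiff: Dual, value, partials

# Hypothetical elementwise function, for illustration only.
f(x) = sin(x) * x

xs = [0.0, 1.0, 2.0]

# Seed every element with a unit perturbation and broadcast f once.
# Each result carries both f(x) and df/dx for its own element.
duals = f.(Dual.(xs, 1.0))

ys = value.(duals)          # primal outputs f.(xs)
dydx = partials.(duals, 1)  # elementwise derivatives

# A pullback for the broadcast then only needs to scale the incoming
# cotangent by these stored derivatives.
pullback(dys) = dys .* dydx
```

Because the derivative is computed in the same pass as the value, there is no per-element call to a reverse-mode pullback, which is consistent with the comment's remark that this path is usually much faster.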