Skip to content

Commit ca8820e

Browse files
committed
add and test zero_tangent
1 parent d2142fa commit ca8820e

File tree

4 files changed

+45
-2
lines changed

4 files changed

+45
-2
lines changed

src/ChainRulesCore.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMod
1010
export frule_via_ad, rrule_via_ad
1111
# definition helper macros
1212
export @non_differentiable, @opt_out, @scalar_rule, @thunk, @not_implemented
13-
export ProjectTo, canonicalize, unthunk # tangent operations
13+
export ProjectTo, canonicalize, unthunk, zero_tangent # tangent operations
1414
export add!!, is_inplaceable_destination # gradient accumulation operations
1515
export ignore_derivatives, @ignore_derivatives
1616
# tangents

src/tangent_types/abstract_zero.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,32 @@ arguments.
9191
```
9292
"""
9393
struct NoTangent <: AbstractZero end
94+
95+
"""
96+
zero_tangent(primal)
97+
98+
This returns an appropriate zero tangent suitable for accumulating tangents of the primal.
99+
For mutable composites types this is a structural []`MutableTangent`](@ref)
100+
For `Array`s, it is applied recursively for each element.
101+
For immutable types, this is simply [`ZeroTangent()`](@ref) as accumulation is default out-of-place for contexts where mutation does not apply.
102+
(Where mutation is not to be supported even for mutable types, then [`ZeroTangent()`](@ref) should be used for everything)
103+
104+
!!! warning Exprimental
105+
`zero_tangent`is an experimental feature, and is part of the mutation support featureset.
106+
While this notice remains it may have changes in behavour, and interface in any _minor_ version of ChainRulesCore.
107+
Exactly how it should be used (e.g. is it forward-mode only?)
108+
"""
109+
function zero_tangent end
110+
zero_tangent(::AbstractString) = ZeroTangent()
111+
# zero_tangent(::Number) = zero(x) # TODO: do we want this?
112+
zero_tangent(primal::Array{<:Number}) = zero(primal) # TODO: do we want this?
113+
zero_tangent(primal::Array) = map(zero_tangent, primal)
114+
@generated function zero_tangent(primal)
115+
has_mutable_tangent(primal) || return ZeroTangent() # note this takes care of tuples
116+
zfield_exprs = map(fieldnames(primal)) do fname
117+
fval = Expr(:call, zero_tangent, Expr(:call, getfield, :primal, QuoteNode(fname)))
118+
Expr(:kw, fname, fval)
119+
end
120+
backing_expr = Expr(:tuple, Expr(:parameters, zfield_exprs...))
121+
return :($MutableTangent{$primal}($backing_expr))
122+
end

src/tangent_types/structural_tangent.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function StructuralTangent{P}(nt::NamedTuple) where {P}
2222
end
2323
end
2424

25-
has_mutable_tangent(::Type{P}) where P = ismutabletype(P) && (!isabstracttype(P) && fieldcount(T) > 0)
25+
has_mutable_tangent(::Type{P}) where P = ismutabletype(P) && (!isabstracttype(P) && fieldcount(P) > 0)
2626

2727

2828
StructuralTangent{P}(tup::Tuple) where P = Tangent{P,typeof(tup)}(tup)

test/tangent_types/abstract_zero.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,3 +160,17 @@
160160
@test isempty(detect_ambiguities(M))
161161
end
162162
end
163+
164+
@testset "zero_tangent" begin
165+
mutable struct MutDemo
166+
x::Float64
167+
end
168+
@test zero_tangent(MutDemo(1.5)) isa MutableTangent{MutDemo}
169+
@test iszero(zero_tangent(MutDemo(1.5)))
170+
171+
@test zero_tangent((;a=1)) isa ZeroTangent
172+
173+
@test zero_tangent([1.0, 2.0]) == [0.0, 0.0]
174+
@test zero_tangent([[1.0, 2.0], [3.0]]) == [[0.0, 0.0], [0.0]]
175+
end
176+

0 commit comments

Comments
 (0)