Skip to content

Commit b3f2a51

Browse files
committed
add and test zero_tangent
1 parent 7f99ce4 commit b3f2a51

File tree

4 files changed

+45
-2
lines changed

4 files changed

+45
-2
lines changed

src/ChainRulesCore.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMod
1111
export frule_via_ad, rrule_via_ad
1212
# definition helper macros
1313
export @non_differentiable, @opt_out, @scalar_rule, @thunk, @not_implemented
14-
export ProjectTo, canonicalize, unthunk # tangent operations
14+
export ProjectTo, canonicalize, unthunk, zero_tangent # tangent operations
1515
export add!!, is_inplaceable_destination # gradient accumulation operations
1616
export ignore_derivatives, @ignore_derivatives
1717
# tangents

src/tangent_types/abstract_zero.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,32 @@ arguments.
8787
```
8888
"""
8989
struct NoTangent <: AbstractZero end
90+
91+
"""
92+
zero_tangent(primal)
93+
94+
This returns an appropriate zero tangent suitable for accumulating tangents of the primal.
95+
For mutable composites types this is a structural []`MutableTangent`](@ref)
96+
For `Array`s, it is applied recursively for each element.
97+
For immutable types, this is simply [`ZeroTangent()`](@ref) as accumulation is default out-of-place for contexts where mutation does not apply.
98+
(Where mutation is not to be supported even for mutable types, then [`ZeroTangent()`](@ref) should be used for everything)
99+
100+
!!! warning Exprimental
101+
`zero_tangent`is an experimental feature, and is part of the mutation support featureset.
102+
While this notice remains it may have changes in behavour, and interface in any _minor_ version of ChainRulesCore.
103+
Exactly how it should be used (e.g. is it forward-mode only?)
104+
"""
105+
function zero_tangent end
106+
zero_tangent(::AbstractString) = ZeroTangent()
107+
# zero_tangent(::Number) = zero(x) # TODO: do we want this?
108+
zero_tangent(primal::Array{<:Number}) = zero(primal) # TODO: do we want this?
109+
zero_tangent(primal::Array) = map(zero_tangent, primal)
110+
@generated function zero_tangent(primal)
111+
has_mutable_tangent(primal) || return ZeroTangent() # note this takes care of tuples
112+
zfield_exprs = map(fieldnames(primal)) do fname
113+
fval = Expr(:call, zero_tangent, Expr(:call, getfield, :primal, QuoteNode(fname)))
114+
Expr(:kw, fname, fval)
115+
end
116+
backing_expr = Expr(:tuple, Expr(:parameters, zfield_exprs...))
117+
return :($MutableTangent{$primal}($backing_expr))
118+
end

src/tangent_types/structural_tangent.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function StructuralTangent{P}(nt::NamedTuple) where {P}
2222
end
2323
end
2424

25-
has_mutable_tangent(::Type{P}) where P = ismutabletype(P) && (!isabstracttype(P) && fieldcount(T) > 0)
25+
has_mutable_tangent(::Type{P}) where P = ismutabletype(P) && (!isabstracttype(P) && fieldcount(P) > 0)
2626

2727

2828
StructuralTangent{P}(tup::Tuple) where P = Tangent{P,typeof(tup)}(tup)

test/tangent_types/abstract_zero.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,3 +154,17 @@
154154
@test isempty(detect_ambiguities(M))
155155
end
156156
end
157+
158+
@testset "zero_tangent" begin
159+
mutable struct MutDemo
160+
x::Float64
161+
end
162+
@test zero_tangent(MutDemo(1.5)) isa MutableTangent{MutDemo}
163+
@test iszero(zero_tangent(MutDemo(1.5)))
164+
165+
@test zero_tangent((;a=1)) isa ZeroTangent
166+
167+
@test zero_tangent([1.0, 2.0]) == [0.0, 0.0]
168+
@test zero_tangent([[1.0, 2.0], [3.0]]) == [[0.0, 0.0], [0.0]]
169+
end
170+

0 commit comments

Comments
 (0)