Skip to content

Commit 20e7f32

Browse files
authored
Merge pull request #205 from JuliaDiff/ox/bundle_helper
add more types supported with bundle helper
2 parents 8dd45c0 + b60acd4 commit 20e7f32

File tree

2 files changed

+41
-5
lines changed

2 files changed

+41
-5
lines changed

src/AbstractDifferentiation.jl

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,27 @@ import AbstractDifferentiation as AD
22
struct DiffractorForwardBackend <: AD.AbstractForwardMode
33
end
44

5-
bundle(x::Number, dx) = TaylorBundle{1}(x, (dx,))
6-
bundle(x::Tuple, dx) = CompositeBundle{1}(x, dx)
7-
bundle(x::AbstractArray{<:Number}, dx::AbstractArray{<:Number}) = TaylorBundle{1}(x, (dx,)) # TODO check me
8-
# TODO: other types of primal
5+
"""
6+
bundle(primal, tangent)
7+
8+
Wraps a primal up with a tangent into the appropriate kind of `AbstractBundle{1}`.
9+
This is more or less the Diffractor equivelent of ForwardDiff.jl's `Dual` type.
10+
"""
11+
function bundle end
12+
bundle(x, dx::ChainRulesCore.AbstractZero) = UniformBundle{1, typeof(x), typeof(dx)}(x, dx)
13+
bundle(x::Number, dx::Number) = TaylorBundle{1}(x, (dx,))
14+
bundle(x::AbstractArray{<:Number}, dx) = TaylorBundle{1}(x, (dx,))
15+
bundle(x::AbstractArray, dx) = error("Nonnumeric arrays not implemented, that type is a mess")
16+
bundle(x::P, dx::Tangent{P}) where P = _bundle(x, ChainRulesCore.canonicalize(dx))
17+
18+
"helper that assumes tangent is in canonical form"
19+
function _bundle(x::P, dx::Tangent{P}) where P
20+
# SoA to AoS flip (hate this, hate it even more cos we just undo it later when we hit chainrules)
21+
the_bundle = ntuple(Val{fieldcount(P)}()) do ii
22+
bundle(getfield(x, ii), getproperty(dx, ii))
23+
end
24+
return CompositeBundle{1, P}(the_bundle)
25+
end
926

1027

1128
AD.@primitive function pushforward_function(b::DiffractorForwardBackend, f, args...)

test/AbstractDifferentiationTests.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,25 @@
1-
using AbstractDifferentiation, Diffractor, Test, LinearAlgebra
1+
using AbstractDifferentiation, Diffractor, Test, LinearAlgebra, ChainRulesCore
22
import AbstractDifferentiation as AD
33
backend = Diffractor.DiffractorForwardBackend()
44

5+
@testset "bundle" begin
6+
bundle = Diffractor.bundle
7+
8+
@test bundle(1.0, 2.0) isa Diffractor.TaylorBundle{1}
9+
@test bundle([1.0, 2.0], [2.0, 3.0]) isa Diffractor.TaylorBundle{1}
10+
@test bundle(1.5=>2.5, Tangent{Pair{Float64, Float64}}(first=1.0, second=2.0)) isa Diffractor.CompositeBundle{1}
11+
@test bundle(1.1, ChainRulesCore.ZeroTangent()) isa Diffractor.ZeroBundle{1}
12+
@test bundle(1.5=>2.5=>3.5, Tangent{Pair{Float64, Pair{Float64, Float64}}}(first=1.0, second=Tangent{Pair{Float64, Float64}}(first=1.0, second=2.0))) isa Diffractor.CompositeBundle{1}
13+
14+
# noncanonical structural tangent
15+
b = bundle(1.5=>2.5=>3.5, Tangent{Pair{Float64, Pair{Float64, Float64}}}(second=Tangent{Pair{Float64, Float64}}(second=2.0, first=1.0)))
16+
t = Diffractor.first_partial(b)
17+
@test b isa Diffractor.CompositeBundle{1}
18+
@test iszero(t.first)
19+
@test t.second.first == 1.0
20+
@test t.second.second == 2.0
21+
end
22+
523
@testset "basics" begin
624
@test AD.derivative(backend, +, 1.5, 10.0) == (1.0, 1.0)
725
@test AD.derivative(backend, *, 1.5, 10.0) == (10.0, 1.5)
@@ -50,3 +68,4 @@ include(joinpath(pathof(AbstractDifferentiation), "..", "..", "test", "test_util
5068
end
5169
end
5270
end
71+

0 commit comments

Comments
 (0)