Skip to content

Commit c485971

Browse files
oscardssmithoxinabox
authored andcommitted
add AbstractDifferentiation support
1 parent 704db4a commit c485971

File tree

6 files changed

+65
-6
lines changed

6 files changed

+65
-6
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["Keno Fischer <keno@juliacomputing.com> and contributors"]
44
version = "0.2.0"
55

66
[deps]
7+
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
78
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
89
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
910
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"

src/AbstractDifferentiation.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import AbstractDifferentiation as AD
2+
struct DiffractorForwardBackend <: AD.AbstractForwardMode
3+
end
4+
5+
bundle(x::Number, dx) = TaylorBundle{1}(x, (dx,))
6+
bundle(x::Tuple, dx) = CompositeBundle{1}(x, dx)
7+
bundle(x::AbstractArray{<:Number}, dx::AbstractArray{<:Number}) = TaylorBundle{1}(x, (dx,)) # TODO check me
8+
# TODO: other types of primal
9+
10+
11+
AD.@primitive function pushforward_function(b::DiffractorForwardBackend, f, args...)
12+
return function pushforward(vs)
13+
z = ∂☆{1}()(ZeroBundle{1}(f), map(bundle, args, vs)...)
14+
z[TaylorTangentIndex(1)]
15+
end
16+
end

src/Diffractor.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ const GENERATORS = Expr[]
4040
include("debugutils.jl")
4141

4242
include("stage1/termination.jl")
43+
include("AbstractDifferentiation.jl")
4344
end
4445

4546
end

src/interface.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ end
118118

119119
function (f::)(args...)
120120
y, f☆ = ∂⃖(getfield(f, :f), args...)
121-
return tail(f☆(dx(y)))
121+
tail(f☆(dx(y)))
122122
end
123123

124124
# N.B: This means the gradient is not available for zero-arg function, but such
@@ -149,9 +149,7 @@ PrimeDerivativeBack(f) = PrimeDerivativeBack{1, typeof(f)}(f)
149149
PrimeDerivativeBack(f::PrimeDerivativeBack{N, T}) where {N, T} = raise_pd(f)
150150

151151
function (f::PrimeDerivativeBack)(x)
152-
z = ∂⃖¹(lower_pd(f), x)
153-
y = getfield(z, 1)
154-
f☆ = getfield(z, 2)
152+
y, f☆ = ∂⃖¹(lower_pd(f), x)
155153
return unthunk(getfield(f☆(dx(y)), 2))
156154
end
157155

@@ -227,5 +225,3 @@ macro ∂(expr)
227225
end
228226
derivative(f, x) = Diffractor.PrimeDerivativeFwd(f)(x)
229227
const gradient =
230-
jacobian(f, x::AbstractArray) = reduce(hcat, vec.(gradient(f, x)))
231-
hessian(f, x::AbstractArray) = jacobian(y -> gradient(f, y), float(x))

test/AbstractDifferentiationTests.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
using AbstractDifferentiation, Diffractor, Test
2+
include(joinpath(pathof(AbstractDifferentiation), "..", "..", "test", "test_utils.jl"))
3+
import AbstractDifferentiation as AD
4+
5+
backend = Diffractor.DiffractorForwardBackend()
6+
@testset "ForwardDiffBackend" begin
7+
backends = [
8+
@inferred(Diffractor.DiffractorForwardBackend())
9+
]
10+
@testset for backend in backends
11+
@test backend isa AD.AbstractForwardMode
12+
13+
@testset "Derivative" begin #setfield!(::Core.Box, ::Symbol, ::Float64)
14+
@test_broken test_derivatives(backend)
15+
end
16+
@testset "Gradient" begin #Diffractor.TangentBundle{1, Float64, Diffractor.TaylorTangent{Tuple{Float64}}}(::Float64, ::Tuple{Float64})
17+
@test_broken test_gradients(backend)
18+
end
19+
@testset "Jacobian" begin #setfield!(::Core.Box, ::Symbol, ::Vector{Float64})
20+
@test_broken test_jacobians(backend)
21+
end
22+
@testset "Hessian" begin #setindex!(::ChainRulesCore.ZeroTangent, ::Float64, ::Int64)
23+
@test_broken test_hessians(backend)
24+
end
25+
@testset "jvp" begin #setfield!(::Core.Box, ::Symbol, ::Vector{Float64})
26+
@test_broken test_jvp(backend; vaugmented=true)
27+
end
28+
@testset "j′vp" begin #setfield!(::Core.Box, ::Symbol, ::Vector{Float64})
29+
@test_broken test_j′vp(backend)
30+
end
31+
@testset "Lazy Derivative" begin
32+
test_lazy_derivatives(backend)
33+
end
34+
@testset "Lazy Gradient" begin #Diffractor.TangentBundle{1, Float64, Diffractor.TaylorTangent{Tuple{Float64}}}(::Float64, ::Tuple{Float64})
35+
@test_broken test_lazy_gradients(backend)
36+
end
37+
@testset "Lazy Jacobian" begin
38+
test_lazy_jacobians(backend; vaugmented=true)
39+
end
40+
@testset "Lazy Hessian" begin # everything everywhere all at once is broken
41+
@test_broken test_lazy_hessians(backend)
42+
end
43+
end
44+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ const bwd = Diffractor.PrimeDerivativeBack
2020
"forward.jl",
2121
"reverse.jl",
2222
"regression.jl",
23+
"AbstractDifferentiationTests.jl"
2324
#"pinn.jl", # Higher order control flow not yet supported (https://github.com/JuliaDiff/Diffractor.jl/issues/24)
2425
)
2526
include(file)

0 commit comments

Comments
 (0)