Skip to content

Commit 0318546

Browse files
authored
Merge pull request #187 from oscardssmith/add-interface-v2
Give Diffractor an AbstractDifferentiation interface
2 parents 8af6de2 + a1e8810 commit 0318546

File tree

8 files changed

+116
-30
lines changed

8 files changed

+116
-30
lines changed

Project.toml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["Keno Fischer <keno@juliacomputing.com> and contributors"]
44
version = "0.2.0"
55

66
[deps]
7+
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
78
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
89
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
910
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
@@ -14,6 +15,13 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1415
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1516
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
1617

18+
[extras]
19+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
20+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
21+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
22+
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
23+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
24+
1725
[compat]
1826
ChainRules = "1.44.6"
1927
ChainRulesCore = "1.15.3"
@@ -24,3 +32,6 @@ PrecompileTools = "1"
2432
StaticArrays = "1"
2533
StructArrays = "0.6"
2634
julia = "1.10"
35+
36+
[targets]
37+
test = ["ForwardDiff", "LinearAlgebra", "Random", "Symbolics", "Test"]

docs/src/index.md

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,32 @@
33
Next-generation AD
44

55
[PDF containing the terminology](terminology.pdf)
6+
7+
## Getting Started
8+
9+
⚠️This certainly has bugs and issues. Please open issues on [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl/), or [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl/issues) as appropriate.⚠️
10+
11+
Diffractor's public API is via [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl/).
12+
Please see the AbstractDifferentiation.jl docs for detailed usage.
13+
14+
```@jldoctest
15+
julia> using Diffractor: DiffractorForwardBackend
16+
17+
julia> using AbstractDifferentiation: derivative
18+
19+
julia> derivative(DiffractorForwardBackend(), +, 1.5, 10.0)
20+
(1.0, 1.0)
21+
22+
julia> derivative(DiffractorForwardBackend(), *, 1.5, 10.0)
23+
(10.0, 1.5)
24+
25+
julia> jacobian(DiffractorForwardBackend(), prod, [1.5, 2.5, 10.0]) |> only
26+
1×3 Matrix{Float64}:
27+
25.0 15.0 3.75
28+
29+
julia> jacobian(DiffractorForwardBackend(), identity, [1.5, 2.5, 10.0]) |> only
30+
3×3 Matrix{Float64}:
31+
1.0 0.0 0.0
32+
0.0 1.0 0.0
33+
0.0 0.0 1.0
34+
```

src/AbstractDifferentiation.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import AbstractDifferentiation as AD
2+
struct DiffractorForwardBackend <: AD.AbstractForwardMode
3+
end
4+
5+
bundle(x::Number, dx) = TaylorBundle{1}(x, (dx,))
6+
bundle(x::Tuple, dx) = CompositeBundle{1}(x, dx)
7+
bundle(x::AbstractArray{<:Number}, dx::AbstractArray{<:Number}) = TaylorBundle{1}(x, (dx,)) # TODO check me
8+
# TODO: other types of primal
9+
10+
11+
AD.@primitive function pushforward_function(b::DiffractorForwardBackend, f, args...)
12+
return function pushforward(vs)
13+
z = ∂☆{1}()(ZeroBundle{1}(f), map(bundle, args, vs)...)
14+
z[TaylorTangentIndex(1)]
15+
end
16+
end

src/Diffractor.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ const GENERATORS = Expr[]
4040
include("debugutils.jl")
4141

4242
include("stage1/termination.jl")
43+
include("AbstractDifferentiation.jl")
4344
end
4445

4546
end

src/interface.jl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ end
118118

119119
function (f::∇)(args...)
120120
y, f☆ = ∂⃖(getfield(f, :f), args...)
121-
return tail(f☆(dx(y)))
121+
tail(f☆(dx(y)))
122122
end
123123

124124
# N.B: This means the gradient is not available for zero-arg function, but such
@@ -127,8 +127,6 @@ function (::Type{∇})(f, x1, args...)
127127
unthunk.(∇(f)(x1, args...))
128128
end
129129

130-
const gradient = ∇
131-
132130
# Star Trek has their prime directive. We have the...
133131
abstract type AbstractPrimeDerivative{N, T}; end
134132

@@ -151,9 +149,7 @@ PrimeDerivativeBack(f) = PrimeDerivativeBack{1, typeof(f)}(f)
151149
PrimeDerivativeBack(f::PrimeDerivativeBack{N, T}) where {N, T} = raise_pd(f)
152150

153151
function (f::PrimeDerivativeBack)(x)
154-
z = ∂⃖¹(lower_pd(f), x)
155-
y = getfield(z, 1)
156-
f☆ = getfield(z, 2)
152+
y, f☆ = ∂⃖¹(lower_pd(f), x)
157153
return unthunk(getfield(f☆(dx(y)), 2))
158154
end
159155

@@ -181,8 +177,8 @@ struct PrimeDerivative{N, T}
181177
end
182178

183179
function (f::PrimeDerivative{N, T})(x) where {N, T}
184-
# For now, this is backwards mode, since that's more fully implemented
185-
return PrimeDerivativeBack{N, T}(f.f)(x)
180+
# For now, this is forward mode, since that's more fully implemented
181+
return PrimeDerivativeFwd{N, T}(f.f)(x)
186182
end
187183

188184
"""
@@ -227,3 +223,5 @@ will compute the derivative `∂^3 f/∂x^2 ∂y` at `(x,y)`.
227223
macro ∂(expr)
228224
error("Write me")
229225
end
226+
derivative(f, x) = Diffractor.PrimeDerivativeFwd(f)(x)
227+
const gradient = ∇

test/AbstractDifferentiationTests.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
using AbstractDifferentiation, Diffractor, Test, LinearAlgebra
2+
import AbstractDifferentiation as AD
3+
backend = Diffractor.DiffractorForwardBackend()
4+
5+
@testset "basics" begin
6+
@test AD.derivative(backend, +, 1.5, 10.0) == (1.0, 1.0)
7+
@test AD.derivative(backend, *, 1.5, 10.0) == (10.0, 1.5)
8+
@test only(AD.jacobian(backend, prod, [1.5, 2.5, 10.0])) == [25.0 15.0 3.75]
9+
@test only(AD.jacobian(backend, identity, [1.5, 2.5, 10.0])) == Matrix(I, 3, 3)
10+
end
11+
12+
# standard tests from AbstractDifferentiation.test_utils
13+
include(joinpath(pathof(AbstractDifferentiation), "..", "..", "test", "test_utils.jl"))
14+
@testset "ForwardDiffBackend" begin
15+
backends = [
16+
@inferred(Diffractor.DiffractorForwardBackend())
17+
]
18+
@testset for backend in backends
19+
@test backend isa AD.AbstractForwardMode
20+
21+
@testset "Derivative" begin #setfield!(::Core.Box, ::Symbol, ::Float64)
22+
@test_broken test_derivatives(backend)
23+
end
24+
@testset "Gradient" begin #Diffractor.TangentBundle{1, Float64, Diffractor.TaylorTangent{Tuple{Float64}}}(::Float64, ::Tuple{Float64})
25+
@test_broken test_gradients(backend)
26+
end
27+
@testset "Jacobian" begin #setfield!(::Core.Box, ::Symbol, ::Vector{Float64})
28+
@test_broken test_jacobians(backend)
29+
end
30+
@testset "Hessian" begin #setindex!(::ChainRulesCore.ZeroTangent, ::Float64, ::Int64)
31+
@test_broken test_hessians(backend)
32+
end
33+
@testset "jvp" begin #setfield!(::Core.Box, ::Symbol, ::Vector{Float64})
34+
@test_broken test_jvp(backend; vaugmented=true)
35+
end
36+
@testset "j′vp" begin #setfield!(::Core.Box, ::Symbol, ::Vector{Float64})
37+
@test_broken test_j′vp(backend)
38+
end
39+
@testset "Lazy Derivative" begin
40+
test_lazy_derivatives(backend)
41+
end
42+
@testset "Lazy Gradient" begin #Diffractor.TangentBundle{1, Float64, Diffractor.TaylorTangent{Tuple{Float64}}}(::Float64, ::Tuple{Float64})
43+
@test_broken test_lazy_gradients(backend)
44+
end
45+
@testset "Lazy Jacobian" begin #MethodError: no method matching *(::Diffractor.PrimeDerivativeBack{1, Diagonal{Bool, Vector{Bool}}}, ::Vector{Float64})
46+
@test_broken test_lazy_jacobians(backend; vaugmented=true)
47+
end
48+
@testset "Lazy Hessian" begin # everything everywhere all at once is broken
49+
@test_broken test_lazy_hessians(backend)
50+
end
51+
end
52+
end

test/Project.toml

Lines changed: 0 additions & 22 deletions
This file was deleted.

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ const bwd = Diffractor.PrimeDerivativeBack
2020
"forward.jl",
2121
"reverse.jl",
2222
"regression.jl",
23+
"AbstractDifferentiationTests.jl"
2324
#"pinn.jl", # Higher order control flow not yet supported (https://github.com/JuliaDiff/Diffractor.jl/issues/24)
2425
)
2526
include(file)

0 commit comments

Comments
 (0)