Skip to content

Commit d2e12ae

Browse files
authored
feat!: specify preparation arguments in DIT Scenario (#786)
* feat!: specify preparation arguments in DIT `Scenario` * Fix * Fixes * Fixes * Fixes * Fix static arrays * Fix * Fix sparse and complex * All works except HVP * Fix tangents for prep same point * Fixes * Update DifferentiationInterfaceTest/src/scenarios/scenario.jl
1 parent c2bd64f commit d2e12ae

File tree

27 files changed

+872
-839
lines changed

27 files changed

+872
-839
lines changed

DifferentiationInterface/src/utils/prep.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ function check_prep(
198198
if SIG != EXEC_SIG
199199
throw(
200200
PreparationMismatchError(
201-
SIG, EXEC_SIG; format=[:f, :backend, :x, :tang, :contexts]
201+
SIG, EXEC_SIG; format=[:f, :backend, :x, :t, :contexts]
202202
),
203203
)
204204
end
@@ -213,7 +213,7 @@ function check_prep(
213213
if SIG != EXEC_SIG
214214
throw(
215215
PreparationMismatchError(
216-
SIG, EXEC_SIG; format=[:f!, :y, :backend, :x, :tang, :contexts]
216+
SIG, EXEC_SIG; format=[:f!, :y, :backend, :x, :t, :contexts]
217217
),
218218
)
219219
end

DifferentiationInterface/test/Back/DifferentiateWith/test.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@ function differentiatewith_scenarios()
1616
DIT.function_place(scen) == :out
1717
end
1818
good_scens = map(bad_scens) do scen
19-
DIT.change_function(
20-
scen, DifferentiateWith(scen.f, AutoFiniteDiff()); keep_smaller=false
21-
)
19+
DIT.change_function(scen, DifferentiateWith(scen.f, AutoFiniteDiff()))
2220
end
2321
return good_scens
2422
end

DifferentiationInterface/test/Back/FiniteDiff/test.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ end
2727
include_cachified=true,
2828
include_constantorcachified=true,
2929
use_tuples=true,
30+
include_smaller=true,
3031
);
3132
excluded=[:second_derivative, :hvp],
3233
logging=LOGGING,

DifferentiationInterface/test/Back/ForwardDiff/test.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ end
4141
include_cachified=true,
4242
include_constantorcachified=true,
4343
use_tuples=true,
44+
include_smaller=true,
4445
);
4546
logging=LOGGING,
4647
)

DifferentiationInterface/test/Core/Internals/signature.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ end
9898
- exec: Nothing
9999
- backend: ✅
100100
- x: ✅
101-
- tang: ✅
101+
- t: ✅
102102
- contexts: ✅
103103
""" pushforward(nothing, prep, backend, x, (x,), Constant(c))
104104
end
@@ -119,7 +119,7 @@ end
119119
- y: ✅
120120
- backend: ✅
121121
- x: ✅
122-
- tang: ✅
122+
- t: ✅
123123
- contexts: ✅
124124
""" pushforward(nothing, y, prep, backend, x, (x,))
125125
end

DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ end
6363
@testset "Dense" begin
6464
test_differentiation(
6565
vcat(backends, second_order_backends),
66-
default_scenarios(; include_constantified=true);
66+
default_scenarios(; include_constantified=true, include_smaller=true);
6767
logging=LOGGING,
6868
)
6969

DifferentiationInterfaceTest/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
### Changed
11+
12+
- Specify preparation arguments in DIT Scenario ([#786])
13+
1014
## [0.9.6] - 2025-03-28
1115

1216
### Added
@@ -18,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1822
[unreleased]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterfaceTest-v0.9.6...main
1923
[0.9.6]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterfaceTest-v0.9.5...DifferentiationInterfaceTest-v0.9.6
2024

25+
[#786]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/786
2126
[#749]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/749
2227
[#748]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/748
2328
[#745]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/745

DifferentiationInterfaceTest/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterfaceTest"
22
uuid = "a82114a7-5aa3-49a8-9643-716bb13727a3"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.9.6"
4+
version = "0.10.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -44,7 +44,7 @@ AllocCheck = "0.2"
4444
Chairmarks = "1.2.1"
4545
ComponentArrays = "0.15"
4646
DataFrames = "1.6.1"
47-
DifferentiationInterface = "0.6.0"
47+
DifferentiationInterface = "0.6.53"
4848
DocStringExtensions = "0.8,0.9"
4949
ExplicitImports = "1.10.1"
5050
FiniteDiff = "2.27.0"

DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestComponentArraysExt/DifferentiationInterfaceTestComponentArraysExt.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,13 @@ function comp_to_num_scenarios_onearg(x::ComponentVector; dx::AbstractVector, dy
3333
append!(
3434
scens,
3535
[
36-
DIT.Scenario{:pullback,pl_op}(f, x; tang=(dy,), res1=(dx_from_dy,)),
36+
DIT.Scenario{:pullback,pl_op}(f, x, (dy,); res1=(dx_from_dy,)),
3737
DIT.Scenario{:gradient,pl_op}(f, x; res1=grad),
3838
],
3939
)
4040
end
4141
for pl_op in (:out,)
42-
append!(
43-
scens, [DIT.Scenario{:pushforward,pl_op}(f, x; tang=(dx,), res1=(dy_from_dx,))]
44-
)
42+
append!(scens, [DIT.Scenario{:pushforward,pl_op}(f, x, (dx,); res1=(dy_from_dx,))])
4543
end
4644
return scens
4745
end

DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,11 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng())
9292
g = gradient_finite_differences(square_loss, model, x)
9393

9494
scen = DIT.Scenario{:gradient,:out}(
95-
square_loss, model; contexts=(DI.Constant(x),), res1=g
95+
square_loss,
96+
model,
97+
DI.Constant(x);
98+
prep_args=(x=model, contexts=(DI.Constant(x),)),
99+
res1=g,
96100
)
97101
push!(scens, scen)
98102

@@ -163,7 +167,11 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng())
163167
Flux.trainmode!(model)
164168
g = gradient_finite_differences(square_loss, model, x)
165169
scen = DIT.Scenario{:gradient,:out}(
166-
square_loss, model; contexts=(DI.Constant(x),), res1=g
170+
square_loss,
171+
model,
172+
DI.Constant(x);
173+
prep_args=(; x=model, contexts=(DI.Constant(x),)),
174+
res1=g,
167175
)
168176
push!(scens, scen)
169177
end
@@ -191,7 +199,11 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng())
191199
Flux.trainmode!(model)
192200
g = gradient_finite_differences(square_loss_iterated, model, x)
193201
scen = DIT.Scenario{:gradient,:out}(
194-
square_loss_iterated, model; contexts=(DI.Constant(x),), res1=g
202+
square_loss_iterated,
203+
model,
204+
DI.Constant(x);
205+
prep_args=(; x=model, contexts=(DI.Constant(x),)),
206+
res1=g,
195207
)
196208
push!(scens, scen)
197209
end

DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,17 @@ myjl(x::DI.Cache{<:Union{Tuple,NamedTuple}}) = map(myjl, map(DI.Cache, DI.unwrap
2323
myjl(::Nothing) = nothing
2424

2525
function myjl(scen::DIT.Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
26-
(; f, x, y, tang, contexts, res1, res2) = scen
27-
return DIT.Scenario{op,pl_op,pl_fun}(
28-
myjl(f);
26+
(; f, x, y, t, contexts, prep_args, res1, res2, name) = scen
27+
return DIT.Scenario{op,pl_op,pl_fun}(;
28+
f=myjl(f),
2929
x=myjl(x),
3030
y=myjl(y),
31-
tang=myjl(tang),
31+
t=myjl(t),
3232
contexts=myjl(contexts),
33+
prep_args=map(myjl, prep_args),
3334
res1=myjl(res1),
3435
res2=myjl(res2),
36+
name,
3537
)
3638
end
3739

DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,14 @@ function DIT.lux_scenarios(rng::AbstractRNG=default_rng())
199199
)
200200
scen = DIT.Scenario{:gradient,:out}(
201201
square_loss,
202-
ComponentArray(ps);
203-
contexts=(DI.Constant(model), DI.Constant(x), DI.Constant(st)),
202+
ComponentArray(ps),
203+
DI.Constant(model),
204+
DI.Constant(x),
205+
DI.Constant(st);
206+
prep_args=(
207+
x=ComponentArray(ps),
208+
contexts=(DI.Constant(model), DI.Constant(x), DI.Constant(st)),
209+
),
204210
res1=g,
205211
)
206212
push!(scens, scen)

DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module DifferentiationInterfaceTestStaticArraysExt
33
import DifferentiationInterface as DI
44
import DifferentiationInterfaceTest as DIT
55
using SparseArrays: SparseArrays, SparseMatrixCSC, nnz, spdiagm
6-
using StaticArrays: MArray, MMatrix, MVector, SArray, SMatrix, SVector
6+
using StaticArrays: StaticArray, MArray, MMatrix, MVector, SArray, SMatrix, SVector
77

88
static_num_to_vec(x::Number) = sin.(SVector(1, 2) .* x)
99
static_num_to_mat(x::Number) = hcat(static_num_to_vec(x), static_num_to_vec(3x))
@@ -36,15 +36,23 @@ end
3636
mystatic(::Nothing) = nothing
3737

3838
function mystatic(scen::DIT.Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
39-
(; f, x, y, tang, contexts, res1, res2) = scen
40-
return DIT.Scenario{op,pl_op,pl_fun}(
41-
mystatic(f);
39+
(; f, x, y, t, contexts, prep_args, res1, res2, name) = scen
40+
new_prep_args = (;
41+
x=mystatic(prep_args.x), contexts=map(mystatic, prep_args.contexts), t=mystatic(t)
42+
)
43+
if pl_fun == :in
44+
new_prep_args = (; new_prep_args..., y=mymutablestatic(prep_args.y))
45+
end
46+
return DIT.Scenario{op,pl_op,pl_fun}(;
47+
f=mystatic(f),
4248
x=mystatic(x),
4349
y=pl_fun == :in ? mymutablestatic(y) : mystatic(y),
44-
tang=mystatic(tang),
50+
t=mystatic(t),
4551
contexts=mystatic(contexts),
52+
prep_args=new_prep_args,
4653
res1=mystatic(res1),
4754
res2=mystatic(res2),
55+
name=name,
4856
)
4957
end
5058

DifferentiationInterfaceTest/src/scenarios/allocfree.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ function identity_scenarios(x::Number; dx::Number, dy::Number)
55
der = one(x)
66

77
return [
8-
Scenario{:pushforward,:out}(f, x; tang=(dx,), res1=(dy_from_dx,)),
9-
Scenario{:pullback,:out}(f, x; tang=(dy,), res1=(dx_from_dy,)),
8+
Scenario{:pushforward,:out}(f, x, (dx,); res1=(dy_from_dx,)),
9+
Scenario{:pullback,:out}(f, x, (dy,); res1=(dx_from_dy,)),
1010
Scenario{:derivative,:out}(f, x; res1=der),
1111
]
1212
end
@@ -19,8 +19,8 @@ function sum_scenarios(x::AbstractArray; dx::AbstractArray, dy::Number)
1919
grad .= one(eltype(x))
2020

2121
return [
22-
Scenario{:pushforward,:out}(f, x; tang=(dx,), res1=(dy_from_dx,)),
23-
Scenario{:pullback,:in}(f, x; tang=(dy,), res1=(dx_from_dy,)),
22+
Scenario{:pushforward,:out}(f, x, (dx,); res1=(dy_from_dx,)),
23+
Scenario{:pullback,:in}(f, x, (dy,); res1=(dx_from_dy,)),
2424
Scenario{:gradient,:in}(f, x; res1=grad),
2525
]
2626
end
@@ -34,8 +34,8 @@ function copyto!_scenarios(x::AbstractArray; dx::AbstractArray, dy::AbstractArra
3434
jac = Matrix(Diagonal(ones(eltype(x), length(x))))
3535

3636
return [
37-
Scenario{:pushforward,:in}(f!, y, x; tang=(dx,), res1=(dy_from_dx,)),
38-
Scenario{:pullback,:in}(f!, y, x; tang=(dy,), res1=(dx_from_dy,)),
37+
Scenario{:pushforward,:in}(f!, y, x, (dx,); res1=(dy_from_dx,)),
38+
Scenario{:pullback,:in}(f!, y, x, (dy,); res1=(dx_from_dy,)),
3939
Scenario{:jacobian,:in}(f!, y, x; res1=jac),
4040
]
4141
end

DifferentiationInterfaceTest/src/scenarios/complex.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ function complex_holomorphic_gradient_scenarios()
99
scens = Scenario[
1010
Scenario{:gradient,:out}(square_only, x; res1=grad),
1111
Scenario{:gradient,:in}(square_only, x; res1=grad),
12-
Scenario{:pullback,:out}(square_only, x; tang=(dy,), res1=(grad,)),
13-
Scenario{:pullback,:in}(square_only, x; tang=(dy,), res1=(grad,)),
12+
Scenario{:pullback,:out}(square_only, x, (dy,); res1=(grad,)),
13+
Scenario{:pullback,:in}(square_only, x, (dy,); res1=(grad,)),
1414
]
1515
return scens
1616
end
@@ -22,8 +22,8 @@ function complex_gradient_scenarios()
2222
scens = Scenario[
2323
Scenario{:gradient,:out}(abs2_only, x; res1=grad),
2424
Scenario{:gradient,:in}(abs2_only, x; res1=grad),
25-
Scenario{:pullback,:out}(abs2_only, x; tang=(dy,), res1=(grad,)),
26-
Scenario{:pullback,:in}(abs2_only, x; tang=(dy,), res1=(grad,)),
25+
Scenario{:pullback,:out}(abs2_only, x, (dy,); res1=(grad,)),
26+
Scenario{:pullback,:in}(abs2_only, x, (dy,); res1=(grad,)),
2727
]
2828
return scens
2929
end

0 commit comments

Comments
 (0)