Skip to content

Commit 9d1ca76

Browse files
committed
fix!: make strict preparation the default
1 parent 4f0e496 commit 9d1ca76

File tree

13 files changed

+132
-151
lines changed

13 files changed

+132
-151
lines changed

DifferentiationInterface/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99

1010
### Changed
1111

12+
- Preparation is now strict by default ([#799])
1213
- New Arxiv preprint for citation ([#795])
1314

1415
## [0.6.54] - 2025-05-11
@@ -28,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2829
[0.6.54]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.53...DifferentiationInterface-v0.6.54
2930
[0.6.53]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.52...DifferentiationInterface-v0.6.53
3031

32+
[#799]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/799
3133
[#795]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/795
3234
[#790]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/790
3335
[#788]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/788

DifferentiationInterface/src/docstrings.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ function docstring_prepare(operator; samepoint=false, inplace=false)
2626
Otherwise, preparation becomes invalid and you need to run it again.
2727
In some settings, invalid preparations may still give correct results (e.g. for backends that require no preparation), but this is not a semantic guarantee and should not be relied upon.
2828
29-
When `strict=Val(true)`, type checking is enforced between preparation and execution (but size checking is left to the user).
29+
When `strict=Val(true)` (the default), type checking is enforced between preparation and execution (but size checking is left to the user).
30+
While your code may work for different types by setting `strict=Val(false)`, this is not guaranteed by the API and can break without warning.
3031
"""
3132
end
3233

DifferentiationInterface/src/first_order/derivative.jl

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,19 @@
11
## Docstrings
22

33
"""
4-
prepare_derivative(f, backend, x, [contexts...]; strict=Val(false)) -> prep
5-
prepare_derivative(f!, y, backend, x, [contexts...]; strict=Val(false)) -> prep
4+
prepare_derivative(f, backend, x, [contexts...]; strict=Val(true)) -> prep
5+
prepare_derivative(f!, y, backend, x, [contexts...]; strict=Val(true)) -> prep
66
77
$(docstring_prepare("derivative"; inplace=true))
88
"""
99
function prepare_derivative(
10-
f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(false)
10+
f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(true)
1111
) where {F,C}
1212
return prepare_derivative_nokwarg(strict, f, backend, x, contexts...)
1313
end
1414

1515
function prepare_derivative(
16-
f!::F,
17-
y,
18-
backend::AbstractADType,
19-
x,
20-
contexts::Vararg{Context,C};
21-
strict::Val=Val(false),
16+
f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(true)
2217
) where {F,C}
2318
return prepare_derivative_nokwarg(strict, f!, y, backend, x, contexts...)
2419
end
@@ -42,8 +37,7 @@ function prepare!_derivative(
4237
old_prep::DerivativePrep,
4338
backend::AbstractADType,
4439
x,
45-
contexts::Vararg{Context,C};
46-
strict::Val=Val(false),
40+
contexts::Vararg{Context,C},
4741
) where {F,C}
4842
check_prep(f!, y, old_prep, backend, x, contexts...)
4943
return prepare_derivative_nokwarg(is_strict(old_prep), f!, y, backend, x, contexts...)

DifferentiationInterface/src/first_order/gradient.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
## Docstrings
22

33
"""
4-
prepare_gradient(f, backend, x, [contexts...]; strict=Val(false)) -> prep
4+
prepare_gradient(f, backend, x, [contexts...]; strict=Val(true)) -> prep
55
66
$(docstring_prepare("gradient"))
77
"""
88
function prepare_gradient(
9-
f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(false)
9+
f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(true)
1010
) where {F,C}
1111
return prepare_gradient_nokwarg(strict, f, backend, x, contexts...)
1212
end

DifferentiationInterface/src/first_order/jacobian.jl

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,19 @@
11
## Docstrings
22

33
"""
4-
prepare_jacobian(f, backend, x, [contexts...]; strict=Val(false)) -> prep
5-
prepare_jacobian(f!, y, backend, x, [contexts...]; strict=Val(false)) -> prep
4+
prepare_jacobian(f, backend, x, [contexts...]; strict=Val(true)) -> prep
5+
prepare_jacobian(f!, y, backend, x, [contexts...]; strict=Val(true)) -> prep
66
77
$(docstring_prepare("jacobian"; inplace=true))
88
"""
99
function prepare_jacobian(
10-
f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(false)
10+
f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(true)
1111
) where {F,C}
1212
return prepare_jacobian_nokwarg(strict, f, backend, x, contexts...)
1313
end
1414

1515
function prepare_jacobian(
16-
f!::F,
17-
y,
18-
backend::AbstractADType,
19-
x,
20-
contexts::Vararg{Context,C};
21-
strict::Val=Val(false),
16+
f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(true)
2217
) where {F,C}
2318
return prepare_jacobian_nokwarg(strict, f!, y, backend, x, contexts...)
2419
end
@@ -43,7 +38,6 @@ function prepare!_jacobian(
4338
backend::AbstractADType,
4439
x,
4540
contexts::Vararg{Context,C};
46-
strict::Val=Val(false),
4741
) where {F,C}
4842
check_prep(f!, y, old_prep, backend, x, contexts...)
4943
return prepare_jacobian_nokwarg(is_strict(old_prep), f!, y, backend, x, contexts...)

DifferentiationInterface/src/first_order/pullback.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
## Docstrings
22

33
"""
4-
prepare_pullback(f, backend, x, ty, [contexts...]; strict=Val(false)) -> prep
5-
prepare_pullback(f!, y, backend, x, ty, [contexts...]; strict=Val(false)) -> prep
4+
prepare_pullback(f, backend, x, ty, [contexts...]; strict=Val(true)) -> prep
5+
prepare_pullback(f!, y, backend, x, ty, [contexts...]; strict=Val(true)) -> prep
66
77
$(docstring_prepare("pullback"; inplace=true))
88
"""
@@ -12,7 +12,7 @@ function prepare_pullback(
1212
x,
1313
ty::NTuple,
1414
contexts::Vararg{Context,C};
15-
strict::Val=Val(false),
15+
strict::Val=Val(true),
1616
) where {F,C}
1717
return prepare_pullback_nokwarg(strict, f, backend, x, ty, contexts...)
1818
end
@@ -24,7 +24,7 @@ function prepare_pullback(
2424
x,
2525
ty::NTuple,
2626
contexts::Vararg{Context,C};
27-
strict::Val=Val(false),
27+
strict::Val=Val(true),
2828
) where {F,C}
2929
return prepare_pullback_nokwarg(strict, f!, y, backend, x, ty, contexts...)
3030
end
@@ -61,8 +61,8 @@ function prepare!_pullback(
6161
end
6262

6363
"""
64-
prepare_pullback_same_point(f, backend, x, ty, [contexts...]; strict=Val(false)) -> prep_same
65-
prepare_pullback_same_point(f!, y, backend, x, ty, [contexts...]; strict=Val(false)) -> prep_same
64+
prepare_pullback_same_point(f, backend, x, ty, [contexts...]; strict=Val(true)) -> prep_same
65+
prepare_pullback_same_point(f!, y, backend, x, ty, [contexts...]; strict=Val(true)) -> prep_same
6666
6767
$(docstring_prepare("pullback"; samepoint=true, inplace=true))
6868
"""
@@ -72,7 +72,7 @@ function prepare_pullback_same_point(
7272
x,
7373
ty::NTuple,
7474
contexts::Vararg{Context,C};
75-
strict::Val=Val(false),
75+
strict::Val=Val(true),
7676
) where {F,C}
7777
return prepare_pullback_same_point_nokwarg(strict, f, backend, x, ty, contexts...)
7878
end
@@ -84,7 +84,7 @@ function prepare_pullback_same_point(
8484
x,
8585
ty::NTuple,
8686
contexts::Vararg{Context,C};
87-
strict::Val=Val(false),
87+
strict::Val=Val(true),
8888
) where {F,C}
8989
return prepare_pullback_same_point_nokwarg(strict, f!, y, backend, x, ty, contexts...)
9090
end

DifferentiationInterface/src/first_order/pushforward.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
## Docstrings
22

33
"""
4-
prepare_pushforward(f, backend, x, tx, [contexts...]; strict=Val(false)) -> prep
5-
prepare_pushforward(f!, y, backend, x, tx, [contexts...]; strict=Val(false)) -> prep
4+
prepare_pushforward(f, backend, x, tx, [contexts...]; strict=Val(true)) -> prep
5+
prepare_pushforward(f!, y, backend, x, tx, [contexts...]; strict=Val(true)) -> prep
66
77
$(docstring_prepare("pushforward"; inplace=true))
88
"""
@@ -12,7 +12,7 @@ function prepare_pushforward(
1212
x,
1313
tx::NTuple,
1414
contexts::Vararg{Context,C};
15-
strict::Val=Val(false),
15+
strict::Val=Val(true),
1616
) where {F,C}
1717
return prepare_pushforward_nokwarg(strict, f, backend, x, tx, contexts...)
1818
end
@@ -24,7 +24,7 @@ function prepare_pushforward(
2424
x,
2525
tx::NTuple,
2626
contexts::Vararg{Context,C};
27-
strict::Val=Val(false),
27+
strict::Val=Val(true),
2828
) where {F,C}
2929
return prepare_pushforward_nokwarg(strict, f!, y, backend, x, tx, contexts...)
3030
end
@@ -63,8 +63,8 @@ function prepare!_pushforward(
6363
end
6464

6565
"""
66-
prepare_pushforward_same_point(f, backend, x, tx, [contexts...]; strict=Val(false)) -> prep_same
67-
prepare_pushforward_same_point(f!, y, backend, x, tx, [contexts...]; strict=Val(false)) -> prep_same
66+
prepare_pushforward_same_point(f, backend, x, tx, [contexts...]; strict=Val(true)) -> prep_same
67+
prepare_pushforward_same_point(f!, y, backend, x, tx, [contexts...]; strict=Val(true)) -> prep_same
6868
6969
$(docstring_prepare("pushforward"; samepoint=true, inplace=true))
7070
"""
@@ -74,7 +74,7 @@ function prepare_pushforward_same_point(
7474
x,
7575
tx::NTuple,
7676
contexts::Vararg{Context,C};
77-
strict::Val=Val(false),
77+
strict::Val=Val(true),
7878
) where {F,C}
7979
return prepare_pushforward_same_point_nokwarg(strict, f, backend, x, tx, contexts...)
8080
end
@@ -86,7 +86,7 @@ function prepare_pushforward_same_point(
8686
x,
8787
tx::NTuple,
8888
contexts::Vararg{Context,C};
89-
strict::Val=Val(false),
89+
strict::Val=Val(true),
9090
) where {F,C}
9191
return prepare_pushforward_same_point_nokwarg(
9292
strict, f!, y, backend, x, tx, contexts...

DifferentiationInterface/src/second_order/hessian.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
## Docstrings
22

33
"""
4-
prepare_hessian(f, backend, x, [contexts...]; strict=Val(false)) -> prep
4+
prepare_hessian(f, backend, x, [contexts...]; strict=Val(true)) -> prep
55
66
$(docstring_prepare("hessian"))
77
"""
88
function prepare_hessian(
9-
f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(false)
9+
f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(true)
1010
) where {F,C}
1111
return prepare_hessian_nokwarg(strict, f, backend, x, contexts...)
1212
end

DifferentiationInterface/src/second_order/hvp.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
## Docstrings
22

33
"""
4-
prepare_hvp(f, backend, x, tx, [contexts...]; strict=Val(false)) -> prep
4+
prepare_hvp(f, backend, x, tx, [contexts...]; strict=Val(true)) -> prep
55
66
$(docstring_prepare("hvp"))
77
"""
@@ -11,7 +11,7 @@ function prepare_hvp(
1111
x,
1212
tx::NTuple,
1313
contexts::Vararg{Context,C};
14-
strict::Val=Val(false),
14+
strict::Val=Val(true),
1515
) where {F,C}
1616
return prepare_hvp_nokwarg(strict, f, backend, x, tx, contexts...)
1717
end
@@ -34,7 +34,7 @@ function prepare!_hvp(
3434
end
3535

3636
"""
37-
prepare_hvp_same_point(f, backend, x, tx, [contexts...]; strict=Val(false)) -> prep_same
37+
prepare_hvp_same_point(f, backend, x, tx, [contexts...]; strict=Val(true)) -> prep_same
3838
3939
$(docstring_prepare("hvp"; samepoint=true))
4040
"""
@@ -44,7 +44,7 @@ function prepare_hvp_same_point(
4444
x,
4545
tx::NTuple,
4646
contexts::Vararg{Context,C};
47-
strict::Val=Val(false),
47+
strict::Val=Val(true),
4848
) where {F,C}
4949
return prepare_hvp_same_point_nokwarg(strict, f, backend, x, tx, contexts...)
5050
end

DifferentiationInterface/src/second_order/second_derivative.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
## Docstrings
22

33
"""
4-
prepare_second_derivative(f, backend, x, [contexts...]; strict=Val(false)) -> prep
4+
prepare_second_derivative(f, backend, x, [contexts...]; strict=Val(true)) -> prep
55
66
$(docstring_prepare("second_derivative"))
77
"""
88
function prepare_second_derivative(
9-
f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(false)
9+
f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(true)
1010
) where {F,C}
1111
return prepare_second_derivative_nokwarg(strict, f, backend, x, contexts...)
1212
end

DifferentiationInterface/src/utils/prep.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ function Base.showerror(
109109
end
110110
println(
111111
io,
112-
"If you are confident that this check is superfluous, you can disable it by running preparation with the keyword argument `strict=Val(false)` inside DifferentiationInterface.",
112+
"If you are confident that this check is superfluous, you can disable it by running preparation with the keyword argument `strict=Val(true)` inside DifferentiationInterface.",
113113
)
114114
return nothing
115115
end

DifferentiationInterface/test/Core/Internals/signature.jl

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ y = zeros(2)
1111
c = 2.0
1212

1313
@testset "Out of place, no tangents" begin
14-
prep = prepare_derivative(f, backend, x, Constant(c); strict=Val(true))
15-
prep_chill = prepare_derivative(f, backend, x, Constant(c); strict=Val(false))
14+
prep = prepare_derivative(f, backend, x, Constant(c))
15+
prep_chill = prepare_derivative(f, backend, x, Constant(c))
1616

1717
@test_throws MethodError derivative(nothing, prep_chill, backend, x, Constant(c))
1818

@@ -68,8 +68,8 @@ c = 2.0
6868
end
6969

7070
@testset "In place, no tangents" begin
71-
prep = prepare_derivative(f!, y, backend, x; strict=Val(true))
72-
prep_chill = prepare_derivative(f!, y, backend, x; strict=Val(false))
71+
prep = prepare_derivative(f!, y, backend, x)
72+
prep_chill = prepare_derivative(f!, y, backend, x)
7373

7474
@test_throws MethodError derivative(nothing, y, prep_chill, backend, x, Constant(c))
7575

@@ -86,8 +86,8 @@ end
8686
end
8787

8888
@testset "Out of place, with tangents" begin
89-
prep = prepare_pushforward(f, backend, x, (x,), Constant(c); strict=Val(true))
90-
prep_chill = prepare_pushforward(f, backend, x, (x,), Constant(c); strict=Val(false))
89+
prep = prepare_pushforward(f, backend, x, (x,), Constant(c))
90+
prep_chill = prepare_pushforward(f, backend, x, (x,), Constant(c))
9191

9292
@test_throws MethodError pushforward(nothing, prep_chill, backend, x, (x,))
9393

@@ -104,10 +104,8 @@ end
104104
end
105105

106106
@testset "In place, with tangents" begin
107-
prep = prepare_pushforward(f!, y, backend, x, (x,); strict=Val(true))
108-
prep_chill = prepare_pushforward(
109-
f!, y, backend, x, (x,), Constant(c); strict=Val(false)
110-
)
107+
prep = prepare_pushforward(f!, y, backend, x, (x,))
108+
prep_chill = prepare_pushforward(f!, y, backend, x, (x,), Constant(c))
111109

112110
@test_throws MethodError pushforward(nothing, y, prep_chill, backend, x, (x,))
113111

0 commit comments

Comments
 (0)