Skip to content

Commit 93c0659

Browse files
authored
Contexts for FiniteDiff and PolyesterForwardDiff, better context testing (#497)
* More contexts * Stick to same contexts for same point preparation * Doc * Myrandom instead of zero * Randomizable constant
1 parent d9a5cab commit 93c0659

File tree

14 files changed

+758
-268
lines changed

14 files changed

+758
-268
lines changed

DifferentiationInterface/docs/src/explanation/operators.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ For different-point preparation, the output `prep` of `prepare_op(f, b, x, [t])`
130130

131131
For same-point preparation, the output `prep` of `prepare_op_same_point(f, b, x, [t])` can be reused in `op(f, prep, b, x, other_t)`, provided that:
132132

133-
- the input `x` remains the same
133+
- the input `x` remains the same (as well as the [`Context`](@ref) constants)
134134
- the tangents in `t` and `other_t` have similar types and equal shapes
135135

136136
!!! warning

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl

+33-11
Original file line numberDiff line numberDiff line change
@@ -5,32 +5,49 @@ struct ChainRulesPullbackPrepSamePoint{Y,PB} <: PullbackPrep
55
pb::PB
66
end
77

8-
function DI.prepare_pullback(f, ::AutoReverseChainRules, x, ty::Tangents)
8+
function DI.prepare_pullback(
9+
f, ::AutoReverseChainRules, x, ty::Tangents, contexts::Vararg{Constant,C}
10+
) where {C}
911
return NoPullbackPrep()
1012
end
1113

1214
function DI.prepare_pullback_same_point(
13-
f, ::NoPullbackPrep, backend::AutoReverseChainRules, x, ty::Tangents
14-
)
15+
f,
16+
::NoPullbackPrep,
17+
backend::AutoReverseChainRules,
18+
x,
19+
ty::Tangents,
20+
contexts::Vararg{Constant,C},
21+
) where {C}
1522
rc = ruleconfig(backend)
16-
y, pb = rrule_via_ad(rc, f, x)
23+
y, pb = rrule_via_ad(rc, f, x, map(unwrap, contexts)...)
1724
return ChainRulesPullbackPrepSamePoint(y, pb)
1825
end
1926

2027
function DI.value_and_pullback(
21-
f, ::NoPullbackPrep, backend::AutoReverseChainRules, x, ty::Tangents
22-
)
28+
f,
29+
::NoPullbackPrep,
30+
backend::AutoReverseChainRules,
31+
x,
32+
ty::Tangents,
33+
contexts::Vararg{Constant,C},
34+
) where {C}
2335
rc = ruleconfig(backend)
24-
y, pb = rrule_via_ad(rc, f, x)
36+
y, pb = rrule_via_ad(rc, f, x, map(unwrap, contexts)...)
2537
tx = map(ty) do dy
2638
pb(dy)[2]
2739
end
2840
return y, tx
2941
end
3042

3143
function DI.value_and_pullback(
32-
f, prep::ChainRulesPullbackPrepSamePoint, ::AutoReverseChainRules, x, ty::Tangents
33-
)
44+
f,
45+
prep::ChainRulesPullbackPrepSamePoint,
46+
::AutoReverseChainRules,
47+
x,
48+
ty::Tangents,
49+
contexts::Vararg{Constant,C},
50+
) where {C}
3451
@compat (; y, pb) = prep
3552
tx = map(ty) do dy
3653
pb(dy)[2]
@@ -39,8 +56,13 @@ function DI.value_and_pullback(
3956
end
4057

4158
function DI.pullback(
42-
f, prep::ChainRulesPullbackPrepSamePoint, ::AutoReverseChainRules, x, ty::Tangents
43-
)
59+
f,
60+
prep::ChainRulesPullbackPrepSamePoint,
61+
::AutoReverseChainRules,
62+
x,
63+
ty::Tangents,
64+
contexts::Vararg{Constant,C},
65+
) where {C}
4466
@compat (; pb) = prep
4567
tx = map(ty) do dy
4668
pb(dy)[2]

DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/DifferentiationInterfaceFiniteDiffExt.jl

+4-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module DifferentiationInterfaceFiniteDiffExt
33
using ADTypes: AutoFiniteDiff
44
import DifferentiationInterface as DI
55
using DifferentiationInterface:
6+
Context,
67
DerivativePrep,
78
GradientPrep,
89
HessianPrep,
@@ -15,7 +16,9 @@ using DifferentiationInterface:
1516
NoJacobianPrep,
1617
NoPullbackPrep,
1718
NoPushforwardPrep,
18-
Tangents
19+
Tangents,
20+
unwrap,
21+
with_contexts
1922
using FiniteDiff:
2023
DerivativeCache,
2124
GradientCache,

0 commit comments

Comments
 (0)