Skip to content

Commit 344f3d5

Browse files
piever and oxinabox authored
Add derivatives_given_output for scalar functions (#453)
* add derivatives_given_output
* esc fixes
* test derivatives_given_output and correct docs
* Update src/rule_definition_tools.jl
  Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
* add experimental warning
* Update src/rule_definition_tools.jl
  Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
* fix typo
  Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
1 parent 5f13e0a commit 344f3d5

File tree

2 files changed

+51
-7
lines changed

2 files changed

+51
-7
lines changed

src/rule_definition_tools.jl

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,15 @@ macro scalar_rule(call, maybe_setup, partials...)
8888
)
8989
f = call.args[1]
9090

91-
frule_expr = scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials)
92-
rrule_expr = scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials)
91+
# Generate one tuple of symbols per output `i`, named `∂f<i>/∂x<j>` for each input `j`, to hold the derivatives
92+
derivatives = map(keys(partials)) do i
93+
syms = map(j -> Symbol("∂f", i, "/∂x", j), keys(inputs))
94+
return Expr(:tuple, syms...)
95+
end
96+
97+
derivative_expr = scalar_derivative_expr(__source__, f, setup_stmts, inputs, partials)
98+
frule_expr = scalar_frule_expr(__source__, f, call, [], inputs, derivatives)
99+
rrule_expr = scalar_rrule_expr(__source__, f, call, [], inputs, derivatives)
93100

94101
# Final return: building the expression to insert in the place of this macro
95102
code = quote
@@ -99,6 +106,7 @@ macro scalar_rule(call, maybe_setup, partials...)
99106
))
100107
end
101108

109+
$(derivative_expr)
102110
$(frule_expr)
103111
$(rrule_expr)
104112
end
@@ -135,16 +143,45 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
135143
# For consistency in code that follows we make all partials tuple expressions
136144
partials = map(partials) do partial
137145
if Meta.isexpr(partial, :tuple)
138-
partial
146+
Expr(:tuple, map(esc, partial.args)...)
139147
else
140148
length(inputs) == 1 || error("Invalid use of `@scalar_rule`")
141-
Expr(:tuple, partial)
149+
Expr(:tuple, esc(partial))
142150
end
143151
end
144152

145153
return call, setup_stmts, inputs, partials
146154
end
147155

156+
"""
157+
derivatives_given_output(Ω, f, xs...)
158+
159+
Compute the derivative of scalar function `f` at primal input point `xs...`,
160+
given that it had primal output `Ω`.
161+
Return a tuple of tuples with the partial derivatives of `f` with respect to the `xs...`.
162+
The derivative of the `i`-th component of `f` with respect to the `j`-th input can be
163+
accessed as `Df[i][j]`, where `Df = derivatives_given_output(Ω, f, xs...)`.
164+
165+
!!! warning "Experimental"
166+
This function is experimental and not part of the stable API.
167+
At the moment, it can be considered an implementation detail of the macro
168+
[`@scalar_rule`](@ref), in which it is used.
169+
In the future, the exact semantics of this function will stabilize, and it
170+
will be added to the stable API.
171+
When that happens, this warning will be removed.
172+
173+
"""
174+
function derivatives_given_output end
175+
176+
function scalar_derivative_expr(__source__, f, setup_stmts, inputs, partials)
177+
return @strip_linenos quote
178+
function ChainRulesCore.derivatives_given_output($(esc(:Ω)), ::Core.Typeof($f), $(inputs...))
179+
$(__source__)
180+
$(setup_stmts...)
181+
return $(Expr(:tuple, partials...))
182+
end
183+
end
184+
end
148185

149186
function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials)
150187
n_outputs = length(partials)
@@ -173,6 +210,7 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials)
173210
$(__source__)
174211
$(esc(:Ω)) = $call
175212
$(setup_stmts...)
213+
$(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...))
176214
return $(esc(:Ω)), $pushforward_returns
177215
end
178216
end
@@ -210,6 +248,7 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials)
210248
$(__source__)
211249
$(esc(:Ω)) = $call
212250
$(setup_stmts...)
251+
$(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...))
213252
return $(esc(:Ω)), $pullback
214253
end
215254
end
@@ -240,9 +279,9 @@ function propagation_expr(Δs, ∂s, _conj=false, proj=identity)
240279
# This is basically Δs ⋅ ∂s
241280
_∂s = map(∂s) do ∂s_i
242281
if _conj
243-
:(conj($(esc(∂s_i))))
282+
:(conj($∂s_i))
244283
else
245-
esc(∂s_i)
284+
∂s_i
246285
end
247286
end
248287

test/rule_definition_tools.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,9 @@ end
235235
@test ẏ == Tangent{typeof(y)}(50f0, 100f0)
236236
# make sure type is exactly as expected:
237237
@test ẏ isa Tangent{Tuple{Irrational{:π}, Float64}, Tuple{Float32, Float32}}
238+
239+
xs, Ω = (3,), (3, 6)
240+
@test ChainRulesCore.derivatives_given_output(Ω, simo, xs...) == ((1f0,), (2f0,))
238241
end
239242

240243
@testset "@scalar_rule projection" begin
@@ -298,7 +301,7 @@ module IsolatedModuleForTestingScoping
298301
module IsolatedSubmodule
299302
# check that rules defined in isolated module without imports can be called
300303
# without errors
301-
using ChainRulesCore: frule, rrule, ZeroTangent, NoTangent
304+
using ChainRulesCore: frule, rrule, ZeroTangent, NoTangent, derivatives_given_output
302305
using ..IsolatedModuleForTestingScoping: fixed, fixed_kwargs, my_id
303306
using Test
304307

@@ -328,6 +331,8 @@ module IsolatedModuleForTestingScoping
328331
y, f_pullback = rrule(my_id, x)
329332
@test y == x
330333
@test f_pullback(Δy) == (NoTangent(), Δy)
334+
335+
@test derivatives_given_output(y, my_id, x) == ((1.0,),)
331336
end
332337
end
333338
end

0 commit comments

Comments
 (0)