Skip to content

WIP: Wirtinger support #54

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 22 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
1548cbc
remove type constraints for Wirtinger
simeonschaub Sep 22, 2019
88bb756
introduce `AbstractWirtinger` and `ComplexGradient`
simeonschaub Sep 23, 2019
e3ce538
add `chain` function
simeonschaub Sep 23, 2019
186009a
make `swap_order` in `chain` a positional arg
simeonschaub Sep 23, 2019
2618146
introduce a function `unwrap_wirtinger`
simeonschaub Sep 23, 2019
aa7a84a
stop using types before they are defined
simeonschaub Sep 23, 2019
20e7134
fix tests
simeonschaub Sep 25, 2019
c42a953
rename `unwrap_wirtinger` -> `unthunk`
simeonschaub Sep 27, 2019
2ac8801
fix `chain` function
simeonschaub Sep 27, 2019
61a60ec
use the new `chain` function in `@scalar_rule`
simeonschaub Sep 27, 2019
41f1071
fix tests accordingly
simeonschaub Sep 27, 2019
93a7b14
Update src/differentials.jl
simeonschaub Oct 3, 2019
c904407
add chain(::Real, ::ComplexGradient)
simeonschaub Oct 5, 2019
a83da2f
Update src/differentials.jl
simeonschaub Oct 18, 2019
4d21e1a
special case `AbstractWirtinger` in `at_thunk`
simeonschaub Oct 19, 2019
71e0c7b
add `refine_differential` for `ComplexGradient`
simeonschaub Oct 19, 2019
491f781
overload `Base.real` for some differentials
simeonschaub Oct 19, 2019
06ccb14
add some docstrings
simeonschaub Oct 19, 2019
6ce400c
move `at_thunk`-magic into separate macro
simeonschaub Oct 19, 2019
b960a09
make at_scalar_rule detect wrong Wirtinger rules
simeonschaub Oct 19, 2019
3bc7e22
use `_thunk` as a function, not macro
simeonschaub Oct 19, 2019
e3bf56d
implement some of @oxinabox's suggestions
simeonschaub Oct 19, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions src/differentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -296,16 +296,9 @@ end
"""
@thunk body

Returns `Thunk(() -> body)`, except for when `body` is a call to [`Wirtinger`](@ref) or [`ComplexGradient`](@ref).
In this case, it is equivalent to `Wirtinger(@thunk(primal), @thunk(conjugate))` / `ComplexGradient(@thunk primal)`.
Returns `Thunk(() -> body)`
"""
macro thunk(body)
if body isa Expr && body.head == :call
fname = body.args[1]
if fname in (:Wirtinger, :ComplexGradient)
return :($fname($((:(@thunk $i) for i in body.args[2:end])...)))
end
end
return :(Thunk(() -> $(esc(body))))
end

Expand Down
20 changes: 18 additions & 2 deletions src/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,16 +224,32 @@ end
"""
function frule_propagation_expr(𝒟, Δs, ∂s)
∂s = map(esc, ∂s)
∂_mul_Δs = [:(chain(@thunk($(∂s[i])), $(Δs[i]))) for i in 1:length(∂s)]
∂_mul_Δs = [:(chain(@_thunk($(∂s[i])), $(Δs[i]))) for i in 1:length(∂s)]
return :(refine_differential($𝒟, +($(∂_mul_Δs...))))
end

function rrule_propagation_expr(𝒟, Δs, ∂s)
∂s = map(esc, ∂s)
∂_mul_Δs = [:(chain($(Δs[i]), @thunk($(∂s[i])))) for i in 1:length(∂s)]
∂_mul_Δs = [:(chain($(Δs[i]), @_thunk($(∂s[i])))) for i in 1:length(∂s)]
return :(refine_differential($𝒟, +($(∂_mul_Δs...))))
end

"""
@_thunk body

Returns `@thunk body`, except for when `body` is a call to [`Wirtinger`](@ref) or [`ComplexGradient`](@ref).
In this case, it is equivalent to `Wirtinger(@thunk(primal), @thunk(conjugate))` / `ComplexGradient(@thunk primal)`.
"""
macro _thunk(body)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't need to be a macro anymore as it is only called from with in a function of ASTs.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you take a look, whether everything is escaped correctly now? I'm not 100% sure, I know how escaping works.

if body isa Expr && body.head == :call
fname = body.args[1]
if fname in (:Wirtinger, :ComplexGradient)
return :($fname($((:(@thunk $(esc(i))) for i in body.args[2:end])...)))
end
end
return :(@thunk $(esc(body)))
end

"""
propagator_name(f, propname)

Expand Down