-
Notifications
You must be signed in to change notification settings - Fork 120
ChainRules rrule
Integration for Unitful
#504
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
30a98b8
16281fe
3c48749
6f47bdf
c98ebdc
3a5e0a3
5e24deb
5f11a68
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,29 @@ | ||||||||||
function rrule(UT::Type{Quantity{T,D,U}}, x::Number) where {T,D,U} | ||||||||||
unitful_x = Quantity{T,D,U}(x) | ||||||||||
projector_x = ProjectTo(x) | ||||||||||
uq_pullback(Δx) = (NoTangent(), projector_x(Δx) * oneunit(UT)) | ||||||||||
return unitful_x, uq_pullback | ||||||||||
end | ||||||||||
Comment on lines
+1
to
+6
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When is this called? If't it's used when attaching units to an initially plain number, And does that work out in practice? With some There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||||||
|
||||||||||
function (projector::ProjectTo{<:Quantity})(x::Number) | ||||||||||
new_val = projector.project_val(ustrip(x)) | ||||||||||
Comment on lines
+8
to
+9
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. convention:
Suggested change
|
||||||||||
return new_val*unit(x) | ||||||||||
end | ||||||||||
|
||||||||||
# Project Unitful Quantities onto numerical types by projecting the value and carrying units | ||||||||||
ProjectTo(x::Quantity) = ProjectTo(x.val) | ||||||||||
|
||||||||||
(project::ProjectTo{<:Real})(dx::Quantity) = project(ustrip(dx))*unit(dx) | ||||||||||
(project::ProjectTo{<:Complex})(dx::Quantity) = project(ustrip(dx))*unit(dx) | ||||||||||
|
||||||||||
function rrule(::typeof(*), x::Quantity, y::Units, z::Units...) | ||||||||||
Ω = *(x, y, z...) | ||||||||||
function times_pb(Δ) | ||||||||||
nots = ntuple(_ -> NoTangent(), 1 + length(z)) | ||||||||||
return (NoTangent(), *(ProjectTo(x)(Δ), y, z...), nots...) | ||||||||||
end | ||||||||||
return Ω, times_pb | ||||||||||
end | ||||||||||
|
||||||||||
rrule(::typeof(/), x::Number, y::Units) = rrule(*, x, inv(y)) | ||||||||||
rrule(::typeof(/), x::Units, y::Number) = rrule(*, x, inv(y)) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
using ChainRulesCore: rrule, ProjectTo, NoTangent | ||
|
||
@testset "ProjectTo" begin | ||
real_test(proj, val) = proj(val) == real(val) | ||
complex_test(proj, val) = proj(val) == val | ||
uval = 8.0*u"W" | ||
p_uval = ProjectTo(uval) | ||
cuval = (1.0+im)*u"kg" | ||
p_cuval = ProjectTo(cuval) | ||
|
||
p_real = ProjectTo(1.0) | ||
p_complex = ProjectTo(1.0+im) | ||
|
||
δval = 6.0*u"m" | ||
δcval = (2.0+3.0im)*u"L" | ||
|
||
# Test projection onto real unitful quantities | ||
for δ in (δval, δcval, 1.0, 1.0+im) | ||
@test real_test(p_uval, δ) | ||
end | ||
|
||
# Test projection onto complex unitful quantities | ||
for δ in (δval, δcval, 1.0, 1.0+im) | ||
@test complex_test(p_cuval, δ) | ||
end | ||
|
||
# Projecting Unitful quantities onto real values | ||
@test p_real(δval) == δval | ||
@test p_real(δcval) == real(δcval) | ||
|
||
# Projecting Unitful quantities onto complex values | ||
@test p_complex(δval) == δval | ||
@test p_complex(δcval) == δcval | ||
end | ||
|
||
@testset "rrules" begin | ||
@testset "Quantity rrule" begin | ||
UT = typeof(1.0*u"W") | ||
x = 5.0 | ||
Ω, pb = rrule(UT, x) | ||
@test Ω == 5.0 * u"W" | ||
@test pb(3.0) == (NoTangent(), 3.0 * u"W") | ||
end | ||
@testset "* rrule" begin | ||
x = 5.0*u"W" | ||
y = u"m" | ||
z = u"L" | ||
Ω, pb = rrule(*, x, y, z) | ||
@test Ω == x*y*z | ||
@test pb(3.0) == (NoTangent(), 3.0*y*z, NoTangent(), NoTangent()) | ||
end | ||
end |
Uh oh!
There was an error while loading. Please reload this page.