ChainRules rrule Integration for Unitful #504

Closed
4 changes: 3 additions & 1 deletion Project.toml
@@ -1,14 +1,16 @@
name = "Unitful"
uuid = "1986cc42-f94f-5a68-af5c-568840ba703d"
-version = "1.11.0"
+version = "1.12.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[compat]
ChainRulesCore = "1"
ConstructionBase = "1"
julia = "1"

3 changes: 3 additions & 0 deletions src/Unitful.jl
@@ -27,6 +27,8 @@ import Random

import ConstructionBase: constructorof

import ChainRulesCore: rrule, NoTangent, ProjectTo

export logunit, unit, absoluteunit, dimension, uconvert, ustrip, upreferred
export @dimension, @derived_dimension, @refunit, @unit, @affineunit, @u_str
export Quantity, DimensionlessQuantity, NoUnits, NoDims
@@ -69,5 +71,6 @@ include("logarithm.jl")
include("complex.jl")
include("pkgdefaults.jl")
include("dates.jl")
include("chainrules.jl")

end
29 changes: 29 additions & 0 deletions src/chainrules.jl
@@ -0,0 +1,29 @@
function rrule(UT::Type{Quantity{T,D,U}}, x::Number) where {T,D,U}
    unitful_x = Quantity{T,D,U}(x)
    projector_x = ProjectTo(x)
    uq_pullback(Δx) = (NoTangent(), projector_x(Δx) * oneunit(UT))
    return unitful_x, uq_pullback
end
Comment on lines +1 to +6
Contributor

@mcabbott mcabbott Dec 3, 2021

When is this called?

If it's used when attaching units to an initially plain number, x=1 -> unitful_x = 1m, then the thinking is that if the loss is a unitless scalar, the gradient for unitful_x will be d loss / d unitful_x = 100/m, and this will produce a gradient for x with no units (or units equivalent to 1)?

And does that work out in practice? With some Zygote.gradient(loss, 1u"m")... must you ensure by hand that you remove the units within loss, or does Zygote.sensitivity do the right thing? Maybe that's a bigger question than this function... I have not thought much about how this all ought to work.

Contributor Author

@SBuercklin SBuercklin Dec 6, 2021

Zygote.sensitivity returns the multiplicative identity, which is usually 1.0, even for Unitful.Quantity. The example worked out in @oxinabox's comment matches how I've thought about this, so I think Zygote is correct here.
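
The unit bookkeeping discussed in this thread can be sketched without any packages. The block below is only an illustration under stated assumptions: `ToyQuantity`, `attach_metres`, and `attach_metres_pullback` are hypothetical names that do not appear in the PR, and the toy type tracks nothing but the exponent of the metre. It mimics the shape of `uq_pullback` (cotangent times one unit) and shows a cotangent in 1/m coming back as a unitless gradient for the plain number.

```julia
# Hypothetical toy type: a value plus the exponent of the metre.
# It stands in for Unitful.Quantity so that the sketch needs no packages.
struct ToyQuantity
    val::Float64
    mexp::Int
end

# Multiplication multiplies the values and adds the metre exponents.
Base.:*(a::ToyQuantity, b::ToyQuantity) = ToyQuantity(a.val * b.val, a.mexp + b.mexp)

# Forward pass: attach metres to a plain number, x -> x m.
attach_metres(x::Float64) = ToyQuantity(x, 1)

# Pullback in the style of `uq_pullback`: multiply the incoming cotangent by one unit (1 m).
attach_metres_pullback(Δ::ToyQuantity) = Δ * ToyQuantity(1.0, 1)

# For a unitless loss, the cotangent of the quantity carries 1/m.
Δq = ToyQuantity(3.0, -1)

# The gradient for the plain number has metre exponent 0, i.e. it is unitless.
Δx = attach_metres_pullback(Δq)
```

Here `Δx.mexp` is 0, which is the "units equivalent to 1" outcome asked about above. Whether real Unitful quantities behave the same way under Zygote is the open question of this thread, not something the toy settles.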


function (projector::ProjectTo{<:Quantity})(x::Number)
    new_val = projector.project_val(ustrip(x))
Comment on lines +8 to +9

convention:

Suggested change
-function (projector::ProjectTo{<:Quantity})(x::Number)
-    new_val = projector.project_val(ustrip(x))
+function (project::ProjectTo{<:Quantity})(x::Number)
+    new_val = project.val(ustrip(x))

    return new_val*unit(x)
end

# Project Unitful Quantities onto numerical types by projecting the value and carrying units
ProjectTo(x::Quantity) = ProjectTo(x.val)

(project::ProjectTo{<:Real})(dx::Quantity) = project(ustrip(dx))*unit(dx)
(project::ProjectTo{<:Complex})(dx::Quantity) = project(ustrip(dx))*unit(dx)

function rrule(::typeof(*), x::Quantity, y::Units, z::Units...)
    Ω = *(x, y, z...)
    function times_pb(Δ)
        nots = ntuple(_ -> NoTangent(), 1 + length(z))
        return (NoTangent(), *(ProjectTo(x)(Δ), y, z...), nots...)
    end
    return Ω, times_pb
end

rrule(::typeof(/), x::Number, y::Units) = rrule(*, x, inv(y))
rrule(::typeof(/), x::Units, y::Number) = rrule(*, x, inv(y))
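
As written, the `*` rule treats the unit arguments as constants, so only `x` receives a tangent and every unit gets `NoTangent()`. A rough equation-form sketch for the three-argument case, where the bar notation for cotangents is an assumption of this note rather than something used in the PR, and the `ProjectTo(x)` step is left out:

```latex
\Omega = x \, y \, z, \qquad \bar{x} = \bar{\Omega} \, y \, z
```

With an incoming cotangent of 3.0, y = m and z = L, this gives 3.0 m L, which is the value the `* rrule` test in test/chainrules.jl checks for.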
52 changes: 52 additions & 0 deletions test/chainrules.jl
@@ -0,0 +1,52 @@
using ChainRulesCore: rrule, ProjectTo, NoTangent

@testset "ProjectTo" begin
    real_test(proj, val) = proj(val) == real(val)
    complex_test(proj, val) = proj(val) == val
    uval = 8.0*u"W"
    p_uval = ProjectTo(uval)
    cuval = (1.0+im)*u"kg"
    p_cuval = ProjectTo(cuval)

    p_real = ProjectTo(1.0)
    p_complex = ProjectTo(1.0+im)

    δval = 6.0*u"m"
    δcval = (2.0+3.0im)*u"L"

    # Test projection onto real unitful quantities
    for δ in (δval, δcval, 1.0, 1.0+im)
        @test real_test(p_uval, δ)
    end

    # Test projection onto complex unitful quantities
    for δ in (δval, δcval, 1.0, 1.0+im)
        @test complex_test(p_cuval, δ)
    end

    # Projecting Unitful quantities onto real values
    @test p_real(δval) == δval
    @test p_real(δcval) == real(δcval)

    # Projecting Unitful quantities onto complex values
    @test p_complex(δval) == δval
    @test p_complex(δcval) == δcval
end

@testset "rrules" begin
    @testset "Quantity rrule" begin
        UT = typeof(1.0*u"W")
        x = 5.0
        Ω, pb = rrule(UT, x)
        @test Ω == 5.0 * u"W"
        @test pb(3.0) == (NoTangent(), 3.0 * u"W")
    end
    @testset "* rrule" begin
        x = 5.0*u"W"
        y = u"m"
        z = u"L"
        Ω, pb = rrule(*, x, y, z)
        @test Ω == x*y*z
        @test pb(3.0) == (NoTangent(), 3.0*y*z, NoTangent(), NoTangent())
    end
end
4 changes: 4 additions & 0 deletions test/runtests.jl
@@ -2047,6 +2047,10 @@ end
"""
end

@testset "ChainRules" begin
    include("./chainrules.jl")
end

# Test precompiled Unitful extension modules
load_path = mktempdir()
load_cache_path = mktempdir()