From c35f043edad1d76414bd04d99d165607700d7c61 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 13 Oct 2021 22:22:27 +0200 Subject: [PATCH 01/23] Remove CommutativeMulNumber restrictions --- src/rulesets/Base/arraymath.jl | 77 ++++++++++++++++------------------ 1 file changed, 35 insertions(+), 42 deletions(-) diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index f70cb878c..a5aaf715f 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -19,12 +19,7 @@ end ##### `*` ##### - -function rrule( - ::typeof(*), - A::AbstractVecOrMat{<:CommutativeMulNumber}, - B::AbstractVecOrMat{<:CommutativeMulNumber}, -) +function rrule(::typeof(*), A::AbstractVecOrMat{<:Number}, B::AbstractVecOrMat{<:Number}) project_A = ProjectTo(A) project_B = ProjectTo(B) function times_pullback(ȳ) @@ -60,13 +55,7 @@ function rrule( return A * B, times_pullback end - - -function rrule( - ::typeof(*), - A::AbstractVector{<:CommutativeMulNumber}, - B::AbstractMatrix{<:CommutativeMulNumber}, -) +function rrule(::typeof(*), A::AbstractVector{<:Number}, B::AbstractMatrix{<:Number}) project_A = ProjectTo(A) project_B = ProjectTo(B) function times_pullback(ȳ) @@ -87,19 +76,14 @@ function rrule( return A * B, times_pullback end - - - -function rrule( - ::typeof(*), A::CommutativeMulNumber, B::AbstractArray{<:CommutativeMulNumber} -) +function rrule(::typeof(*), A::Number, B::AbstractArray{<:Number}) project_A = ProjectTo(A) project_B = ProjectTo(B) function times_pullback(ȳ) Ȳ = unthunk(ȳ) return ( NoTangent(), - @thunk(project_A(dot(Ȳ, B)')), + @thunk(project_A(eltype(B) isa CommutativeMulNumber ? dot(Ȳ, B)' : dot(Ȳ', B'))), InplaceableThunk( X̄ -> mul!(X̄, conj(A), Ȳ, true, true), @thunk(project_B(A' * Ȳ)), @@ -110,7 +94,7 @@ function rrule( end function rrule( - ::typeof(*), B::AbstractArray{<:CommutativeMulNumber}, A::CommutativeMulNumber + ::typeof(*), B::AbstractArray{<:Number}, A::Number ) project_A = ProjectTo(A) project_B = ProjectTo(B) @@ -122,7 +106,7 @@ function rrule( X̄ -> mul!(X̄, conj(A), Ȳ, true, true), @thunk(project_B(A' * Ȳ)), ), - @thunk(project_A(dot(Ȳ, B)')), + @thunk(project_A(eltype(A) isa CommutativeMulNumber ? dot(Ȳ, B)' : dot(Ȳ', B'))), ) end return A * B, times_pullback @@ -135,9 +119,9 @@ end function rrule( ::typeof(muladd), - A::AbstractMatrix{<:CommutativeMulNumber}, - B::AbstractVecOrMat{<:CommutativeMulNumber}, - z::Union{CommutativeMulNumber, AbstractVecOrMat{<:CommutativeMulNumber}}, + A::AbstractMatrix{<:Number}, + B::AbstractVecOrMat{<:Number}, + z::Union{Number, AbstractVecOrMat{<:Number}}, ) project_A = ProjectTo(A) project_B = ProjectTo(B) @@ -173,9 +157,9 @@ end function rrule( ::typeof(muladd), - ut::LinearAlgebra.AdjOrTransAbsVec{<:CommutativeMulNumber}, - v::AbstractVector{<:CommutativeMulNumber}, - z::CommutativeMulNumber, + ut::LinearAlgebra.AdjOrTransAbsVec{<:Number}, + v::AbstractVector{<:Number}, + z::Number, ) project_ut = ProjectTo(ut) project_v = ProjectTo(v) @@ -186,11 +170,11 @@ function rrule( dy = unthunk(ȳ) ut_thunk = InplaceableThunk( dut -> dut .+= v' .* dy, - @thunk(project_ut(v' .* dy)), + @thunk(project_ut((v * dy')')), ) v_thunk = InplaceableThunk( dv -> dv .+= ut' .* dy, - @thunk(project_v(ut' .* dy)), + @thunk(project_v(ut' * dy)), ) (NoTangent(), ut_thunk, v_thunk, z isa Bool ? NoTangent() : project_z(dy)) end @@ -199,9 +183,9 @@ end function rrule( ::typeof(muladd), - u::AbstractVector{<:CommutativeMulNumber}, - vt::LinearAlgebra.AdjOrTransAbsVec{<:CommutativeMulNumber}, - z::Union{CommutativeMulNumber, AbstractVecOrMat{<:CommutativeMulNumber}}, + u::AbstractVector{<:Number}, + vt::LinearAlgebra.AdjOrTransAbsVec{<:Number}, + z::Union{Number, AbstractVecOrMat{<:Number}}, ) project_u = ProjectTo(u) project_vt = ProjectTo(vt) @@ -211,8 +195,8 @@ function rrule( function muladd_pullback_3(ȳ) Ȳ = unthunk(ȳ) proj = ( - @thunk(project_u(vec(sum(Ȳ .* conj.(vt), dims=2)))), - @thunk(project_vt(vec(sum(u .* conj.(Ȳ), dims=1))')), + @thunk(project_u(Ȳ * vec(vt'))), + @thunk(project_vt((Ȳ' * u)')), ) addon = if z isa Bool NoTangent() @@ -280,7 +264,7 @@ end ##### `\`, `/` matrix-scalar_rule ##### -function rrule(::typeof(/), A::AbstractArray{<:CommutativeMulNumber}, b::CommutativeMulNumber) +function rrule(::typeof(/), A::AbstractArray{<:Number}, b::Number) Y = A/b function slash_pullback_scalar(ȳ) Ȳ = unthunk(ȳ) @@ -288,17 +272,26 @@ function rrule(::typeof(/), A::AbstractArray{<:CommutativeMulNumber}, b::Commuta dA -> dA .+= Ȳ ./ conj(b), @thunk(Ȳ / conj(b)), ) - bthunk = @thunk(-dot(A,Ȳ) / conj(b^2)) + bthunk = @thunk(-dot(Y,Ȳ) / conj(b)) return (NoTangent(), Athunk, bthunk) end return Y, slash_pullback_scalar end -function rrule(::typeof(\), b::CommutativeMulNumber, A::AbstractArray{<:CommutativeMulNumber}) - Y, back = rrule(/, A, b) - function backslash_pullback(dY) # just reverses the arguments! - d0, dA, db = back(dY) - return (d0, db, dA) +function rrule(::typeof(\), b::Number, A::AbstractArray{<:Number}) + Y = b \ A + function backslash_pullback(ȳ) + Ȳ = unthunk(ȳ) + Athunk = InplaceableThunk( + @thunk(conj(b) \ Ȳ), + dA -> dA .+= conj(b) .\ Ȳ, + ) + bthunk = if eltype(Y) isa CommutativeMulNumber + @thunk(-conj(b) \ dot(Y, Ȳ)) + else + @thunk(-conj(b) \ dot(Ȳ',Y')) + end + return (NoTangent(), Athunk, bthunk) end return Y, backslash_pullback end From 5e036ca292bef2ea9a8b19c98178813d292c9bd9 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 13 Oct 2021 22:23:32 +0200 Subject: [PATCH 02/23] Simplify rrule for ldiv --- src/rulesets/Base/arraymath.jl | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index a5aaf715f..9dfdce637 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -217,20 +217,17 @@ end ##### `/` ##### -function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real}) - Aᵀ, dA_pb = rrule(adjoint, A) - Bᵀ, dB_pb = rrule(adjoint, B) - Cᵀ, dS_pb = rrule(\, Bᵀ, Aᵀ) - C, dC_pb = rrule(adjoint, Cᵀ) - function slash_pullback(Ȳ) - # Optimization note: dAᵀ, dBᵀ, dC are calculated no matter which partial you want - _, dC = dC_pb(Ȳ) - _, dBᵀ, dAᵀ = dS_pb(unthunk(dC)) - - ∂A = last(dA_pb(unthunk(dAᵀ))) - ∂B = last(dB_pb(unthunk(dBᵀ))) - - (NoTangent(), ∂A, ∂B) +function rrule(::typeof(/), A::AbstractVecOrMat{<:Number}, B::AbstractVecOrMat{<:Number}) + C = A / B + project_A = ProjectTo(A) + project_B = ProjectTo(B) + function slash_pullback(ΔC) + ∂A = unthunk(ΔC) / B' + ∂B = InplaceableThunk( + @thunk(C' * -∂A), + B̄ -> mul!(B̄, C', ∂A, -1, true) + ) + return NoTangent(), project_A(∂A), project_B(∂B) end return C, slash_pullback end From 331c65f55dead94a6cee05965ad96db928e85eb0 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 13 Oct 2021 22:25:47 +0200 Subject: [PATCH 03/23] Release type constraints --- src/rulesets/Base/base.jl | 4 ++-- src/rulesets/Base/fastmath_able.jl | 16 ++++++++-------- src/rulesets/LinearAlgebra/structured.jl | 6 +++--- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 5b9237bcf..8b9648ecd 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -76,13 +76,13 @@ end @scalar_rule hypot(x::Real) sign(x) -function frule((_, Δz), ::typeof(hypot), z::Complex) +function frule((_, Δz), ::typeof(hypot), z::Number) Ω = hypot(z) ∂Ω = _realconjtimes(z, Δz) / ifelse(iszero(Ω), one(Ω), Ω) return Ω, ∂Ω end -function rrule(::typeof(hypot), z::Complex) +function rrule(::typeof(hypot), z::Number) Ω = hypot(z) function hypot_pullback(ΔΩ) return (NoTangent(), (real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω)) * z) diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index d1839518a..143e120c7 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -64,14 +64,14 @@ let # Unary complex functions ## abs - function frule((_, Δx), ::typeof(abs), x::Union{Real, Complex}) + function frule((_, Δx), ::typeof(abs), x::Number) Ω = abs(x) # `ifelse` is applied only to denominator to ensure type-stability. signx = x isa Real ? sign(x) : x / ifelse(iszero(x), one(Ω), Ω) return Ω, _realconjtimes(signx, Δx) end - function rrule(::typeof(abs), x::Union{Real, Complex}) + function rrule(::typeof(abs), x::Number) Ω = abs(x) function abs_pullback(ΔΩ) signx = x isa Real ? sign(x) : x / ifelse(iszero(x), one(Ω), Ω) @@ -81,11 +81,11 @@ let end ## abs2 - function frule((_, Δz), ::typeof(abs2), z::Union{Real, Complex}) + function frule((_, Δz), ::typeof(abs2), z::Number) return abs2(z), 2 * _realconjtimes(z, Δz) end - function rrule(::typeof(abs2), z::Union{Real, Complex}) + function rrule(::typeof(abs2), z::Number) function abs2_pullback(ΔΩ) Δu = real(ΔΩ) return (NoTangent(), 2Δu*z) @@ -94,10 +94,10 @@ let end ## conj - function frule((_, Δz), ::typeof(conj), z::Union{Real, Complex}) + function frule((_, Δz), ::typeof(conj), z::Number) return conj(z), conj(Δz) end - function rrule(::typeof(conj), z::Union{Real, Complex}) + function rrule(::typeof(conj), z::Number) function conj_pullback(ΔΩ) return (NoTangent(), conj(ΔΩ)) end @@ -143,14 +143,14 @@ let ::typeof(hypot), x::T, y::T, - ) where {T<:Union{Real,Complex}} + ) where {T<:Number} Ω = hypot(x, y) n = ifelse(iszero(Ω), one(Ω), Ω) ∂Ω = (_realconjtimes(x, Δx) + _realconjtimes(y, Δy)) / n return Ω, ∂Ω end - function rrule(::typeof(hypot), x::T, y::T) where {T<:Union{Real,Complex}} + function rrule(::typeof(hypot), x::T, y::T) where {T<:Number} Ω = hypot(x, y) function hypot_pullback(ΔΩ) c = real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω) diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index 513f8e3fb..5284c9db3 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -99,13 +99,13 @@ function _diagm_back(p, ȳ) return Tangent{typeof(p)}(second = d) end -function rrule(::typeof(*), D::Diagonal{<:Real}, V::AbstractVector{<:Real}) +function rrule(::typeof(*), D::Diagonal{<:Number}, V::AbstractVector{<:Number}) project_D = ProjectTo(D) project_V = ProjectTo(V) function times_pullback(ȳ) Ȳ = unthunk(ȳ) - dD = @thunk(project_D(Diagonal(Ȳ .* V))) - dV = @thunk(project_V(D * Ȳ)) + dD = @thunk(project_D(Diagonal(Ȳ .* conj.(V)))) + dV = @thunk(project_V(conj.(D.diag) .* Ȳ)) return (NoTangent(), dD, dV) end return D * V, times_pullback From 405427f5c304b8a7e78134b67422223ff8403725 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 13 Oct 2021 22:27:14 +0200 Subject: [PATCH 04/23] Constrain types where necessary --- src/rulesets/Base/base.jl | 103 +++++++++++++++-------- src/rulesets/Base/fastmath_able.jl | 88 ++++++++++++------- src/rulesets/LinearAlgebra/structured.jl | 6 +- 3 files changed, 127 insertions(+), 70 deletions(-) diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 8b9648ecd..1c6b7d036 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -90,8 +90,24 @@ function rrule(::typeof(hypot), z::Number) return (Ω, hypot_pullback) end -@scalar_rule fma(x, y, z) (y, x, true) -@scalar_rule muladd(x, y, z) (y, x, true) +@scalar_rule fma(x, y::CommutativeMulNumber, z) (y, x, true) +function frule((_, Δx, Δy, Δz), ::typeof(fma), x::Number, y::Number, z::Number) + return fma(x, y, z), muladd(Δx, y, muladd(x, Δy, Δz)) +end +function rrule(::typeof(fma), x::Number, y::Number, z::Number) + projectx, projecty, projectz = ProjectTo(x), ProjectTo(y), ProjectTo(z) + fma_pullback(ΔΩ) = NoTangent(), projectx(ΔΩ * y'), projecty(x' * ΔΩ), projectz(ΔΩ) + fma(x, y, z), fma_pullback +end +@scalar_rule muladd(x, y::CommutativeMulNumber, z) (y, x, true) +function frule((_, Δx, Δy, Δz), ::typeof(muladd), x::Number, y::Number, z::Number) + return muladd(x, y, z), muladd(Δx, y, muladd(x, Δy, Δz)) +end +function rrule(::typeof(muladd), x::Number, y::Number, z::Number) + projectx, projecty, projectz = ProjectTo(x), ProjectTo(y), ProjectTo(z) + muladd_pullback(ΔΩ) = NoTangent(), projectx(ΔΩ * y'), projecty(x' * ΔΩ), projectz(ΔΩ) + muladd(x, y, z), muladd_pullback +end @scalar_rule rem2pi(x, r::RoundingMode) (true, NoTangent()) @scalar_rule( mod(x, y), @@ -105,51 +121,51 @@ end @scalar_rule(ldexp(x, y), (2^y, NoTangent())) # Can't multiply though sqrt in acosh because of negative complex case for x -@scalar_rule acosh(x) inv(sqrt(x - 1) * sqrt(x + 1)) -@scalar_rule acoth(x) inv(1 - x ^ 2) -@scalar_rule acsch(x) -(inv(x ^ 2 * sqrt(1 + x ^ -2))) +@scalar_rule acosh(x::CommutativeMulNumber) inv(sqrt(x - 1) * sqrt(x + 1)) +@scalar_rule acoth(x::CommutativeMulNumber) inv(1 - x ^ 2) +@scalar_rule acsch(x::CommutativeMulNumber) -(inv(x ^ 2 * sqrt(1 + x ^ -2))) @scalar_rule acsch(x::Real) -(inv(abs(x) * sqrt(1 + x ^ 2))) -@scalar_rule asech(x) -(inv(x * sqrt(1 - x ^ 2))) -@scalar_rule asinh(x) inv(sqrt(x ^ 2 + 1)) -@scalar_rule atanh(x) inv(1 - x ^ 2) +@scalar_rule asech(x::CommutativeMulNumber) -(inv(x * sqrt(1 - x ^ 2))) +@scalar_rule asinh(x::CommutativeMulNumber) inv(sqrt(x ^ 2 + 1)) +@scalar_rule atanh(x::CommutativeMulNumber) inv(1 - x ^ 2) -@scalar_rule acosd(x) (-(oftype(x, 180)) / π) / sqrt(1 - x ^ 2) -@scalar_rule acotd(x) (-(oftype(x, 180)) / π) / (1 + x ^ 2) -@scalar_rule acscd(x) ((-(oftype(x, 180)) / π) / x ^ 2) / sqrt(1 - x ^ -2) +@scalar_rule acosd(x::CommutativeMulNumber) (-(oftype(x, 180)) / π) / sqrt(1 - x ^ 2) +@scalar_rule acotd(x::CommutativeMulNumber) (-(oftype(x, 180)) / π) / (1 + x ^ 2) +@scalar_rule acscd(x::CommutativeMulNumber) ((-(oftype(x, 180)) / π) / x ^ 2) / sqrt(1 - x ^ -2) @scalar_rule acscd(x::Real) ((-(oftype(x, 180)) / π) / abs(x)) / sqrt(x ^ 2 - 1) -@scalar_rule asecd(x) ((oftype(x, 180) / π) / x ^ 2) / sqrt(1 - x ^ -2) +@scalar_rule asecd(x::CommutativeMulNumber) ((oftype(x, 180) / π) / x ^ 2) / sqrt(1 - x ^ -2) @scalar_rule asecd(x::Real) ((oftype(x, 180) / π) / abs(x)) / sqrt(x ^ 2 - 1) -@scalar_rule asind(x) (oftype(x, 180) / π) / sqrt(1 - x ^ 2) -@scalar_rule atand(x) (oftype(x, 180) / π) / (1 + x ^ 2) - -@scalar_rule cot(x) -((1 + Ω ^ 2)) -@scalar_rule coth(x) -(csch(x) ^ 2) -@scalar_rule cotd(x) -(π / oftype(x, 180)) * (1 + Ω ^ 2) -@scalar_rule csc(x) -Ω * cot(x) -@scalar_rule cscd(x) -(π / oftype(x, 180)) * Ω * cotd(x) -@scalar_rule csch(x) -(coth(x)) * Ω -@scalar_rule sec(x) Ω * tan(x) -@scalar_rule secd(x) (π / oftype(x, 180)) * Ω * tand(x) -@scalar_rule sech(x) -(tanh(x)) * Ω - -@scalar_rule acot(x) -(inv(1 + x ^ 2)) -@scalar_rule acsc(x) -(inv(x ^ 2 * sqrt(1 - x ^ -2))) +@scalar_rule asind(x::CommutativeMulNumber) (oftype(x, 180) / π) / sqrt(1 - x ^ 2) +@scalar_rule atand(x::CommutativeMulNumber) (oftype(x, 180) / π) / (1 + x ^ 2) + +@scalar_rule cot(x::CommutativeMulNumber) -((1 + Ω ^ 2)) +@scalar_rule coth(x::CommutativeMulNumber) -(csch(x) ^ 2) +@scalar_rule cotd(x::CommutativeMulNumber) -(π / oftype(x, 180)) * (1 + Ω ^ 2) +@scalar_rule csc(x::CommutativeMulNumber) -Ω * cot(x) +@scalar_rule cscd(x::CommutativeMulNumber) -(π / oftype(x, 180)) * Ω * cotd(x) +@scalar_rule csch(x::CommutativeMulNumber) -(coth(x)) * Ω +@scalar_rule sec(x::CommutativeMulNumber) Ω * tan(x) +@scalar_rule secd(x::CommutativeMulNumber) (π / oftype(x, 180)) * Ω * tand(x) +@scalar_rule sech(x::CommutativeMulNumber) -(tanh(x)) * Ω + +@scalar_rule acot(x::CommutativeMulNumber) -(inv(1 + x ^ 2)) +@scalar_rule acsc(x::CommutativeMulNumber) -(inv(x ^ 2 * sqrt(1 - x ^ -2))) @scalar_rule acsc(x::Real) -(inv(abs(x) * sqrt(x ^ 2 - 1))) -@scalar_rule asec(x) inv(x ^ 2 * sqrt(1 - x ^ -2)) +@scalar_rule asec(x::CommutativeMulNumber) inv(x ^ 2 * sqrt(1 - x ^ -2)) @scalar_rule asec(x::Real) inv(abs(x) * sqrt(x ^ 2 - 1)) -@scalar_rule cosd(x) -(π / oftype(x, 180)) * sind(x) -@scalar_rule cospi(x) -π * sinpi(x) -@scalar_rule sind(x) (π / oftype(x, 180)) * cosd(x) -@scalar_rule sinpi(x) π * cospi(x) -@scalar_rule tand(x) (π / oftype(x, 180)) * (1 + Ω ^ 2) +@scalar_rule cosd(x::CommutativeMulNumber) -(π / oftype(x, 180)) * sind(x) +@scalar_rule cospi(x::CommutativeMulNumber) -π * sinpi(x) +@scalar_rule sind(x::CommutativeMulNumber) (π / oftype(x, 180)) * cosd(x) +@scalar_rule sinpi(x::CommutativeMulNumber) π * cospi(x) +@scalar_rule tand(x::CommutativeMulNumber) (π / oftype(x, 180)) * (1 + Ω ^ 2) -@scalar_rule sinc(x) cosc(x) +@scalar_rule sinc(x::CommutativeMulNumber) cosc(x) # the position of the minus sign below warrants the correct type for π if VERSION ≥ v"1.6" - @scalar_rule sincospi(x) @setup((sinpix, cospix) = Ω) (π * cospix) (π * (-sinpix)) + @scalar_rule sincospi(x::CommutativeMulNumber) @setup((sinpix, cospix) = Ω) (π * cospix) (π * (-sinpix)) end @scalar_rule( @@ -160,7 +176,22 @@ end ), (!(islow | ishigh), islow, ishigh), ) -@scalar_rule x \ y (-(Ω / x), one(y) / x) + +@scalar_rule x::CommutativeMulNumber \ y::CommutativeMulNumber (-(x \ Ω), x \ one(y)) +function frule((_, Δx, Δy), ::typeof(\), x::Number, y::Number) + Ω = x \ y + return Ω, x \ muladd(Δy, -Δx, Ω) +end +function rrule(::typeof(\), x::Number, y::Number) + Ω = x \ y + project_x = ProjectTo(x) + project_y = ProjectTo(y) + function backslash_pullback(ΔΩ) + ∂x = x' \ ΔΩ + return NoTangent(), project_x(∂x), project_y(-∂x * Ω') + end + return Ω, backslash_pullback +end function frule((_, ẏ), ::typeof(identity), x) return (x, ẏ) diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index 143e120c7..1efc4f3f4 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -9,58 +9,71 @@ let ## for the rules for `sin` and `cos` ## See issue: https://github.com/JuliaDiff/ChainRules.jl/issues/291 ## sin - function rrule(::typeof(sin), x::Number) + function rrule(::typeof(sin), x::CommutativeMulNumber) sinx, cosx = sincos(x) sin_pullback(Δy) = (NoTangent(), cosx' * Δy) return (sinx, sin_pullback) end - function frule((_, Δx), ::typeof(sin), x::Number) + function frule((_, Δx), ::typeof(sin), x::CommutativeMulNumber) sinx, cosx = sincos(x) return (sinx, cosx * Δx) end ## cos - function rrule(::typeof(cos), x::Number) + function rrule(::typeof(cos), x::CommutativeMulNumber) sinx, cosx = sincos(x) cos_pullback(Δy) = (NoTangent(), -sinx' * Δy) return (cosx, cos_pullback) end - function frule((_, Δx), ::typeof(cos), x::Number) + function frule((_, Δx), ::typeof(cos), x::CommutativeMulNumber) sinx, cosx = sincos(x) return (cosx, -sinx * Δx) end - @scalar_rule tan(x) 1 + Ω ^ 2 + @scalar_rule tan(x::CommutativeMulNumber) 1 + Ω ^ 2 # Trig-Hyperbolic - @scalar_rule cosh(x) sinh(x) - @scalar_rule sinh(x) cosh(x) - @scalar_rule tanh(x) 1 - Ω ^ 2 + @scalar_rule cosh(x::CommutativeMulNumber) sinh(x) + @scalar_rule sinh(x::CommutativeMulNumber) cosh(x) + @scalar_rule tanh(x::CommutativeMulNumber) 1 - Ω ^ 2 # Trig- Inverses - @scalar_rule acos(x) -(inv(sqrt(1 - x ^ 2))) - @scalar_rule asin(x) inv(sqrt(1 - x ^ 2)) - @scalar_rule atan(x) inv(1 + x ^ 2) + @scalar_rule acos(x::CommutativeMulNumber) -(inv(sqrt(1 - x ^ 2))) + @scalar_rule asin(x::CommutativeMulNumber) inv(sqrt(1 - x ^ 2)) + @scalar_rule atan(x::CommutativeMulNumber) inv(1 + x ^ 2) # Trig-Multivariate - @scalar_rule atan(y, x) @setup(u = x ^ 2 + y ^ 2) (x / u, -y / u) - @scalar_rule sincos(x) @setup((sinx, cosx) = Ω) cosx -sinx + @scalar_rule atan(y::Real, x::Real) @setup(u = x ^ 2 + y ^ 2) (x / u, -y / u) + @scalar_rule sincos(x::CommutativeMulNumber) @setup((sinx, cosx) = Ω) cosx -sinx # exponents - @scalar_rule cbrt(x) inv(3 * Ω ^ 2) - @scalar_rule inv(x) -(Ω ^ 2) - @scalar_rule sqrt(x) inv(2Ω) # gradient +Inf at x==0 - @scalar_rule exp(x) Ω - @scalar_rule exp10(x) Ω * log(oftype(x, 10)) - @scalar_rule exp2(x) Ω * log(oftype(x, 2)) - @scalar_rule expm1(x) exp(x) - @scalar_rule log(x) inv(x) - @scalar_rule log10(x) inv(x) / log(oftype(x, 10)) - @scalar_rule log1p(x) inv(x + 1) - @scalar_rule log2(x) inv(x) / log(oftype(x, 2)) + @scalar_rule cbrt(x::CommutativeMulNumber) inv(3 * Ω ^ 2) + @scalar_rule inv(x::CommutativeMulNumber) -(Ω ^ 2) + function frule((_, Δx), ::typeof(inv), x::Number) + Ω = inv(x) + return Ω, Ω * -Δx * Ω + end + function rrule(::typeof(inv), x::Number) + Ω = inv(x) + project_x = ProjectTo(x) + function inv_pullback(ΔΩ) + Ω′ = conj(Ω) + return NoTangent(), project_x(Ω′ * -ΔΩ * Ω′) + end + return Ω, inv_pullback + end + @scalar_rule sqrt(x::CommutativeMulNumber) inv(2Ω) # gradient +Inf at x==0 + @scalar_rule exp(x::CommutativeMulNumber) Ω + @scalar_rule exp10(x::CommutativeMulNumber) Ω * log(oftype(x, 10)) + @scalar_rule exp2(x::CommutativeMulNumber) Ω * log(oftype(x, 2)) + @scalar_rule expm1(x::CommutativeMulNumber) exp(x) + @scalar_rule log(x::CommutativeMulNumber) inv(x) + @scalar_rule log10(x::CommutativeMulNumber) inv(x) / log(oftype(x, 10)) + @scalar_rule log1p(x::CommutativeMulNumber) inv(x + 1) + @scalar_rule log2(x::CommutativeMulNumber) inv(x) / log(oftype(x, 2)) # Unary complex functions ## abs @@ -105,7 +118,7 @@ let end ## angle - function frule((_, Δx), ::typeof(angle), x) + function frule((_, Δx), ::typeof(angle), x::Union{Real,Complex}) Ω = angle(x) # `ifelse` is applied only to denominator to ensure type-stability. n = ifelse(iszero(x), one(real(x)), abs2(x)) @@ -161,11 +174,24 @@ let @scalar_rule x + y (true, true) @scalar_rule x - y (true, -1) - @scalar_rule x / y (one(x) / y, -(Ω / y)) - + @scalar_rule x / y::CommutativeMulNumber (one(x) / y, -(Ω / y)) + function frule((_, Δx, Δy), ::typeof(/), x::Number, y::Number) + Ω = x / y + return Ω, muladd(Δx, Ω, -Δy) / y + end + function rrule(::typeof(/), x::Number, y::Number) + Ω = x / y + project_x = ProjectTo(x) + project_y = ProjectTo(y) + function slash_pullback(ΔΩ) + ∂x = ΔΩ / y' + return NoTangent(), project_x(∂x), project_y(Ω' * -∂x) + end + return Ω, slash_pullback + end ## power # literal_pow is in base.jl - function frule((_, Δx, Δp), ::typeof(^), x::Number, p::Number) + function frule((_, Δx, Δp), ::typeof(^), x::CommutativeMulNumber, p::CommutativeMulNumber) y = x ^ p _dx = _pow_grad_x(x, p, float(y)) if iszero(Δp) @@ -178,7 +204,7 @@ let end end - function rrule(::typeof(^), x::Number, p::Number) + function rrule(::typeof(^), x::CommutativeMulNumber, p::CommutativeMulNumber) y = x^p project_x = ProjectTo(x) project_p = ProjectTo(p) @@ -209,14 +235,14 @@ let @scalar_rule -x -1 ## `sign` - function frule((_, Δx), ::typeof(sign), x) + function frule((_, Δx), ::typeof(sign), x::Number) n = ifelse(iszero(x), one(real(x)), abs(x)) Ω = x isa Real ? sign(x) : x / n ∂Ω = Ω * (_imagconjtimes(Ω, Δx) / n) * im return Ω, ∂Ω end - function rrule(::typeof(sign), x) + function rrule(::typeof(sign), x::Number) n = ifelse(iszero(x), one(real(x)), abs(x)) Ω = x isa Real ? sign(x) : x / n function sign_pullback(ΔΩ) diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index 5284c9db3..387b71f24 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -38,7 +38,7 @@ end _diagview(x::Diagonal) = x.diag _diagview(x::AbstractMatrix) = view(x, diagind(x)) _diagview(x::Tangent{<:Diagonal}) = x.diag -function ChainRulesCore.rrule(::typeof(sqrt), d::Diagonal) +function ChainRulesCore.rrule(::typeof(sqrt), d::Diagonal{<:CommutativeMulNumber}) y = sqrt(d) @assert y isa Diagonal function sqrt_pullback(Δ) @@ -250,7 +250,7 @@ end _diag_view(X) = view(X, diagind(X)) _diag_view(X::Diagonal) = parent(X) #Diagonal wraps a Vector of just Diagonal elements -function rrule(::typeof(det), X::Union{Diagonal, AbstractTriangular}) +function rrule(::typeof(det), X::Union{Diagonal{T}, AbstractTriangular{T}}) where {T<:CommutativeMulNumber} y = det(X) s = conj!(y ./ _diag_view(X)) function det_pullback(ȳ) @@ -259,7 +259,7 @@ function rrule(::typeof(det), X::Union{Diagonal, AbstractTriangular}) return y, det_pullback end -function rrule(::typeof(logdet), X::Union{Diagonal, AbstractTriangular}) +function rrule(::typeof(logdet), X::Union{Diagonal{T}, AbstractTriangular{T}}) where {T<:CommutativeMulNumber} y = logdet(X) s = conj!(one(eltype(X)) ./ _diag_view(X)) function logdet_pullback(ȳ) From b4e90dc6e83d9f0c6558f3ccb776159f34c6884e Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 13 Oct 2021 22:27:36 +0200 Subject: [PATCH 05/23] Add missing projectors --- src/rulesets/Base/fastmath_able.jl | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index 1efc4f3f4..1154311bd 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -86,9 +86,10 @@ let function rrule(::typeof(abs), x::Number) Ω = abs(x) + project_x = ProjectTo(x) function abs_pullback(ΔΩ) signx = x isa Real ? sign(x) : x / ifelse(iszero(x), one(Ω), Ω) - return (NoTangent(), signx * real(ΔΩ)) + return (NoTangent(), project_x(signx * real(ΔΩ))) end return Ω, abs_pullback end @@ -99,9 +100,10 @@ let end function rrule(::typeof(abs2), z::Number) + project_z = ProjectTo(z) function abs2_pullback(ΔΩ) Δu = real(ΔΩ) - return (NoTangent(), 2Δu*z) + return (NoTangent(), project_z(2Δu * z)) end return abs2(z), abs2_pullback end @@ -111,8 +113,9 @@ let return conj(z), conj(Δz) end function rrule(::typeof(conj), z::Number) + project_z = ProjectTo(z) function conj_pullback(ΔΩ) - return (NoTangent(), conj(ΔΩ)) + return (NoTangent(), project_z(conj(ΔΩ))) end return conj(z), conj_pullback end @@ -165,9 +168,11 @@ let function rrule(::typeof(hypot), x::T, y::T) where {T<:Number} Ω = hypot(x, y) + project_x = ProjectTo(x) + project_y = ProjectTo(y) function hypot_pullback(ΔΩ) c = real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω) - return (NoTangent(), c * x, c * y) + return (NoTangent(), project_x(c * x), project_y(c * y)) end return (Ω, hypot_pullback) end @@ -245,9 +250,10 @@ let function rrule(::typeof(sign), x::Number) n = ifelse(iszero(x), one(real(x)), abs(x)) Ω = x isa Real ? sign(x) : x / n + project_x = ProjectTo(x) function sign_pullback(ΔΩ) ∂x = Ω * (_imagconjtimes(Ω, ΔΩ) / n) * im - return (NoTangent(), ∂x) + return (NoTangent(), project_x(∂x)) end return Ω, sign_pullback end From 723ae92501306d2e06dab6f6ce7575eed5dd5e5c Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 13 Oct 2021 22:27:50 +0200 Subject: [PATCH 06/23] Remove unneeded broadcasting and comments --- src/rulesets/Base/fastmath_able.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index 1154311bd..3b31a9746 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -258,12 +258,11 @@ let return Ω, sign_pullback end - # product rule requires special care for arguments where `mul` is non-commutative function frule((_, Δx, Δy), ::typeof(*), x::Number, y::Number) - # Optimized version of `Δx .* y .+ x .* Δy`. Also, it is potentially more + # Optimized version of `Δx * y + x * Δy`. Also, it is potentially more # accurate on machines with FMA instructions, since there are only two # rounding operations, one in `muladd/fma` and the other in `*`. - ∂xy = muladd.(Δx, y, x .* Δy) + ∂xy = muladd(Δx, y, x * Δy) return x * y, ∂xy end From f7293d2cb64340869f588058de1d99b86c1b71b4 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 13 Oct 2021 22:28:16 +0200 Subject: [PATCH 07/23] Small speedup when cotangent is zero --- src/rulesets/Base/arraymath.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index 9dfdce637..a50be8e63 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -10,7 +10,7 @@ end function rrule(::typeof(inv), x::AbstractArray) Ω = inv(x) function inv_pullback(ΔΩ) - return NoTangent(), -Ω' * ΔΩ * Ω' + return NoTangent(), Ω' * -ΔΩ * Ω' end return Ω, inv_pullback end From 41a7d27003bc95318680a92cadcc4fe510af70e8 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 13 Oct 2021 22:28:26 +0200 Subject: [PATCH 08/23] Generalize cross --- src/rulesets/LinearAlgebra/dense.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 3cb362900..cece8d7e7 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -67,15 +67,14 @@ function frule((_, Δa, Δb), ::typeof(cross), a::AbstractVector, b::AbstractVec return cross(a, b), cross(Δa, b) .+ cross(a, Δb) end -# TODO: support complex vectors -function rrule(::typeof(cross), a::AbstractVector{<:Real}, b::AbstractVector{<:Real}) +function rrule(::typeof(cross), a::AbstractVector, b::AbstractVector) project_a = ProjectTo(a) project_b = ProjectTo(b) Ω = cross(a, b) function cross_pullback(Ω̄) ΔΩ = unthunk(Ω̄) - da = @thunk(project_a(cross(b, ΔΩ))) - db = @thunk(project_b(cross(ΔΩ, a))) + da = @thunk(project_a(eltype(b) <: Real ? cross(b, ΔΩ) : -cross(ΔΩ, vec(b')))) + db = @thunk(project_b(eltype(a) <: Real ? cross(ΔΩ, a) : -cross(vec(a'), ΔΩ))) return (NoTangent(), da, db) end return Ω, cross_pullback From 0cb446fd5aaa0f507907a56abfb2abe441d4cdd8 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 12 Feb 2022 11:42:12 +0100 Subject: [PATCH 09/23] Add some projectors --- src/rulesets/Base/fastmath_able.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index 8d3a0c9d6..8630dbbfe 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -11,7 +11,8 @@ let ## sin function rrule(::typeof(sin), x::CommutativeMulNumber) sinx, cosx = sincos(x) - sin_pullback(Δy) = (NoTangent(), cosx' * Δy) + project_x = ProjectTo(x) + sin_pullback(Δy) = (NoTangent(), project_x(cosx' * Δy)) return (sinx, sin_pullback) end @@ -23,7 +24,8 @@ let ## cos function rrule(::typeof(cos), x::CommutativeMulNumber) sinx, cosx = sincos(x) - cos_pullback(Δy) = (NoTangent(), -sinx' * Δy) + project_x = ProjectTo(x) + cos_pullback(Δy) = (NoTangent(), -project_x(sinx' * Δy)) return (cosx, cos_pullback) end @@ -61,7 +63,7 @@ let project_x = ProjectTo(x) function inv_pullback(ΔΩ) Ω′ = conj(Ω) - return NoTangent(), project_x(Ω′ * -ΔΩ * Ω′) + return NoTangent(), -project_x(Ω′ * ΔΩ * Ω′) end return Ω, inv_pullback end From 964e42438d1f5f0f885d15aaae2ea5b6922d59fe Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 12 Feb 2022 22:16:12 +0100 Subject: [PATCH 10/23] More generalizations --- src/rulesets/Base/arraymath.jl | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index 2c4bd6bfe..8054898ac 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -353,14 +353,16 @@ end ##### -function frule((_, ΔA, Δb), ::typeof(/), A::AbstractArray{<:CommutativeMulNumber}, b::CommutativeMulNumber) - return A/b, ΔA/b - A*(Δb/b^2) +function frule((_, ΔA, Δb), ::typeof(/), A::AbstractArray{<:Number}, b::Number) + Y = A / b + return Y, muladd(Y, -Δb, ΔA) / b end -function frule((_, Δa, ΔB), ::typeof(\), a::CommutativeMulNumber, B::AbstractArray{<:CommutativeMulNumber}) - return B/a, ΔB/a - B*(Δa/a^2) +function frule((_, Δa, ΔB), ::typeof(\), a::Number, B::AbstractArray{<:Number}) + Y = a \ B + return Y, a \ muladd(-Δa, Y, ΔB) end -function rrule(::typeof(/), A::AbstractArray{<:CommutativeMulNumber}, b::CommutativeMulNumber) +function rrule(::typeof(/), A::AbstractArray{<:Number}, b::Number) Y = A/b function slash_pullback_scalar(ȳ) Ȳ = unthunk(ȳ) @@ -385,7 +387,9 @@ function rrule(::typeof(\), b::Number, A::AbstractArray{<:Number}) bthunk = if eltype(Y) isa CommutativeMulNumber @thunk(-conj(b) \ dot(Y, Ȳ)) else - @thunk(-conj(b) \ dot(Ȳ',Y')) + # NOTE: dot(Ȳ', Y') currently incorrect for non-commutative numbers + # https://github.com/JuliaLang/julia/issues/44152 + @thunk(-conj(b) \ dot(conj(Ȳ), conj(Y))) end return (NoTangent(), Athunk, bthunk) end From d1e8c2a33d836155b32741ec96eb9bf8b10701ab Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 12 Feb 2022 22:20:44 +0100 Subject: [PATCH 11/23] Revert "Add some projectors" This reverts commit 0cb446fd5aaa0f507907a56abfb2abe441d4cdd8. --- src/rulesets/Base/fastmath_able.jl | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index 8630dbbfe..8d3a0c9d6 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -11,8 +11,7 @@ let ## sin function rrule(::typeof(sin), x::CommutativeMulNumber) sinx, cosx = sincos(x) - project_x = ProjectTo(x) - sin_pullback(Δy) = (NoTangent(), project_x(cosx' * Δy)) + sin_pullback(Δy) = (NoTangent(), cosx' * Δy) return (sinx, sin_pullback) end @@ -24,8 +23,7 @@ let ## cos function rrule(::typeof(cos), x::CommutativeMulNumber) sinx, cosx = sincos(x) - project_x = ProjectTo(x) - cos_pullback(Δy) = (NoTangent(), -project_x(sinx' * Δy)) + cos_pullback(Δy) = (NoTangent(), -sinx' * Δy) return (cosx, cos_pullback) end @@ -63,7 +61,7 @@ let project_x = ProjectTo(x) function inv_pullback(ΔΩ) Ω′ = conj(Ω) - return NoTangent(), -project_x(Ω′ * ΔΩ * Ω′) + return NoTangent(), project_x(Ω′ * -ΔΩ * Ω′) end return Ω, inv_pullback end From 13d0c8403ccd9edc444c9396fbdf2a73d5575b3b Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 15 Feb 2022 23:18:57 +0100 Subject: [PATCH 12/23] Remove projector --- src/rulesets/Base/fastmath_able.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index 8d3a0c9d6..7dc0f2e4c 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -250,10 +250,9 @@ let function rrule(::typeof(sign), x::Number) n = ifelse(iszero(x), one(real(x)), abs(x)) Ω = x isa Real ? sign(x) : x / n - project_x = ProjectTo(x) function sign_pullback(ΔΩ) ∂x = Ω * (_imagconjtimes(Ω, ΔΩ) / n) * im - return (NoTangent(), project_x(∂x)) + return (NoTangent(), ∂x) end return Ω, sign_pullback end From 55abdeb2ec8aa1969634a8b229e1b3165647a89c Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 16 Feb 2022 00:07:10 +0100 Subject: [PATCH 13/23] Fix return arguments --- src/rulesets/Base/arraymath.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index 8054898ac..7499792ba 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -391,7 +391,7 @@ function rrule(::typeof(\), b::Number, A::AbstractArray{<:Number}) # https://github.com/JuliaLang/julia/issues/44152 @thunk(-conj(b) \ dot(conj(Ȳ), conj(Y))) end - return (NoTangent(), Athunk, bthunk) + return (NoTangent(), bthunk, Athunk) end return Y, backslash_pullback end From c5f032d02c24fdb7b29d1e6e1591a81c16a4c237 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 16 Feb 2022 00:07:32 +0100 Subject: [PATCH 14/23] Fix argument order --- src/rulesets/Base/arraymath.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index 7499792ba..64c1e1641 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -381,8 +381,8 @@ function rrule(::typeof(\), b::Number, A::AbstractArray{<:Number}) function backslash_pullback(ȳ) Ȳ = unthunk(ȳ) Athunk = InplaceableThunk( - @thunk(conj(b) \ Ȳ), dA -> dA .+= conj(b) .\ Ȳ, + @thunk(conj(b) \ Ȳ), ) bthunk = if eltype(Y) isa CommutativeMulNumber @thunk(-conj(b) \ dot(Y, Ȳ)) From 057d7d69e8100662e616fe70d9d5292eb4bbc944 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 16 Feb 2022 00:07:51 +0100 Subject: [PATCH 15/23] Undo changes to rrule --- src/rulesets/Base/arraymath.jl | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index 64c1e1641..3272c39a1 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -309,16 +309,19 @@ end ##### function rrule(::typeof(/), A::AbstractVecOrMat{<:Number}, B::AbstractVecOrMat{<:Number}) - C = A / B - project_A = ProjectTo(A) - project_B = ProjectTo(B) - function slash_pullback(ΔC) - ∂A = unthunk(ΔC) / B' - ∂B = InplaceableThunk( - @thunk(C' * -∂A), - B̄ -> mul!(B̄, C', ∂A, -1, true) - ) - return NoTangent(), project_A(∂A), project_B(∂B) + Aᵀ, dA_pb = rrule(adjoint, A) + Bᵀ, dB_pb = rrule(adjoint, B) + Cᵀ, dS_pb = rrule(\, Bᵀ, Aᵀ) + C, dC_pb = rrule(adjoint, Cᵀ) + function slash_pullback(Ȳ) + # Optimization note: dAᵀ, dBᵀ, dC are calculated no matter which partial you want + _, dC = dC_pb(Ȳ) + _, dBᵀ, dAᵀ = dS_pb(unthunk(dC)) + + ∂A = last(dA_pb(unthunk(dAᵀ))) + ∂B = last(dB_pb(unthunk(dBᵀ))) + + (NoTangent(), ∂A, ∂B) end return C, slash_pullback end From ed16465b4c54b93efd2137610fbc9417c7f7102b Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 16 Feb 2022 00:08:07 +0100 Subject: [PATCH 16/23] Work around Julia issue --- src/rulesets/Base/arraymath.jl | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index 3272c39a1..5bb12466b 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -98,7 +98,16 @@ function rrule( Ȳ = unthunk(ȳ) return ( NoTangent(), - @thunk(project_A(eltype(B) isa CommutativeMulNumber ? dot(Ȳ, B)' : dot(Ȳ', B'))), + Thunk() do + if eltype(B) isa CommutativeMulNumber + project_A(dot(Ȳ, B)') + elseif ndims(B) < 3 + # https://github.com/JuliaLang/julia/issues/44152 + project_A(dot(conj(Ȳ), conj(B))) + else + project_A(dot(conj(vec(Ȳ)), conj(vec(B)))) + end + end, InplaceableThunk( X̄ -> mul!(X̄, conj(A), Ȳ, true, true), @thunk(project_B(A' * Ȳ)), @@ -121,7 +130,17 @@ function rrule( X̄ -> mul!(X̄, conj(A), Ȳ, true, true), @thunk(project_B(A' * Ȳ)), ), - @thunk(project_A(eltype(A) isa CommutativeMulNumber ? dot(Ȳ, B)' : dot(Ȳ', B'))), + # @thunk(project_A(eltype(A) isa CommutativeMulNumber ? dot(Ȳ, B)' : dot(Ȳ', B'))), + Thunk() do + if eltype(B) isa CommutativeMulNumber + project_A(dot(Ȳ, B)') + elseif ndims(B) < 3 + # https://github.com/JuliaLang/julia/issues/44152 + project_A(dot(conj(Ȳ), conj(B))) + else + project_A(dot(conj(vec(Ȳ)), conj(vec(B)))) + end + end, ) end return A * B, times_pullback From 8fd8c419d17325ca986260056b5fa2eada5eda66 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 16 Feb 2022 00:08:29 +0100 Subject: [PATCH 17/23] Relax constraints --- src/rulesets/Base/arraymath.jl | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index 5bb12466b..77c868abc 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -23,11 +23,7 @@ frule((_, ΔA, ΔB), ::typeof(*), A, B) = A * B, muladd(ΔA, B, A * ΔB) frule((_, ΔA, ΔB, ΔC), ::typeof(*), A, B, C) = A*B*C, ΔA*B*C + A*ΔB*C + A*B*ΔC -function rrule( - ::typeof(*), - A::AbstractVecOrMat{<:CommutativeMulNumber}, - B::AbstractVecOrMat{<:CommutativeMulNumber}, -) +function rrule(::typeof(*), A::AbstractVecOrMat{<:Number}, B::AbstractVecOrMat{<:Number}) project_A = ProjectTo(A) project_B = ProjectTo(B) function times_pullback(ȳ) @@ -45,8 +41,8 @@ end # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/411 function rrule( ::typeof(*), - A::StridedMatrix{<:CommutativeMulNumber}, - B::StridedVecOrMat{<:CommutativeMulNumber}, + A::StridedMatrix{<:Number}, + B::StridedVecOrMat{<:Number}, ) function times_pullback(ȳ) Ȳ = unthunk(ȳ) @@ -90,7 +86,7 @@ end ##### function rrule( - ::typeof(*), A::CommutativeMulNumber, B::AbstractArray{<:CommutativeMulNumber} + ::typeof(*), A::Number, B::AbstractArray{<:Number} ) project_A = ProjectTo(A) project_B = ProjectTo(B) @@ -349,7 +345,7 @@ end ##### `\` ##### -function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real}) +function rrule(::typeof(\), A::AbstractVecOrMat{<:Number}, B::AbstractVecOrMat{<:Number}) project_A = ProjectTo(A) project_B = ProjectTo(B) From 9237994229862c56d79d70bee80ab8222554fc95 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 16 Feb 2022 10:22:33 +0100 Subject: [PATCH 18/23] Remove added projectors They cause tests to fail and should be added in a separate PR --- src/rulesets/Base/fastmath_able.jl | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index 7dc0f2e4c..ac63632c8 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -86,10 +86,9 @@ let function rrule(::typeof(abs), x::Number) Ω = abs(x) - project_x = ProjectTo(x) function abs_pullback(ΔΩ) signx = x isa Real ? sign(x) : x / ifelse(iszero(x), one(Ω), Ω) - return (NoTangent(), project_x(signx * real(ΔΩ))) + return (NoTangent(), signx * real(ΔΩ)) end return Ω, abs_pullback end @@ -100,10 +99,9 @@ let end function rrule(::typeof(abs2), z::Number) - project_z = ProjectTo(z) function abs2_pullback(ΔΩ) Δu = real(ΔΩ) - return (NoTangent(), project_z(2Δu * z)) + return (NoTangent(), 2Δu * z) end return abs2(z), abs2_pullback end @@ -113,9 +111,8 @@ let return conj(z), conj(Δz) end function rrule(::typeof(conj), z::Number) - project_z = ProjectTo(z) function conj_pullback(ΔΩ) - return (NoTangent(), project_z(conj(ΔΩ))) + return (NoTangent(), conj(ΔΩ)) end return conj(z), conj_pullback end @@ -168,11 +165,9 @@ let function rrule(::typeof(hypot), x::T, y::T) where {T<:Number} Ω = hypot(x, y) - project_x = ProjectTo(x) - project_y = ProjectTo(y) function hypot_pullback(ΔΩ) c = real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω) - return (NoTangent(), project_x(c * x), project_y(c * y)) + return (NoTangent(), c * x, c * y) end return (Ω, hypot_pullback) end From 0e6f76b36310923194fd3409ecdc1487b52a900d Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 17 Feb 2022 00:21:46 +0100 Subject: [PATCH 19/23] Fix slash rules --- src/rulesets/Base/base.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 13ec203a6..668cf9e28 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -178,15 +178,15 @@ end @scalar_rule x::CommutativeMulNumber \ y::CommutativeMulNumber (-(x \ Ω), x \ one(y)) function frule((_, Δx, Δy), ::typeof(\), x::Number, y::Number) Ω = x \ y - return Ω, x \ muladd(Δy, -Δx, Ω) + return Ω, x \ muladd(-Δx, Ω, Δy) end function rrule(::typeof(\), x::Number, y::Number) Ω = x \ y project_x = ProjectTo(x) project_y = ProjectTo(y) function backslash_pullback(ΔΩ) - ∂x = x' \ ΔΩ - return NoTangent(), project_x(∂x), project_y(-∂x * Ω') + ∂y = x' \ ΔΩ + return NoTangent(), project_x(-∂y * Ω'), project_y(∂y) end return Ω, backslash_pullback end From 8be7c70f5965c6d512659387bb0f681090a290c9 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 17 Feb 2022 01:04:34 +0100 Subject: [PATCH 20/23] Repair rules for non-commutative mul inputs --- src/rulesets/Base/fastmath_able.jl | 10 ++++++---- src/rulesets/LinearAlgebra/dense.jl | 6 +++--- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index ac63632c8..beb01b709 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -238,7 +238,8 @@ let function frule((_, Δx), ::typeof(sign), x::Number) n = ifelse(iszero(x), one(real(x)), abs(x)) Ω = x isa Real ? sign(x) : x / n - ∂Ω = Ω * (_imagconjtimes(Ω, Δx) / n) * im + d = dot(Ω, Δx) + ∂Ω = Ω * ((d - real(d)) / n) return Ω, ∂Ω end @@ -246,7 +247,8 @@ let n = ifelse(iszero(x), one(real(x)), abs(x)) Ω = x isa Real ? sign(x) : x / n function sign_pullback(ΔΩ) - ∂x = Ω * (_imagconjtimes(Ω, ΔΩ) / n) * im + d = dot(Ω, ΔΩ) + ∂x = Ω * ((d - real(d)) / n) return (NoTangent(), ∂x) end return Ω, sign_pullback @@ -275,9 +277,9 @@ let ΔΩ = unthunk(Ω̇) return ( NoTangent(), - ProjectTo(x)(ΔΩ * y' * z'), + ProjectTo(x)(ΔΩ * z' * y'), ProjectTo(y)(x' * ΔΩ * z'), - ProjectTo(z)(x' * y' * ΔΩ), + ProjectTo(z)(y' * x' * ΔΩ), ) end return x * y * z, times_pullback3 diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 4f6f7ab20..2bdc56c9f 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -34,9 +34,9 @@ function rrule(::typeof(dot), x::AbstractVector{<:Number}, A::AbstractMatrix{<:N z = adjoint(x) * Ay function dot_pullback(Ω̄) ΔΩ = unthunk(Ω̄) - dx = @thunk project_x(conj(ΔΩ) .* Ay) - dA = @thunk project_A(ΔΩ .* x .* adjoint(y)) - dy = @thunk project_y(ΔΩ .* (adjoint(A) * x)) + dx = @thunk project_x(Ay .* conj(ΔΩ)) + dA = @thunk project_A(x .* ΔΩ .* adjoint(y)) + dy = @thunk project_y((adjoint(A) * x) .* ΔΩ) return (NoTangent(), dx, dA, dy) end dot_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent(), ZeroTangent(), ZeroTangent()) From 947544718cc102f5f46b00de00c37e554d1860a6 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 17 Feb 2022 01:17:40 +0100 Subject: [PATCH 21/23] Add quaternion helper --- test/test_helpers.jl | 56 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/test/test_helpers.jl b/test/test_helpers.jl index 83de7b35d..80cf33d7c 100644 --- a/test/test_helpers.jl +++ b/test/test_helpers.jl @@ -75,6 +75,62 @@ function ChainRulesCore.rrule(::typeof(make_two_vec), x) return make_two_vec(x), make_two_vec_pullback end +# Minimal quaternion implementation for testing rules that accept numbers that don't commute +# under multiplication +# adapted from Base Julia +# https://github.com/JuliaLang/julia/blob/bb5b98e72a151c41471d8cc14cacb495d647fb7f/test/testhelpers/Quaternions.jl + +export Quaternion + +struct Quaternion{T<:Real} <: Number + s::T + v1::T + v2::T + v3::T +end +const QuaternionF64 = Quaternion{Float64} +Quaternion(s::Real, v1::Real, v2::Real, v3::Real) = Quaternion(promote(s, v1, v2, v3)...) +Quaternion{T}(s::Real) where {T} = Quaternion(T(s), zero(T), zero(T), zero(T)) +Base.convert(::Type{Quaternion{T}}, s::Real) where {T <: Real} = + Quaternion{T}(convert(T, s), zero(T), zero(T), zero(T)) +Base.promote_rule(::Type{Quaternion{S}}, ::Type{T}) where {S<:Real,T<:Real} = Quaternion{Base.promote_type(S,T)} +Base.abs2(q::Quaternion) = q.s*q.s + q.v1*q.v1 + q.v2*q.v2 + q.v3*q.v3 +Base.float(z::Quaternion{T}) where T = Quaternion(float(z.s), float(z.v1), float(z.v2), float(z.v3)) +Base.abs(q::Quaternion) = sqrt(abs2(q)) +Base.real(::Type{Quaternion{T}}) where {T} = T +Base.real(q::Quaternion) = q.s +Base.conj(q::Quaternion) = Quaternion(q.s, -q.v1, -q.v2, -q.v3) +Base.isfinite(q::Quaternion) = isfinite(q.s) & isfinite(q.v1) & isfinite(q.v2) & isfinite(q.v3) +Base.zero(::Type{Quaternion{T}}) where T = Quaternion{T}(zero(T), zero(T), zero(T), zero(T)) +Base.:(+)(ql::Quaternion, qr::Quaternion) = + Quaternion(ql.s + qr.s, ql.v1 + qr.v1, ql.v2 + qr.v2, ql.v3 + qr.v3) +Base.:(-)(ql::Quaternion, qr::Quaternion) = + Quaternion(ql.s - qr.s, ql.v1 - qr.v1, ql.v2 - qr.v2, ql.v3 - qr.v3) +Base.:(-)(q::Quaternion) = Quaternion(-q.s, -q.v1, -q.v2, -q.v3) +Base.:(*)(q::Quaternion, w::Quaternion) = Quaternion(q.s*w.s - q.v1*w.v1 - q.v2*w.v2 - q.v3*w.v3, + q.s*w.v1 + q.v1*w.s + q.v2*w.v3 - q.v3*w.v2, + q.s*w.v2 - q.v1*w.v3 + q.v2*w.s + q.v3*w.v1, + q.s*w.v3 + q.v1*w.v2 - q.v2*w.v1 + q.v3*w.s) +Base.:(*)(q::Quaternion, r::Real) = Quaternion(q.s*r, q.v1*r, q.v2*r, q.v3*r) +Base.:(*)(q::Quaternion, b::Bool) = b * q # remove method ambiguity +Base.:(/)(q::Quaternion, w::Quaternion) = q * conj(w) * (1.0 / abs2(w)) +Base.:(\)(q::Quaternion, w::Quaternion) = conj(q) * w * (1.0 / abs2(q)) +function Base.rand(rng::AbstractRNG, ::Random.SamplerType{Quaternion{T}}) where {T<:Real} + return Quaternion{T}(rand(rng, T), rand(rng, T), rand(rng, T), rand(rng, T)) +end +function Base.randn(rng::AbstractRNG, ::Type{Quaternion{T}}) where {T<:AbstractFloat} + return Quaternion{T}( + randn(rng, T) / 2, + randn(rng, T) / 2, + randn(rng, T) / 2, + randn(rng, T) / 2, + ) +end +(project::ProjectTo{<:Real})(dx::Quaternion) = project(real(dx)) +function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, q::Quaternion{Float64}) + return Quaternion(rand(rng, -9:0.1:9), rand(rng, -9:0.1:9), rand(rng, -9:0.1:9), rand(rng, -9:0.1:9)) +end + @testset "test_helpers.jl" begin @testset "Multiplier" begin From dc2965572cf0d0387704502f31cc2823bc80588f Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 17 Feb 2022 01:17:55 +0100 Subject: [PATCH 22/23] Test rules using quaternions --- test/rulesets/Base/base.jl | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 77dff1827..ccbfa5d37 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -73,10 +73,12 @@ end end - @testset "Unary complex functions" begin - for x in (-4.1, 6.4, 0.0, 0.0 + 0.0im, 0.5 + 0.25im) + @testset "Unary non-real functions" begin + for x in (-4.1, 6.4, 0.0, 0.0 + 0.0im, 0.5 + 0.25im, Quaternion(0.0,0,0,0), Quaternion(0.6,1,-0.2,0.5)) test_scalar(real, x) - test_scalar(imag, x) + if !isa(x, Quaternion) + test_scalar(imag, x) + end test_scalar(hypot, x) test_scalar(adjoint, x) end @@ -92,9 +94,12 @@ @testset "*(x, y) (scalar)" begin # This is pretty important so testing it fairly heavily - test_points = (0.0, -2.1, 3.2, 3.7+2.12im, 14.2-7.1im) + test_points = (0.0, -2.1, 3.2, 3.7+2.12im, 14.2-7.1im, Quaternion(-0.48, -0.15, 0.63, 0.14)) @testset "($x) * ($y)" for x in test_points, y in test_points + if (x isa Complex && y isa Quaternion) || (x isa Quaternion && y isa Complex) + continue + end # ensure all complex if any complex for FiniteDifferences x, y = Base.promote(x, y) @@ -103,6 +108,10 @@ test_rrule(*, x, y) end @testset "*($x, $y, ...)" for x in test_points, y in test_points + if (x isa Complex && y isa Quaternion) || (x isa Quaternion && y isa Complex) + continue + end + # This promotion is only for FiniteDifferences, the rules allow mixtures: x, y = Base.promote(x, y) @@ -119,7 +128,7 @@ end end - @testset "\\(x::$T, y::$T) (scalar)" for T in (Float64, ComplexF64) + @testset "\\(x::$T, y::$T) (scalar)" for T in (Float64, ComplexF64, QuaternionF64) test_frule(*, randn(T), randn(T)) test_rrule(*, randn(T), randn(T)) end @@ -132,7 +141,7 @@ test_rrule(mod, (rand(0:10) + .6rand() + .2) * base, base) end - @testset "identity" for T in (Float64, ComplexF64) + @testset "identity" for T in (Float64, ComplexF64, QuaternionF64) test_frule(identity, randn(T)) test_frule(identity, randn(T, 4)) test_frule(identity, Tuple(randn(T, 3))) @@ -142,14 +151,14 @@ test_rrule(identity, Tuple(randn(T, 3))) end - @testset "Constants" for x in (-0.1, 6.4, 1.0+0.5im, -10.0+0im, 0.0+200im) + @testset "Constants" for x in (-0.1, 6.4, 1.0+0.5im, -10.0+0im, 0.0+200im, Quaternion(2.0,0,3,0)) test_scalar(one, x) test_scalar(zero, x) end - @testset "muladd(x::$T, y::$T, z::$T)" for T in (Float64, ComplexF64) - test_frule(muladd, 10randn(), randn(), randn()) - test_rrule(muladd, 10randn(), randn(), randn()) + @testset "muladd(x::$T, y::$T, z::$T)" for T in (Float64, ComplexF64, QuaternionF64) + test_frule(muladd, 10randn(T), randn(T), randn(T)) + test_rrule(muladd, 10randn(T), randn(T), randn(T)) end @testset "fma" begin From 3e6e83845bf9215357078113928cc6e3704ed5d4 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 17 Feb 2022 01:23:44 +0100 Subject: [PATCH 23/23] Add note RE source of rand implementations --- test/test_helpers.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_helpers.jl b/test/test_helpers.jl index 80cf33d7c..0261165bf 100644 --- a/test/test_helpers.jl +++ b/test/test_helpers.jl @@ -115,6 +115,7 @@ Base.:(*)(q::Quaternion, r::Real) = Quaternion(q.s*r, q.v1*r, q.v2*r, q.v3*r) Base.:(*)(q::Quaternion, b::Bool) = b * q # remove method ambiguity Base.:(/)(q::Quaternion, w::Quaternion) = q * conj(w) * (1.0 / abs2(w)) Base.:(\)(q::Quaternion, w::Quaternion) = conj(q) * w * (1.0 / abs2(q)) +# rand implementations adapted from https://github.com/JuliaGeometry/Quaternions.jl/pull/42/files function Base.rand(rng::AbstractRNG, ::Random.SamplerType{Quaternion{T}}) where {T<:Real} return Quaternion{T}(rand(rng, T), rand(rng, T), rand(rng, T), rand(rng, T)) end