Skip to content

Assume commutative multiplication exactly when necessary #540

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 25 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 47 additions & 57 deletions src/rulesets/Base/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ end
function rrule(::typeof(inv), x::AbstractArray)
Ω = inv(x)
function inv_pullback(ΔΩ)
return NoTangent(), -Ω' * ΔΩ * Ω'
return NoTangent(), Ω' * -ΔΩ * Ω'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can I ask why you moved the minus?

If it was -true * Ω' * ΔΩ * Ω' then I think you'd save a copy (since this gets fused into mul!).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So that if ΔΩ is an AbstractZero or a UniformScaling, then the negation is cheaper.

If it was -true * Ω' * ΔΩ * Ω' then I think you'd save a copy (since this gets fused into mul!).

I didn't follow this. How is this fused into the mul!?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(this was not an important change, and I'm happy to remove)

Copy link
Member

@mcabbott mcabbott Oct 14, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I didn't think about those. For dense matrices there's a 4-arg method which fuses this:

julia> f1(Ω, ΔΩ) = Ω' * -ΔΩ * Ω';

julia> f2(Ω, ΔΩ) = -true * Ω' * ΔΩ * Ω';

julia> @btime f1(Ω, ΔΩ) setup=(N=100; Ω=rand(N,N); ΔΩ=rand(N,N));
  min 74.708 μs, mean 101.133 μs (6 allocations, 234.52 KiB. GC mean 6.51%)

julia> @btime f2(Ω, ΔΩ) setup=(N=100; Ω=rand(N,N); ΔΩ=rand(N,N));
  min 73.125 μs, mean 92.756 μs (4 allocations, 156.34 KiB. GC mean 4.82%)

julia> @which -1 * ones(2,2) * ones(2,2) * ones(2,2)
*(α::Union{Real, Complex}, B::AbstractMatrix{<:Union{Real, Complex}}, C::AbstractMatrix{<:Union{Real, Complex}}, D::AbstractMatrix{<:Union{Real, Complex}}) in LinearAlgebra at /Users/me/.julia/dev/julia/usr/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:1134

But with I, no fusion, hence f2 is slower. Maybe * should have some extra methods for cases with I.

end
return Ω, inv_pullback
end
Expand All @@ -19,12 +19,7 @@ end
##### `*`
#####


function rrule(
::typeof(*),
A::AbstractVecOrMat{<:CommutativeMulNumber},
B::AbstractVecOrMat{<:CommutativeMulNumber},
)
function rrule(::typeof(*), A::AbstractVecOrMat{<:Number}, B::AbstractVecOrMat{<:Number})
project_A = ProjectTo(A)
project_B = ProjectTo(B)
function times_pullback(ȳ)
Expand Down Expand Up @@ -60,13 +55,7 @@ function rrule(
return A * B, times_pullback
end



function rrule(
::typeof(*),
A::AbstractVector{<:CommutativeMulNumber},
B::AbstractMatrix{<:CommutativeMulNumber},
)
function rrule(::typeof(*), A::AbstractVector{<:Number}, B::AbstractMatrix{<:Number})
project_A = ProjectTo(A)
project_B = ProjectTo(B)
function times_pullback(ȳ)
Expand All @@ -87,19 +76,14 @@ function rrule(
return A * B, times_pullback
end




function rrule(
::typeof(*), A::CommutativeMulNumber, B::AbstractArray{<:CommutativeMulNumber}
)
function rrule(::typeof(*), A::Number, B::AbstractArray{<:Number})
project_A = ProjectTo(A)
project_B = ProjectTo(B)
function times_pullback(ȳ)
Ȳ = unthunk(ȳ)
return (
NoTangent(),
@thunk(project_A(dot(Ȳ, B)')),
@thunk(project_A(eltype(B) isa CommutativeMulNumber ? dot(Ȳ, B)' : dot(Ȳ', B'))),
InplaceableThunk(
X̄ -> mul!(X̄, conj(A), Ȳ, true, true),
@thunk(project_B(A' * Ȳ)),
Expand All @@ -110,7 +94,7 @@ function rrule(
end

function rrule(
::typeof(*), B::AbstractArray{<:CommutativeMulNumber}, A::CommutativeMulNumber
::typeof(*), B::AbstractArray{<:Number}, A::Number
)
project_A = ProjectTo(A)
project_B = ProjectTo(B)
Expand All @@ -122,7 +106,7 @@ function rrule(
X̄ -> mul!(X̄, conj(A), Ȳ, true, true),
@thunk(project_B(A' * Ȳ)),
),
@thunk(project_A(dot(Ȳ, B)')),
@thunk(project_A(eltype(A) isa CommutativeMulNumber ? dot(Ȳ, B)' : dot(Ȳ', B'))),
)
end
return A * B, times_pullback
Expand All @@ -135,9 +119,9 @@ end

function rrule(
::typeof(muladd),
A::AbstractMatrix{<:CommutativeMulNumber},
B::AbstractVecOrMat{<:CommutativeMulNumber},
z::Union{CommutativeMulNumber, AbstractVecOrMat{<:CommutativeMulNumber}},
A::AbstractMatrix{<:Number},
B::AbstractVecOrMat{<:Number},
z::Union{Number, AbstractVecOrMat{<:Number}},
)
project_A = ProjectTo(A)
project_B = ProjectTo(B)
Expand Down Expand Up @@ -173,9 +157,9 @@ end

function rrule(
::typeof(muladd),
ut::LinearAlgebra.AdjOrTransAbsVec{<:CommutativeMulNumber},
v::AbstractVector{<:CommutativeMulNumber},
z::CommutativeMulNumber,
ut::LinearAlgebra.AdjOrTransAbsVec{<:Number},
v::AbstractVector{<:Number},
z::Number,
)
project_ut = ProjectTo(ut)
project_v = ProjectTo(v)
Expand All @@ -186,11 +170,11 @@ function rrule(
dy = unthunk(ȳ)
ut_thunk = InplaceableThunk(
dut -> dut .+= v' .* dy,
@thunk(project_ut(v' .* dy)),
@thunk(project_ut((v * dy')')),
)
v_thunk = InplaceableThunk(
dv -> dv .+= ut' .* dy,
@thunk(project_v(ut' .* dy)),
@thunk(project_v(ut' * dy)),
)
(NoTangent(), ut_thunk, v_thunk, z isa Bool ? NoTangent() : project_z(dy))
end
Expand All @@ -199,9 +183,9 @@ end

function rrule(
::typeof(muladd),
u::AbstractVector{<:CommutativeMulNumber},
vt::LinearAlgebra.AdjOrTransAbsVec{<:CommutativeMulNumber},
z::Union{CommutativeMulNumber, AbstractVecOrMat{<:CommutativeMulNumber}},
u::AbstractVector{<:Number},
vt::LinearAlgebra.AdjOrTransAbsVec{<:Number},
z::Union{Number, AbstractVecOrMat{<:Number}},
)
project_u = ProjectTo(u)
project_vt = ProjectTo(vt)
Expand All @@ -211,8 +195,8 @@ function rrule(
function muladd_pullback_3(ȳ)
Ȳ = unthunk(ȳ)
proj = (
@thunk(project_u(vec(sum(.* conj.(vt), dims=2)))),
@thunk(project_vt(vec(sum(u .* conj.(Ȳ), dims=1))')),
@thunk(project_u(Ȳ * vec(vt'))),
@thunk(project_vt((Ȳ' * u)')),
)
addon = if z isa Bool
NoTangent()
Expand All @@ -233,20 +217,17 @@ end
##### `/`
#####

function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real})
Aᵀ, dA_pb = rrule(adjoint, A)
Bᵀ, dB_pb = rrule(adjoint, B)
Cᵀ, dS_pb = rrule(\, Bᵀ, Aᵀ)
C, dC_pb = rrule(adjoint, Cᵀ)
function slash_pullback(Ȳ)
# Optimization note: dAᵀ, dBᵀ, dC are calculated no matter which partial you want
_, dC = dC_pb(Ȳ)
_, dBᵀ, dAᵀ = dS_pb(unthunk(dC))

∂A = last(dA_pb(unthunk(dAᵀ)))
∂B = last(dB_pb(unthunk(dBᵀ)))

(NoTangent(), ∂A, ∂B)
function rrule(::typeof(/), A::AbstractVecOrMat{<:Number}, B::AbstractVecOrMat{<:Number})
C = A / B
project_A = ProjectTo(A)
project_B = ProjectTo(B)
function slash_pullback(ΔC)
∂A = unthunk(ΔC) / B'
∂B = InplaceableThunk(
@thunk(C' * -∂A),
B̄ -> mul!(B̄, C', ∂A, -1, true)
)
return NoTangent(), project_A(∂A), project_B(∂B)
end
return C, slash_pullback
end
Expand Down Expand Up @@ -280,25 +261,34 @@ end
##### `\`, `/` matrix-scalar_rule
#####

function rrule(::typeof(/), A::AbstractArray{<:CommutativeMulNumber}, b::CommutativeMulNumber)
function rrule(::typeof(/), A::AbstractArray{<:Number}, b::Number)
Y = A/b
function slash_pullback_scalar(ȳ)
Ȳ = unthunk(ȳ)
Athunk = InplaceableThunk(
dA -> dA .+= Ȳ ./ conj(b),
@thunk(Ȳ / conj(b)),
)
bthunk = @thunk(-dot(A,Ȳ) / conj(b^2))
bthunk = @thunk(-dot(Y,Ȳ) / conj(b))
return (NoTangent(), Athunk, bthunk)
end
return Y, slash_pullback_scalar
end

function rrule(::typeof(\), b::CommutativeMulNumber, A::AbstractArray{<:CommutativeMulNumber})
Y, back = rrule(/, A, b)
function backslash_pullback(dY) # just reverses the arguments!
d0, dA, db = back(dY)
return (d0, db, dA)
function rrule(::typeof(\), b::Number, A::AbstractArray{<:Number})
Y = b \ A
function backslash_pullback(ȳ)
Ȳ = unthunk(ȳ)
Athunk = InplaceableThunk(
@thunk(conj(b) \ Ȳ),
dA -> dA .+= conj(b) .\ Ȳ,
)
bthunk = if eltype(Y) isa CommutativeMulNumber
@thunk(-conj(b) \ dot(Y, Ȳ))
else
@thunk(-conj(b) \ dot(Ȳ',Y'))
end
return (NoTangent(), Athunk, bthunk)
end
return Y, backslash_pullback
end
Expand Down
107 changes: 69 additions & 38 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,22 +76,38 @@ end

@scalar_rule hypot(x::Real) sign(x)

function frule((_, Δz), ::typeof(hypot), z::Complex)
function frule((_, Δz), ::typeof(hypot), z::Number)
Ω = hypot(z)
∂Ω = _realconjtimes(z, Δz) / ifelse(iszero(Ω), one(Ω), Ω)
return Ω, ∂Ω
end

function rrule(::typeof(hypot), z::Complex)
function rrule(::typeof(hypot), z::Number)
Ω = hypot(z)
function hypot_pullback(ΔΩ)
return (NoTangent(), (real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω)) * z)
end
return (Ω, hypot_pullback)
end

@scalar_rule fma(x, y, z) (y, x, true)
@scalar_rule muladd(x, y, z) (y, x, true)
@scalar_rule fma(x, y::CommutativeMulNumber, z) (y, x, true)
function frule((_, Δx, Δy, Δz), ::typeof(fma), x::Number, y::Number, z::Number)
return fma(x, y, z), muladd(Δx, y, muladd(x, Δy, Δz))
end
function rrule(::typeof(fma), x::Number, y::Number, z::Number)
projectx, projecty, projectz = ProjectTo(x), ProjectTo(y), ProjectTo(z)
fma_pullback(ΔΩ) = NoTangent(), projectx(ΔΩ * y'), projecty(x' * ΔΩ), projectz(ΔΩ)
fma(x, y, z), fma_pullback
end
@scalar_rule muladd(x, y::CommutativeMulNumber, z) (y, x, true)
function frule((_, Δx, Δy, Δz), ::typeof(muladd), x::Number, y::Number, z::Number)
return muladd(x, y, z), muladd(Δx, y, muladd(x, Δy, Δz))
end
function rrule(::typeof(muladd), x::Number, y::Number, z::Number)
projectx, projecty, projectz = ProjectTo(x), ProjectTo(y), ProjectTo(z)
muladd_pullback(ΔΩ) = NoTangent(), projectx(ΔΩ * y'), projecty(x' * ΔΩ), projectz(ΔΩ)
muladd(x, y, z), muladd_pullback
end
Comment on lines +106 to +110
Copy link
Member

@mcabbott mcabbott Feb 22, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

E.g. here I think this is very clear, the pattern of where the Δs go is important:

function frule((_, Δx, Δy, Δz), ::typeof(muladd), x::Number, y::Number, z::Number)
    return muladd(x, y, z), muladd(Δx, y, muladd(x, Δy, Δz))
end

but I'd like to make the corresponding rrule more distinct:

Suggested change
function rrule(::typeof(muladd), x::Number, y::Number, z::Number)
projectx, projecty, projectz = ProjectTo(x), ProjectTo(y), ProjectTo(z)
muladd_pullback(ΔΩ) = NoTangent(), projectx(ΔΩ * y'), projecty(x' * ΔΩ), projectz(ΔΩ)
muladd(x, y, z), muladd_pullback
end
function rrule(::typeof(muladd), x::Number, y::Number, z::Number)
muladd_pullback(∇Ω) = NoTangent(), ProjectTo(x)(∇Ω * y'), ProjectTo(y)(x' * ∇Ω), ProjectTo(z)(∇Ω)
return muladd(x, y, z), muladd_pullback
end

or perhaps ∂Ω is a sort-of lower-case ∇Ω for scalars?

And since it closes over x,y,z already, there is nothing gained by constructing projectors outside.

@scalar_rule rem2pi(x, r::RoundingMode) (true, NoTangent())
@scalar_rule(
mod(x, y),
Expand All @@ -105,51 +121,51 @@ end
@scalar_rule(ldexp(x, y), (2^y, NoTangent()))

# Can't multiply though sqrt in acosh because of negative complex case for x
@scalar_rule acosh(x) inv(sqrt(x - 1) * sqrt(x + 1))
@scalar_rule acoth(x) inv(1 - x ^ 2)
@scalar_rule acsch(x) -(inv(x ^ 2 * sqrt(1 + x ^ -2)))
@scalar_rule acosh(x::CommutativeMulNumber) inv(sqrt(x - 1) * sqrt(x + 1))
@scalar_rule acoth(x::CommutativeMulNumber) inv(1 - x ^ 2)
@scalar_rule acsch(x::CommutativeMulNumber) -(inv(x ^ 2 * sqrt(1 + x ^ -2)))
@scalar_rule acsch(x::Real) -(inv(abs(x) * sqrt(1 + x ^ 2)))
@scalar_rule asech(x) -(inv(x * sqrt(1 - x ^ 2)))
@scalar_rule asinh(x) inv(sqrt(x ^ 2 + 1))
@scalar_rule atanh(x) inv(1 - x ^ 2)
@scalar_rule asech(x::CommutativeMulNumber) -(inv(x * sqrt(1 - x ^ 2)))
@scalar_rule asinh(x::CommutativeMulNumber) inv(sqrt(x ^ 2 + 1))
@scalar_rule atanh(x::CommutativeMulNumber) inv(1 - x ^ 2)


@scalar_rule acosd(x) (-(oftype(x, 180)) / π) / sqrt(1 - x ^ 2)
@scalar_rule acotd(x) (-(oftype(x, 180)) / π) / (1 + x ^ 2)
@scalar_rule acscd(x) ((-(oftype(x, 180)) / π) / x ^ 2) / sqrt(1 - x ^ -2)
@scalar_rule acosd(x::CommutativeMulNumber) (-(oftype(x, 180)) / π) / sqrt(1 - x ^ 2)
@scalar_rule acotd(x::CommutativeMulNumber) (-(oftype(x, 180)) / π) / (1 + x ^ 2)
@scalar_rule acscd(x::CommutativeMulNumber) ((-(oftype(x, 180)) / π) / x ^ 2) / sqrt(1 - x ^ -2)
@scalar_rule acscd(x::Real) ((-(oftype(x, 180)) / π) / abs(x)) / sqrt(x ^ 2 - 1)
@scalar_rule asecd(x) ((oftype(x, 180) / π) / x ^ 2) / sqrt(1 - x ^ -2)
@scalar_rule asecd(x::CommutativeMulNumber) ((oftype(x, 180) / π) / x ^ 2) / sqrt(1 - x ^ -2)
@scalar_rule asecd(x::Real) ((oftype(x, 180) / π) / abs(x)) / sqrt(x ^ 2 - 1)
@scalar_rule asind(x) (oftype(x, 180) / π) / sqrt(1 - x ^ 2)
@scalar_rule atand(x) (oftype(x, 180) / π) / (1 + x ^ 2)

@scalar_rule cot(x) -((1 + Ω ^ 2))
@scalar_rule coth(x) -(csch(x) ^ 2)
@scalar_rule cotd(x) -(π / oftype(x, 180)) * (1 + Ω ^ 2)
@scalar_rule csc(x) -Ω * cot(x)
@scalar_rule cscd(x) -(π / oftype(x, 180)) * Ω * cotd(x)
@scalar_rule csch(x) -(coth(x)) * Ω
@scalar_rule sec(x) Ω * tan(x)
@scalar_rule secd(x) (π / oftype(x, 180)) * Ω * tand(x)
@scalar_rule sech(x) -(tanh(x)) * Ω

@scalar_rule acot(x) -(inv(1 + x ^ 2))
@scalar_rule acsc(x) -(inv(x ^ 2 * sqrt(1 - x ^ -2)))
@scalar_rule asind(x::CommutativeMulNumber) (oftype(x, 180) / π) / sqrt(1 - x ^ 2)
@scalar_rule atand(x::CommutativeMulNumber) (oftype(x, 180) / π) / (1 + x ^ 2)

@scalar_rule cot(x::CommutativeMulNumber) -((1 + Ω ^ 2))
@scalar_rule coth(x::CommutativeMulNumber) -(csch(x) ^ 2)
@scalar_rule cotd(x::CommutativeMulNumber) -(π / oftype(x, 180)) * (1 + Ω ^ 2)
@scalar_rule csc(x::CommutativeMulNumber) -Ω * cot(x)
@scalar_rule cscd(x::CommutativeMulNumber) -(π / oftype(x, 180)) * Ω * cotd(x)
@scalar_rule csch(x::CommutativeMulNumber) -(coth(x)) * Ω
@scalar_rule sec(x::CommutativeMulNumber) Ω * tan(x)
@scalar_rule secd(x::CommutativeMulNumber) (π / oftype(x, 180)) * Ω * tand(x)
@scalar_rule sech(x::CommutativeMulNumber) -(tanh(x)) * Ω

@scalar_rule acot(x::CommutativeMulNumber) -(inv(1 + x ^ 2))
@scalar_rule acsc(x::CommutativeMulNumber) -(inv(x ^ 2 * sqrt(1 - x ^ -2)))
@scalar_rule acsc(x::Real) -(inv(abs(x) * sqrt(x ^ 2 - 1)))
@scalar_rule asec(x) inv(x ^ 2 * sqrt(1 - x ^ -2))
@scalar_rule asec(x::CommutativeMulNumber) inv(x ^ 2 * sqrt(1 - x ^ -2))
@scalar_rule asec(x::Real) inv(abs(x) * sqrt(x ^ 2 - 1))

@scalar_rule cosd(x) -(π / oftype(x, 180)) * sind(x)
@scalar_rule cospi(x) -π * sinpi(x)
@scalar_rule sind(x) (π / oftype(x, 180)) * cosd(x)
@scalar_rule sinpi(x) π * cospi(x)
@scalar_rule tand(x) (π / oftype(x, 180)) * (1 + Ω ^ 2)
@scalar_rule cosd(x::CommutativeMulNumber) -(π / oftype(x, 180)) * sind(x)
@scalar_rule cospi(x::CommutativeMulNumber) -π * sinpi(x)
@scalar_rule sind(x::CommutativeMulNumber) (π / oftype(x, 180)) * cosd(x)
@scalar_rule sinpi(x::CommutativeMulNumber) π * cospi(x)
@scalar_rule tand(x::CommutativeMulNumber) (π / oftype(x, 180)) * (1 + Ω ^ 2)

@scalar_rule sinc(x) cosc(x)
@scalar_rule sinc(x::CommutativeMulNumber) cosc(x)

# the position of the minus sign below warrants the correct type for π
if VERSION ≥ v"1.6"
@scalar_rule sincospi(x) @setup((sinpix, cospix) = Ω) (π * cospix) (π * (-sinpix))
@scalar_rule sincospi(x::CommutativeMulNumber) @setup((sinpix, cospix) = Ω) (π * cospix) (π * (-sinpix))
end

@scalar_rule(
Expand All @@ -160,7 +176,22 @@ end
),
(!(islow | ishigh), islow, ishigh),
)
@scalar_rule x \ y (-(Ω / x), one(y) / x)

@scalar_rule x::CommutativeMulNumber \ y::CommutativeMulNumber (-(x \ Ω), x \ one(y))
function frule((_, Δx, Δy), ::typeof(\), x::Number, y::Number)
Ω = x \ y
return Ω, x \ muladd(Δy, -Δx, Ω)
end
function rrule(::typeof(\), x::Number, y::Number)
Ω = x \ y
project_x = ProjectTo(x)
project_y = ProjectTo(y)
function backslash_pullback(ΔΩ)
∂x = x' \ ΔΩ
return NoTangent(), project_x(∂x), project_y(-∂x * Ω')
end
return Ω, backslash_pullback
end

function frule((_, ẏ), ::typeof(identity), x)
return (x, ẏ)
Expand Down
Loading